import base64
import io
import os
import subprocess
from functools import partial

import gradio as gr
import httpx
from const import BASE_URL, CLI_COMMAND, CSS, FOOTER, HEADER, MODELS, PLACEHOLDER
from openai import OpenAI
from PIL import Image


def get_token() -> str:
    """Fetch a fresh bearer token by shelling out to the configured CLI."""
    return (
        subprocess.run(
            CLI_COMMAND,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            env=os.environ.copy(),
        )
        .stdout.decode("utf-8")
        .strip()
    )


def get_headers(host: str) -> dict:
    return {
        "Authorization": f"Bearer {get_token()}",
        "Host": host,
        "Accept": "application/json",
        "Content-Type": "application/json",
    }


def proxy(request: httpx.Request, model_info: dict) -> httpx.Request:
    """httpx request hook: rewrite the path and headers for the selected model."""
    request.url = request.url.copy_with(path=model_info["endpoint"])
    request.headers.update(get_headers(host=model_info["host"]))
    return request


def encode_image_with_pillow(image_path: str) -> str:
    """Downscale the image to fit within 384x384 and return it as base64 JPEG."""
    with Image.open(image_path) as img:
        img.thumbnail((384, 384))
        buffered = io.BytesIO()
        img.convert("RGB").save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")


def call_chat_api(message, history, model_name):
    # Use the image attached to the current message if there is one;
    # otherwise fall back to the most recent image found in the history.
    image = None
    if message["files"]:
        if isinstance(message["files"], dict):
            image = message["files"]["path"]
        else:
            image = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        # No image anywhere in the conversation; surface a clear UI error.
        raise gr.Error("Please upload an image first.")

    img_base64 = encode_image_with_pillow(image)

    # Build an OpenAI-style message list. The image always travels with the
    # first user message as a base64 data URL.
    history_openai_format = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{img_base64}",
                    },
                },
            ],
        }
    ]
    if len(history) == 0:
        # First turn: attach the user's text to the image message itself.
        history_openai_format[0]["content"].append(
            {"type": "text", "text": message["text"]}
        )
    else:
        # history[0] is the image-upload entry, so replay the turns after it.
        # The first text turn is merged into the image message; later turns
        # become plain user/assistant messages.
        for human, assistant in history[1:]:
            if len(history_openai_format) == 1:
                history_openai_format[0]["content"].append(
                    {"type": "text", "text": human}
                )
            else:
                history_openai_format.append({"role": "user", "content": human})
            history_openai_format.append(
                {"role": "assistant", "content": assistant}
            )
        history_openai_format.append({"role": "user", "content": message["text"]})

    # Route requests through the proxy hook, which injects the auth token and
    # rewrites the URL for the selected model's backend.
    client = OpenAI(
        api_key="",  # auth is handled by the proxy hook, not the OpenAI client
        base_url=BASE_URL,
        http_client=httpx.Client(
            event_hooks={
                "request": [partial(proxy, model_info=MODELS[model_name])],
            },
            verify=False,  # internal endpoint; TLS verification is disabled
        ),
    )
    stream = client.chat.completions.create(
        model=f"/data/cyberagent/{model_name}",
        messages=history_openai_format,
        temperature=0.2,
        top_p=1.0,
        max_tokens=1024,
        stream=True,
        extra_body={"repetition_penalty": 1.1},
    )

    # Stream partial responses back to Gradio as they arrive.
    response = ""
    for chunk in stream:
        content = chunk.choices[0].delta.content or ""
        response += content
        yield response


def run():
    chatbot = gr.Chatbot(
        elem_id="chatbot", placeholder=PLACEHOLDER, scale=1, height=700
    )
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
    )
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown(HEADER)
        with gr.Row():
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
        gr.ChatInterface(
            fn=call_chat_api,
            stop_btn="Stop Generation",
            examples=[
                [
                    {
                        # "Please describe this image in detail."
                        "text": "この画像を詳しく説明してください。",
                        "files": ["./examples/cat.jpg"],
                    },
                ],
                [
                    {
                        # "Please tell me in detail what this dish tastes like."
                        "text": "この料理はどんな味がするか詳しく教えてください。",
                        "files": ["./examples/takoyaki.jpg"],
                    },
                ],
            ],
            multimodal=True,
            textbox=chat_input,
            chatbot=chatbot,
            additional_inputs=[model_selector],
        )
        gr.Markdown(FOOTER)
    demo.queue().launch(share=False)


if __name__ == "__main__":
    run()
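

# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the shape of the `const` module this
# script imports. Only the field names are taken from the code above
# (`proxy()` reads "endpoint" and "host" from each MODELS entry); every
# concrete value below is a hypothetical placeholder, not the real
# deployment's configuration, which ships separately as const.py.
# ---------------------------------------------------------------------------
# BASE_URL = "https://inference-gateway.example.com"      # hypothetical gateway URL
# CLI_COMMAND = ["gcloud", "auth", "print-access-token"]  # hypothetical token-issuing CLI
# MODELS = {
#     "llava-calm2-siglip": {                   # hypothetical model key
#         "endpoint": "/v1/chat/completions",   # path proxy() rewrites requests to
#         "host": "model-a.example.internal",   # Host header proxy() sets
#     },
# }
# HEADER = "# Visual chat demo"                 # markdown rendered above the UI
# FOOTER = "Footer markdown"
# CSS = ""                                      # extra CSS for gr.Blocks
# PLACEHOLDER = "Upload an image and ask a question about it."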