import os
from threading import Thread

import torch
import gradio as gr
from dotenv import load_dotenv
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    BitsAndBytesConfig,
    CodeGenTokenizerFast as Tokenizer,
    LlavaForConditionalGeneration,
    TextIteratorStreamer,
)

from db_client import get_user_history, update_user_history, delete_user_history

load_dotenv()


TESTING = False

IS_LOGGED_IN = True
USER_ID = "[email protected]"
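# db_client is assumed to store history in the same message-dict format built in
# bot_streaming below, e.g.:
#   [{"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image"}]},
#    {"role": "assistant", "content": [{"type": "text", "text": "..."}]}]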


model_id = "blanchon/pixtral-nutrition-2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
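# Note: the 4-bit NF4 setup above (double quantization, bfloat16 compute) is what lets the
# Pixtral-based checkpoint fit on a single GPU; actual memory use depends on the GPU and
# the installed bitsandbytes version.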


if TESTING:
    model_id = "vikhyatk/moondream1"
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    processor = Tokenizer.from_pretrained(model_id)
else:
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(model_id)
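# Caveat: the TESTING branch pairs moondream1 with a plain tokenizer, so the chat-template
# and processor.tokenizer calls below only fully apply to the Pixtral/LLaVA path; treat
# TESTING as a rough smoke test rather than a drop-in substitute.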


# Pixtral-style chat template: each user turn is wrapped in "<s>[INST] ... [/INST]" with an
# [IMG] placeholder per image; assistant turns are emitted verbatim and closed with "</s>".
# The four-space indentation is only for readability and is stripped below (single spaces
# must be kept so the Jinja tags stay valid).
processor.chat_template = """
{%- for message in messages %}
    {%- if message.role == "user" %}
        <s>[INST]
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- elif item.type == "image" %}
                \n[IMG]
            {%- endif %}
        {%- endfor %}
        [/INST]
    {%- elif message.role == "assistant" %}
        {%- for item in message.content %}
            {%- if item.type == "text" %}
                {{ item.text }}
            {%- endif %}
        {%- endfor %}
        </s>
    {%- endif %}
{%- endfor %}
""".replace("    ", "")

processor.tokenizer.pad_token = processor.tokenizer.eos_token
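# For reference, a single user turn with one image renders to roughly
# "<s>[INST]<question>\n[IMG][/INST]" (a sketch; exact whitespace depends on Jinja trimming), e.g.:
#   processor.apply_chat_template(
#       [{"role": "user", "content": [{"type": "text", "text": "Is this healthy?"}, {"type": "image"}]}]
#   )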


def bot_streaming(chatbot, image_input, max_new_tokens=250):
    # Pull the stored conversation for this user; the latest user message is the last chatbot row.
    messages = get_user_history(USER_ID)
    images = []
    text_input = chatbot[-1][0]

    # If the user sent an image without a question, fall back to the default nutrition prompt.
    if text_input == "":
        text_input = "You are a nutrition expert. Identify the food/ingredients in this image. Is this a healthy meal? Can you think of how to improve it?"

    if image_input is not None:
        if isinstance(image_input, Image.Image):
            image = image_input.convert("RGB")
        else:
            image = Image.fromarray(image_input).convert("RGB")
        images.append(image)
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": text_input}, {"type": "image"}]
        })
    else:
        messages.append({
            "role": "user",
            "content": [{"type": "text", "text": text_input}]
        })

    texts = processor.apply_chat_template(messages)

    if not images:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
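    # model.generate runs in a background thread while the main thread iterates the
    # streamer, yielding partial chatbot updates so the UI shows tokens as they arrive.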
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot

    thread.join()

    # Debug trace of the full conversation.
    print('*' * 60)
    print('*' * 60)
    print('BOT_STREAMING_CONV_START')
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f'Q{i}:\n {request}')
        print(f'A{i}:\n {answer}')
    print('New_Q:\n', text_input)
    print('New_A:\n', response)
    print('BOT_STREAMING_CONV_END')

    # Persist the updated conversation, including the new assistant reply, for logged-in users.
    if IS_LOGGED_IN:
        new_history = messages + [{"role": "assistant", "content": [{"type": "text", "text": response}]}]
        update_user_history(USER_ID, new_history)


html = """
<p align="center" style="font-size: 2.5em; line-height: 1;">
    <span style="display: inline-block; vertical-align: middle;">🍽️</span>
    <span style="display: inline-block; vertical-align: middle;">PixDiet</span>
</p>
<center><font size=3><b>PixDiet</b> is your AI nutrition expert. Upload an image of your meal and chat with our AI to get personalized advice on your diet, meal composition, and ways to improve your nutrition.</font></center>
<div style="display: flex; justify-content: center; align-items: center; margin-top: 20px;">
    <img src="https://static.alan.com/murray/e5830dc1cc164ef2b054023851e32206_Alan_Lockup_Horizontal_Black_RGB_Large.png" alt="Alan AI Logo" style="height: 50px; margin-right: 20px;">
    <img src="https://seeklogo.com/images/M/mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" alt="Mistral AI Logo" style="height: 50px;">
</div>
"""


latex_delimiters_set = [
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
    {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
    {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
    {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
    {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    {"left": "\\[", "right": "\\]", "display": True},
]


with gr.Blocks(title="PixDiet", theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(label="Upload your meal image", height=350, type="pil")
            gr.Examples(
                examples=[
                    ["./examples/mistral_breakfast.jpeg", ""],
                    ["./examples/pate_carbo.jpg", ""],
                    ["./examples/mc_do.jpeg", ""],
                ],
                inputs=[image_input, gr.Textbox(visible=False)],
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Chat with PixDiet", layout="panel", height=600, show_copy_button=True, latex_delimiters=latex_delimiters_set)
            text_input = gr.Textbox(label="Ask about your meal", placeholder="Enter your question here...", lines=1, container=False)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Delete my history", variant="secondary")

    def submit_chat(chatbot, text_input):
        # Append the new turn as a mutable [user, assistant] pair so bot_streaming
        # can fill in the assistant reply in place while streaming.
        response = ''
        chatbot.append([text_input, response])
        return chatbot, ''

    def clear_chat():
        delete_user_history(USER_ID)
        return [], None, ""

    send_click_event = send_btn.click(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(
        bot_streaming, [chatbot, image_input], chatbot
    )

    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
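    # Note: only the Send button is wired up; pressing Enter in the textbox could be supported
    # by attaching the same chain with text_input.submit(...), mirroring send_btn.click above.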


if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)