import os
import logging
import gradio as gr
import torch
import uuid
import time
import ldclient
from ldclient.config import Config
from ldclient import Context
from transformers import AutoTokenizer, AutoModelForCausalLM
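

# capture_logs writes one line per query to both a log file and stdout,
# tagging it with the request's UUID so a slow response can be traced
# back to the parameters that produced it.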
def capture_logs(uuid_label, log_file, log_body):
    logger = logging.getLogger('MyApp')
    logger.setLevel(logging.INFO)
    # Attach handlers only once; getLogger returns the same instance on
    # every call, so re-adding handlers would duplicate each log line
    if not logger.handlers:
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.INFO)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        logger.addHandler(fh)
        logger.addHandler(ch)
    logger.info('uuid: %s - %s', uuid_label, log_body)
print("CUDA available: ", torch.cuda.is_available()) | |
print("MPS available: ", torch.backends.mps.is_available()) | |
sdkKey = os.getenv('sdkKEY')
ldclient.set_config(Config(sdkKey))
client = ldclient.get()
context = Context.builder("huggie-face") \
    .set("application", "deepSeekChatbot") \
    .build()
flag_value = client.variation("themeColors", context, False)
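# The third argument to variation() is the fallback value the SDK returns
# when the flag is missing or LaunchDarkly is unreachable, so the app
# degrades to the default theme instead of failing.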
if flag_value:
    print("Feature flag on")
    theme = gr.themes.Soft(
        primary_hue="fuchsia",
        neutral_hue="blue",
    )
else:
    print("Feature flag off")
    theme = gr.themes.Soft(
        primary_hue="sky",
        neutral_hue="slate",
    )
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True,
    torch_dtype=torch.bfloat16)
# Disable tokenizers parallelism to silence the fork warning from the
# Hugging Face tokenizers library
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Prefer CUDA, then MPS (Metal Performance Shaders), then fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model = model.to(device)
# Handle user input and generate a response
def chatbot_response(query, tokens, top_k, top_p):
    uuid_label = str(uuid.uuid4())
    start_time = time.time()  # Start timer
    # Format the query with the model's chat template and generate a reply
    messages = [{'role': 'user', 'content': query}]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs,
        max_new_tokens=tokens,
        do_sample=True,     # sample instead of greedy decoding
        top_k=top_k,        # keep only the k most likely next tokens
        top_p=top_p,        # nucleus sampling probability mass
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, skipping the echoed prompt
    model_response = tokenizer.decode(
        outputs[0][len(inputs[0]):], skip_special_tokens=True)
    end_time = time.time()  # End timer
    performance_time = round(end_time - start_time, 2)
    log_body = 'query: %s, processTime: %s, tokens: %s, top_k: %s, top_p: %s' % (
        query, performance_time, tokens, top_k, top_p)
    capture_logs(uuid_label, 'query_logs.csv', log_body)
    return model_response
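

# Quick sanity check without launching the UI (hypothetical values; runs one
# pass through the full generation pipeline):
#   print(chatbot_response("Write a Python hello world", 256, 50, 0.95))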
# Set up the Gradio interface
iface = gr.Interface(
    fn=chatbot_response,
    inputs=[
        gr.Textbox(label="Ask our DSChatbot Expert"),
        gr.Slider(label="Max New Tokens", minimum=128,
                  maximum=2048, step=128, value=512),
        gr.Slider(label="Top K", minimum=0, maximum=100, step=10, value=50),
        gr.Slider(label="Top P", minimum=0.0,
                  maximum=1.0, step=0.1, value=0.95),
    ],
    outputs=gr.Textbox(label="Hope it helps!"),
    theme=theme,
    title="DSChatbot"
)
if __name__ == "__main__":
    iface.launch()
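    # To expose the app beyond localhost, launch() also accepts
    # server_name="0.0.0.0" and share=True (temporary public Gradio link);
    # both are standard gr.Interface.launch() options.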