import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

MODELS = {
    "TinyLlama-1.1B": {
        "base": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "adapter": "morangold1/vacation-rental-assistant",
    },
    "DeepSeek-7B": {
        "base": "deepseek-ai/deepseek-llm-7b-chat",
        "adapter": "morangold1/vacation-rental-assistant-deepseek",
    },
}


class VacationRentalAssistant:
    def __init__(self):
        self.current_model = None
        self.current_tokenizer = None
        self.model_name = None  # nothing is loaded until the first request

    def load_model(self, model_name):
        # Load lazily, and only when the selection actually changes.
        if self.current_model is None or self.model_name != model_name:
            print(f"Loading {model_name}...")
            config = MODELS[model_name]
            base_model = AutoModelForCausalLM.from_pretrained(
                config["base"],
                torch_dtype=torch.float32,
                device_map="auto",
                trust_remote_code=True,
            )
            tokenizer = AutoTokenizer.from_pretrained(config["base"])
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Attach the fine-tuned LoRA adapter on top of the base weights.
            model = PeftModel.from_pretrained(base_model, config["adapter"])
            model.eval()

            self.current_model = model
            self.current_tokenizer = tokenizer
            self.model_name = model_name
            print(f"Model {model_name} loaded successfully!")

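    # Added sketch (not part of the original script): free the previous model
    # before swapping so both checkpoints never sit in memory at once. Could
    # be called at the top of load_model; empty_cache() assumes a CUDA device
    # and is skipped on CPU-only machines.
    def _unload_current(self):
        import gc

        if self.current_model is not None:
            del self.current_model
            self.current_model = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
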
    def get_response(self, message, history, model_name):
        self.load_model(model_name)

        system_prompt = (
            "You are a vacation rental property assistant. Help guests with "
            "inquiries, maintenance requests, and local information."
        )
        # Build a ChatML-style prompt.
        prompt = (
            f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{message}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        inputs = self.current_tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=2048
        )
        inputs = {k: v.to(self.current_model.device) for k, v in inputs.items()}

        outputs = self.current_model.generate(
            **inputs,
            max_new_tokens=1024,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=self.current_tokenizer.pad_token_id,
            eos_token_id=self.current_tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens. Stripping the prompt with
        # str.replace would fail here: skip_special_tokens removes the
        # <|im_start|> markers, so the decoded text never contains the
        # prompt verbatim.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response = self.current_tokenizer.decode(new_tokens, skip_special_tokens=True)
        return response.strip()
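
    # Added sketch (not part of the original script): a streaming variant
    # built on transformers' TextIteratorStreamer. gr.ChatInterface accepts
    # generator functions, so this could be swapped in as fn to stream tokens
    # instead of waiting for the full reply.
    def get_response_stream(self, message, history, model_name):
        from threading import Thread

        from transformers import TextIteratorStreamer

        self.load_model(model_name)
        prompt = (
            "<|im_start|>system\nYou are a vacation rental property "
            "assistant.<|im_end|>\n"
            f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
        )
        inputs = self.current_tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.current_model.device) for k, v in inputs.items()}
        streamer = TextIteratorStreamer(
            self.current_tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # generate() blocks, so it runs in a thread while we drain the streamer.
        Thread(
            target=self.current_model.generate,
            kwargs={**inputs, "streamer": streamer, "max_new_tokens": 1024},
        ).start()
        partial = ""
        for chunk in streamer:
            partial += chunk
            yield partial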


assistant = VacationRentalAssistant()

with gr.Blocks() as demo:
    model_name = gr.Dropdown(
        choices=list(MODELS.keys()),
        value="TinyLlama-1.1B",
        label="Select Model",
    )
    # Route the dropdown through additional_inputs so every request receives
    # the current selection; a component's .value only holds its initial value.
    chatbot = gr.ChatInterface(
        fn=assistant.get_response,
        additional_inputs=[model_name],
        title="Vacation Rental Assistant",
        description="Ask questions about your vacation rental property, make requests, or get local information.",
        examples=[
            "What time is check-in?",
            "Is early check-in available?",
            "The AC isn't working properly, can you help?",
            "What amenities are available?",
            "Is there a grocery store nearby?",
        ],
    )
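
# Optional addition (assumes Gradio 4.x; older releases used a
# concurrency_count argument instead): queue requests so concurrent users
# don't interleave generate() calls on the single shared model.
demo.queue(default_concurrency_limit=1)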

if __name__ == "__main__":
    demo.launch()