import gradio as gr
import spaces
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from typing import Generator

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("./lora_model")
model = AutoPeftModelForCausalLM.from_pretrained("./lora_model", device_map=0, torch_dtype="auto")
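

# spaces.GPU() has ZeroGPU attach a GPU for the duration of each call, and
# torch.no_grad() turns off gradient tracking since this is inference only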
@spaces.GPU()
@torch.no_grad()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
) -> Generator[str, None, None]:
    torch.cuda.empty_cache()
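
    # Rebuild the conversation: system prompt, prior turns, then the new user message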
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
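
    # Render the conversation with the chat template; add_generation_prompt=True
    # appends the assistant header so the model continues as the assistant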
    convo_string = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    assert isinstance(convo_string, str)
    # Tokenize the conversation (the chat template already added the special tokens)
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)
    input_ids = torch.tensor(convo_tokens, dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)

    # Add the batch dimension and move everything to the GPU
    input_ids = input_ids.unsqueeze(0).to("cuda")
    attention_mask = attention_mask.unsqueeze(0).to("cuda")
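
    # model.generate runs on a background thread; TextIteratorStreamer yields
    # decoded text chunks as they are produced (prompt and special tokens skipped)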
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=temperature,
        top_k=None,
        top_p=top_p,
        streamer=streamer,
    )
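
    # Treat temperature 0 as greedy decoding (sampling disabled)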
    if temperature == 0:
        generate_kwargs["do_sample"] = False
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
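
    # Seed the streamed reply with the "score_7_up," tag so every prompt begins with it,
    # then yield the accumulated text after each chunk so the UI updates live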
    outputs = ["score_7_up,"]
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
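

# ChatInterface re-renders the conversation with each string yielded by respond;
# the additional_inputs below are passed positionally as system_message,
# max_tokens, temperature, and top_p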
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful image generation prompt writing AI. You write image generation prompts based on user requests. The prompt you write should be 150 words or longer.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["Please write a random prompt."],
        ["I'd like an image based on the tags: black and white, two women, gym, minimalist design, exposed beams, kneeling, holding head, casual wear."],
        ["Can you create an image of a woman hiking and resting on a rock in a beautiful forest with mountains?"],
        ["can u make a creepy hallway pic, like something out of a weird dream, with shadows and a mysterious figure at the end? maybe some reds and blacks, make it look kinda eerie and otherworldly pls"],
        ["Beach sunset with silhouettes on rocks and birds flying"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()