File size: 3,462 Bytes
f44a0df
2e62a51
 
 
 
 
 
f44a0df
 
2e62a51
 
 
f44a0df
2e62a51
 
 
f44a0df
 
 
 
 
 
 
2e62a51
 
 
f44a0df
 
 
 
 
 
 
 
 
 
2e62a51
 
 
 
 
 
 
 
f44a0df
2e62a51
 
 
 
 
 
 
 
 
 
 
 
 
f44a0df
2e62a51
f44a0df
2e62a51
 
 
 
 
 
 
 
f44a0df
cfe04ee
2e62a51
 
 
f44a0df
 
 
 
 
2e62a51
f44a0df
bfe63bd
f44a0df
 
 
bfe63bd
f44a0df
 
 
 
cfe04ee
0b1848e
 
 
 
 
cfe04ee
 
f44a0df
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import spaces
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from typing import Generator


# Directory holding the fine-tuned LoRA checkpoint and its tokenizer files.
_LORA_DIR = "./lora_model"

# Both objects are created once, at import time, and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(_LORA_DIR)
# device_map=0 places the model on GPU 0; torch_dtype="auto" defers the dtype
# choice to the checkpoint instead of forcing float32.
model = AutoPeftModelForCausalLM.from_pretrained(
    _LORA_DIR,
    device_map=0,
    torch_dtype="auto",
)


def _build_messages(message, history, system_message) -> list[dict]:
    """Flatten the Gradio chat state into a chat-template message list.

    ``history`` may be a list of ``(user, assistant)`` pairs, where empty or
    ``None`` halves are skipped. It may also be a list of
    ``{"role": ..., "content": ...}`` dicts, which are passed through as-is.
    The system prompt comes first and the new user ``message`` comes last.
    """
    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if isinstance(turn, dict):
            messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU()
@torch.no_grad()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
) -> Generator[str, None, None]:
    """Stream the model's reply to ``message`` given the prior chat ``history``.

    Args:
        message: The latest user message.
        history: Prior turns, either as ``(user, assistant)`` pairs or as
            role/content dicts.
        system_message: System prompt placed at the start of the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature. ``0`` switches to greedy decoding.
        top_p: Nucleus-sampling probability mass. It is ignored when greedy.

    Yields:
        The cumulative reply text so far. It always starts with ``"score_7_up,"``.

    Raises:
        TypeError: If the tokenizer's chat template does not return a string.
    """
    torch.cuda.empty_cache()

    messages = _build_messages(message, history, system_message)

    convo_string = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    if not isinstance(convo_string, str):
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        raise TypeError(
            f"apply_chat_template returned {type(convo_string).__name__}, expected str"
        )

    # The rendered template already contains whatever special tokens it needs,
    # so don't let encode() add another set.
    convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)

    # Batch of one, moved to the GPU the model was loaded on (device_map=0).
    input_ids = torch.tensor(convo_tokens, dtype=torch.long).unsqueeze(0).to("cuda")
    attention_mask = torch.ones_like(input_ids)

    # skip_prompt: stream only newly generated text, not the echoed input.
    # timeout: iterating the streamer raises if no text arrives within 10 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        do_sample=True,
        suppress_tokens=None,
        use_cache=True,
        temperature=temperature,
        top_k=None,
        top_p=top_p,
        streamer=streamer,
    )

    if temperature == 0:
        # Greedy decoding. temperature/top_p only matter when sampling. Use
        # their neutral values (1.0) so transformers presumably doesn't warn
        # about them being ignored; temperature=0 itself is not a valid
        # sampling temperature.
        generate_kwargs.update(do_sample=False, temperature=1.0, top_p=1.0)

    # generate() blocks until completion, so run it in a worker thread and
    # consume its output incrementally through the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    # NOTE(review): fixed prefix prepended to every reply; presumably a quality
    # tag expected by the downstream image model. Confirm with the author.
    outputs = ["score_7_up,"]
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation has finished. Reap the worker
    # so it does not outlive the request.
    worker.join()


# Default system prompt shown in (and editable from) the UI.
_DEFAULT_SYSTEM_MESSAGE = "You are a helpful image generation prompt writing AI. You write image generation prompts based on user requests. The prompt you write should be 150 words or longer."

# Sample requests offered as clickable examples. Each one becomes a
# single-item example row.
_EXAMPLE_REQUESTS = [
    "Please write a random prompt.",
    "I'd like an image based on the tags: black and white, two women, gym, minimalist design, exposed beams, kneeling, holding head, casual wear.",
    "Can you create an image of a woman hiking and resting on a rock in a beautiful forest with mountains?",
    "can u make a creepy hallway pic, like something out of a weird dream, with shadows and a mysterious figure at the end? maybe some reds and blacks, make it look kinda eerie and otherworldly pls",
    "Beach sunset with silhouettes on rocks and birds flying",
]

# Extra controls, listed in the same order as respond()'s parameters after
# (message, history): system_message, max_tokens, temperature, top_p.
_EXTRA_CONTROLS = [
    gr.Textbox(value=_DEFAULT_SYSTEM_MESSAGE, label="System message"),
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
]

demo = gr.ChatInterface(
    respond,
    additional_inputs=_EXTRA_CONTROLS,
    examples=[[request] for request in _EXAMPLE_REQUESTS],
    cache_examples=False,
)


def main() -> None:
    """Start the Gradio app with its default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()