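"""Gradio chat demo for vilm/vinallama-7b.

Loads the model as a 4-bit quantized text-generation pipeline and streams
its output into a Chatbot UI token by token. Besides gradio, transformers,
and torch, 4-bit loading needs the bitsandbytes and accelerate packages.
"""
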
from threading import Thread

import gradio as gr
import torch
import transformers
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)


def chat_history(pipeline, history) -> str:
    # Flatten Gradio's [[user, assistant], ...] pairs into the message
    # format expected by the chat template, skipping empty turns (the
    # pending assistant slot is None while a reply is being generated).
    messages = [
        {"role": role, "content": content}
        for user_msg, bot_msg in history
        for role, content in (("user", user_msg), ("assistant", bot_msg))
        if content
    ]

    # tokenize=False returns the rendered prompt as a plain string.
    return pipeline.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def model_loading_pipeline():
    model_id = "vilm/vinallama-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # skip_prompt=True keeps the input prompt out of the streamed text;
    # timeout unblocks the consuming loop if generation stalls.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=5)
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        tokenizer=tokenizer,
        model_kwargs={
            "torch_dtype": torch.float16,
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            ),
        },
        # Forwarded to model.generate() so tokens flow into the streamer.
        streamer=streamer,
    )
    return pipeline, streamer


def launch_app(pipeline, streamer):
    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Append the new user turn and clear the textbox.
            return "", history + [[user_message, None]]

        def bot(history):
            # Build the prompt from the full conversation so far.
            prompt = chat_history(pipeline, history)
            history[-1][1] = ""
            kwargs = {
                "text_inputs": prompt,
                "max_new_tokens": 2048,
                "do_sample": True,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
            }
            # Generate on a worker thread so this handler can consume the
            # streamer while tokens are still being produced.
            thread = Thread(target=pipeline, kwargs=kwargs)
            thread.start()

            # Append each streamed chunk to the assistant turn and yield,
            # so Gradio repaints the chat incrementally.
            for token in streamer:
                history[-1][1] += token
                yield history

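        # Submit first echoes the user turn (queue=False makes it
        # immediate), then chains into bot() to stream the reply.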
        msg.submit(user, [msg, chat], [msg, chat], queue=False).then(bot, chat, chat)
        clear.click(lambda: None, None, chat, queue=False)

    demo.queue()
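    # queue() is required for generator callbacks like bot(); share=True
    # exposes a temporary public URL and debug=True surfaces errors.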
    demo.launch(share=True, debug=True)


if __name__ == "__main__":
    pipe, streamer = model_loading_pipeline()
    launch_app(pipe, streamer)