File size: 3,494 Bytes
b19bf93
 
 
 
 
 
 
 
2f28407
8af24d4
cf9593c
f309e2f
 
8af24d4
 
 
f309e2f
2f28407
b19bf93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8af24d4
b19bf93
 
 
 
 
 
 
 
 
 
8af24d4
b19bf93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8af24d4
b19bf93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import random

import requests

# Template
# Static copy for the Gradio page: header title, (empty) description, and
# a footer "article" documenting the generation parameters.
title = "A conversation with Gandalf (GPTJ-6B) 🧙"
description = ""
# HTML rendered under the interface. NOTE(review): the <img> tag has a
# stray comma between src and alt; browsers tolerate it — left as-is.
article = """
<p> To reset you <b>need to reload the page.</b> </p>
<p> If you liked don't forget to 💖 the project 🥰 </p>
<h2> Parameters: </h2>
<ul>
    <li><i>top_p</i>:  control how deterministic the model is in generating a response.</li>
    <li><i>temperature</i>: (sampling temperature) higher values means the model will take more risks.</li>
    <li><i>max_new_tokens</i>: Max number of tokens in generation.</li>
</ul>
<img src='http://www.simoninithomas.com/test/gandalf.jpg', alt="Gandalf"/>"""
# Built-in Gradio theme name.
theme="huggingface"
# Pre-filled example inputs: [top_p, temperature, max_new_tokens, message].
examples = [[0.9, 1.1, 50, "Hey Gandalf! How are you?"], [0.9, 1.1, 50, "Hey Gandalf, why you didn't use the great eagles to fly Frodo to Mordor?"]]

# GPT-J-6B API
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"


def query(payload, timeout=30):
    """POST `payload` to the HF Inference API and return the decoded JSON.

    Args:
        payload: JSON-serializable dict with "inputs" and "parameters".
        timeout: seconds to wait for the request. The original had no
            timeout, so a stalled inference endpoint would hang the UI
            thread indefinitely.

    Returns:
        The parsed JSON body: a list of generations on success, or an
        error dict (e.g. {"error": ...} while the model is loading).
    """
    response = requests.post(API_URL, json=payload, timeout=timeout)
    return response.json()
# Scene-setting preamble prepended to every prompt sent to the model.
context_setup = "The following is a conversation with Gandalf, the mage of 'the Lord of the Rings'"
context=context_setup
# Speaker labels used to format dialogue turns: [user, bot].
interlocutor_names = ["Human", "Gandalf"]

# Builds the prompt from what previously happened 
def build_prompt(conversation, context):
    """Assemble the text prompt sent to the model.

    Args:
        conversation: list of (user_msg, resp_msg) tuples, oldest first.
            The in-progress turn has an empty resp_msg, producing a
            trailing "- Gandalf:" line for the model to complete.
        context: scene-setting preamble placed before the dialogue.

    Returns:
        str: context + "\\n" followed by one "\\n- Name:message" line per
        speaker per turn.
    """
    # Collect pieces and join once instead of repeated string +=.
    # (The original also ended with a no-op `prompt += ""`; removed.)
    pieces = [context + "\n"]
    for user_msg, resp_msg in conversation:
        pieces.append("\n- " + interlocutor_names[0] + ":" + user_msg)
        pieces.append("\n- " + interlocutor_names[1] + ":" + resp_msg)
    return "".join(pieces)

# Attempt to recognize what the model said, if it used the correct format
def clean_chat_output(txt, prompt):
    """Extract the bot's reply from the raw model continuation.

    Args:
        txt: generated text returned by the API.
        prompt: the prompt that was sent; any echoed copy is stripped.

    Returns:
        str: the text up to the point where the model starts a new
        "- Human" turn, or the whole remainder if it never does.
    """
    delimiter = "\n- " + interlocutor_names[0]
    output = txt.replace(prompt, '')
    cut = output.find(delimiter)
    # Bug fix: `find` returns -1 when the delimiter is absent, and the
    # original's `output[:output.find(delimiter)]` then sliced [:-1],
    # silently dropping the last character of a perfectly good reply.
    return output if cut == -1 else output[:cut]


def chat(top_p, temperature, max_new_tokens, message):
    """Handle one chat turn: record the message, query GPT-J, render HTML.

    Args:
        top_p: nucleus-sampling cutoff forwarded to the API.
        temperature: sampling temperature forwarded to the API.
        max_new_tokens: generation length cap forwarded to the API.
        message: the user's new utterance.

    Returns:
        str: HTML markup of the full conversation so far.
    """
    # Per-session history via gradio's legacy state API: a list of
    # (user_msg, response) tuples. The new turn starts with an empty
    # response so build_prompt ends on a "- Gandalf:" cue.
    history = gr.get_state() or []
    history.append((message, ""))
    gr.set_state(history)
    prompt = build_prompt(history, context)

    # Build the inference request body.
    json_ = {
        "inputs": prompt,
        "parameters": {
            "top_p": top_p,
            "temperature": temperature,
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,
        },
    }

    output = query(json_)
    # Robustness fix: on failure the API returns an error dict (e.g. while
    # the model loads); the original `output[0]['generated_text']` then
    # crashed with KeyError. Surface the error in the chat instead.
    if isinstance(output, list) and output and 'generated_text' in output[0]:
        response = clean_chat_output(output[0]['generated_text'], prompt)
    else:
        response = "[error: " + str(output) + "]"
    history[-1] = (message, response)
    gr.set_state(history)

    # Render the transcript. Bug fix: the Interface CSS defines `.chatbox`,
    # but the original emitted class 'chatbot', so the flex column layout
    # never applied.
    html = "<div class='chatbox'>"
    for user_msg, resp_msg in history:
        html += f"<div class='user_msg'>{user_msg}</div>"
        html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html

# Wire the app together: three sliders + a text box in, rendered HTML out,
# with custom CSS for the chat bubbles.
# NOTE(review): this uses the pre-1.0 gradio API (gr.inputs.Slider,
# allow_screenshot, get_state/set_state); it will not run on modern gradio.
iface = gr.Interface(
        chat, 
        [ 
            gr.inputs.Slider(minimum=0.5, maximum=1, step=0.05, default=0.9, label="top_p"),
            gr.inputs.Slider(minimum=0.5, maximum=1.5, step=0.1, default=1.1, label="temperature"),
            gr.inputs.Slider(minimum=20, maximum=250, step=10, default=50, label="max_new_tokens"),
            "text",
        ],
     "html", css="""
    .chatbox {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
""", allow_screenshot=True, 
allow_flagging=True,
title=title,
article=article,
theme=theme,
examples=examples)

# Start the Gradio web server when run as a script (not on import).
if __name__ == "__main__":
  iface.launch()