File size: 2,614 Bytes
22bd0c2
fcabc39
 
59b000d
7e9b0bd
fcabc39
 
59b000d
 
fcabc39
 
 
 
 
 
7e9b0bd
 
be52360
7e9b0bd
 
fcabc39
 
 
 
 
22bd0c2
fcabc39
 
59b000d
fcabc39
 
 
7e9b0bd
 
fcabc39
 
f7d8d0c
7e9b0bd
59b000d
7e9b0bd
f7d8d0c
7e9b0bd
 
 
 
 
 
 
 
 
 
 
 
 
f7d8d0c
 
7e9b0bd
f7d8d0c
7e9b0bd
22bd0c2
 
7e9b0bd
22bd0c2
fcabc39
 
59b000d
22bd0c2
 
 
 
59b000d
7e9b0bd
22bd0c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
import requests
import os
import json
import sseclient

# Set up the API endpoint and key from the environment.
# Both must be configured (e.g. in the Space/host settings) before launch;
# os.getenv returns None when unset, which surfaces as a request error later.
API_URL = os.getenv("RUNPOD_API_URL")
API_KEY = os.getenv("RUNPOD_API_KEY")

# Shared headers for every request to the chat-completions endpoint.
headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

# Fixed system prompt sent as the first message of every conversation.
# (Grammar fix: "You an advanced" -> "You are an advanced".)
SYSTEM_PROMPT = "You are an advanced artificial intelligence system, capable of <thinking> <reflection> and you output a brief and small to the point <output>."

def stream_response(message, history, max_tokens, temperature, top_p):
    """Stream an assistant reply from the RunPod chat-completions endpoint.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list[tuple[str, str]]
        Prior (user, assistant) turns, as supplied by ``gr.ChatInterface``.
    max_tokens : int
        Completion token limit forwarded to the API.
    temperature : float
        Sampling temperature forwarded to the API.
    top_p : float
        Nucleus-sampling value forwarded to the API.

    Yields
    ------
    str
        The accumulated, HTML-escaped assistant reply so far. On failure,
        a single ``"Error: ..."`` / ``"Unexpected error: ..."`` string.
    """
    # Rebuild the full conversation, starting from the fixed system prompt.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True
    }

    try:
        # (connect, read) timeout so a dead endpoint cannot hang the worker;
        # the long read timeout still allows slow token-by-token streams.
        response = requests.post(
            API_URL, headers=headers, json=data, stream=True, timeout=(10, 300)
        )
        response.raise_for_status()
        client = sseclient.SSEClient(response)

        full_response = ""
        for event in client.events():
            if event.data == "[DONE]":
                break  # server's end-of-stream sentinel — stop reading
            try:
                chunk = json.loads(event.data)
            except json.JSONDecodeError:
                # Best-effort: log malformed events and keep streaming.
                print(f"Failed to decode JSON: {event.data}")
                continue
            if chunk.get('choices'):
                content = chunk['choices'][0].get('delta', {}).get('content', '')
                if content:
                    full_response += content
                    # BUG FIX: yield the *cumulative* text, not just the delta.
                    # gr.ChatInterface replaces the shown message with each
                    # yielded value, so yielding only the chunk displayed
                    # nothing but the latest fragment.
                    # Escape < and > so <thinking>/<output> tags render as text.
                    yield full_response.replace('<', '&lt;').replace('>', '&gt;')

    except requests.exceptions.RequestException as e:
        yield f"Error: {str(e)}"
    except Exception as e:
        yield f"Unexpected error: {str(e)}"

# Sampling controls shown under the chat box; Gradio appends their values
# to each call as (max_tokens, temperature, top_p) after (message, history).
_sampling_controls = [
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

# Chat UI wired to the streaming generator above.
demo = gr.ChatInterface(stream_response, additional_inputs=_sampling_controls)

def _main():
    """Log the startup configuration, then serve the Gradio app."""
    print(f"Starting application with API URL: {API_URL}")
    print(f"Using system prompt: {SYSTEM_PROMPT}")
    demo.launch()


if __name__ == "__main__":
    _main()