rishiraj committed
Commit bbbe561 · verified · 1 Parent(s): 1ad4bf1

Create app.py

Files changed (1):
  app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
+ import asyncio
+ import base64
+ import os
+ import time
+ from io import BytesIO
+
+ import gradio as gr
+ import numpy as np
+ from google import genai
+ from gradio_webrtc import (
+     AsyncAudioVideoStreamHandler,
+     WebRTC,
+     async_aggregate_bytes_to_16bit,
+     VideoEmitType,
+     AudioEmitType,
+     get_twilio_turn_credentials,
+ )
+ from PIL import Image
+
+
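+ # Helpers that wrap raw audio and video frames as the base64-encoded PCM and
+ # JPEG payloads sent to the Gemini Live session.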
+ def encode_audio(data: np.ndarray) -> dict:
+     """Encode Audio data to send to the server"""
+     return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}
+
+
+ def encode_image(data: np.ndarray) -> dict:
+     with BytesIO() as output_bytes:
+         pil_image = Image.fromarray(data)
+         pil_image.save(output_bytes, "JPEG")
+         bytes_data = output_bytes.getvalue()
+     base64_str = str(base64.b64encode(bytes_data), "utf-8")
+     return {"mime_type": "image/jpeg", "data": base64_str}
+
+
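+ # Streams the user's webcam and microphone to the Gemini Live API and plays
+ # Gemini's spoken replies (plus the local video echo) back through WebRTC.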
+ class GeminiHandler(AsyncAudioVideoStreamHandler):
+     def __init__(
+         self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
+     ) -> None:
+         super().__init__(
+             expected_layout,
+             output_sample_rate,
+             output_frame_size,
+             input_sample_rate=16000,
+         )
+         self.audio_queue = asyncio.Queue()
+         self.video_queue = asyncio.Queue()
+         self.quit = asyncio.Event()
+         self.session = None
+         self.last_frame_time = 0
+
+     def copy(self) -> "GeminiHandler":
+         return GeminiHandler(
+             expected_layout=self.expected_layout,
+             output_sample_rate=self.output_sample_rate,
+             output_frame_size=self.output_frame_size,
+         )
+
+     async def video_receive(self, frame: np.ndarray):
+         if self.session:
+             # send image every 1 second
+             if time.time() - self.last_frame_time > 1:
+                 self.last_frame_time = time.time()
+                 await self.session.send(encode_image(frame))
+         self.video_queue.put_nowait(frame)
+
+     async def video_emit(self) -> VideoEmitType:
+         return await self.video_queue.get()
+
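+     # Opens a Gemini Live session with the user's API key, using the assembled
+     # interview prompt as the system instruction, and keeps it open until shutdown.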
+     async def connect(self, api_key: str, filled_prompt: str):
+         if self.session is None:
+             client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"})
+             config = {"response_modalities": ["AUDIO"], "system_instruction": filled_prompt}
+             async with client.aio.live.connect(
+                 model="gemini-2.0-flash-exp", config=config
+             ) as session:
+                 self.session = session
+                 asyncio.create_task(self.receive_audio())
+                 await self.quit.wait()
+
+     async def generator(self):
+         while not self.quit.is_set():
+             turn = self.session.receive()
+             async for response in turn:
+                 if data := response.data:
+                     yield data
+
+     async def receive_audio(self):
+         async for audio_response in async_aggregate_bytes_to_16bit(
+             self.generator()
+         ):
+             self.audio_queue.put_nowait(audio_response)
+
+     async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+         _, array = frame
+         array = array.squeeze()
+         audio_message = encode_audio(array)
+         if self.session:
+             await self.session.send(audio_message)
+
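+     # Waits for the stream arguments (API key, prompt) on the first call, lazily
+     # starts the Gemini session, then returns the next audio chunk for playback.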
+     async def emit(self) -> AudioEmitType:
+         if not self.args_set.is_set():
+             await self.wait_for_args()
+         if self.session is None:
+             asyncio.create_task(self.connect(self.latest_args[1], self.latest_args[2]))
+         array = await self.audio_queue.get()
+         return (self.output_sample_rate, array)
+
+     def shutdown(self) -> None:
+         self.quit.set()
+         self.session = None
+         self.args_set.clear()
+         self.quit.clear()
+
+
+
+ css = """
+ #video-source {max-width: 600px !important; max-height: 600px !important;}
+ """
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(
+         """
+         <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
+             <div style="background-color: var(--block-background-fill); border-radius: 8px">
+                 <img src="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png" style="width: 100px; height: 100px;">
+             </div>
+             <div>
+                 <h1>Gen AI SDK Voice Chat</h1>
+                 <p>Speak with Gemini using real-time audio + video streaming</p>
+                 <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
+                 <p>Get an API Key <a href="https://support.google.com/googleapi/answer/6158862?hl=en">here</a></p>
+             </div>
+         </div>
+         """
+     )
+     with gr.Row() as api_key_row:
+         api_key = gr.Textbox(label="API Key", type="password", placeholder="Enter your API Key", value=os.getenv("GOOGLE_API_KEY"))
+     with gr.Row(visible=False) as row:
+         with gr.Column():
+             webrtc = WebRTC(
+                 label="Video Chat",
+                 modality="audio-video",
+                 mode="send-receive",
+                 elem_id="video-source",
+                 rtc_configuration=get_twilio_turn_credentials(),
+                 icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
+                 pulse_color="rgb(35, 157, 225)",
+                 icon_button_color="rgb(35, 157, 225)",
+             )
+         with gr.Column():
+             about_company = gr.Textbox(label="About company", lines=3, placeholder="Enter details about the company")
+             about_role = gr.Textbox(label="About this role", lines=3, placeholder="Enter details about the role")
+             responsibilities = gr.Textbox(label="This is what you’ll be doing", lines=4, placeholder="Describe the responsibilities")
+             requirements = gr.Textbox(label="What we are looking for", lines=4, placeholder="Describe the requirements")
+             benefits = gr.Textbox(label="What we offer", lines=3, placeholder="Describe the benefits offered")
+             interview_questions = gr.Textbox(label="Interview questions", lines=4, placeholder="Provide interview questions")
+
+     # Hidden field holding the interview prompt assembled from the entries above;
+     # its value is passed to the handler (with the API key) when streaming starts.
+     filled_prompt = gr.Textbox(visible=False)
+
+     def build_prompt(about_company, about_role, responsibilities, requirements, benefits, interview_questions):
+         with open("prompt.md", "r") as file:
+             template = file.read()
+         return template.format(
+             about_company=about_company,
+             about_role=about_role,
+             responsibilities=responsibilities,
+             requirements=requirements,
+             benefits=benefits,
+             interview_questions=interview_questions,
+         )
+
+     # Rebuild the prompt whenever any of the fields change.
+     prompt_fields = [about_company, about_role, responsibilities, requirements, benefits, interview_questions]
+     for field in prompt_fields:
+         field.change(build_prompt, prompt_fields, filled_prompt)
+
+     webrtc.stream(
+         GeminiHandler(),
+         inputs=[webrtc, api_key, filled_prompt],
+         outputs=[webrtc],
+         time_limit=90,
+         concurrency_limit=2,
+     )
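+     # Submitting the API key hides the key row and reveals the video chat UI.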
+     api_key.submit(
+         lambda: (gr.update(visible=False), gr.update(visible=True)),
+         None,
+         [api_key_row, row],
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()