saumitras commited on
Commit
a149cce
·
unverified ·
1 Parent(s): e3f298a

added gemini and openai rag

Browse files
Files changed (3) hide show
  1. rag.py +101 -0
  2. requirements.txt +1 -0
  3. utils.py +5 -0
rag.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+ import google.generativeai as genai
4
+
5
+ from typing import List
6
+ from utils import encode_image
7
+ from PIL import Image
8
+
9
+ class Rag:
10
+
11
+ def get_answer_from_gemini(self, query, imagePaths):
12
+ try:
13
+ genai.configure(api_key=os.environ['GEMINI_API_KEY'])
14
+ model = genai.GenerativeModel('gemini-1.5-flash')
15
+
16
+ images = [Image.open(path) for path in imagePaths]
17
+
18
+
19
+ chat = model.start_chat()
20
+ response = chat.send_message([*images, query])
21
+
22
+ answer = response.text
23
+
24
+ print(answer)
25
+
26
+ return answer
27
+
28
+ except Exception as e:
29
+ print(f"An error occurred while querying Gemini: {e}")
30
+ return f"Error: {str(e)}"
31
+
32
+
33
+ def get_answer_from_openai(self, query, imagesPaths):
34
+ try:
35
+ print(f"Querying LLM for query={query}, imagesPaths={imagesPaths}")
36
+
37
+ payload = self.__get_openai_api_payload(query, imagesPaths)
38
+
39
+ headers = {
40
+ "Content-Type": "application/json",
41
+ "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
42
+ }
43
+
44
+ response = requests.post(
45
+ url="https://api.openai.com/v1/chat/completions",
46
+ headers=headers,
47
+ json=payload
48
+ )
49
+ response.raise_for_status() # Raise an HTTPError for bad responses
50
+
51
+ answer = response.json()["choices"][0]["message"]["content"]
52
+
53
+ print(answer)
54
+
55
+ return answer
56
+
57
+ except Exception as e:
58
+ print(f"An error occurred while querying OpenAI: {e}")
59
+ return None
60
+
61
+
62
+ def __get_openai_api_payload(self, query:str, imagesPaths:List[str]):
63
+ image_payload = []
64
+
65
+ for imagePath in imagesPaths:
66
+ base64_image = encode_image(imagePath)
67
+ image_payload.append({
68
+ "type": "image_url",
69
+ "image_url": {
70
+ "url": f"data:image/jpeg;base64,{base64_image}"
71
+ }
72
+ })
73
+
74
+ payload = {
75
+ "model": "gpt-4o",
76
+ "messages": [
77
+ {
78
+ "role": "user",
79
+ "content": [
80
+ {
81
+ "type": "text",
82
+ "text": query
83
+ },
84
+ *image_payload
85
+ ]
86
+ }
87
+ ],
88
+ "max_tokens": 1024
89
+ }
90
+
91
+ return payload
92
+
93
+
94
+
95
+ # if __name__ == "__main__":
96
+ # rag = Rag()
97
+
98
+ # query = "Based on attached images, how many new cases were reported during second wave peak"
99
+ # imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
100
+
101
+ # rag.get_answer_from_gemini(query, imagesPaths)
requirements.txt CHANGED
@@ -6,3 +6,4 @@ colpali_engine==0.3.4
6
  tqdm==4.66.5
7
  pillow==10.4.0
8
  spaces==0.30.4
 
 
6
  tqdm==4.66.5
7
  pillow==10.4.0
8
  spaces==0.30.4
9
+ google-generativeai==0.8.3
utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import base64
2
+
3
+ def encode_image(image_path):
4
+ with open(image_path, "rb") as image_file:
5
+ return base64.b64encode(image_file.read()).decode('utf-8')