davanstrien HF staff commited on
Commit
30ae6d8
·
verified ·
1 Parent(s): d334b52

regex parsing

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -7,6 +7,7 @@ subprocess.run(
7
  )
8
  import spaces
9
  import gradio as gr
 
10
 
11
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
12
  from qwen_vl_utils import process_vision_info
@@ -18,15 +19,15 @@ from typing import Tuple
18
 
19
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
20
 
21
-
22
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
23
  "Qwen/Qwen2.5-VL-7B-Instruct",
24
  torch_dtype=torch.bfloat16,
25
  attn_implementation="flash_attention_2",
26
  device_map="auto",
27
  )
28
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
29
-
 
30
 
31
  class GeneralRetrievalQuery(BaseModel):
32
  broad_topical_query: str
@@ -36,6 +37,17 @@ class GeneralRetrievalQuery(BaseModel):
36
  visual_element_query: str
37
  visual_element_explanation: str
38
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
41
  if prompt_name != "general":
@@ -76,11 +88,9 @@ Generate the queries based on this image and provide the response in the specifi
76
 
77
  return prompt, GeneralRetrievalQuery
78
 
79
-
80
  # defined like this so we can later add more prompting options
81
  prompt, pydantic_model = get_retrieval_prompt("general")
82
 
83
-
84
  def _prep_data_for_input(image):
85
  messages = [
86
  {
@@ -109,7 +119,6 @@ def _prep_data_for_input(image):
109
  return_tensors="pt",
110
  )
111
 
112
-
113
  @spaces.GPU
114
  def generate_response(image):
115
  inputs = _prep_data_for_input(image)
@@ -125,13 +134,20 @@ def generate_response(image):
125
  generated_ids_trimmed,
126
  skip_special_tokens=True,
127
  clean_up_tokenization_spaces=False,
128
- )
 
129
  try:
130
- return json.loads(output_text[0])
 
 
 
 
 
 
 
131
  except Exception:
132
  gr.Warning("Failed to parse JSON from output")
133
- return output_text[0]
134
-
135
 
136
  title = "ColPali Query Generator using Qwen2.5-VL"
137
  description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.
 
7
  )
8
  import spaces
9
  import gradio as gr
10
+ import re
11
 
12
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
13
  from qwen_vl_utils import process_vision_info
 
19
 
20
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
21
 
 
22
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
23
  "Qwen/Qwen2.5-VL-7B-Instruct",
24
  torch_dtype=torch.bfloat16,
25
  attn_implementation="flash_attention_2",
26
  device_map="auto",
27
  )
28
+ processor = AutoProcessor.from_pretrained(
29
+ "Qwen/Qwen2.5-VL-7B-Instruct",
30
+ )
31
 
32
  class GeneralRetrievalQuery(BaseModel):
33
  broad_topical_query: str
 
37
  visual_element_query: str
38
  visual_element_explanation: str
39
 
40
+ def extract_json_with_regex(text):
41
+ # Pattern to match content between code backticks
42
+ pattern = r'```(?:json)?\s*(.+?)\s*```'
43
+
44
+ # Find all matches (should typically be one)
45
+ matches = re.findall(pattern, text, re.DOTALL)
46
+
47
+ if matches:
48
+ # Return the first match
49
+ return matches[0]
50
+ return None
51
 
52
  def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
53
  if prompt_name != "general":
 
88
 
89
  return prompt, GeneralRetrievalQuery
90
 
 
91
  # defined like this so we can later add more prompting options
92
  prompt, pydantic_model = get_retrieval_prompt("general")
93
 
 
94
  def _prep_data_for_input(image):
95
  messages = [
96
  {
 
119
  return_tensors="pt",
120
  )
121
 
 
122
  @spaces.GPU
123
  def generate_response(image):
124
  inputs = _prep_data_for_input(image)
 
134
  generated_ids_trimmed,
135
  skip_special_tokens=True,
136
  clean_up_tokenization_spaces=False,
137
+ )[0]
138
+
139
  try:
140
+ # Try to extract JSON from code block first
141
+ json_str = extract_json_with_regex(output_text)
142
+ if json_str:
143
+ parsed = json.loads(json_str)
144
+ return json.dumps(parsed, indent=2)
145
+ # If no code block found, try direct JSON parsing
146
+ parsed = json.loads(output_text)
147
+ return json.dumps(parsed, indent=2)
148
  except Exception:
149
  gr.Warning("Failed to parse JSON from output")
150
+ return output_text
 
151
 
152
  title = "ColPali Query Generator using Qwen2.5-VL"
153
  description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.