Markgazol committed · verified
Commit a55cbd2 · 1 Parent(s): a19ef03

Upload folder using huggingface_hub

adapter_config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "alpha_pattern": {},
   "auto_mapping": null,
-  "base_model_name_or_path": "./models/qwenstella-base",
+  "base_model_name_or_path": "Metric-AI/ColQwenStella-base-2b",
   "bias": "none",
   "fan_in_fan_out": false,
   "inference_mode": true,
preprocessor_config.json CHANGED
@@ -25,5 +25,8 @@
     "max_pixels": 12845056,
     "min_pixels": 3136
   },
-  "temporal_patch_size": 2
+  "temporal_patch_size": 2,
+  "auto_map": {
+    "AutoProcessor": "processing_colqwenstella.ColQwenStellaProcessor"
+  }
 }
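
The new "auto_map" entry is what lets AutoProcessor pick up the custom processor class shipped in this repo. A minimal sketch (not part of this commit); the repo id is a hypothetical placeholder, and trust_remote_code=True is required because the class lives in repo-local code rather than in transformers:

# Sketch: AutoProcessor resolves processing_colqwenstella.ColQwenStellaProcessor
# from the repo itself via the "auto_map" entry added above.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "Markgazol/your-adapter-repo",  # hypothetical placeholder
    trust_remote_code=True,
)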
processing_colqwenstella.py ADDED
@@ -0,0 +1,206 @@
+import math
+from typing import ClassVar, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import BatchFeature
+from transformers.models.qwen2_vl import Qwen2VLProcessor
+
+from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+
+
+def round_by_factor(number: float, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: float, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: float, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+class ColQwenStellaProcessor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
+    """
+    Processor for ColQwenStella.
+    """
+
+    visual_prompt_prefix: ClassVar[str] = (
+        "<|im_start|><|image_pad|><|im_end|><|endoftext|>"
+    )
+    query_prefix: ClassVar[str] = "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "
+    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
+    image_token: ClassVar[str] = "<|image_pad|>"
+
+    @property
+    def image_token_id(self) -> int:
+        return self.tokenizer.convert_tokens_to_ids(self.image_token)
+
+    def __init__(self, *args, **kwargs):
+        num_image_tokens = kwargs.pop("num_image_tokens", 768)
+        super().__init__(*args, **kwargs)
+        self.tokenizer.padding_side = "left"
+        self.min_pixels = 4 * 28 * 28
+        self.max_pixels = num_image_tokens * 28 * 28
+        self.factor = 28
+        self.max_ratio = 200
+
+    @staticmethod
+    def smart_resize_helper(
+        width: int,
+        height: int,
+        factor: int,
+        max_ratio: int,
+        min_pixels: int,
+        max_pixels: int,
+    ) -> Tuple[int, int]:
+        """
+        Returns the image size so that the following conditions are met:
+        1. Both dimensions (height and width) are divisible by 'factor'.
+        2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+        3. The aspect ratio of the image is maintained as closely as possible.
+        """
+
+        if max(height, width) / min(height, width) > max_ratio:
+            raise ValueError(
+                f"absolute aspect ratio must be smaller than {max_ratio}, "
+                f"got {max(height, width) / min(height, width)}"
+            )
+
+        h_bar = max(factor, round_by_factor(height, factor))
+        w_bar = max(factor, round_by_factor(width, factor))
+
+        if h_bar * w_bar > max_pixels:
+            beta = math.sqrt((height * width) / max_pixels)
+            h_bar = floor_by_factor(height / beta, factor)
+            w_bar = floor_by_factor(width / beta, factor)
+        elif h_bar * w_bar < min_pixels:
+            beta = math.sqrt(min_pixels / (height * width))
+            h_bar = ceil_by_factor(height * beta, factor)
+            w_bar = ceil_by_factor(width * beta, factor)
+
+        return h_bar, w_bar
+
+    def smart_resize(self, image: Image.Image) -> Image.Image:
+        """
+        Resize and convert the image to the required format.
+        """
+        image_size = image.size
+        resized_height, resized_width = self.smart_resize_helper(
+            width=image_size[0],
+            height=image_size[1],
+            factor=self.factor,
+            max_ratio=self.max_ratio,
+            min_pixels=self.min_pixels,
+            max_pixels=self.max_pixels,
+        )
+        return image.convert("RGB").resize((resized_width, resized_height))
+
+    def process_images(
+        self,
+        images: List[Image.Image],
+    ) -> BatchFeature:
+        """
+        Process images for ColQwenStella.
+        """
+        texts_doc = [self.visual_prompt_prefix] * len(images)
+
+        resized_images: List[Image.Image] = [self.smart_resize(image) for image in images]
+        batch_doc = self(
+            text=texts_doc,
+            images=resized_images,
+            padding="longest",
+            return_tensors="pt",
+        )
+        # Remap the Qwen2-VL image-pad token id (151655) to the image token id
+        # expected by this model (151646).
+        for i in range(batch_doc["input_ids"].shape[0]):
+            batch_doc["input_ids"][i][batch_doc["input_ids"][i] == 151655] = 151646
+
+        # NOTE: The following code is a hack to make sure the scatter in DDP is done correctly when training
+        # on multiple GPUs.
+        offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
+
+        # separate pixel_values for each image
+        pixel_values = torch.split(batch_doc["pixel_values"], offsets.tolist())
+
+        # pad pixel_values to the same length to be able to make it into a tensor
+        max_length = max([len(pv) for pv in pixel_values])
+
+        pixel_values = [
+            torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)])
+            for pv in pixel_values
+        ]
+        batch_doc["pixel_values"] = torch.stack(pixel_values)
+
+        return batch_doc
+
+    def process_queries(
+        self,
+        queries: List[str],
+        max_length: int = 50,
+        suffix: Optional[str] = None,
+    ) -> BatchFeature:
+        """
+        Process queries for ColQwenStella.
+        """
+        if suffix is None:
+            suffix = self.query_augmentation_token * 10
+        texts_query: List[str] = []
+
+        for query in queries:
+            query = self.query_prefix + query + suffix
+            texts_query.append(query)
+
+        batch_query = self(
+            text=texts_query,
+            return_tensors="pt",
+            padding="longest",
+        )
+
+        return batch_query
+
+    def score(
+        self,
+        qs: List[torch.Tensor],
+        ps: List[torch.Tensor],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.
+        """
+        return self.score_multi_vector(qs, ps, device=device, **kwargs)
+
+    def get_n_patches(
+        self,
+        image_size: Tuple[int, int],
+        patch_size: int,
+        spatial_merge_size: int,
+    ) -> Tuple[int, int]:
+        """
+        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
+        size (height, width) with the given patch size.
+
+        The `spatial_merge_size` is the number of patches that will be merged spatially. It is stored
+        as a `Qwen2VLForConditionalGeneration` attribute under `model.spatial_merge_size`.
+        """
+        height_new, width_new = self.smart_resize_helper(
+            width=image_size[0],
+            height=image_size[1],
+            factor=self.factor,
+            max_ratio=self.max_ratio,
+            min_pixels=self.min_pixels,
+            max_pixels=self.max_pixels,
+        )
+
+        n_patches_x = width_new // patch_size // spatial_merge_size
+        n_patches_y = height_new // patch_size // spatial_merge_size
+
+        return n_patches_x, n_patches_y
+
+    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
+        return batch_images.input_ids == self.image_token_id
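
For reference, a processor-side usage sketch (not part of the commit) exercising the three entry points defined above. The embeddings are random stand-ins for real model outputs, and the embedding dimension of 128 is an assumption made only to illustrate the shapes that score() expects:

# Sketch: assumes `processor` is a ColQwenStellaProcessor instance
# (see the AutoProcessor example earlier in this commit).
import torch
from PIL import Image

batch_images = processor.process_images([Image.new("RGB", (800, 600))])
batch_queries = processor.process_queries(["What is shown in the figure?"])

# Stand-in multi-vector embeddings: one query with 20 vectors, one page with
# 700 vectors; 128 is an assumed dimension, not taken from this repo.
qs = [torch.randn(20, 128)]
ps = [torch.randn(700, 128)]
scores = processor.score(qs, ps)  # MaxSim matrix of shape (len(qs), len(ps))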