SHREYSH commited on
Commit
a6723ff
·
verified ·
1 Parent(s): 092997a

Updating prepare_models (#1)

Browse files

- Updating prepare_models (8e98ce788746309169c34d162fa96fa2d4d7de56)

Files changed (1) hide show
  1. util/prepare_utils.py +13 -12
util/prepare_utils.py CHANGED
@@ -184,17 +184,14 @@ def extract_features(imgs, feature_extractor_ensemble, dim):
184
  return features
185
 
186
 
187
-
188
-
189
-
190
  def prepare_models(model_backbones,
191
- input_size,
192
- model_roots,
193
- kernel_size_attack,
194
- sigma_attack,
195
- combination,
196
- using_subspace,
197
- V_reduction_root):
198
 
199
  backbone_dict = {'IR_50': IR_50(input_size), 'IR_152': IR_152(input_size), 'ResNet_50': ResNet_50(input_size),
200
  'ResNet_152': ResNet_152(input_size)}
@@ -205,11 +202,14 @@ def prepare_models(model_backbones,
205
  models_attack = []
206
  for i in range(len(model_backbones)):
207
  model = backbone_dict[model_backbones[i]]
208
- model.load_state_dict(torch.load(model_roots[i], map_location=device))
 
 
 
 
209
  models_attack.append(model)
210
 
211
  if using_subspace:
212
-
213
  V_reduction = []
214
  for i in range(len(model_backbones)):
215
  V_reduction.append(torch.tensor(np.load(V_reduction_root[i])))
@@ -221,6 +221,7 @@ def prepare_models(model_backbones,
221
 
222
  return models_attack, V_reduction, dim
223
 
 
224
  def prepare_data(query_data_root, target_data_root, freq, batch_size, warp = False, theta_warp = None):
225
 
226
  data = datasets.ImageFolder(query_data_root, tensor_transform)
 
184
  return features
185
 
186
 
 
 
 
187
  def prepare_models(model_backbones,
188
+ input_size,
189
+ model_roots,
190
+ kernel_size_attack,
191
+ sigma_attack,
192
+ combination,
193
+ using_subspace,
194
+ V_reduction_root):
195
 
196
  backbone_dict = {'IR_50': IR_50(input_size), 'IR_152': IR_152(input_size), 'ResNet_50': ResNet_50(input_size),
197
  'ResNet_152': ResNet_152(input_size)}
 
202
  models_attack = []
203
  for i in range(len(model_backbones)):
204
  model = backbone_dict[model_backbones[i]]
205
+ try:
206
+ model.load_state_dict(torch.load(model_roots[i], map_location=device))
207
+ except Exception as e:
208
+ print(f"Error loading model {model_roots[i]}: {e}")
209
+ continue
210
  models_attack.append(model)
211
 
212
  if using_subspace:
 
213
  V_reduction = []
214
  for i in range(len(model_backbones)):
215
  V_reduction.append(torch.tensor(np.load(V_reduction_root[i])))
 
221
 
222
  return models_attack, V_reduction, dim
223
 
224
+
225
  def prepare_data(query_data_root, target_data_root, freq, batch_size, warp = False, theta_warp = None):
226
 
227
  data = datasets.ImageFolder(query_data_root, tensor_transform)