MohamedRashad committed
Commit 8e9da2c · 1 parent: 2d4f4be

Make device optional in the load_infinity function; default to 'cuda' when available, otherwise 'cpu', and adjust the autocast dtype handling to match
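For context, the device fallback and dtype check this commit adds can be read in isolation. The helper below is a hypothetical sketch (pick_device_and_dtype is not a name from app.py); it only mirrors the inline logic the diff introduces:

import torch

def pick_device_and_dtype(bf16: bool):
    # Hypothetical helper mirroring the commit's inline logic in load_infinity.
    # Fall back to CPU when CUDA is unavailable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Only use bfloat16 autocast when the GPU actually supports it.
    if bf16 and device == 'cuda' and torch.cuda.is_bf16_supported():
        autocast_dtype = torch.bfloat16
    else:
        autocast_dtype = torch.float32
        bf16 = False  # disable bf16 when unsupported
    return device, autocast_dtype, bf16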

Files changed (1): app.py (+20, -2)
app.py CHANGED

@@ -188,7 +188,7 @@ def load_infinity(
     model_path='',
     scale_schedule=None,
     vae=None,
-    device='cuda',
+    device=None,  # Make device optional
     model_kwargs=None,
     text_channels=2048,
     apply_spatial_patchify=0,
@@ -196,9 +196,23 @@ def load_infinity(
     bf16=False,
 ):
     print(f'[Loading Infinity]')
+
+    # Set device if not provided
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    print(f'Using device: {device}')
+
+    # Set autocast dtype based on bf16 and device support
+    if bf16 and device == 'cuda' and torch.cuda.is_bf16_supported():
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float32
+        bf16 = False  # Disable bf16 if not supported
+
     text_maxlen = 512
     torch.cuda.empty_cache()
-    with torch.amp.autocast(device_type=device, enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
+
+    with torch.amp.autocast(device_type=device, enabled=bf16, dtype=autocast_dtype, cache_enabled=True), torch.no_grad():
         infinity_test: Infinity = Infinity(
             vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
             shared_aln=True, raw_scale_schedule=scale_schedule,
@@ -217,6 +231,7 @@ def load_infinity(
         train_h_div_w_list=[1.0],
         **model_kwargs,
     ).to(device)
+
     print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')
 
     if bf16:
@@ -229,7 +244,10 @@ def load_infinity(
     print(f'[Load Infinity weights]')
     state_dict = torch.load(model_path, map_location=device)
     print(infinity_test.load_state_dict(state_dict))
+
+    # Initialize random number generator on the correct device
     infinity_test.rng = torch.Generator(device=device)
+
     return infinity_test
 
 def transform(pil_img, tgt_h, tgt_w):
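A usage sketch under the new signature. The weight path and the vae/scale_schedule values below are placeholders for illustration, not values taken from this repo:

# device is now optional; on a CPU-only machine it resolves to 'cpu'
# and bf16 is silently downgraded to float32 inside load_infinity.
model = load_infinity(
    model_path='infinity_weights.pth',  # placeholder path
    scale_schedule=scale_schedule,      # assumed defined elsewhere in app.py
    vae=vae,                            # assumed defined elsewhere in app.py
    model_kwargs={},
    bf16=True,
)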