MohamedRashad committed on
Commit 54f9225 · 1 Parent(s): c493a61

Add device and dtype logging in app.py for better debugging

Files changed (1):
  1. app.py +3 -2
app.py CHANGED
@@ -188,7 +188,7 @@ def load_infinity(
     model_path='',
     scale_schedule=None,
     vae=None,
-    device='cuda',
+    device='cuda',
     model_kwargs=None,
     text_channels=2048,
     apply_spatial_patchify=0,
@@ -295,6 +295,7 @@ def load_visual_tokenizer(args):
 
 def load_transformer(vae, args):
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Device: {device}")
     model_path = args.model_path
     if args.checkpoint_type == 'torch':
         # copy large model to local; save slim to local; and copy slim to nas; load local slim model
@@ -420,8 +421,8 @@ weights_path.mkdir(exist_ok=True)
 download_infinity_weights(weights_path)
 
 # Device setup
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+print(f"Using dtype: {dtype}")
 
 # Define args
 args = argparse.Namespace(
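
For reference, the net effect of the three hunks is the device/dtype selection plus the new logging shown below. This is a minimal standalone sketch of that pattern, not a verbatim excerpt of app.py (the surrounding model-loading code is omitted):

# Minimal standalone sketch of the device/dtype selection and logging added by this commit.
# Assumption: mirrors the pattern in app.py but is simplified to run on its own.
import torch

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Use bfloat16 only when a CUDA device is present and reports bf16 support;
# otherwise stay in float32 so CPU runs keep working.
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
print(f"Using dtype: {dtype}")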