MohamedRashad commited on
Commit
2d4f4be
·
1 Parent(s): 9914b63

Enhance Infinity model loading by specifying device type in autocast context for improved performance

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -198,7 +198,7 @@ def load_infinity(
198
  print(f'[Loading Infinity]')
199
  text_maxlen = 512
200
  torch.cuda.empty_cache()
201
- with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
202
  infinity_test: Infinity = Infinity(
203
  vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
204
  shared_aln=True, raw_scale_schedule=scale_schedule,
 
198
  print(f'[Loading Infinity]')
199
  text_maxlen = 512
200
  torch.cuda.empty_cache()
201
+ with torch.amp.autocast(device_type=device, enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
202
  infinity_test: Infinity = Infinity(
203
  vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
204
  shared_aln=True, raw_scale_schedule=scale_schedule,