prithivMLmods commited on
Commit
43f0687
·
verified ·
1 Parent(s): 6c3e861

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -113,7 +113,7 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
113
  ).to(device)
114
  sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
115
 
116
- # **Fix for dtype mismatch in the text encoder:**
117
  if torch.cuda.is_available():
118
  sd_pipe.text_encoder = sd_pipe.text_encoder.half()
119
 
@@ -171,12 +171,19 @@ def generate_image_fn(
171
  options["use_resolution_binning"] = True
172
 
173
  images = []
 
174
  for i in range(0, num_images, BATCH_SIZE):
175
  batch_options = options.copy()
176
  batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
177
  if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
178
  batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
179
- images.extend(sd_pipe(**batch_options).images)
 
 
 
 
 
 
180
  image_paths = [save_image(img) for img in images]
181
  return image_paths, seed
182
 
 
113
  ).to(device)
114
  sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
115
 
116
+ # Ensure that the text encoder is in half-precision if using CUDA.
117
  if torch.cuda.is_available():
118
  sd_pipe.text_encoder = sd_pipe.text_encoder.half()
119
 
 
171
  options["use_resolution_binning"] = True
172
 
173
  images = []
174
+ # Process in batches
175
  for i in range(0, num_images, BATCH_SIZE):
176
  batch_options = options.copy()
177
  batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
178
  if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
179
  batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
180
+ # Wrap the pipeline call in autocast if using CUDA
181
+ if device.type == "cuda":
182
+ with torch.autocast("cuda", dtype=torch.float16):
183
+ outputs = sd_pipe(**batch_options)
184
+ else:
185
+ outputs = sd_pipe(**batch_options)
186
+ images.extend(outputs.images)
187
  image_paths = [save_image(img) for img in images]
188
  return image_paths, seed
189