anvilarth committed
Commit 293785c · verified · 1 Parent(s): be8c1ae

update app.py

Files changed (1)
  1. app.py +53 -3
app.py CHANGED
@@ -1,9 +1,12 @@
+import os
 import cv2
-import numpy as np
-import torch
 import time
-from PIL import Image, ImageDraw
+import torch
+import subprocess
+import numpy as np
 import gradio as gr
+import urllib.request
+from PIL import Image, ImageDraw
 import matplotlib.pyplot as plt
 
 from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, sam_model_registry
@@ -43,6 +46,7 @@ class GradioWindow():
         self.SAM_CHECKPOINT_PATH = SAM_CHECKPOINT_PATH
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.download_weights()
         # for debug
         # self.augmenter = None
         self.augmenter = Augmenter(device=self.device)
@@ -159,6 +163,52 @@ class GradioWindow():
             outputs=[augmented_img, generated_prompt],
         )
 
+    def download_weights(self):
+        models = [
+            "https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/",
+            "https://huggingface.co/llava-hf/llava-1.5-7b-hf",
+            "https://huggingface.co/danulkin/llama",
+        ]
+
+        destinations = [
+            "Garage/models/checkpoints/ppt-v2-1",
+            "Garage/models/checkpoints/llava-1.5-7b-hf",
+            "Garage/models/checkpoints/llama-3-8b-Instruct",
+        ]
+
+        if not os.path.exists("Garage/models/checkpoints"):
+            os.makedirs("Garage/models/checkpoints")
+
+        for model, destination in zip(models, destinations):
+            # Git LFS clone command
+            command = ["git", "lfs", "clone", model, destination]
+            try:
+                result = subprocess.run(command, check=True, text=True, capture_output=True)
+                print("Command Output:", result.stdout)
+
+            except subprocess.CalledProcessError as e:
+                print(f"Error: {e}")
+                print("Command Output:", e.output)
+
+        models = [
+            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+            "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
+        ]
+
+        destinations = [
+            "Garage/models/checkpoints/GroundedSegmentAnything/sam_vit_h_4b8939.pth",
+            "Garage/models/checkpoints/GroundedSegmentAnything/groundingdino_swint_ogc.pth",
+        ]
+        if not os.path.exists("Garage/models/checkpoints/GroundedSegmentAnything"):
+            os.makedirs("Garage/models/checkpoints/GroundedSegmentAnything")
+
+        for model, destination in zip(models, destinations):
+            if not os.path.exists(destination):
+                urllib.request.urlretrieve(model, destination)
+                print(f"Downloaded {model} to {destination}")
+            else:
+                print(f"Model {model} already exists")
+
     def setup_model(self) -> SamPredictor:
         self.sam = sam_model_registry[self.model_type](checkpoint=self.SAM_CHECKPOINT_PATH)
         self.sam.to(device=self.device)
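
The new download_weights() method shells out to `git lfs clone` for the three Hugging Face repositories; recent Git LFS releases deprecate that subcommand in favor of a plain `git clone`. As a point of comparison only, a minimal sketch using huggingface_hub.snapshot_download (an assumed dependency, not part of this commit) would fetch the same repositories without invoking Git at all:

# Sketch only: huggingface_hub is an assumed dependency, not used in the commit above.
from huggingface_hub import snapshot_download

repos = {
    "JunhaoZhuang/PowerPaint-v2-1": "Garage/models/checkpoints/ppt-v2-1",
    "llava-hf/llava-1.5-7b-hf": "Garage/models/checkpoints/llava-1.5-7b-hf",
    "danulkin/llama": "Garage/models/checkpoints/llama-3-8b-Instruct",
}

for repo_id, local_dir in repos.items():
    # Downloads every file in the repository into local_dir.
    snapshot_download(repo_id=repo_id, local_dir=local_dir)

The urllib.request.urlretrieve path for the SAM and GroundingDINO .pth checkpoints is independent of that choice and would stay as committed.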