update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,12 @@
|
|
|
|
1 |
import cv2
|
2 |
-
import numpy as np
|
3 |
-
import torch
|
4 |
import time
|
5 |
-
|
|
|
|
|
6 |
import gradio as gr
|
|
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
|
9 |
from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, sam_model_registry
|
@@ -43,6 +46,7 @@ class GradioWindow():
|
|
43 |
self.SAM_CHECKPOINT_PATH = SAM_CHECKPOINT_PATH
|
44 |
|
45 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
46 |
# for debug
|
47 |
# self.augmenter = None
|
48 |
self.augmenter = Augmenter(device=self.device)
|
@@ -159,6 +163,52 @@ class GradioWindow():
|
|
159 |
outputs=[augmented_img, generated_prompt],
|
160 |
)
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
def setup_model(self) -> SamPredictor:
|
163 |
self.sam = sam_model_registry[self.model_type](checkpoint=self.SAM_CHECKPOINT_PATH)
|
164 |
self.sam.to(device=self.device)
|
|
|
1 |
+
import os
|
2 |
import cv2
|
|
|
|
|
3 |
import time
|
4 |
+
import torch
|
5 |
+
import subprocess
|
6 |
+
import numpy as np
|
7 |
import gradio as gr
|
8 |
+
import urllib.request
|
9 |
+
from PIL import Image, ImageDraw
|
10 |
import matplotlib.pyplot as plt
|
11 |
|
12 |
from Garage.models.GroundedSegmentAnything.segment_anything.segment_anything import SamPredictor, sam_model_registry
|
|
|
46 |
self.SAM_CHECKPOINT_PATH = SAM_CHECKPOINT_PATH
|
47 |
|
48 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
+
self.download_weights()
|
50 |
# for debug
|
51 |
# self.augmenter = None
|
52 |
self.augmenter = Augmenter(device=self.device)
|
|
|
163 |
outputs=[augmented_img, generated_prompt],
|
164 |
)
|
165 |
|
166 |
+
def download_weights(self):
|
167 |
+
models = [
|
168 |
+
"https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/",
|
169 |
+
"https://huggingface.co/llava-hf/llava-1.5-7b-hf",
|
170 |
+
"https://huggingface.co/danulkin/llama",
|
171 |
+
]
|
172 |
+
|
173 |
+
destinations = [
|
174 |
+
"Garage/models/checkpoints/ppt-v2-1",
|
175 |
+
"Garage/models/checkpoints/llava-1.5-7b-hf",
|
176 |
+
"Garage/models/checkpoints/llama-3-8b-Instruct",
|
177 |
+
]
|
178 |
+
|
179 |
+
if not os.path.exists("Garage/models/checkpoints"):
|
180 |
+
os.makedirs("Garage/models/checkpoints")
|
181 |
+
|
182 |
+
for model, destination in zip(models, destinations):
|
183 |
+
# Git LFS clone command
|
184 |
+
command = ["git", "lfs", "clone", model, destination]
|
185 |
+
try:
|
186 |
+
result = subprocess.run(command, check=True, text=True, capture_output=True)
|
187 |
+
print("Command Output:", result.stdout)
|
188 |
+
|
189 |
+
except subprocess.CalledProcessError as e:
|
190 |
+
print(f"Error: {e}")
|
191 |
+
print("Command Output:", e.output)
|
192 |
+
|
193 |
+
models = [
|
194 |
+
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
195 |
+
"https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
196 |
+
]
|
197 |
+
|
198 |
+
destinations = [
|
199 |
+
"Garage/models/checkpoints/GroundedSegmentAnything/sam_vit_h_4b8939.pth",
|
200 |
+
"Garage/models/checkpoints/GroundedSegmentAnything/groundingdino_swint_ogc.pth",
|
201 |
+
]
|
202 |
+
if not os.path.exists("Garage/models/checkpoints/GroundedSegmentAnything"):
|
203 |
+
os.makedirs("Garage/models/checkpoints/GroundedSegmentAnything")
|
204 |
+
|
205 |
+
for model, destination in zip(models, destinations):
|
206 |
+
if not os.path.exists(destination):
|
207 |
+
urllib.request.urlretrieve(model, destination)
|
208 |
+
print(f"Downloaded {model} to {destination}")
|
209 |
+
else:
|
210 |
+
print(f"Model {model} already exists")
|
211 |
+
|
212 |
def setup_model(self) -> SamPredictor:
|
213 |
self.sam = sam_model_registry[self.model_type](checkpoint=self.SAM_CHECKPOINT_PATH)
|
214 |
self.sam.to(device=self.device)
|