Spaces:
Running
Running
Create helper.py
Browse files- src/helper.py +87 -0
src/helper.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
from urllib.parse import urlparse
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.hub import download_url_to_file, get_dir
|
9 |
+
|
10 |
+
LAMA_MODEL_URL = os.environ.get(
|
11 |
+
"LAMA_MODEL_URL",
|
12 |
+
"https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def download_model(url=LAMA_MODEL_URL):
|
17 |
+
parts = urlparse(url)
|
18 |
+
hub_dir = get_dir()
|
19 |
+
model_dir = os.path.join(hub_dir, "checkpoints")
|
20 |
+
if not os.path.isdir(model_dir):
|
21 |
+
os.makedirs(os.path.join(model_dir, "hub", "checkpoints"))
|
22 |
+
filename = os.path.basename(parts.path)
|
23 |
+
cached_file = os.path.join(model_dir, filename)
|
24 |
+
if not os.path.exists(cached_file):
|
25 |
+
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
26 |
+
hash_prefix = None
|
27 |
+
download_url_to_file(url, cached_file, hash_prefix, progress=True)
|
28 |
+
return cached_file
|
29 |
+
|
30 |
+
|
31 |
+
def ceil_modulo(x, mod):
|
32 |
+
if x % mod == 0:
|
33 |
+
return x
|
34 |
+
return (x // mod + 1) * mod
|
35 |
+
|
36 |
+
|
37 |
+
def numpy_to_bytes(image_numpy: np.ndarray) -> bytes:
|
38 |
+
data = cv2.imencode(".jpg", image_numpy)[1]
|
39 |
+
image_bytes = data.tobytes()
|
40 |
+
return image_bytes
|
41 |
+
|
42 |
+
|
43 |
+
def load_img(img_bytes, gray: bool = False):
|
44 |
+
nparr = np.frombuffer(img_bytes, np.uint8)
|
45 |
+
if gray:
|
46 |
+
np_img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
|
47 |
+
else:
|
48 |
+
np_img = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
49 |
+
if len(np_img.shape) == 3 and np_img.shape[2] == 4:
|
50 |
+
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGRA2RGB)
|
51 |
+
else:
|
52 |
+
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
53 |
+
|
54 |
+
return np_img
|
55 |
+
|
56 |
+
|
57 |
+
def norm_img(np_img):
|
58 |
+
if len(np_img.shape) == 2:
|
59 |
+
np_img = np_img[:, :, np.newaxis]
|
60 |
+
np_img = np.transpose(np_img, (2, 0, 1))
|
61 |
+
np_img = np_img.astype("float32") / 255
|
62 |
+
return np_img
|
63 |
+
|
64 |
+
|
65 |
+
def resize_max_size(
|
66 |
+
np_img, size_limit: int, interpolation=cv2.INTER_CUBIC
|
67 |
+
) -> np.ndarray:
|
68 |
+
# Resize image's longer size to size_limit if longer size larger than size_limit
|
69 |
+
h, w = np_img.shape[:2]
|
70 |
+
if max(h, w) > size_limit:
|
71 |
+
ratio = size_limit / max(h, w)
|
72 |
+
new_w = int(w * ratio + 0.5)
|
73 |
+
new_h = int(h * ratio + 0.5)
|
74 |
+
return cv2.resize(np_img, dsize=(new_w, new_h), interpolation=interpolation)
|
75 |
+
else:
|
76 |
+
return np_img
|
77 |
+
|
78 |
+
|
79 |
+
def pad_img_to_modulo(img, mod):
|
80 |
+
channels, height, width = img.shape
|
81 |
+
out_height = ceil_modulo(height, mod)
|
82 |
+
out_width = ceil_modulo(width, mod)
|
83 |
+
return np.pad(
|
84 |
+
img,
|
85 |
+
((0, 0), (0, out_height - height), (0, out_width - width)),
|
86 |
+
mode="symmetric",
|
87 |
+
)
|