Commit
·
23fa49c
1
Parent(s):
488df98
First commit
Browse files- README.md +0 -2
- app.py +138 -0
- data/wd-v1-4-convnext-tagger-v2/clip.msgpack +3 -0
- index/cosine_ids.npy +3 -0
- index/cosine_infos.json +1 -0
- index/cosine_knn.index +3 -0
- requirements.txt +3 -0
README.md
CHANGED
@@ -9,5 +9,3 @@ app_file: app.py
|
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
---
|
|
|
|
app.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import faiss
|
4 |
+
import flax
|
5 |
+
import gradio as gr
|
6 |
+
import jax
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import requests
|
10 |
+
|
11 |
+
from Models.CLIP import CLIP
|
12 |
+
|
13 |
+
|
14 |
+
def danbooru_id_to_url(image_id, selected_ratings, api_username="", api_key=""):
|
15 |
+
headers = {"User-Agent": "image_similarity_tool"}
|
16 |
+
ratings_to_letters = {
|
17 |
+
"General": "g",
|
18 |
+
"Sensitive": "s",
|
19 |
+
"Questionable": "q",
|
20 |
+
"Explicit": "e",
|
21 |
+
}
|
22 |
+
|
23 |
+
acceptable_ratings = [ratings_to_letters[x] for x in selected_ratings]
|
24 |
+
|
25 |
+
image_url = f"https://danbooru.donmai.us/posts/{image_id}.json"
|
26 |
+
if api_username != "" and api_key != "":
|
27 |
+
image_url = f"{image_url}?api_key={api_key}&login={api_username}"
|
28 |
+
|
29 |
+
r = requests.get(image_url, headers=headers)
|
30 |
+
if r.status_code != 200:
|
31 |
+
return None
|
32 |
+
|
33 |
+
content = json.loads(r.text)
|
34 |
+
image_url = content["large_file_url"] if "large_file_url" in content else None
|
35 |
+
image_url = image_url if content["rating"] in acceptable_ratings else None
|
36 |
+
return image_url
|
37 |
+
|
38 |
+
|
39 |
+
class Predictor:
|
40 |
+
def __init__(self):
|
41 |
+
self.base_model = "wd-v1-4-convnext-tagger-v2"
|
42 |
+
|
43 |
+
with open(f"data/{self.base_model}/clip.msgpack", "rb") as f:
|
44 |
+
data = f.read()
|
45 |
+
|
46 |
+
self.params = flax.serialization.msgpack_restore(data)["model"]
|
47 |
+
self.model = CLIP()
|
48 |
+
|
49 |
+
self.tags_df = pd.read_csv("data/selected_tags.csv")
|
50 |
+
|
51 |
+
self.images_ids = np.load("index/cosine_ids.npy")
|
52 |
+
|
53 |
+
self.knn_index = faiss.read_index("index/cosine_knn.index")
|
54 |
+
|
55 |
+
config = json.loads(open("index/cosine_infos.json").read())["index_param"]
|
56 |
+
faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
|
57 |
+
|
58 |
+
def predict(self, positive_tags, negative_tags, n_neighbours=5):
|
59 |
+
tags_df = self.tags_df
|
60 |
+
model = self.model
|
61 |
+
|
62 |
+
num_classes = len(tags_df)
|
63 |
+
|
64 |
+
positive_tags = positive_tags.split(",")
|
65 |
+
negative_tags = negative_tags.split(",")
|
66 |
+
|
67 |
+
positive_tags_idxs = tags_df.index[tags_df["name"].isin(positive_tags)].tolist()
|
68 |
+
negative_tags_idxs = tags_df.index[tags_df["name"].isin(negative_tags)].tolist()
|
69 |
+
|
70 |
+
tags = np.zeros((1, num_classes), dtype=np.float32)
|
71 |
+
tags[0][positive_tags_idxs] = 1
|
72 |
+
emb_from_logits = model.apply(
|
73 |
+
{"params": self.params},
|
74 |
+
tags,
|
75 |
+
method=model.encode_text,
|
76 |
+
)
|
77 |
+
emb_from_logits = jax.device_get(emb_from_logits)
|
78 |
+
|
79 |
+
if len(negative_tags_idxs) > 0:
|
80 |
+
tags = np.zeros((1, num_classes), dtype=np.float32)
|
81 |
+
tags[0][negative_tags_idxs] = 1
|
82 |
+
|
83 |
+
neg_emb_from_logits = model.apply(
|
84 |
+
{"params": self.params},
|
85 |
+
tags,
|
86 |
+
method=model.encode_text,
|
87 |
+
)
|
88 |
+
neg_emb_from_logits = jax.device_get(neg_emb_from_logits)
|
89 |
+
emb_from_logits = emb_from_logits - neg_emb_from_logits
|
90 |
+
|
91 |
+
faiss.normalize_L2(emb_from_logits)
|
92 |
+
|
93 |
+
dists, indexes = self.knn_index.search(emb_from_logits, k=n_neighbours)
|
94 |
+
neighbours_ids = self.images_ids[indexes][0]
|
95 |
+
neighbours_ids = [int(x) for x in neighbours_ids]
|
96 |
+
|
97 |
+
captions = []
|
98 |
+
image_urls = []
|
99 |
+
for image_id, dist in zip(neighbours_ids, dists[0]):
|
100 |
+
current_url = danbooru_id_to_url(
|
101 |
+
image_id,
|
102 |
+
[
|
103 |
+
"General",
|
104 |
+
"Sensitive",
|
105 |
+
"Questionable",
|
106 |
+
"Explicit",
|
107 |
+
],
|
108 |
+
)
|
109 |
+
if current_url is not None:
|
110 |
+
image_urls.append(current_url)
|
111 |
+
captions.append(f"{image_id}/{dist:.2f}")
|
112 |
+
return list(zip(image_urls, captions))
|
113 |
+
|
114 |
+
|
115 |
+
def main():
|
116 |
+
predictor = Predictor()
|
117 |
+
|
118 |
+
with gr.Blocks() as demo:
|
119 |
+
with gr.Row():
|
120 |
+
positive_tags = gr.Textbox(label="Positive tags")
|
121 |
+
negative_tags = gr.Textbox(label="Negative tags")
|
122 |
+
|
123 |
+
find_btn = gr.Button("Find similar images")
|
124 |
+
|
125 |
+
similar_images = gr.Gallery(label="Similar images", columns=[5])
|
126 |
+
|
127 |
+
find_btn.click(
|
128 |
+
fn=predictor.predict,
|
129 |
+
inputs=[positive_tags, negative_tags],
|
130 |
+
outputs=[similar_images],
|
131 |
+
)
|
132 |
+
|
133 |
+
demo.queue()
|
134 |
+
demo.launch()
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
main()
|
data/wd-v1-4-convnext-tagger-v2/clip.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3be3b97824313f01d9f1d74c43e441199b7ea485f5698d2008739f34c3e41200
|
3 |
+
size 48689306
|
index/cosine_ids.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df724519c8c1981e49d80e2430261deb4fb6edf6d9c04e134427879710747394
|
3 |
+
size 21830676
|
index/cosine_infos.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"index_key": "OPQ256_1280,IVF16384_HNSW32,PQ256x8", "index_param": "nprobe=16,efSearch=32,ht=2048", "index_path": "/home/SmilingWolf/eval/index/ConvNextBV1_01_14_2023_08h37m46s_cosine_knn.index", "size in bytes": 1535843672, "avg_search_speed_ms": 10.164478485783887, "99p_search_speed_ms": 12.419190758373587, "reconstruction error %": 22.007358074188232, "nb vectors": 5457637, "vectors dimension": 1024, "compression ratio": 14.555180035276402}
|
index/cosine_knn.index
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3a718ab8370df8b9d84002c55f945ef241e4cc3450d306c2ecd97661f51022ad
|
3 |
+
size 1535843672
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
faiss
|
2 |
+
jax[cpu]
|
3 |
+
flax
|