yeq6x committed
Commit 54507dc · Parent: 27b8d0c
.gitignore ADDED
@@ -0,0 +1,4 @@
+ output.glb
+ venv/
+ __pycache__/
+ temp_output.ply
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ COPY requirements.txt /app/requirements.txt
+
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . /app
+
+ EXPOSE 7860
+
+ CMD ["python", "app.py"]
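One caveat worth noting for this Dockerfile: Gradio's `launch()` binds to 127.0.0.1 by default, so the `EXPOSE`d port 7860 (Gradio's default) is unreachable from outside the container. A minimal fix, assuming no other launch options are needed, is either `ENV GRADIO_SERVER_NAME=0.0.0.0` in the Dockerfile or binding explicitly in the final line of app.py:

```python
# Bind to all interfaces so the container's published port is reachable
demo.launch(server_name="0.0.0.0", server_port=7860)
```

With that in place, `docker build -t position-map .` (any tag name works) followed by `docker run -p 7860:7860 position-map` serves the app at http://localhost:7860.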
app.py ADDED
@@ -0,0 +1,198 @@
+ import numpy as np
+ from PIL import Image
+ import gradio as gr
+ import open3d as o3d
+ import trimesh
+ from tqdm import tqdm
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, EulerAncestralDiscreteScheduler
+ import torch
+ from collections import Counter
+ import random
+
+ # import spaces
+
+ pipe = None
+
+ def load_model():
+     global pipe
+     pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+         "yeq6x/animagine_position_map",
+         controlnet=ControlNetModel.from_pretrained("yeq6x/Image2PositionColor_v3"),
+         # torch_dtype=torch.float16,
+         use_safetensors=True,
+         # variant="fp16"
+     ).to("cuda")
+     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+
+ def convert_pil_to_opencv(pil_image):
+     return np.array(pil_image)
+
+ def inv_func(y,
+              c=-712.380100,
+              a=137.375240,
+              b=192.435866):  # b is accepted but unused in the expression below
+     return (np.exp((y - c) / a) - np.exp(-c/a)) / 964.8468371292845
+
+ def create_point_cloud(img1, img2):
+     if img1.shape != img2.shape:
+         raise ValueError("Both images must have the same dimensions.")
+
+     h, w, _ = img1.shape
+     points = []
+     colors = []
+     for y in tqdm(range(h)):
+         for x in range(w):
+             # Read the RGB at pixel (x, y) and treat it as XYZ
+             r, g, b = img1[y, x]
+             r = inv_func(r) * 0.9
+             g = inv_func(g) / 1.7 * 0.6
+             b = inv_func(b)
+             r *= 150
+             g *= 150
+             b *= 150
+             points.append([g, b, r])  # X, Y, Z
+             # Take the color of image 2 at the same pixel
+             colors.append(img2[y, x] / 255.0)  # scale colors to 0-1
+
+     return np.array(points), np.array(colors)
+
+ def point_cloud_to_glb(points, colors):
+     # Build the point cloud with Open3D
+     pc = o3d.geometry.PointCloud()
+     pc.points = o3d.utility.Vector3dVector(points)
+     pc.colors = o3d.utility.Vector3dVector(colors)
+
+     # Save temporarily in PLY format
+     temp_ply_file = "temp_output.ply"
+     o3d.io.write_point_cloud(temp_ply_file, pc)
+
+     # Convert the PLY to GLB
+     mesh = trimesh.load(temp_ply_file)
+     glb_file = "output.glb"
+     mesh.export(glb_file)
+
+     return glb_file
+
+ def visualize_3d(image1, image2):
+     print("Processing...")
+     # Convert the PIL images to OpenCV (NumPy array) format
+     img1 = convert_pil_to_opencv(image1)
+     img2 = convert_pil_to_opencv(image2)
+
+     # Generate the point cloud
+     points, colors = create_point_cloud(img1, img2)
+
+     # Convert to GLB format
+     glb_file = point_cloud_to_glb(points, colors)
+
+     return glb_file
+
+ def scale_image(original_image):
+     aspect_ratio = original_image.width / original_image.height
+
+     if original_image.width > original_image.height:
+         new_width = 1024
+         new_height = round(new_width / aspect_ratio)
+     else:
+         new_height = 1024
+         new_width = round(new_height * aspect_ratio)
+
+     resized_original = original_image.resize((new_width, new_height), Image.LANCZOS)
+
+     return resized_original
+
+ def get_edge_mode_color(img, edge_width=10):
+     # Crop the 10-pixel-wide border regions
+     left = img.crop((0, 0, edge_width, img.height))  # left edge
+     right = img.crop((img.width - edge_width, 0, img.width, img.height))  # right edge
+     top = img.crop((0, 0, img.width, edge_width))  # top edge
+     bottom = img.crop((0, img.height - edge_width, img.width, img.height))  # bottom edge
+
+     # Collect the pixel data from each region and concatenate it
+     colors = list(left.getdata()) + list(right.getdata()) + list(top.getdata()) + list(bottom.getdata())
+
+     # Compute the mode, i.e. the most frequently occurring color
+     mode_color = Counter(colors).most_common(1)[0][0]
+
+     return mode_color
+
+ def paste_image(resized_img):
+     # Use the mode color of the 10 px border as the background
+     mode_color = get_edge_mode_color(resized_img, edge_width=10)
+     mode_background = Image.new("RGBA", (1024, 1024), mode_color)
+     mode_background = mode_background.convert('RGB')
+
+     # Center the resized image on the 1024x1024 background
+     x = (1024 - resized_img.width) // 2
+     y = (1024 - resized_img.height) // 2
+     mode_background.paste(resized_img, (x, y))
+
+     return mode_background
+
+ def outpaint_image(image):
+     if image is None:
+         return None
+     resized_img = scale_image(image)
+     image = paste_image(resized_img)
+
+     return image
+
+ # @spaces.GPU
+ def predict_image(cond_image, prompt, negative_prompt):
+     if pipe is None:  # load the pipeline lazily on first use
+         load_model()
+     generator = torch.Generator()
+     generator.manual_seed(random.randint(0, 2147483647))
+
+     # Fall back to the defaults when the UI textboxes are empty
+     if not prompt:
+         prompt = 'position map, 1girl, white background'
+     if not negative_prompt:
+         negative_prompt = "lowres, bad anatomy, bad hands, bad feet, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
+
+     image = pipe(
+         prompt,       # prompt for the first SDXL text encoder
+         prompt,       # prompt_2 for the second text encoder (same text)
+         cond_image,   # ControlNet conditioning image
+         negative_prompt=negative_prompt,
+         width=1024,
+         height=1024,
+         guidance_scale=8,
+         num_inference_steps=20,
+         generator=generator,
+         guess_mode=True,
+         controlnet_conditioning_scale=0.6,
+     ).images[0]
+
+     return image
+
+ # load_model()  # eager load disabled; predict_image loads the model lazily
+
+ # Gradio application
+ with gr.Blocks() as demo:
+     gr.Markdown("## Position Map Visualizer")
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 img1 = gr.Image(type="pil", label="color Image", height=300)
+                 img2 = gr.Image(type="pil", label="map Image", height=300)
+             prompt = gr.Textbox("position map, 1girl, white background", label="Prompt")
+             negative_prompt = gr.Textbox("lowres, bad anatomy, bad hands, bad feet, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry", label="Negative Prompt")
+             predict_map_btn = gr.Button("Predict Position Map")
+             visualize_3d_btn = gr.Button("Generate 3D Point Cloud")
+         with gr.Column():
+             reconstruction_output = gr.Model3D(label="3D Viewer", height=600)
+             gr.Examples(
+                 examples=[
+                     ["resources/source/000006.png", "resources/target/000006.png"],
+                     ["resources/source/006420.png", "resources/target/006420.png"],
+                 ],
+                 inputs=[img1, img2]
+             )
+
+     img1.input(outpaint_image, inputs=img1, outputs=img1)
+     predict_map_btn.click(predict_image, inputs=[img1, prompt, negative_prompt], outputs=img2)
+     visualize_3d_btn.click(visualize_3d, inputs=[img2, img1], outputs=reconstruction_output)
+
+ demo.launch()
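The nested per-pixel loop in `create_point_cloud` runs roughly a million Python iterations for a 1024x1024 map. Because `inv_func` uses only `np.exp`, it already operates element-wise on whole arrays, so the same mapping can be vectorized. A minimal sketch (the function name is hypothetical; it reuses `inv_func` from app.py and preserves the loop's `[g, b, r]` axis order and row-major pixel order):

```python
import numpy as np

def create_point_cloud_vectorized(img1, img2):
    # Same mapping as the loop in create_point_cloud, using whole-array ops.
    if img1.shape != img2.shape:
        raise ValueError("Both images must have the same dimensions.")
    rgb = img1.astype(np.float64)
    r = inv_func(rgb[..., 0]) * 0.9 * 150   # inv_func(r) * 0.9, then *= 150
    g = inv_func(rgb[..., 1]) / 1.7 * 0.6 * 150
    b = inv_func(rgb[..., 2]) * 150
    points = np.stack([g, b, r], axis=-1).reshape(-1, 3)
    colors = img2.reshape(-1, 3) / 255.0
    return points, colors
```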
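Similarly, `point_cloud_to_glb` round-trips through a temporary PLY file on disk. trimesh can assemble the point cloud in memory instead; a sketch under the assumption that trimesh's glTF exporter accepts `PointCloud` objects (writing them as POINTS primitives), with the hypothetical name `point_cloud_to_glb_direct`:

```python
import numpy as np
import trimesh

def point_cloud_to_glb_direct(points, colors):
    # colors arrive as floats in [0, 1]; trimesh expects uint8 RGBA
    rgba = np.hstack([
        (np.asarray(colors) * 255).astype(np.uint8),
        np.full((len(colors), 1), 255, dtype=np.uint8),  # opaque alpha
    ])
    pc = trimesh.PointCloud(vertices=np.asarray(points), colors=rgba)
    pc.export("output.glb")
    return "output.glb"
```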
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ torch==2.2.0
+ diffusers
+ transformers  # required by the diffusers SDXL pipeline (text encoders)
+ gradio
+ open3d
+ numpy
+ opencv-python
+ trimesh
+ Pillow  # imported directly in app.py (PIL)
+ tqdm  # imported directly in app.py
resources/source/000006.png ADDED
resources/source/006420.png ADDED
resources/target/000006.png ADDED
resources/target/006420.png ADDED