# Spaces:
# Runtime error
# Runtime error
import argparse | |
import os | |
import re | |
import imageio | |
import matplotlib.pyplot as plt | |
import moviepy.editor as mvp | |
import numpy as np | |
import pydiffvg | |
import torch | |
from IPython.display import Image as Image_colab | |
from IPython.display import display, SVG | |
from PIL import Image | |
# Command-line interface.
# Both options are mandatory: the rest of the script builds the results
# directory from --target_file and selects the sketch file by --num_strokes,
# so omitting either one can only end in an obscure TypeError/IndexError later.
parser = argparse.ArgumentParser()
parser.add_argument("--target_file", type=str, required=True,
                    help="target image file, located in <target_images>")
parser.add_argument("--num_strokes", type=int, required=True,
                    help="stroke count used to select the matching "
                         "'<N>strokes' sketch file")
args = parser.parse_args()
def read_svg(path_svg, multiply=False):
    """Rasterize an SVG file with pydiffvg and composite it over white.

    Args:
        path_svg: Path of the SVG file to load.
        multiply: When true, double the canvas size as well as every
            path's control points and stroke width before rendering.

    Returns:
        The rendered image as a tensor whose last axis holds three colour
        channels (the renderer's four-channel output alpha-blended onto a
        white background). The tensor stays on whatever device the
        renderer produced it on.
    """
    canvas_width, canvas_height, shapes, shape_groups = pydiffvg.svg_to_scene(
        path_svg)
    if multiply:
        canvas_width *= 2
        canvas_height *= 2
        for path in shapes:
            # In-place scaling of the scene geometry loaded above.
            path.points *= 2
            path.stroke_width *= 2

    render = pydiffvg.RenderFunction.apply
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    img = render(canvas_width,   # width
                 canvas_height,  # height
                 2,              # num_samples_x
                 2,              # num_samples_y
                 0,              # seed
                 None,
                 *scene_args)

    # Alpha-composite over a white (all-ones) background. Broadcasting
    # (1 - alpha) across the colour channels is numerically identical to
    # multiplying it by an explicit tensor of ones, but needs no separately
    # allocated tensor -- so it cannot end up on a different device than
    # the rendered image.
    alpha = img[:, :, 3:4]
    return alpha * img[:, :, :3] + (1 - alpha)
# ---------------------------------------------------------------------------
# Locate the best sketch produced for the requested target / stroke count,
# show it next to the input image, and export the optimisation frames.
# ---------------------------------------------------------------------------
BEST_SUFFIX = "_best.svg"


def _svg_to_pil(path_svg):
    """Render *path_svg* at double size and return it as an 8-bit RGB image."""
    pixels = read_svg(path_svg, multiply=True).cpu().numpy()
    return Image.fromarray((pixels * 255).astype('uint8'), 'RGB')


abs_path = os.path.abspath(os.getcwd())
result_path = f"{abs_path}/output_sketches/{os.path.splitext(args.target_file)[0]}"

# Match the stroke count as a whole number: a plain substring test would let
# "--num_strokes 6" also select "16strokes" files.
strokes_pattern = re.compile(rf"(?<!\d){args.num_strokes}strokes")
# os.listdir returns entries in arbitrary order; sorting makes the selected
# file deterministic when several runs match.
svg_files = sorted(
    f for f in os.listdir(result_path)
    if f.endswith(BEST_SUFFIX) and strokes_pattern.search(f)
)
if not svg_files:
    raise FileNotFoundError(
        f"no '*{BEST_SUFFIX}' sketch with {args.num_strokes} strokes "
        f"found in {result_path}")

# The run directory is named like the best-sketch file minus its suffix;
# deriving every path from that one stem keeps them consistent.
best_sketch_dir = svg_files[0][:-len(BEST_SUFFIX)]
svg_output_path = f"{result_path}/{svg_files[0]}"
cur_path = f"{result_path}/{best_sketch_dir}"
target_path = f"{cur_path}/input.png"

sketch_res = _svg_to_pil(svg_output_path)
input_im = Image.open(target_path).resize((224, 224))
display(input_im)
display(SVG(svg_output_path))

sketch_res.save(f"{cur_path}/final_sketch.png")
print(f"You can download the result sketch from {cur_path}/final_sketch.png")

# exist_ok avoids the check-then-create race of exists() + mkdir().
os.makedirs(f"{cur_path}/svg_to_png", exist_ok=True)

sketches = []
if os.path.exists(f"{cur_path}/config.npy"):
    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
    # load config files written by a trusted run.
    config = np.load(f"{cur_path}/config.npy", allow_pickle=True)[()]
    inter = config["save_interval"]
    loss_eval = np.array(config['loss_eval'])
    inds = np.argsort(loss_eval)
    # Export every saved iteration up to and including the evaluation step
    # with the lowest loss (inds[0]).
    intervals = list(range(0, (inds[0] + 1) * inter, inter))
    for i_ in intervals:
        sketch = _svg_to_pil(f"{cur_path}/svg_logs/svg_iter{i_}.svg")
        sketch.save("{0}/{1}/iter_{2:04}.png".format(cur_path, "svg_to_png", int(i_)))
        sketches.append(sketch)
    # Written only when frames were rendered; an empty frame list cannot
    # be encoded as a GIF.
    if sketches:
        imageio.mimsave(f"{cur_path}/sketch.gif", sketches)
print(cur_path)