Spaces:
Runtime error
Runtime error
import warnings | |
warnings.filterwarnings('ignore') | |
warnings.simplefilter('ignore') | |
import argparse | |
import math | |
import os | |
import sys | |
import time | |
import traceback | |
import numpy as np | |
import PIL | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import wandb | |
from PIL import Image | |
from torchvision import models, transforms | |
from tqdm.auto import tqdm, trange | |
import config | |
import sketch_utils as utils | |
from models.loss import Loss | |
from models.painter_params import Painter, PainterOptimizer | |
from IPython.display import display, SVG | |
def load_renderer(args, target_im=None, mask=None):
    """Construct the stroke ``Painter`` configured by ``args`` on ``args.device``.

    Args:
        args: Parsed run configuration; reads ``num_paths``, ``num_segments``,
            ``image_scale`` and ``device`` (and passes ``args`` through).
        target_im: Optional target image forwarded to ``Painter``.
        mask: Optional mask forwarded to ``Painter``.

    Returns:
        The ``Painter`` instance after ``.to(args.device)``.
    """
    painter_kwargs = dict(
        num_strokes=args.num_paths,
        args=args,
        num_segments=args.num_segments,
        imsize=args.image_scale,
        device=args.device,
        target_im=target_im,
        mask=mask,
    )
    return Painter(**painter_kwargs).to(args.device)
def get_target(args):
    """Load ``args.target`` and turn it into a batched tensor on ``args.device``.

    Transparent RGBA inputs are flattened onto a white background before the
    RGB conversion. A U2-Net mask is always computed via
    ``utils.get_mask_u2net``. When ``args.mask_object`` is set, the masked
    image replaces the target. When ``args.fix_scale`` is set, the image is
    passed through ``utils.fix_image_scale``.

    Args:
        args: Parsed run configuration; reads ``target``, ``mask_object``,
            ``fix_scale``, ``image_scale`` and ``device``.

    Returns:
        ``(target_tensor, mask)``. ``target_tensor`` comes from ``ToTensor``
        plus ``unsqueeze(0)``, so it carries a leading batch dim of 1, and it
        is ``image_scale`` x ``image_scale`` spatially. ``mask`` is whatever
        ``utils.get_mask_u2net`` returns (its type/shape is not visible here).
    """
    image = Image.open(args.target)
    if image.mode == "RGBA":
        # Composite onto an opaque white canvas so transparent pixels end up white.
        canvas = Image.new("RGBA", image.size, "WHITE")
        canvas.paste(image, (0, 0), image)
        image = canvas
    image = image.convert("RGB")

    masked_im, mask = utils.get_mask_u2net(args, image)
    if args.mask_object:
        image = masked_im
    if args.fix_scale:
        image = utils.fix_image_scale(image)

    scale = args.image_scale
    width, height = image.size
    # Non-square images are stretched straight to scale x scale (aspect ratio is
    # not preserved). Square images are resized by a single size value.
    resize_size = (scale, scale) if width != height else scale
    pipeline = transforms.Compose([
        transforms.Resize(resize_size, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(scale),
        transforms.ToTensor(),
    ])
    target_tensor = pipeline(image).unsqueeze(0).to(args.device)
    return target_tensor, mask
def main(args):
    """Optimize a stroke sketch of ``args.target`` and return the recorded eval losses.

    Runs up to ``args.num_iter`` optimization steps of the ``Painter`` strokes
    against ``Loss(args)``.

    - Every ``args.save_interval`` epochs, a JPG and an SVG snapshot are
      written under ``args.output_dir``.
    - Every ``args.eval_interval`` epochs, the losses are re-evaluated under
      ``torch.no_grad()``. An improvement of more than ``min_delta`` over the
      best eval loss is saved as ``best_iter.jpg`` / ``best_iter.svg``.
    - Training stops early when the eval loss comes within ``min_delta`` of
      the best loss for a second time since the last improvement.

    Args:
        args: Parsed run configuration (see ``config.parse_arguments``).

    Returns:
        dict mapping ``"loss_eval"`` and every individual eval-loss name to the
        list of values recorded at each evaluation.
    """
    loss_func = Loss(args)
    inputs, mask = get_target(args)
    utils.log_input(args.use_wandb, 0, inputs, args.output_dir)
    renderer = load_renderer(args, inputs, mask)
    optimizer = PainterOptimizer(args, renderer)
    counter = 0
    configs_to_save = {"loss_eval": []}
    # Seed with +inf (not a magic 100) so the first evaluation always counts
    # as an improvement. That guarantees best_iter.svg exists for the final
    # summary, whatever the loss scale.
    best_loss, best_fc_loss = math.inf, math.inf
    # best_iter_fc is tracked but not currently reported anywhere in this block.
    best_iter, best_iter_fc = 0, 0
    min_delta = 1e-5
    terminate = False

    renderer.set_random_noise(0)
    renderer.init_image(stage=0)
    optimizer.init_optimizers()

    # Not using tqdm for the jupyter demo.
    if args.display:
        epoch_range = range(args.num_iter)
    else:
        epoch_range = tqdm(range(args.num_iter))

    for epoch in epoch_range:
        if not args.display:
            epoch_range.refresh()
        renderer.set_random_noise(epoch)
        if args.lr_scheduler:
            optimizer.update_lr(counter)

        # --- training step ---
        optimizer.zero_grad_()
        sketches = renderer.get_image().to(args.device)
        losses_dict = loss_func(sketches, inputs.detach(),
                                renderer.get_color_parameters(), renderer, counter, optimizer)
        loss = sum(losses_dict.values())
        loss.backward()
        optimizer.step_()

        # --- periodic snapshots ---
        if epoch % args.save_interval == 0:
            utils.plot_batch(inputs, sketches, f"{args.output_dir}/jpg_logs", counter,
                             use_wandb=args.use_wandb, title=f"iter{epoch}.jpg")
            renderer.save_svg(
                f"{args.output_dir}/svg_logs", f"svg_iter{epoch}")

        # --- periodic evaluation / best tracking / early stopping ---
        if epoch % args.eval_interval == 0:
            with torch.no_grad():
                losses_dict_eval = loss_func(sketches, inputs, renderer.get_color_parameters(),
                                             renderer.get_points_parans(), counter, optimizer,
                                             mode="eval")
                loss_eval = sum(losses_dict_eval.values()).item()
                configs_to_save["loss_eval"].append(loss_eval)
                for k, v in losses_dict_eval.items():
                    configs_to_save.setdefault(k, []).append(v.item())

                if args.clip_fc_loss_weight:
                    # Compare and store the *unweighted* fc loss consistently.
                    # Previously the weighted value was compared against an
                    # unweighted stored best.
                    fc_loss = losses_dict_eval["fc"].item() / args.clip_fc_loss_weight
                    if fc_loss < best_fc_loss:
                        best_fc_loss = fc_loss
                        best_iter_fc = epoch

                cur_delta = loss_eval - best_loss
                if abs(cur_delta) > min_delta and cur_delta < 0:
                    best_loss = loss_eval
                    best_iter = epoch
                    terminate = False
                    utils.plot_batch(
                        inputs, sketches, args.output_dir, counter,
                        use_wandb=args.use_wandb, title="best_iter.jpg")
                    renderer.save_svg(args.output_dir, "best_iter")

                if args.use_wandb:
                    wandb.run.summary["best_loss"] = best_loss
                    wandb.run.summary["best_loss_fc"] = best_fc_loss
                    wandb_dict = {"loss_eval": loss_eval}
                    # On the first evaluation the delta against the +inf seed
                    # is -inf, which is not a meaningful metric.
                    if math.isfinite(cur_delta):
                        wandb_dict["delta"] = cur_delta
                    for k, v in losses_dict_eval.items():
                        wandb_dict[k + "_eval"] = v.item()
                    wandb.log(wandb_dict, step=counter)

                # Stop on the second "no meaningful change" evaluation since
                # the last improvement.
                if abs(cur_delta) <= min_delta:
                    if terminate:
                        break
                    terminate = True

        if counter == 0 and args.attention_init:
            utils.plot_atten(renderer.get_attn(), renderer.get_thresh(), inputs, renderer.get_inds(),
                             args.use_wandb, "{}/{}.jpg".format(
                                 args.output_dir, "attention_map"),
                             args.saliency_model, args.display_logs)

        if args.use_wandb:
            wandb_dict = {"loss": loss.item(), "lr": optimizer.get_lr()}
            for k, v in losses_dict.items():
                wandb_dict[k] = v.item()
            wandb.log(wandb_dict, step=counter)

        counter += 1

    renderer.save_svg(args.output_dir, "final_svg")
    path_svg = os.path.join(args.output_dir, "best_iter.svg")
    utils.log_sketch_summary_final(
        path_svg, args.use_wandb, args.device, best_iter, best_loss, "best total")
    return configs_to_save
if __name__ == "__main__":
    # Script entry point: run the optimization and persist the merged
    # configuration plus recorded eval losses to <output_dir>/config.npy.
    args = config.parse_arguments()
    # vars() returns args.__dict__, so the update below also mutates args
    # (same as the original item-by-item assignment).
    final_config = vars(args)
    try:
        configs_to_save = main(args)
    except Exception as err:
        # Top-level boundary. Catch Exception rather than BaseException so
        # KeyboardInterrupt/SystemExit are no longer swallowed and rewritten
        # as exit code 1.
        print(f"Unexpected error occurred:\n {err}")
        print(traceback.format_exc())
        if args.use_wandb:
            # Mark the wandb run as failed instead of leaving it dangling.
            wandb.finish(exit_code=1)
        sys.exit(1)
    final_config.update(configs_to_save)
    # np.save pickles the dict into a 0-d object array; load with allow_pickle=True.
    np.save(f"{args.output_dir}/config.npy", final_config)
    if args.use_wandb:
        wandb.finish()