spacepxl commited on
Commit
57845c8
·
verified ·
1 Parent(s): 48a076a

Upload convert.py

Browse files
Files changed (1) hide show
  1. convert.py +22 -0
convert.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from safetensors.torch import save_file
3
+
4
+ input_file = "AnimateDiffV3MotionModule____training_NOISEWARP_72x40_Test_Sep14_FromV3_RandomlyDegraded-2024-09-19T12-37-23____checkpoint_latest.ckpt"
5
+
6
+ model = torch.load(input_file, weights_only=True)
7
+
8
+ mm_sd = {}
9
+ for key in model["state_dict"].keys():
10
+ if "motion_modules" in key:
11
+ new_key = key.replace("module.", "")
12
+ mm_sd[new_key] = model["state_dict"][key]
13
+
14
+ save_file(mm_sd, "motion_model.safetensors")
15
+
16
+ unet_sd = {}
17
+ for key in model["state_dict"].keys():
18
+ if "motion_modules" not in key:
19
+ new_key = key.replace("module.", "")
20
+ unet_sd[new_key] = model["state_dict"][key]
21
+
22
+ save_file(unet_sd, "unet.safetensors")