Transformers
Inference Endpoints
pantat88 commited on
Commit
4c306dc
·
verified ·
1 Parent(s): 755b6d7

Upload lora_block_weight.py

Browse files
Files changed (1) hide show
  1. lora_block_weight.py +1152 -0
lora_block_weight.py ADDED
@@ -0,0 +1,1152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import json
3
+ import os
4
+ import gc
5
+ import re
6
+ import sys
7
+ import torch
8
+ import shutil
9
+ import math
10
+ import importlib
11
+ import numpy as np
12
+ import gradio as gr
13
+ import os.path
14
+ import random
15
+ from pprint import pprint
16
+ import modules.ui
17
+ import modules.scripts as scripts
18
+ from PIL import Image, ImageFont, ImageDraw
19
+ import modules.shared as shared
20
+ from modules import devices, sd_models, images,cmd_args, extra_networks, sd_hijack
21
+ from modules.shared import cmd_opts, opts, state
22
+ from modules.processing import process_images, Processed
23
+ from modules.script_callbacks import CFGDenoiserParams, on_cfg_denoiser
24
+
25
+ LBW_T = "customscript/lora_block_weight.py/txt2img/Active/value"
26
+ LBW_I = "customscript/lora_block_weight.py/img2img/Active/value"
27
+
28
+ if os.path.exists(cmd_opts.ui_config_file):
29
+ with open(cmd_opts.ui_config_file, 'r', encoding="utf-8") as json_file:
30
+ ui_config = json.load(json_file)
31
+ else:
32
+ print("ui config file not found, using default values")
33
+ ui_config = {}
34
+
35
+ startup_t = ui_config[LBW_T] if LBW_T in ui_config else None
36
+ startup_i = ui_config[LBW_I] if LBW_I in ui_config else None
37
+ active_t = "Active" if startup_t else "Not Active"
38
+ active_i = "Active" if startup_i else "Not Active"
39
+
40
+ lxyz = ""
41
+ lzyx = ""
42
+ prompts = ""
43
+ xyelem = ""
44
+ princ = False
45
+
46
+ try:
47
+ from ldm_patched.modules import model_management
48
+ forge = True
49
+ except:
50
+ forge = False
51
+
52
+ BLOCKID26=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
53
+ BLOCKID17=["BASE","IN01","IN02","IN04","IN05","IN07","IN08","M00","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11"]
54
+ BLOCKID12=["BASE","IN04","IN05","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05"]
55
+ BLOCKID20=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08"]
56
+ BLOCKNUMS = [12,17,20,26]
57
+ BLOCKIDS=[BLOCKID12,BLOCKID17,BLOCKID20,BLOCKID26]
58
+
59
+ BLOCKS=["encoder",
60
+ "diffusion_model_input_blocks_0_",
61
+ "diffusion_model_input_blocks_1_",
62
+ "diffusion_model_input_blocks_2_",
63
+ "diffusion_model_input_blocks_3_",
64
+ "diffusion_model_input_blocks_4_",
65
+ "diffusion_model_input_blocks_5_",
66
+ "diffusion_model_input_blocks_6_",
67
+ "diffusion_model_input_blocks_7_",
68
+ "diffusion_model_input_blocks_8_",
69
+ "diffusion_model_input_blocks_9_",
70
+ "diffusion_model_input_blocks_10_",
71
+ "diffusion_model_input_blocks_11_",
72
+ "diffusion_model_middle_block_",
73
+ "diffusion_model_output_blocks_0_",
74
+ "diffusion_model_output_blocks_1_",
75
+ "diffusion_model_output_blocks_2_",
76
+ "diffusion_model_output_blocks_3_",
77
+ "diffusion_model_output_blocks_4_",
78
+ "diffusion_model_output_blocks_5_",
79
+ "diffusion_model_output_blocks_6_",
80
+ "diffusion_model_output_blocks_7_",
81
+ "diffusion_model_output_blocks_8_",
82
+ "diffusion_model_output_blocks_9_",
83
+ "diffusion_model_output_blocks_10_",
84
+ "diffusion_model_output_blocks_11_",
85
+ "embedders"]
86
+
87
+ loopstopper = True
88
+
89
+ ATYPES =["none","Block ID","values","seed","Original Weights","elements"]
90
+
91
+ DEF_WEIGHT_PRESET = "\
92
+ NONE:0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
93
+ ALL:1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\n\
94
+ INS:1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0\n\
95
+ IND:1,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
96
+ INALL:1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0\n\
97
+ MIDD:1,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0\n\
98
+ OUTD:1,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0\n\
99
+ OUTS:1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1\n\
100
+ OUTALL:1,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1\n\
101
+ ALL0.5:0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
102
+
103
+ scriptpath = os.path.dirname(os.path.abspath(__file__))
104
+
105
+ class Script(modules.scripts.Script):
106
+ def __init__(self):
107
+ self.log = {}
108
+ self.stops = {}
109
+ self.starts = {}
110
+ self.active = False
111
+ self.lora = {}
112
+ self.lycoris = {}
113
+ self.networks = {}
114
+
115
+ self.stopsf = []
116
+ self.startsf = []
117
+ self.uf = []
118
+ self.lf = []
119
+ self.ef = []
120
+
121
+ def title(self):
122
+ return "LoRA Block Weight"
123
+
124
+ def show(self, is_img2img):
125
+ return modules.scripts.AlwaysVisible
126
+
127
+ def ui(self, is_img2img):
128
+ LWEIGHTSPRESETS = DEF_WEIGHT_PRESET
129
+
130
+ runorigin = scripts.scripts_txt2img.run
131
+ runorigini = scripts.scripts_img2img.run
132
+
133
+ scriptpath = os.path.dirname(os.path.abspath(__file__))
134
+ path_root = scripts.basedir()
135
+
136
+ extpath = os.path.join(scriptpath, "lbwpresets.txt")
137
+ extpathe = os.path.join(scriptpath, "elempresets.txt")
138
+ filepath = os.path.join(path_root,"scripts", "lbwpresets.txt")
139
+ filepathe = os.path.join(path_root,"scripts", "elempresets.txt")
140
+
141
+ if os.path.isfile(filepath) and not os.path.isfile(extpath):
142
+ shutil.move(filepath,extpath)
143
+
144
+ if os.path.isfile(filepathe) and not os.path.isfile(extpathe):
145
+ shutil.move(filepathe,extpathe)
146
+
147
+ lbwpresets=""
148
+
149
+ try:
150
+ with open(extpath,encoding="utf-8") as f:
151
+ lbwpresets = f.read()
152
+ except OSError as e:
153
+ lbwpresets=LWEIGHTSPRESETS
154
+ if not os.path.isfile(extpath):
155
+ try:
156
+ with open(extpath,mode = 'w',encoding="utf-8") as f:
157
+ f.write(lbwpresets)
158
+ except:
159
+ pass
160
+
161
+ try:
162
+ with open(extpathe,encoding="utf-8") as f:
163
+ elempresets = f.read()
164
+ except OSError as e:
165
+ elempresets=ELEMPRESETS
166
+ if not os.path.isfile(extpathe):
167
+ try:
168
+ with open(extpathe,mode = 'w',encoding="utf-8") as f:
169
+ f.write(elempresets)
170
+ except:
171
+ pass
172
+
173
+ loraratios=lbwpresets.splitlines()
174
+ lratios={}
175
+ for i,l in enumerate(loraratios):
176
+ if checkloadcond(l) : continue
177
+ lratios[l.split(":")[0]]=l.split(":")[1]
178
+ ratiostags = [k for k in lratios.keys()]
179
+ ratiostags = ",".join(ratiostags)
180
+
181
+ if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
182
+ args = cmd_args.parser.parse_args()
183
+ else:
184
+ args, _ = cmd_args.parser.parse_known_args()
185
+ if args.api:
186
+ register()
187
+
188
+ with gr.Accordion(f"LoRA Block Weight : {active_i if is_img2img else active_t}",open = False) as acc:
189
+ with gr.Row():
190
+ with gr.Column(min_width = 50, scale=1):
191
+ lbw_useblocks = gr.Checkbox(value = True,label="Active",interactive =True,elem_id="lbw_active")
192
+ debug = gr.Checkbox(value = False,label="Debug",interactive =True,elem_id="lbw_debug")
193
+ with gr.Column(scale=5):
194
+ bw_ratiotags= gr.TextArea(label="",value=ratiostags,visible =True,interactive =True,elem_id="lbw_ratios")
195
+ with gr.Accordion("XYZ plot",open = False):
196
+ gr.HTML(value='<p style= "word-wrap:break-word;">changeable blocks : BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11</p>')
197
+ xyzsetting = gr.Radio(label = "Active",choices = ["Disable","XYZ plot","Effective Block Analyzer"], value ="Disable",type = "index")
198
+ with gr.Row(visible = False) as esets:
199
+ diffcol = gr.Radio(label = "diff image color",choices = ["black","white"], value ="black",type = "value",interactive =True)
200
+ revxy = gr.Checkbox(value = False,label="change X-Y",interactive =True,elem_id="lbw_changexy")
201
+ thresh = gr.Textbox(label="difference threshold",lines=1,value="20",interactive =True,elem_id="diff_thr")
202
+ xtype = gr.Dropdown(label="X Types", choices=[x for x in ATYPES], value=ATYPES [2],interactive =True,elem_id="lbw_xtype")
203
+ xmen = gr.Textbox(label="X Values",lines=1,value="0,0.25,0.5,0.75,1",interactive =True,elem_id="lbw_xmen")
204
+ ytype = gr.Dropdown(label="Y Types", choices=[y for y in ATYPES], value=ATYPES [1],interactive =True,elem_id="lbw_ytype")
205
+ ymen = gr.Textbox(label="Y Values" ,lines=1,value="IN05-OUT05",interactive =True,elem_id="lbw_ymen")
206
+ ztype = gr.Dropdown(label="Z type", choices=[z for z in ATYPES], value=ATYPES[0],interactive =True,elem_id="lbw_ztype")
207
+ zmen = gr.Textbox(label="Z values",lines=1,value="",interactive =True,elem_id="lbw_zmen")
208
+
209
+ exmen = gr.Textbox(label="Range",lines=1,value="0.5,1",interactive =True,elem_id="lbw_exmen",visible = False)
210
+ eymen = gr.Textbox(label="Blocks (12ALL,17ALL,20ALL,26ALL also can be used)" ,lines=1,value="BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11",interactive =True,elem_id="lbw_eymen",visible = False)
211
+ ecount = gr.Number(value=1, label="number of seed", interactive=True, visible = True)
212
+
213
+ with gr.Accordion("Weights setting",open = True):
214
+ with gr.Row():
215
+ reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
216
+ reloadtags = gr.Button(value="Reload Tags",variant='primary',elem_id="lbw_reload")
217
+ savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
218
+ openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
219
+ lbw_loraratios = gr.TextArea(label="",value=lbwpresets,visible =True,interactive = True,elem_id="lbw_ratiospreset")
220
+
221
+ with gr.Accordion("Elemental",open = False):
222
+ with gr.Row():
223
+ e_reloadtext = gr.Button(value="Reload Presets",variant='primary',elem_id="lbw_reload")
224
+ e_savetext = gr.Button(value="Save Presets",variant='primary',elem_id="lbw_savetext")
225
+ e_openeditor = gr.Button(value="Open TextEditor",variant='primary',elem_id="lbw_openeditor")
226
+ elemsets = gr.Checkbox(value = False,label="print change",interactive =True,elem_id="lbw_print_change")
227
+ elemental = gr.TextArea(label="Identifer:BlockID:Elements:Ratio,...,separated by empty line ",value = elempresets,interactive =True,elem_id="element")
228
+
229
+ d_true = gr.Checkbox(value = True,visible = False)
230
+ d_false = gr.Checkbox(value = False,visible = False)
231
+
232
+ lbw_useblocks.change(fn=lambda x:gr.update(label = f"LoRA Block Weight : {'Active' if x else 'Not Active'}"),inputs=lbw_useblocks, outputs=[acc])
233
+
234
+ import subprocess
235
+ def openeditors(b):
236
+ path = extpath if b else extpathe
237
+ subprocess.Popen(['start', path], shell=True)
238
+
239
+ def reloadpresets(isweight):
240
+ if isweight:
241
+ try:
242
+ with open(extpath,encoding="utf-8") as f:
243
+ return f.read()
244
+ except OSError as e:
245
+ pass
246
+ else:
247
+ try:
248
+ with open(extpath,encoding="utf-8") as f:
249
+ return f.read()
250
+ except OSError as e:
251
+ pass
252
+
253
+ def tagdicter(presets):
254
+ presets=presets.splitlines()
255
+ wdict={}
256
+ for l in presets:
257
+ if checkloadcond(l) : continue
258
+ w=[]
259
+ if ":" in l :
260
+ key = l.split(":",1)[0]
261
+ w = l.split(":",1)[1]
262
+ if any(len([w for w in w.split(",")]) == x for x in BLOCKNUMS):
263
+ wdict[key.strip()]=w
264
+ return ",".join(list(wdict.keys()))
265
+
266
+ def savepresets(text,isweight):
267
+ if isweight:
268
+ with open(extpath,mode = 'w',encoding="utf-8") as f:
269
+ f.write(text)
270
+ else:
271
+ with open(extpathe,mode = 'w',encoding="utf-8") as f:
272
+ f.write(text)
273
+
274
+ reloadtext.click(fn=reloadpresets,inputs=[d_true],outputs=[lbw_loraratios])
275
+ reloadtags.click(fn=tagdicter,inputs=[lbw_loraratios],outputs=[bw_ratiotags])
276
+ savetext.click(fn=savepresets,inputs=[lbw_loraratios,d_true],outputs=[])
277
+ openeditor.click(fn=openeditors,inputs=[d_true],outputs=[])
278
+
279
+ e_reloadtext.click(fn=reloadpresets,inputs=[d_false],outputs=[elemental])
280
+ e_savetext.click(fn=savepresets,inputs=[elemental,d_false],outputs=[])
281
+ e_openeditor.click(fn=openeditors,inputs=[d_false],outputs=[])
282
+
283
+ def urawaza(active):
284
+ if active > 0:
285
+ register()
286
+ scripts.scripts_txt2img.run = newrun
287
+ scripts.scripts_img2img.run = newrun
288
+ if active == 1:return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]]
289
+ else:return [*[gr.update(visible = False) for x in range(6)],*[gr.update(visible = True) for x in range(4)]]
290
+ else:
291
+ scripts.scripts_txt2img.run = runorigin
292
+ scripts.scripts_img2img.run = runorigini
293
+ return [*[gr.update(visible = True) for x in range(6)],*[gr.update(visible = False) for x in range(4)]]
294
+
295
+ xyzsetting.change(fn=urawaza,inputs=[xyzsetting],outputs =[xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,esets])
296
+
297
+ return lbw_loraratios,lbw_useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug
298
+
299
+ def process(self, p, loraratios,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug):
300
+ #print("self =",self,"p =",p,"presets =",loraratios,"useblocks =",useblocks,"xyzsettings =",xyzsetting,"xtype =",xtype,"xmen =",xmen,"ytype =",ytype,"ymen =",ymen,"ztype =",ztype,"zmen =",zmen)
301
+ #Note that this does not use the default arg syntax because the default args are supposed to be at the end of the function
302
+ if(loraratios == None):
303
+ loraratios = DEF_WEIGHT_PRESET
304
+ if(useblocks == None):
305
+ useblocks = True
306
+
307
+ lorachecker(self)
308
+ self.log["enable LBW"] = useblocks
309
+ self.log["registerd"] = registerd
310
+
311
+ if useblocks:
312
+ self.active = True
313
+ loraratios=loraratios.splitlines()
314
+ elemental = elemental.split("\n\n") if elemental is not None else []
315
+ lratios={}
316
+ elementals={}
317
+ for l in loraratios:
318
+ if checkloadcond(l) : continue
319
+ l0=l.split(":",1)[0]
320
+ lratios[l0.strip()]=l.split(":",1)[1]
321
+ for e in elemental:
322
+ if ":" not in e: continue
323
+ e0=e.split(":",1)[0]
324
+ elementals[e0.strip()]=e.split(":",1)[1]
325
+ if elemsets : print(xyelem)
326
+ if xyzsetting and "XYZ" in p.prompt:
327
+ lratios["XYZ"] = lxyz
328
+ lratios["ZYX"] = lzyx
329
+ if xyelem != "":
330
+ if "XYZ" in elementals.keys():
331
+ elementals["XYZ"] = elementals["XYZ"] + ","+ xyelem
332
+ else:
333
+ elementals["XYZ"] = xyelem
334
+ self.lratios = lratios
335
+ self.elementals = elementals
336
+ global princ
337
+ princ = elemsets
338
+
339
+ if not hasattr(self,"lbt_dr_callbacks"):
340
+ self.lbt_dr_callbacks = on_cfg_denoiser(self.denoiser_callback)
341
+
342
+ def denoiser_callback(self, params: CFGDenoiserParams):
343
+ def setparams(self, key, te, u ,sets):
344
+ for dicts in [self.lora,self.lycoris,self.networks]:
345
+ for lora in dicts:
346
+ if lora.name.split("_in_LBW_")[0] == key:
347
+ lora.te_multiplier = te
348
+ lora.unet_multiplier = u
349
+ sets.append(key)
350
+
351
+ if forge and self.active:
352
+ if params.sampling_step in self.startsf:
353
+ shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
354
+ for key, vals in shared.sd_model.forge_objects.unet.patches.items():
355
+ n_vals = []
356
+ lvals = [val for val in vals if val[1][0] in LORAS]
357
+ for s, v, m, l, e in zip(self.startsf, lvals, self.uf, self.lf, self.ef):
358
+ if s is not None and s == params.sampling_step:
359
+ ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
360
+ n_vals.append((ratio * m, *v[1:]))
361
+ else:
362
+ n_vals.append(v)
363
+ shared.sd_model.forge_objects.unet.patches[key] = n_vals
364
+ shared.sd_model.forge_objects.unet.patch_model()
365
+
366
+ if params.sampling_step in self.stopsf:
367
+ shared.sd_model.forge_objects.unet.unpatch_model(device_to=devices.device)
368
+ for key, vals in shared.sd_model.forge_objects.unet.patches.items():
369
+ n_vals = []
370
+ lvals = [val for val in vals if val[1][0] in LORAS]
371
+ for s, v, m, l, e in zip(self.stopsf, lvals, self.uf, self.lf, self.ef):
372
+ if s is not None and s == params.sampling_step:
373
+ n_vals.append((0, *v[1:]))
374
+ else:
375
+ n_vals.append(v)
376
+ shared.sd_model.forge_objects.unet.patches[key] = n_vals
377
+ shared.sd_model.forge_objects.unet.patch_model()
378
+
379
+ elif self.active:
380
+ if self.starts and params.sampling_step == 0:
381
+ for key, step_te_u in self.starts.items():
382
+ setparams(self, key, 0, 0, [])
383
+ #print("\nstart 0", self, key, 0, 0, [])
384
+
385
+ if self.starts:
386
+ sets = []
387
+ for key, step_te_u in self.starts.items():
388
+ step, te, u = step_te_u
389
+ if params.sampling_step > step - 2:
390
+ setparams(self, key, te, u, sets)
391
+ #print("\nstart", self, key, u, te, sets)
392
+ for key in sets:
393
+ del self.starts[key]
394
+
395
+ if self.stops:
396
+ sets = []
397
+ for key, step in self.stops.items():
398
+ if params.sampling_step > step - 2:
399
+ setparams(self, key, 0, 0, sets)
400
+ #print("\nstop", self, key, 0, 0, sets)
401
+ for key in sets:
402
+ del self.stops[key]
403
+
404
+ def before_process_batch(self, p, loraratios,useblocks,*args,**kwargs):
405
+ if useblocks:
406
+ resetmemory()
407
+ if not self.isnet: p.disable_extra_networks = False
408
+ global prompts
409
+ prompts = kwargs["prompts"].copy()
410
+
411
+ def process_batch(self, p, loraratios,useblocks,*args,**kwargs):
412
+ if useblocks:
413
+ if not self.isnet: p.disable_extra_networks = True
414
+
415
+ o_prompts = [p.prompt]
416
+ for prompt in prompts:
417
+ if "<lora" in prompt or "<lyco" in prompt:
418
+ o_prompts = prompts.copy()
419
+ if not self.isnet: loradealer(self, o_prompts ,self.lratios,self.elementals)
420
+
421
+ def postprocess(self, p, processed, presets,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug,*args):
422
+ if not useblocks:
423
+ return
424
+ lora = importer(self)
425
+ emb_db = sd_hijack.model_hijack.embedding_db
426
+
427
+ for net in lora.loaded_loras:
428
+ if hasattr(net,"bundle_embeddings"):
429
+ for emb_name, embedding in net.bundle_embeddings.items():
430
+ if embedding.loaded:
431
+ emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
432
+
433
+ lora.loaded_loras.clear()
434
+
435
+ if forge:
436
+ sd_models.model_data.get_sd_model().current_lora_hash = None
437
+ shared.sd_model.forge_objects_after_applying_lora.unet.unpatch_model()
438
+ shared.sd_model.forge_objects_after_applying_lora.clip.patcher.unpatch_model()
439
+
440
+ global lxyz,lzyx,xyelem
441
+ lxyz = lzyx = xyelem = ""
442
+ if debug:
443
+ print(self.log)
444
+ gc.collect()
445
+
446
+ def after_extra_networks_activate(self, p, presets,useblocks, *args, **kwargs):
447
+ if useblocks:
448
+ loradealer(self, kwargs["prompts"] ,self.lratios,self.elementals,kwargs["extra_network_data"])
449
+
450
+ def run(self,p,presets,useblocks,xyzsetting,xtype,xmen,ytype,ymen,ztype,zmen,exmen,eymen,ecount,diffcol,thresh,revxy,elemental,elemsets,debug):
451
+ if not useblocks:
452
+ return
453
+ self.__init__()
454
+ self.log["pass XYZ"] = True
455
+ self.log["XYZsets"] = xyzsetting
456
+ self.log["enable LBW"] = useblocks
457
+
458
+ if xyzsetting >0:
459
+ lorachecker(self)
460
+ lora = importer(self)
461
+ loraratios=presets.splitlines()
462
+ lratios={}
463
+ for l in loraratios:
464
+ if checkloadcond(l) : continue
465
+ l0=l.split(":",1)[0]
466
+ lratios[l0.strip()]=l.split(":",1)[1]
467
+
468
+ if "XYZ" in p.prompt:
469
+ base = lratios["XYZ"] if "XYZ" in lratios.keys() else "1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1"
470
+ else: return
471
+
472
+ for i, all in enumerate(["12ALL","17ALL","20ALL","26ALL"]):
473
+ if eymen == all:
474
+ eymen = ",".join(BLOCKIDS[i])
475
+
476
+ if xyzsetting > 1:
477
+ xmen,ymen = exmen,eymen
478
+ xtype,ytype = "values","ID"
479
+ ebase = xmen.split(",")[1]
480
+ ebase = [ebase.strip()]*26
481
+ base = ",".join(ebase)
482
+ ztype = ""
483
+ if ecount > 1:
484
+ ztype = "seed"
485
+ zmen = ",".join([str(random.randrange(4294967294)) for x in range(int(ecount))])
486
+
487
+ #ATYPES =["none","Block ID","values","seed","Base Weights"]
488
+
489
+ def dicedealer(am):
490
+ for i,a in enumerate(am):
491
+ if a =="-1": am[i] = str(random.randrange(4294967294))
492
+ print(f"the die was thrown : {am}")
493
+
494
+ if p.seed == -1: p.seed = str(random.randrange(4294967294))
495
+
496
+ #print(f"xs:{xmen},ys:{ymen},zs:{zmen}")
497
+
498
+ def adjuster(a,at):
499
+ if "none" in at:a = ""
500
+ a = [a.strip() for a in a.split(',')]
501
+ if "seed" in at:dicedealer(a)
502
+ return a
503
+
504
+ xs = adjuster(xmen,xtype)
505
+ ys = adjuster(ymen,ytype)
506
+ zs = adjuster(zmen,ztype)
507
+
508
+ ids = alpha =seed = ""
509
+ p.batch_size = 1
510
+
511
+ print(f"xs:{xs},ys:{ys},zs:{zs}")
512
+
513
+ images = []
514
+
515
+ def weightsdealer(alpha,ids,base):
516
+ #print(f"weights from : {base}")
517
+ ids = [z.strip() for z in ids.split(' ')]
518
+ weights_t = [w.strip() for w in base.split(',')]
519
+ blockid = BLOCKIDS[BLOCKNUMS.index(len(weights_t))]
520
+ if ids[0]!="NOT":
521
+ flagger=[False]*len(weights_t)
522
+ changer = True
523
+ else:
524
+ flagger=[True]*len(weights_t)
525
+ changer = False
526
+ for id in ids:
527
+ if id =="NOT":continue
528
+ if "-" in id:
529
+ it = [it.strip() for it in id.split('-')]
530
+ if blockid.index(it[1]) > blockid.index(it[0]):
531
+ flagger[blockid.index(it[0]):blockid.index(it[1])+1] = [changer]*(blockid.index(it[1])-blockid.index(it[0])+1)
532
+ else:
533
+ flagger[blockid.index(it[1]):blockid.index(it[0])+1] = [changer]*(blockid.index(it[0])-blockid.index(it[1])+1)
534
+ else:
535
+ flagger[blockid.index(id)] =changer
536
+ for i,f in enumerate(flagger):
537
+ if f:weights_t[i]=alpha
538
+ outext = ",".join(weights_t)
539
+ #print(f"weights changed: {outext}")
540
+ return outext
541
+
542
+ generatedbases=[]
543
+ def xyzdealer(a,at):
544
+ nonlocal ids,alpha,p,base,c_base,generatedbases
545
+ if "ID" in at:return
546
+ if "values" in at:alpha = a
547
+ if "seed" in at:
548
+ p.seed = int(a)
549
+ generatedbases=[]
550
+ if "Weights" in at:base =c_base = lratios[a]
551
+ if "elements" in at:
552
+ global xyelem
553
+ xyelem = a
554
+
555
+ def imagedupewatcher(baselist,basetocheck,currentiteration):
556
+ for idx,alreadygenerated in enumerate(baselist):
557
+ if (basetocheck == alreadygenerated):
558
+ # E.g., we already generated IND+OUTS and this is now OUTS+IND with identical weights.
559
+ baselist.insert(currentiteration-1, basetocheck)
560
+ return idx
561
+ return -1
562
+
563
+ def strThree(someNumber): # Returns 1.12345 as 1.123 and 1.0000 as 1
564
+ return format(someNumber, ".3f").rstrip('0').rstrip('.')
565
+
566
+ # Adds X and Y together using array addition.
567
+ # If both X and Y have a value in the same block then Y's is set to 0;
568
+ # both values are used due to both XY and YX being generated, but the diagonal then only show the first value.
569
+ # imagedupwatcher prevents duplicate images from being generated;
570
+ # when X and Y have non-overlapping blocks then the upper triangular images are identical to the lower ones.
571
+ def xyoriginalweightsdealer(x,y):
572
+ xweights = np.asarray(lratios[x].split(','), dtype=np.float32) # np array easier to add later
573
+ yweights = np.asarray(lratios[y].split(','), dtype=np.float32)
574
+ for idx,xval in np.ndenumerate(xweights):
575
+ yval = yweights[idx]
576
+ if xval != 0 and yval != 0:
577
+ yweights[idx] = 0
578
+ # Add xweights to yweights, round to 3 places,
579
+ # map floats to string with format of 3 decimals trailing zeroes and decimal stripped
580
+ baseListToStrings = list(map(strThree, np.around(np.add(xweights,yweights,),3).tolist()))
581
+ return ",".join(baseListToStrings)
582
+
583
+ grids = []
584
+ images =[]
585
+
586
+ totalcount = len(xs)*len(ys)*len(zs) if xyzsetting < 2 else len(xs)*len(ys)*len(zs) //2 +1
587
+ shared.total_tqdm.updateTotal(totalcount)
588
+ xc = yc =zc = 0
589
+ state.job_count = totalcount
590
+ totalcount = len(xs)*len(ys)*len(zs)
591
+ c_base = base
592
+
593
+ for z in zs:
594
+ generatedbases=[]
595
+ images = []
596
+ yc = 0
597
+ xyzdealer(z,ztype)
598
+ for y in ys:
599
+ xc = 0
600
+ xyzdealer(y,ytype)
601
+ for x in xs:
602
+ xyzdealer(x,xtype)
603
+ if "Weights" in xtype and "Weights" in ytype:
604
+ c_base = xyoriginalweightsdealer(x,y)
605
+ else:
606
+ if "ID" in xtype:
607
+ if "values" in ytype:c_base = weightsdealer(y,x,base)
608
+ if "values" in ztype:c_base = weightsdealer(z,x,base)
609
+ if "ID" in ytype:
610
+ if "values" in xtype:c_base = weightsdealer(x,y,base)
611
+ if "values" in ztype:c_base = weightsdealer(z,y,base)
612
+ if "ID" in ztype:
613
+ if "values" in xtype:c_base = weightsdealer(x,z,base)
614
+ if "values" in ytype:c_base = weightsdealer(y,z,base)
615
+
616
+ iteration = len(xs)*len(ys)*zc + yc*len(xs) +xc +1
617
+ print(f"X:{xtype}, {x},Y: {ytype},{y}, Z:{ztype},{z}, base:{c_base} ({iteration}/{totalcount})")
618
+
619
+ dupe_index = imagedupewatcher(generatedbases,c_base,iteration)
620
+ if dupe_index > -1:
621
+ print(f"Skipping generation of duplicate base:{c_base}")
622
+ images.append(images[dupe_index].copy())
623
+ xc += 1
624
+ continue
625
+
626
+ global lxyz,lzyx
627
+ lxyz = c_base
628
+
629
+ cr_base = c_base.split(",")
630
+ cr_base_t=[]
631
+ for x in cr_base:
632
+ if not identifier(x):
633
+ cr_base_t.append(str(1-float(x)))
634
+ else:
635
+ cr_base_t.append(x)
636
+ lzyx = ",".join(cr_base_t)
637
+
638
+ if not(xc == 1 and not (yc ==0 ) and xyzsetting >1):
639
+ lora.loaded_loras.clear()
640
+ p.cached_c = [None,None]
641
+ p.cached_uc = [None,None]
642
+ p.cached_hr_c = [None, None]
643
+ p.cached_hr_uc = [None, None]
644
+ processed:Processed = process_images(p)
645
+ images.append(processed.images[0])
646
+ generatedbases.insert(iteration-1, c_base)
647
+ xc += 1
648
+ yc += 1
649
+ zc += 1
650
+ origin = loranames(processed.all_prompts) + ", "+ znamer(ztype,z,base)
651
+ images,xst,yst = effectivechecker(images,xs.copy(),ys.copy(),diffcol,thresh,revxy) if xyzsetting >1 else (images,xs.copy(),ys.copy())
652
+ grids.append(smakegrid(images,xst,yst,origin,p))
653
+ processed.images= grids
654
+ lora.loaded_loras.clear()
655
+ return processed
656
+
657
+ def identifier(char):
658
+ return char[0] in ["R", "U", "X"]
659
+
660
+ def znamer(at,a,base):
661
+ if "ID" in at:return f"Block : {a}"
662
+ if "values" in at:return f"value : {a}"
663
+ if "seed" in at:return f"seed : {a}"
664
+ if "Weights" in at:return f"original weights :\n {base}"
665
+ else: return ""
666
+
667
+ def loranames(all_prompts):
668
+ _, extra_network_data = extra_networks.parse_prompts(all_prompts[0:1])
669
+ calledloras = extra_network_data["lora"] if "lyco" not in extra_network_data.keys() else extra_network_data["lyco"]
670
+ names = ""
671
+ for called in calledloras:
672
+ if len(called.items) <3:continue
673
+ names += called.items[0]
674
+ return names
675
+
676
+ def lorachecker(self):
677
+ try:
678
+ import networks
679
+ self.isnet = True
680
+ self.layer_name = "network_layer_name"
681
+ except:
682
+ self.isnet = False
683
+ self.layer_name = "lora_layer_name"
684
+ try:
685
+ import lora
686
+ self.islora = True
687
+ except:
688
+ pass
689
+ try:
690
+ import lycoris
691
+ self.islyco = True
692
+ except:
693
+ pass
694
+ self.onlyco = (not self.islora) and self.islyco
695
+ self.isxl = hasattr(shared.sd_model,"conditioner")
696
+
697
+ self.log["isnet"] = self.isnet
698
+ self.log["isxl"] = self.isxl
699
+ self.log["islora"] = self.islora
700
+
701
+ def resetmemory():
702
+ try:
703
+ import networks as nets
704
+ nets.networks_in_memory = {}
705
+ gc.collect()
706
+
707
+ except:
708
+ pass
709
+
710
+ def importer(self):
711
+ if self.onlyco:
712
+ # lycorisモジュールを動的にインポート
713
+ lora_module = importlib.import_module("lycoris")
714
+ return lora_module
715
+ else:
716
+ # loraモジュールを動的にインポート
717
+ lora_module = importlib.import_module("lora")
718
+ return lora_module
719
+
720
+ def loradealer(self, prompts,lratios,elementals, extra_network_data = None):
721
+ if extra_network_data is None:
722
+ _, extra_network_data = extra_networks.parse_prompts(prompts)
723
+ moduletypes = extra_network_data.keys()
724
+
725
+ for ltype in moduletypes:
726
+ lorans = []
727
+ lorars = []
728
+ te_multipliers = []
729
+ unet_multipliers = []
730
+ elements = []
731
+ starts = []
732
+ stops = []
733
+ fparams = []
734
+ load = False
735
+ go_lbw = False
736
+
737
+ if not (ltype == "lora" or ltype == "lyco") : continue
738
+ for called in extra_network_data[ltype]:
739
+ items = called.items
740
+ setnow = False
741
+ name = items[0]
742
+ te = syntaxdealer(items,"te=",1)
743
+ unet = syntaxdealer(items,"unet=",2)
744
+ te,unet = multidealer(te,unet)
745
+
746
+ weights = syntaxdealer(items,"lbw=",2) if syntaxdealer(items,"lbw=",2) is not None else syntaxdealer(items,"w=",2)
747
+ elem = syntaxdealer(items, "lbwe=",3)
748
+ start = syntaxdealer(items,"start=",None)
749
+ stop = syntaxdealer(items,"stop=",None)
750
+ start, stop = stepsdealer(syntaxdealer(items,"step=",None), start, stop)
751
+
752
+ if weights is not None and (weights in lratios or any(weights.count(",") == x - 1 for x in BLOCKNUMS)):
753
+ wei = lratios[weights] if weights in lratios else weights
754
+ ratios = [w.strip() for w in wei.split(",")]
755
+ for i,r in enumerate(ratios):
756
+ if r =="R":
757
+ ratios[i] = round(random.random(),3)
758
+ elif r == "U":
759
+ ratios[i] = round(random.uniform(-0.5,1.5),3)
760
+ elif r[0] == "X":
761
+ base = syntaxdealer(items,"x=", 3) if len(items) >= 4 else 1
762
+ ratios[i] = getinheritedweight(base, r)
763
+ else:
764
+ ratios[i] = float(r)
765
+
766
+ if len(ratios) != 26:
767
+ ratios = to26(ratios)
768
+ setnow = True
769
+ else:
770
+ ratios = [1] * 26
771
+
772
+ if elem in elementals:
773
+ setnow = True
774
+ elem = elementals[elem]
775
+ else:
776
+ elem = ""
777
+
778
+ if setnow:
779
+ go_lbw = True
780
+ fparams.append([unet,ratios,elem])
781
+ settolist([lorans,te_multipliers,unet_multipliers,lorars,elements,starts,stops],[name,te,unet,ratios,elem,start,stop])
782
+
783
+ if start:
784
+ self.starts[name] = [int(start),te,unet]
785
+ self.log["starts"] = load = True
786
+
787
+ if stop:
788
+ self.stops[name] = int(stop)
789
+ self.log["stops"] = load = True
790
+
791
+ self.startsf = [int(s) if s is not None else None for s in starts]
792
+ self.stopsf = [int(s) if s is not None else None for s in stops]
793
+ self.uf = unet_multipliers
794
+ self.lf = lorars
795
+ self.ef = elements
796
+
797
+ if self.isnet: ltype = "nets"
798
+ if forge: ltype = "forge"
799
+ if go_lbw or load: load_loras_blocks(self, lorans,lorars,te_multipliers,unet_multipliers,elements,ltype, starts=starts)
800
+
801
+ def stepsdealer(step, start, stop):
802
+ if step is None or "-" not in step:
803
+ return start, stop
804
+ return step.split("-")
805
+
806
+ def settolist(ls,vs):
807
+ for l, v in zip(ls,vs):
808
+ l.append(v)
809
+
810
+ def syntaxdealer(items,target,index): #type "unet=", "x=", "lwbe="
811
+ for item in items:
812
+ if target in item:
813
+ return item.replace(target,"")
814
+ if index is None or index + 1> len(items): return None
815
+ if "=" in items[index]:return None
816
+ return items[index] if "@" not in items[index] else 1
817
+
818
+ def isfloat(t):
819
+ try:
820
+ float(t)
821
+ return True
822
+ except:
823
+ return False
824
+
825
+ def multidealer(t, u):
826
+ if t is None and u is None:
827
+ return 1,1
828
+ elif t is None:
829
+ return float(u),float(u)
830
+ elif u is None:
831
+ return float(t), float(t)
832
+ else:
833
+ return float(t),float(u)
834
+
835
+ re_inherited_weight = re.compile(r"X([+-])?([\d.]+)?")
836
+
837
+ def getinheritedweight(weight, offset):
838
+ match = re_inherited_weight.search(offset)
839
+ if match.group(1) == "+":
840
+ return float(weight) + float(match.group(2))
841
+ elif match.group(1) == "-":
842
+ return float(weight) - float(match.group(2))
843
+ else:
844
+ return float(weight)
845
+
846
+ def load_loras_blocks(self, names, lwei,te,unet,elements,ltype = "lora", starts = None):
847
+ oldnew=[]
848
+ if "lora" == ltype:
849
+ lora = importer(self)
850
+ self.lora = lora.loaded_loras
851
+ for loaded in lora.loaded_loras:
852
+ for n, name in enumerate(names):
853
+ if name == loaded.name:
854
+ if lwei[n] == [1] * 26 and elements[n] == "": continue
855
+ lbw(loaded,lwei[n],elements[n])
856
+ setall(loaded,te[n],unet[n])
857
+ newname = loaded.name +"_in_LBW_"+ str(round(random.random(),3))
858
+ oldname = loaded.name
859
+ loaded.name = newname
860
+ oldnew.append([oldname,newname])
861
+
862
+ elif "lyco" == ltype:
863
+ import lycoris as lycomo
864
+ self.lycoris = lycomo.loaded_lycos
865
+ for loaded in lycomo.loaded_lycos:
866
+ for n, name in enumerate(names):
867
+ if name == loaded.name:
868
+ lbw(loaded,lwei[n],elements[n])
869
+ setall(loaded,te[n],unet[n])
870
+
871
+ elif "nets" == ltype:
872
+ import networks as nets
873
+ self.networks = nets.loaded_networks
874
+ for loaded in nets.loaded_networks:
875
+ for n, name in enumerate(names):
876
+ if name == loaded.name:
877
+ lbw(loaded,lwei[n],elements[n])
878
+ setall(loaded,te[n],unet[n])
879
+
880
+ elif "forge" == ltype:
881
+ lbwf(te, unet, lwei, elements, starts)
882
+
883
+ try:
884
+ import lora_ctl_network as ctl
885
+ for old,new in oldnew:
886
+ if old in ctl.lora_weights.keys():
887
+ ctl.lora_weights[new] = ctl.lora_weights[old]
888
+ except:
889
+ pass
890
+
891
+ def setall(m,te,unet):
892
+ m.name = m.name + "_in_LBW_"+ str(round(random.random(),3))
893
+ m.te_multiplier = te
894
+ m.unet_multiplier = unet
895
+ m.multiplier = unet
896
+
897
+ def smakegrid(imgs,xs,ys,currentmodel,p):
898
+ ver_texts = [[images.GridAnnotation(y)] for y in ys]
899
+ hor_texts = [[images.GridAnnotation(x)] for x in xs]
900
+
901
+ w, h = imgs[0].size
902
+ grid = Image.new('RGB', size=(len(xs) * w, len(ys) * h), color='black')
903
+
904
+ for i, img in enumerate(imgs):
905
+ grid.paste(img, box=(i % len(xs) * w, i // len(xs) * h))
906
+
907
+ grid = images.draw_grid_annotations(grid,w, h, hor_texts, ver_texts)
908
+ grid = draw_origin(grid, currentmodel,w*len(xs),h*len(ys),w)
909
+ if opts.grid_save:
910
+ images.save_image(grid, opts.outdir_txt2img_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=p.seed, grid=True, p=p)
911
+
912
+ return grid
913
+
914
+ def get_font(fontsize):
915
+ fontpath = os.path.join(scriptpath, "Roboto-Regular.ttf")
916
+ try:
917
+ return ImageFont.truetype(opts.font or fontpath, fontsize)
918
+ except Exception:
919
+ return ImageFont.truetype(fontpath, fontsize)
920
+
921
+ def draw_origin(grid, text,width,height,width_one):
922
+ grid_d= Image.new("RGB", (grid.width,grid.height), "white")
923
+ grid_d.paste(grid,(0,0))
924
+
925
+ d= ImageDraw.Draw(grid_d)
926
+ color_active = (0, 0, 0)
927
+ fontsize = (width+height)//25
928
+ fnt = get_font(fontsize)
929
+
930
+ if grid.width != width_one:
931
+ while d.multiline_textsize(text, font=fnt)[0] > width_one*0.75 and fontsize > 0:
932
+ fontsize -=1
933
+ fnt = get_font(fontsize)
934
+ d.multiline_text((0,0), text, font=fnt, fill=color_active,align="center")
935
+ return grid_d
936
+
937
+ def newrun(p, *args):
938
+ script_index = args[0]
939
+
940
+ if args[0] ==0:
941
+ script = None
942
+ for obj in scripts.scripts_txt2img.alwayson_scripts:
943
+ if "lora_block_weight" in obj.filename:
944
+ script = obj
945
+ script_args = args[script.args_from:script.args_to]
946
+ else:
947
+ script = scripts.scripts_txt2img.selectable_scripts[script_index-1]
948
+
949
+ if script is None:
950
+ return None
951
+
952
+ script_args = args[script.args_from:script.args_to]
953
+
954
+ processed = script.run(p, *script_args)
955
+
956
+ shared.total_tqdm.clear()
957
+
958
+ return processed
959
+
960
+ registerd = False
961
+
962
+ def register():
963
+ global registerd
964
+ registerd = True
965
+ for obj in scripts.scripts_txt2img.alwayson_scripts:
966
+ if "lora_block_weight" in obj.filename:
967
+ if obj not in scripts.scripts_txt2img.selectable_scripts:
968
+ scripts.scripts_txt2img.selectable_scripts.append(obj)
969
+ scripts.scripts_txt2img.titles.append("LoRA Block Weight")
970
+ for obj in scripts.scripts_img2img.alwayson_scripts:
971
+ if "lora_block_weight" in obj.filename:
972
+ if obj not in scripts.scripts_img2img.selectable_scripts:
973
+ scripts.scripts_img2img.selectable_scripts.append(obj)
974
+ scripts.scripts_img2img.titles.append("LoRA Block Weight")
975
+
976
+ def effectivechecker(imgs,ss,ls,diffcol,thresh,revxy):
977
+ orig = imgs[1]
978
+ imgs = imgs[::2]
979
+ diffs = []
980
+ outnum =[]
981
+
982
+ for img in imgs:
983
+ abs_diff = cv2.absdiff(np.array(img) , np.array(orig))
984
+
985
+ abs_diff_t = cv2.threshold(abs_diff, int(thresh), 255, cv2.THRESH_BINARY)[1]
986
+ res = abs_diff_t.astype(np.uint8)
987
+ percentage = (np.count_nonzero(res) * 100)/ res.size
988
+ if "white" in diffcol: abs_diff = cv2.bitwise_not(abs_diff)
989
+ outnum.append(percentage)
990
+
991
+ abs_diff = Image.fromarray(abs_diff)
992
+
993
+ diffs.append(abs_diff)
994
+
995
+ outs = []
996
+ for i in range(len(ls)):
997
+ ls[i] = ls[i] + "\n Diff : " + str(round(outnum[i],3)) + "%"
998
+
999
+ if not revxy:
1000
+ for diff,img in zip(diffs,imgs):
1001
+ outs.append(diff)
1002
+ outs.append(img)
1003
+ outs.append(orig)
1004
+ ss = ["diff",ss[0],"source"]
1005
+ return outs,ss,ls
1006
+ else:
1007
+ outs = [orig]*len(diffs) + imgs + diffs
1008
+ ss = ["source",ss[0],"diff"]
1009
+ return outs,ls,ss
1010
+
1011
+ def lbw(lora,lwei,elemental):
1012
+ elemental = elemental.split(",")
1013
+ for key in lora.modules.keys():
1014
+ ratio, errormodules = ratiodealer(key, lwei, elemental)
1015
+
1016
+ ltype = type(lora.modules[key]).__name__
1017
+ set = False
1018
+ if ltype in LORAANDSOON.keys():
1019
+ if "OFT" not in ltype:
1020
+ setattr(lora.modules[key],LORAANDSOON[ltype],torch.nn.Parameter(getattr(lora.modules[key],LORAANDSOON[ltype]) * ratio))
1021
+ else:
1022
+ setattr(lora.modules[key],LORAANDSOON[ltype],getattr(lora.modules[key],LORAANDSOON[ltype]) * ratio)
1023
+ set = True
1024
+ else:
1025
+ if hasattr(lora.modules[key],"up_model"):
1026
+ lora.modules[key].up_model.weight= torch.nn.Parameter(lora.modules[key].up_model.weight *ratio)
1027
+ #print("LoRA using LoCON")
1028
+ set = True
1029
+ else:
1030
+ lora.modules[key].up.weight= torch.nn.Parameter(lora.modules[key].up.weight *ratio)
1031
+ #print("LoRA")
1032
+ set = True
1033
+ if not set :
1034
+ print("unkwon LoRA")
1035
+
1036
+ if len(errormodules) > 0:
1037
+ print(errormodules)
1038
+ return lora
1039
+
1040
+ LORAS = ["lora", "loha", "lokr"]
1041
+
1042
+ def lbwf(mt, mu, lwei, elemental, starts):
1043
+ for key, vals in shared.sd_model.forge_objects_after_applying_lora.unet.patches.items():
1044
+ n_vals = []
1045
+ lvals = [val for val in vals if val[1][0] in LORAS]
1046
+ for v, m, l, e ,s in zip(lvals, mu, lwei, elemental, starts):
1047
+ ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
1048
+ n_vals.append((ratio * m if s is None else 0, *v[1:]))
1049
+ shared.sd_model.forge_objects_after_applying_lora.unet.patches[key] = n_vals
1050
+
1051
+ for key, vals in shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches.items():
1052
+ n_vals = []
1053
+ lvals = [val for val in vals if val[1][0] in LORAS]
1054
+ for v, m, l, e in zip(lvals, mt, lwei, elemental):
1055
+ ratio, errormodules = ratiodealer(key.replace(".","_"), l, e)
1056
+ n_vals.append((ratio * m, *v[1:]))
1057
+ shared.sd_model.forge_objects_after_applying_lora.clip.patcher.patches[key] = n_vals
1058
+
1059
+ def ratiodealer(key, lwei, elemental):
1060
+ ratio = 1
1061
+ picked = False
1062
+ errormodules = []
1063
+ currentblock = 0
1064
+
1065
+ for i,block in enumerate(BLOCKS):
1066
+ if block in key:
1067
+ if i == 26:
1068
+ i = 0
1069
+ ratio = lwei[i]
1070
+ picked = True
1071
+ currentblock = i
1072
+
1073
+ if not picked:
1074
+ errormodules.append(key)
1075
+
1076
+ if len(elemental) > 0:
1077
+ skey = key + BLOCKID26[currentblock]
1078
+ for d in elemental:
1079
+ if d.count(":") != 2 :continue
1080
+ dbs,dws,dr = (hyphener(d.split(":")[0]),d.split(":")[1],d.split(":")[2])
1081
+ dbs,dws = (dbs.split(" "), dws.split(" "))
1082
+ dbn,dbs = (True,dbs[1:]) if dbs[0] == "NOT" else (False,dbs)
1083
+ dwn,dws = (True,dws[1:]) if dws[0] == "NOT" else (False,dws)
1084
+ flag = dbn
1085
+ for db in dbs:
1086
+ if db in skey:
1087
+ flag = not dbn
1088
+ if flag:flag = dwn
1089
+ else:continue
1090
+ for dw in dws:
1091
+ if dw in skey:
1092
+ flag = not dwn
1093
+ if flag:
1094
+ dr = float(dr)
1095
+ if princ :print(dbs,dws,key,dr)
1096
+ ratio = dr
1097
+
1098
+ return ratio, errormodules
1099
+
1100
+ LORAANDSOON = {
1101
+ "LoraHadaModule" : "w1a",
1102
+ "LycoHadaModule" : "w1a",
1103
+ "NetworkModuleHada": "w1a",
1104
+ "FullModule" : "weight",
1105
+ "NetworkModuleFull": "weight",
1106
+ "IA3Module" : "w",
1107
+ "NetworkModuleIa3" : "w",
1108
+ "LoraKronModule" : "w1",
1109
+ "LycoKronModule" : "w1",
1110
+ "NetworkModuleLokr": "w1",
1111
+ "NetworkModuleGLora": "w1a",
1112
+ "NetworkModuleNorm": "w_norm",
1113
+ "NetworkModuleOFT": "scale"
1114
+ }
1115
+
1116
+ def hyphener(t):
1117
+ t = t.split(" ")
1118
+ for i,e in enumerate(t):
1119
+ if "-" in e:
1120
+ e = e.split("-")
1121
+ if BLOCKID26.index(e[1]) > BLOCKID26.index(e[0]):
1122
+ t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[0]):BLOCKID26.index(e[1])+1])
1123
+ else:
1124
+ t[i] = " ".join(BLOCKID26[BLOCKID26.index(e[1]):BLOCKID26.index(e[0])+1])
1125
+ return " ".join(t)
1126
+
1127
+ ELEMPRESETS="\
1128
+ ATTNDEEPON:IN05-OUT05:attn:1\n\n\
1129
+ ATTNDEEPOFF:IN05-OUT05:attn:0\n\n\
1130
+ PROJDEEPOFF:IN05-OUT05:proj:0\n\n\
1131
+ XYZ:::1"
1132
+
1133
+ def to26(ratios):
1134
+ ids = BLOCKIDS[BLOCKNUMS.index(len(ratios))]
1135
+ output = [0]*26
1136
+ for i, id in enumerate(ids):
1137
+ output[BLOCKID26.index(id)] = ratios[i]
1138
+ return output
1139
+
1140
+ def checkloadcond(l:str)->bool:
1141
+ # ここの条件分岐は読み込んだ行がBlock Waightの書式にあっているかを確認している。
1142
+ # [:]が含まれ、16個(LoRa)か25個(LyCORIS),11,19(XL),のカンマが含まれる形式であるうえ、
1143
+ # それがコメントアウト行(# foobar)でないことが求められている。
1144
+ # 逆に言うとコメントアウトしたいなら絶対"# "から始めることを要求している。
1145
+
1146
+ # This conditional branch is checking whether the loaded line conforms to the Block Weight format.
1147
+ # It is required that "[:]" is included, and the format contains either 16 commas (for LoRa) or 25 commas (for LyCORIS),
1148
+ # and it's not a comment line (e.g., "# foobar").
1149
+ # Conversely, if you want to comment out, it requires that it absolutely starts with "# ".
1150
+ res=(":" not in l) or (not any(l.count(",") == x - 1 for x in BLOCKNUMS)) or ("#" in l)
1151
+ #print("[debug]", res,repr(l))
1152
+ return res