|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from diffusers.models.lora import LoRACompatibleLinear |
|
from diffusers.models.lora import LoRALinearLayer,LoRAConv2dLayer |
|
from einops import rearrange |
|
|
|
from diffusers.models.transformer_2d import Transformer2DModel |
|
|
|
class AttnProcessor(nn.Module):
    r"""
    Baseline attention processor that uses the explicit attention-probability path.

    All projections, normalisation layers and head-reshaping helpers are taken from the `attn` module that is
    passed to `__call__`; the processor itself owns no parameters.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Both arguments are accepted but neither is stored nor used.
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        r"""
        Run attention for `attn` on `hidden_states`.

        Args:
            attn:
                Attention module providing `to_q` / `to_k` / `to_v` / `to_out`, the optional norms and the
                head reshaping helpers.
            hidden_states (`torch.Tensor`):
                Query-side input. A 4D input is unpacked as (batch, channel, height, width) and flattened to
                tokens for the computation, then restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Source of keys and values. When `None`, `hidden_states` is used.
            attention_mask (*optional*):
                Passed through `attn.prepare_attention_mask` and on to `attn.get_attention_scores`.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        came_as_4d = input_ndim == 4
        if came_as_4d:
            batch_size, channel, height, width = hidden_states.shape
            # (batch, channel, height, width) -> (batch, height * width, channel)
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and sequence length follow the key/value source.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            # The group norm is applied with the channel axis in position 1.
            channels_first = hidden_states.transpose(1, 2)
            hidden_states = attn.group_norm(channels_first).transpose(1, 2)

        query = attn.to_q(hidden_states)

        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        weights = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(weights, value))

        # to_out[0] then to_out[1], applied in order.
        out_proj, out_post = attn.to_out[0], attn.to_out[1]
        hidden_states = out_post(out_proj(attended))

        if came_as_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
|
|
|
|
|
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    The last `num_tokens` entries of `encoder_hidden_states` are treated as image-prompt tokens. They are
    split off, projected with this processor's own `to_k_ip` / `to_v_ip` layers, attended to with the same
    queries, and the result is added to the regular attention output scaled by `scale`.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when `None`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        r"""
        Run attention for `attn`, adding the image-prompt attention term when prompt tokens are present.

        Args:
            attn:
                Attention module providing the projections, optional norms and head reshaping helpers.
            hidden_states (`torch.Tensor`):
                Query-side input; a 4D input is flattened to tokens and restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Context whose trailing `num_tokens` entries along dim 1 are the image-prompt tokens. When
                `None` the layer runs plain self-attention and the image-prompt term is skipped.
            attention_mask (*optional*):
                Passed through `attn.prepare_attention_mask`; only used for the non-image attention.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # NOTE(review): sequence_length is taken before the image tokens are split off below.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None on the self-attention path, where there are no image-prompt tokens to attend to.
        # Previously this path hit an unbound `ip_hidden_states` further down.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the trailing image-prompt tokens from the rest of the context.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states is not None:
            # Image-prompt branch: same queries, dedicated key/value projections, no mask.
            ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
            ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # to_out[0] then to_out[1], applied in order.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor built on `F.scaled_dot_product_attention` (available from PyTorch 2.0).

    The processor owns no parameters; every projection and norm comes from the `attn` module passed to
    `__call__`.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Both arguments are accepted but neither is stored nor used.
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        scale= 1.0,
    ):
        r"""
        Run scaled-dot-product attention for `attn` on `hidden_states`.

        Args:
            attn:
                Attention module providing `to_q` / `to_k` / `to_v` / `to_out`, `heads` and the optional norms.
            hidden_states (`torch.Tensor`):
                Query-side input; a 4D input is flattened to tokens and restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Source of keys and values. When `None`, `hidden_states` is used.
            attention_mask (*optional*):
                When given, prepared by `attn.prepare_attention_mask` and viewed per head.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.
            scale (`float`, defaults to 1.0):
                Accepted for call compatibility; not used by this implementation.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        came_as_4d = input_ndim == 4
        if came_as_4d:
            batch_size, channel, height, width = hidden_states.shape
            # (batch, channel, height, width) -> (batch, height * width, channel)
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and sequence length follow the key/value source.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # The prepared mask is viewed as (batch, heads, -1, last_dim) for the SDPA call.
            attention_mask = prepared.view(batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            channels_first = hidden_states.transpose(1, 2)
            hidden_states = attn.group_norm(channels_first).transpose(1, 2)

        query = attn.to_q(hidden_states)

        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        query, key, value = (
            t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value)
        )

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back and keep the query dtype.
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # to_out[0] then to_out[1], applied in order.
        hidden_states = attn.to_out[1](attn.to_out[0](attended))

        if came_as_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
|
|
|
|
|
class AttnProcessor2_0_attn(torch.nn.Module):
    r"""
    Scaled-dot-product attention processor that tolerates the cloth-pass keyword arguments.

    The attention computation uses `F.scaled_dot_product_attention`. The additional keyword arguments of
    `__call__` (`is_cloth_pass`, `cloth`, the counters, `inside_up`, `inside_down`, `cloth_text`) are accepted
    and ignored.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Both arguments are accepted but neither is stored nor used.
        super().__init__()
        sdpa_available = hasattr(F, "scaled_dot_product_attention")
        if not sdpa_available:
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_cloth_pass=False,
        cloth = None,
        up_cnt=None,
        mid_cnt=None,
        down_cnt=None,
        inside_up=None,
        inside_down=None,
        cloth_text=None,
    ):
        r"""
        Run scaled-dot-product attention for `attn` on `hidden_states`.

        Args:
            attn:
                Attention module providing `to_q` / `to_k` / `to_v` / `to_out`, `heads` and the optional norms.
            hidden_states (`torch.Tensor`):
                Query-side input; a 4D input is flattened to tokens and restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Source of keys and values. When `None`, `hidden_states` is used.
            attention_mask (*optional*):
                When given, prepared by `attn.prepare_attention_mask` and viewed per head.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.
            is_cloth_pass, cloth, up_cnt, mid_cnt, down_cnt, inside_up, inside_down, cloth_text:
                Accepted for call compatibility; none of them is read by this implementation.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        came_as_4d = input_ndim == 4
        if came_as_4d:
            batch_size, channel, height, width = hidden_states.shape
            # (batch, channel, height, width) -> (batch, height * width, channel)
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Batch size and sequence length follow the key/value source.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = shape_source.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # The prepared mask is viewed as (batch, heads, -1, last_dim) for the SDPA call.
            attention_mask = prepared.view(batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            channels_first = hidden_states.transpose(1, 2)
            hidden_states = attn.group_norm(channels_first).transpose(1, 2)

        query = attn.to_q(hidden_states)

        context = encoder_hidden_states
        if context is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        query, key, value = (
            t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for t in (query, key, value)
        )

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back and keep the query dtype.
        attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        attended = attended.to(query.dtype)

        # to_out[0] then to_out[1], applied in order.
        hidden_states = attn.to_out[1](attn.to_out[0](attended))

        if came_as_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
|
|
|
|
|
class AttnProcessor2_0_Lora(torch.nn.Module):
    r"""
    Scaled-dot-product attention processor with optional additive LoRA branches.

    When the `attn` module carries `q_lora`, `k_lora`, `v_lora` or `out_lora` attributes, the output of each
    is added to the corresponding projection, scaled by `scale_lora`. Attributes that are absent are simply
    skipped.

    Args:
        scale_lora (`float`, defaults to 1.0):
            Multiplier applied to every LoRA branch output before it is added.
        hidden_size (*optional*):
            Accepted but not used.
        cross_attention_dim (*optional*):
            Accepted but not used.
    """

    def __init__(
        self,
        scale_lora =1.0,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.scale_lora = scale_lora

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        r"""
        Run scaled-dot-product attention for `attn`, adding LoRA deltas where `attn` provides them.

        Args:
            attn:
                Attention module providing `to_q` / `to_k` / `to_v` / `to_out`, `heads`, the optional norms
                and, optionally, `q_lora` / `k_lora` / `v_lora` / `out_lora`.
            hidden_states (`torch.Tensor`):
                Query-side input; a 4D input is flattened to tokens and restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Source of keys and values. When `None`, `hidden_states` is used.
            attention_mask (*optional*):
                When given, prepared by `attn.prepare_attention_mask` and viewed per head.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if hasattr(attn, 'q_lora'):
            query = query + self.scale_lora * attn.q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Keys, values and their LoRA deltas are all projected from the key/value source. The previous
        # code projected keys (and both LoRA deltas) from `hidden_states`, which is only equivalent for
        # self-attention and mismatches the values for cross-attention.
        key = attn.to_k(encoder_hidden_states)
        if hasattr(attn, 'k_lora'):
            key = key + self.scale_lora * attn.k_lora(encoder_hidden_states)

        value = attn.to_v(encoder_hidden_states)
        if hasattr(attn, 'v_lora'):
            value = value + self.scale_lora * attn.v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = attn.to_out[0](hidden_states)
        if hasattr(attn, 'out_lora'):
            # NOTE(review): out_lora is fed the *projected* output rather than the projection's input.
            # Kept as-is because existing weights may depend on it - confirm the intent.
            hidden_states = hidden_states + self.scale_lora * attn.out_lora(hidden_states)

        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
class IPAttnProcessor_clothpass_noip(torch.nn.Module):
    r"""
    Scaled-dot-product attention processor with an optional cloth-feature attention term.

    Two modes are selected per call:

    * `is_cloth_pass` is true or `up_cnt` is `None`: plain attention.
    * otherwise: plain attention plus `scale` times an attention of the same queries over a cloth feature
      map taken from `cloth`, projected with `self.to_k_c` / `self.to_v_c`.

    `to_k_c` and `to_v_c` are NOT created by `__init__`; they have to be attached to the processor before the
    cloth term is used, otherwise an `AttributeError` is raised. `to_k_ip`, `to_v_ip` and `num_tokens` are
    created/stored by `__init__` but are not read by `__call__`.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when `None`.
        scale (`float`, defaults to 1.0):
            The weight applied to the cloth attention term.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def _cloth_projections(self):
        # Return (to_k_c, to_v_c), failing with a descriptive error when they were never attached.
        missing = [name for name in ("to_k_c", "to_v_c") if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                f"{type(self).__name__} needs {', '.join(missing)} for the cloth attention term, "
                "but __init__ does not create them; attach them before calling with `up_cnt` set."
            )
        return self.to_k_c, self.to_v_c

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_cloth_pass=False,
        cloth = None,
        up_cnt=None,
        mid_cnt=None,
        down_cnt=None,
        inside=None,
    ):
        r"""
        Run attention for `attn`, optionally adding the cloth attention term.

        Args:
            attn:
                Attention module providing `to_q` / `to_k` / `to_v` / `to_out`, `heads` and the optional norms.
            hidden_states (`torch.Tensor`):
                Query-side input; a 4D input is flattened to tokens and restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Source of keys and values. When `None`, `hidden_states` is used.
            attention_mask (*optional*):
                When given, prepared by `attn.prepare_attention_mask` and viewed per head. Only used for the
                regular attention, never for the cloth term.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.
            is_cloth_pass (`bool`, defaults to `False`):
                When true, the cloth term is disabled.
            cloth (*optional*):
                Indexable collection of cloth feature maps; entry `up_cnt * 3 + inside - 1` is used and is
                unpacked as (batch, channel, height, width).
            up_cnt (*optional*):
                When `None`, the cloth term is disabled.
            mid_cnt, down_cnt (*optional*):
                Accepted but not used.
            inside (*optional*):
                Offset used in the `cloth` index; required when the cloth term is active.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.

        Raises:
            AttributeError: If the cloth term is active but `to_k_c` / `to_v_c` were not attached.
        """
        use_cloth = not (is_cloth_pass or up_cnt is None)
        if use_cloth:
            # Fail before any computation if the cloth projections are unavailable.
            to_k_c, to_v_c = self._cloth_projections()

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            if use_cloth:
                # Debug output preserved from the original cloth branch.
                print('!!!!attention_mask is not NoNE!!!!')
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if use_cloth:
            cloth_feature = cloth[up_cnt * 3 + inside - 1]
            # Equivalent to einops "b c h w -> b (h w) c", followed by .contiguous().
            c_batch, c_channel, c_height, c_width = cloth_feature.shape
            cloth_feature = (
                cloth_feature.reshape(c_batch, c_channel, c_height * c_width).transpose(1, 2).contiguous()
            )

            c_key = to_k_c(cloth_feature)
            c_value = to_v_c(cloth_feature)

            c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Same queries attend to the cloth tokens, without a mask.
            hidden_states_cloth = F.scaled_dot_product_attention(
                query, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            hidden_states_cloth = hidden_states_cloth.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            hidden_states_cloth = hidden_states_cloth.to(query.dtype)

            hidden_states = hidden_states + self.scale * hidden_states_cloth

        # to_out[0] then to_out[1], applied in order.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
class IPAttnProcessor_clothpass(torch.nn.Module):
    r"""
    Scaled-dot-product IP-Adapter attention processor with a cloth-pass mode and a cloth-feature mode.

    Three modes are selected per call:

    * `is_cloth_pass` is true: plain attention over the full `encoder_hidden_states`; nothing is split off.
    * `is_cloth_pass` is false and `up_cnt` is `None`: the trailing `num_tokens` entries of
      `encoder_hidden_states` are split off as image-prompt tokens, attended to through `to_k_ip` /
      `to_v_ip`, and added to the regular attention output scaled by `scale`.
    * otherwise: as above, but the image-prompt attention output is additionally used as the query of an
      attention over a cloth feature map (projected with `self.to_k_c` / `self.to_v_c`) before being added.

    `to_k_c` and `to_v_c` are NOT created by `__init__`; they have to be attached to the processor before the
    cloth mode is used, otherwise an `AttributeError` is raised.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`. Falls back to `hidden_size` when `None`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt term.
        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def _cloth_projections(self):
        # Return (to_k_c, to_v_c), failing with a descriptive error when they were never attached.
        missing = [name for name in ("to_k_c", "to_v_c") if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                f"{type(self).__name__} needs {', '.join(missing)} for the cloth attention term, "
                "but __init__ does not create them; attach them before calling with `up_cnt` set."
            )
        return self.to_k_c, self.to_v_c

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_cloth_pass=False,
        cloth = None,
        up_cnt=None,
        mid_cnt=None,
        down_cnt=None,
        inside=None,
    ):
        r"""
        Run attention for `attn` in one of the three modes described on the class.

        Args:
            attn:
                Attention module providing `to_q` / `to_k` / `to_v` / `to_out`, `heads` and the optional norms.
            hidden_states (`torch.Tensor`):
                Query-side input; a 4D input is flattened to tokens and restored before returning.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Source of keys and values. When `None`, `hidden_states` is used and, because there are no
                image-prompt tokens, the image-prompt / cloth term is skipped.
            attention_mask (*optional*):
                When given, prepared by `attn.prepare_attention_mask` and viewed per head. Only used for the
                regular attention.
            temb (*optional*):
                Forwarded to `attn.spatial_norm` when that norm is set.
            is_cloth_pass (`bool`, defaults to `False`):
                When true, plain attention is run and no tokens are split off.
            cloth (*optional*):
                Indexable collection of cloth feature maps; entry `up_cnt * 3 + inside - 1` is used and is
                unpacked as (batch, channel, height, width).
            up_cnt (*optional*):
                When `None` (and not a cloth pass), only the image-prompt term is added.
            mid_cnt, down_cnt (*optional*):
                Accepted but not used.
            inside (*optional*):
                Offset used in the `cloth` index; required in cloth mode.

        Returns:
            `torch.Tensor`: The attended states, reshaped back to 4D when the input was 4D.

        Raises:
            AttributeError: If cloth mode is active but `to_k_c` / `to_v_c` were not attached.
        """
        split_ip_tokens = not is_cloth_pass
        use_cloth = split_ip_tokens and up_cnt is not None

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            if use_cloth:
                # Debug output preserved from the original cloth branch.
                print('!!!!attention_mask is not NoNE!!!!')
            # NOTE(review): sequence_length is taken before any image tokens are split off below.
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None when there are no image-prompt tokens (self-attention input or cloth pass).
        # Previously the self-attention case hit an unbound `ip_hidden_states` further down.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            if split_ip_tokens:
                # Split the trailing image-prompt tokens from the rest of the context.
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    encoder_hidden_states[:, end_pos:, :],
                )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            # Image-prompt branch: same queries, dedicated key/value projections, no mask.
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            if use_cloth:
                to_k_c, to_v_c = self._cloth_projections()

                cloth_feature = cloth[up_cnt * 3 + inside - 1]
                # Equivalent to einops "b c h w -> b (h w) c", followed by .contiguous().
                c_batch, c_channel, c_height, c_width = cloth_feature.shape
                cloth_feature = (
                    cloth_feature.reshape(c_batch, c_channel, c_height * c_width).transpose(1, 2).contiguous()
                )

                c_key = to_k_c(cloth_feature)
                c_value = to_v_c(cloth_feature)

                c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # The per-head image-prompt attention output acts as the query over the cloth tokens.
                ip_hidden_states = F.scaled_dot_product_attention(
                    ip_hidden_states, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # to_out[0] then to_out[1], applied in order.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAttnProcessor_clothpass_extend(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter (PyTorch 2.0) with an additional garment ("cloth") attention stage.

    ``__call__`` runs in one of two modes:

    * ``is_cloth_pass=True``: plain scaled-dot-product attention over the whole
      ``encoder_hidden_states``; no image-prompt tokens are split off.
    * otherwise (at least one of ``down_cnt`` / ``mid_cnt`` / ``up_cnt`` must be given): text attention,
      plus an image-prompt attention whose output is used as the query of a second attention over
      ``cloth[inside_down]``. That result is added to the text attention with weight ``scale``.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features, i.e. how many trailing tokens of
            `encoder_hidden_states` are split off as image-prompt tokens.

    NOTE(review): the garment stage calls ``self.to_k_c`` / ``self.to_v_c``, which are not created in
    ``__init__``. They must be attached by the caller before the garment mode is used; confirm where
    that happens (their input width depends on the garment feature channels, which is not visible here).
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_cloth_pass=False,
        cloth=None,
        up_cnt=None,
        mid_cnt=None,
        down_cnt=None,
        inside_up=None,
        inside_down=None,
    ):
        """
        Run attention for ``attn`` on ``hidden_states``.

        Args:
            attn: attention module providing ``to_q`` / ``to_k`` / ``to_v`` / ``to_out``, ``heads`` and the
                optional norm layers read below.
            hidden_states: 3-D ``(batch, tokens, channels)`` or 4-D ``(batch, channels, height, width)`` input.
            encoder_hidden_states: context tokens; in garment mode the last ``num_tokens`` tokens are
                treated as image-prompt tokens.
            attention_mask: optional mask for the text attention only.
            temb: forwarded to ``attn.spatial_norm`` when that layer exists.
            is_cloth_pass: select the plain-attention mode.
            cloth: indexable collection of 4-D garment features; indexed with ``inside_down``.
            up_cnt, mid_cnt, down_cnt: only checked for "is any of them set" to enable garment mode.
            inside_up: accepted for interface compatibility; not read.
            inside_down: index into ``cloth``.

        Returns:
            Tensor with the same shape as the input ``hidden_states``.

        Raises:
            ValueError: if neither mode is selected, or garment mode is used without
                ``encoder_hidden_states``.
        """
        if not is_cloth_pass:
            if down_cnt is None and up_cnt is None and mid_cnt is None:
                # Was `assert(False)`, which disappears under `python -O` and would then return None.
                raise ValueError(
                    "IPAttnProcessor_clothpass_extend needs is_cloth_pass=True or one of "
                    "down_cnt / mid_cnt / up_cnt to be set."
                )
            if encoder_hidden_states is None:
                # Without context there are no image-prompt tokens to split off.
                raise ValueError(
                    "encoder_hidden_states is required when is_cloth_pass is False "
                    "(its last num_tokens tokens are the image-prompt tokens)."
                )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape to (batch, heads, -1, target_len) to line up with the per-head layout used below.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            if not is_cloth_pass:
                # Trailing num_tokens tokens are the image prompt; the rest is the text context.
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    encoder_hidden_states[:, end_pos:, :],
                )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        def merge_heads(tensor):
            # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
            return tensor.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = merge_heads(hidden_states).to(query.dtype)

        if not is_cloth_pass:
            cloth_feature = cloth[inside_down]
            cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous()

            # Stage 1: latent queries attend to the image-prompt tokens.
            ip_key = split_heads(self.to_k_ip(ip_hidden_states))
            ip_value = split_heads(self.to_v_ip(ip_hidden_states))
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            # Stage 2: the stage-1 output (still per-head) is the query over the garment tokens.
            c_key = split_heads(self.to_k_c(cloth_feature))
            c_value = split_heads(self.to_v_c(cloth_feature))
            ip_hidden_states = F.scaled_dot_product_attention(
                ip_hidden_states, c_key, c_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            ip_hidden_states = merge_heads(ip_hidden_states).to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
class IPAttnProcessorMulti2_0_2(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter (PyTorch 2.0) with an optional garment-conditioned branch.

    ``__call__`` always computes the text attention and then adds, weighted by ``scale``, one of:

    * image-prompt attention over the last ``num_tokens`` context tokens
      (when ``is_cloth_pass`` is true or ``up_cnt`` is ``None``), or
    * a garment branch: garment features attend to ``cloth_text`` and the result provides the
      keys/values that the latent queries attend to.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.

    NOTE(review): the garment branch calls ``self.q_additional`` / ``self.k_additional`` /
    ``self.v_additional``, none of which is created in ``__init__``; they must be attached by the caller.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_cloth_pass=False,
        cloth=None,
        up_cnt=None,
        mid_cnt=None,
        down_cnt=None,
        inside=None,
        cloth_text=None,
    ):
        """
        Run attention for ``attn`` on ``hidden_states``.

        Args:
            attn: attention module providing ``to_q`` / ``to_k`` / ``to_v`` / ``to_out``, ``heads`` and the
                optional norm layers read below.
            hidden_states: 3-D ``(batch, tokens, channels)`` or 4-D ``(batch, channels, height, width)`` input.
            encoder_hidden_states: context tokens; the last ``num_tokens`` are split off as image-prompt
                tokens (they are only consumed by the image-prompt branch).
            attention_mask: optional mask for the text attention only.
            temb: forwarded to ``attn.spatial_norm`` when that layer exists.
            is_cloth_pass, up_cnt: the garment branch runs only when ``is_cloth_pass`` is false and
                ``up_cnt`` is not ``None``.
            cloth: indexable collection of 4-D garment features, indexed with ``up_cnt * 3 + inside - 1``.
            mid_cnt, down_cnt: accepted for interface compatibility; not read.
            inside: offset used in the ``cloth`` index.
            cloth_text: tokens projected by ``to_k_ip`` / ``to_v_ip`` in the garment branch.

        Returns:
            Tensor with the same shape as the input ``hidden_states``.

        Raises:
            ValueError: if the image-prompt branch is selected without ``encoder_hidden_states``.
        """
        use_cloth_branch = not (is_cloth_pass or up_cnt is None)

        if encoder_hidden_states is None and not use_cloth_branch:
            # Previously failed later with an unbound `ip_hidden_states`.
            raise ValueError(
                "encoder_hidden_states is required for the image-prompt branch "
                "(its last num_tokens tokens are the image-prompt tokens)."
            )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape to (batch, heads, -1, target_len) to line up with the per-head layout used below.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Trailing num_tokens tokens are the image prompt; the rest is the text context.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        def merge_heads(tensor):
            # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
            return tensor.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = merge_heads(hidden_states).to(query.dtype)

        if use_cloth_branch:
            cloth_feature = cloth[up_cnt * 3 + inside - 1]
            cloth_feature = rearrange(cloth_feature, "b c h w -> b (h w) c").contiguous()

            # Garment tokens attend to `cloth_text`...
            query_cloth = split_heads(self.q_additional(cloth_feature))
            ip_key = split_heads(self.to_k_ip(cloth_text))
            ip_value = split_heads(self.to_v_ip(cloth_text))
            hidden_states_cloth = F.scaled_dot_product_attention(
                query_cloth, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_cloth = merge_heads(hidden_states_cloth).to(query.dtype)

            # ...and the attended garment tokens become keys/values for the latent queries.
            ip_key = split_heads(self.k_additional(hidden_states_cloth))
            ip_value = split_heads(self.v_additional(hidden_states_cloth))
        else:
            ip_key = split_heads(self.to_k_ip(ip_hidden_states))
            ip_value = split_heads(self.to_v_ip(ip_hidden_states))

        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        ip_hidden_states = merge_heads(ip_hidden_states).to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAttnProcessor2_0_paint(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter (PyTorch 2.0) that attends twice to the same context.

    Unlike the other IP processors in this file, no tokens are split off ``encoder_hidden_states``:
    the whole context is projected once by ``attn.to_k`` / ``attn.to_v`` and once by this processor's
    own ``to_k_ip`` / ``to_v_ip``; the second attention result is added with weight ``scale``.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features. Stored but not read by ``__call__``.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        # Identity check instead of `== None` (PEP 8); the message is unchanged.
        if cross_attention_dim is None:
            print("cross_attention_dim is none")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Falls back to hidden_size when no cross-attention width is given.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention for ``attn`` on ``hidden_states``.

        Args:
            attn: attention module providing ``to_q`` / ``to_k`` / ``to_v`` / ``to_out``, ``heads`` and the
                optional norm layers read below.
            hidden_states: 3-D ``(batch, tokens, channels)`` or 4-D ``(batch, channels, height, width)`` input.
            encoder_hidden_states: context tokens; ``hidden_states`` is used when ``None``.
            attention_mask: optional mask, applied to the first attention only.
            temb: forwarded to ``attn.spatial_norm`` when that layer exists.

        Returns:
            Tensor with the same shape as the input ``hidden_states``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape to (batch, heads, -1, target_len) to line up with the per-head layout used below.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        def merge_heads(tensor):
            # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
            return tensor.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = merge_heads(hidden_states).to(query.dtype)

        # Second attention over the same context, through this processor's own projections.
        ip_key = split_heads(self.to_k_ip(encoder_hidden_states))
        ip_value = split_heads(self.to_v_ip(encoder_hidden_states))

        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        ip_hidden_states = merge_heads(ip_hidden_states).to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
class IPAttnProcessor2_0_variant(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter (PyTorch 2.0), chained variant.

    The text attention output (kept in per-head layout) is used as the *query* of the image-prompt
    attention, and the image-prompt attention output replaces the text attention output entirely
    (``scale`` is stored but not applied in ``__call__``).

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features, i.e. how many trailing tokens of
            `encoder_hidden_states` are split off as image-prompt tokens.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention for ``attn`` on ``hidden_states``.

        Side effect: stores ``self.attn_map``, the softmax-normalised, ``1/sqrt(head_dim)``-scaled
        scores of the projected ``query`` against the image-prompt keys, shaped
        ``(batch, heads, query_tokens, num_tokens)``.

        Args:
            attn: attention module providing ``to_q`` / ``to_k`` / ``to_v`` / ``to_out``, ``heads`` and the
                optional norm layers read below.
            hidden_states: 3-D ``(batch, tokens, channels)`` or 4-D ``(batch, channels, height, width)`` input.
            encoder_hidden_states: context tokens; the last ``num_tokens`` are the image-prompt tokens.
            attention_mask: optional mask for the text attention only.
            temb: forwarded to ``attn.spatial_norm`` when that layer exists.

        Returns:
            Tensor with the same shape as the input ``hidden_states``.

        Raises:
            ValueError: if ``encoder_hidden_states`` is ``None`` (there would be no image-prompt tokens).
        """
        if encoder_hidden_states is None:
            # Previously failed later with an unbound `ip_hidden_states`.
            raise ValueError(
                "encoder_hidden_states is required: its last num_tokens tokens are the image-prompt tokens."
            )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape to (batch, heads, -1, target_len) to line up with the per-head layout used below.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Trailing num_tokens tokens are the image prompt; the rest is the text context.
        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
        encoder_hidden_states, ip_hidden_states = (
            encoder_hidden_states[:, :end_pos, :],
            encoder_hidden_states[:, end_pos:, :],
        )
        if attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        # Text attention; the result stays in (batch, heads, tokens, head_dim) layout on purpose,
        # because it is fed straight in as the query of the image-prompt attention.
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        ip_key = split_heads(self.to_k_ip(ip_hidden_states))
        ip_value = split_heads(self.to_v_ip(ip_hidden_states))

        ip_hidden_states = F.scaled_dot_product_attention(
            hidden_states, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        with torch.no_grad():
            # Fix: softmax must be applied to the scores, not to the transposed key
            # (`a @ b.softmax()` binds the softmax to `b`). Scaled like scaled_dot_product_attention.
            # NOTE(review): the map is built from `query`, while the attention above is driven by the
            # text-attended states; kept as in the original, confirm which one is wanted.
            scores = (query @ ip_key.transpose(-2, -1)) * (head_dim ** -0.5)
            self.attn_map = scores.softmax(dim=-1)

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Computes the usual text cross-attention and adds, weighted by ``self.scale``, a second attention
    of the same queries over the image-prompt tokens (the last ``num_tokens`` context tokens).

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        scale=1.0
    ):
        """
        Run attention for ``attn`` on ``hidden_states``.

        Side effect: stores ``self.attn_map``, the softmax-normalised, ``1/sqrt(head_dim)``-scaled
        scores of the queries against the image-prompt keys, shaped
        ``(batch, heads, query_tokens, num_tokens)``.

        Args:
            attn: attention module providing ``to_q`` / ``to_k`` / ``to_v`` / ``to_out``, ``heads`` and the
                optional norm layers read below.
            hidden_states: 3-D ``(batch, tokens, channels)`` or 4-D ``(batch, channels, height, width)`` input.
            encoder_hidden_states: context tokens; the last ``num_tokens`` are the image-prompt tokens.
            attention_mask: optional mask for the text attention only.
            temb: forwarded to ``attn.spatial_norm`` when that layer exists.
            scale: accepted for interface compatibility; not read (the image-prompt weight is
                ``self.scale``).

        Returns:
            Tensor with the same shape as the input ``hidden_states``.

        Raises:
            ValueError: if ``encoder_hidden_states`` is ``None`` (there would be no image-prompt tokens).
        """
        if encoder_hidden_states is None:
            # Previously failed later with an unbound `ip_hidden_states`.
            raise ValueError(
                "encoder_hidden_states is required: its last num_tokens tokens are the image-prompt tokens."
            )

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, h*w, c).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Reshape to (batch, heads, -1, target_len) to line up with the per-head layout used below.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Trailing num_tokens tokens are the image prompt; the rest is the text context.
        end_pos = encoder_hidden_states.shape[1] - self.num_tokens
        encoder_hidden_states, ip_hidden_states = (
            encoder_hidden_states[:, :end_pos, :],
            encoder_hidden_states[:, end_pos:, :],
        )
        if attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        def split_heads(tensor):
            # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
            return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        def merge_heads(tensor):
            # (batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)
            return tensor.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        query = split_heads(query)
        key = split_heads(key)
        value = split_heads(value)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = merge_heads(hidden_states).to(query.dtype)

        # Image-prompt attention: same queries, keys/values from the image tokens.
        ip_key = split_heads(self.to_k_ip(ip_hidden_states))
        ip_value = split_heads(self.to_v_ip(ip_hidden_states))

        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        with torch.no_grad():
            # Fix: softmax must be applied to the scores, not to the transposed key
            # (`a @ b.softmax()` binds the softmax to `b`). Scaled like scaled_dot_product_attention.
            scores = (query @ ip_key.transpose(-2, -1)) * (head_dim ** -0.5)
            self.attn_map = scores.softmax(dim=-1)

        ip_hidden_states = merge_heads(ip_hidden_states).to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IPAttnProcessor_referencenet_2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    When `encoder_hidden_states` is given, its last `num_tokens` entries along
    dim 1 are split off as image-prompt tokens. The remaining tokens go through
    the layer's own `to_k`/`to_v`; the image-prompt tokens go through this
    processor's `to_k_ip`/`to_v_ip`, and the two attention results are summed
    (the image-prompt one weighted by `scale`).

    After a call that includes image-prompt tokens, `self.attn_map` holds the
    softmax-normalised query/image-key attention weights with shape
    `(batch, heads, query_len, ip_tokens)`, computed without gradient tracking.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
        attn_head_dim (`int`, defaults to 10):
            Stored as `self.attn_head_dim`; it is not read by `__call__`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, attn_head_dim=10):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.attn_head_dim = attn_head_dim

        # Separate key/value projections used only for the image-prompt tokens.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention for `attn`, adding the image-prompt branch when available.

        Args:
            attn: attention layer providing `to_q`/`to_k`/`to_v`/`to_out`,
                `heads`, the optional norms and the residual settings.
            hidden_states: 3-D `(batch, seq, dim)` or 4-D
                `(batch, channel, height, width)` input tensor.
            encoder_hidden_states: optional context whose last `num_tokens`
                entries along dim 1 are the image-prompt tokens. When `None`,
                plain self-attention is performed and the image-prompt branch
                (including the `attn_map` update) is skipped.
            attention_mask: optional mask for the non-image-prompt attention.
            temb: forwarded to `attn.spatial_norm` when that norm is set.

        Returns:
            Tensor with the same layout as the `hidden_states` input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention takes a mask laid out per head.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            # Self-attention carries no image-prompt tokens; without this the
            # name would be unbound further down.
            ip_hidden_states = None
        else:
            # Split the context into its leading tokens and the trailing
            # image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            with torch.no_grad():
                # The product must be parenthesised: `.softmax` binds tighter
                # than `@`, so the unparenthesised form normalises the keys
                # instead of the scores. The 1/sqrt(head_dim) factor matches
                # the default scaling of scaled_dot_product_attention.
                ip_scores = (query @ ip_key.transpose(-2, -1)) * (head_dim ** -0.5)
                self.attn_map = ip_scores.softmax(dim=-1)

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # Output projection followed by the second entry of `to_out`.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
class IPAttnProcessor2_0_Lora(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter with LoRA for PyTorch 2.0.

    Behaves like the IP-Adapter processor (the last `num_tokens` entries of
    `encoder_hidden_states` are image-prompt tokens with their own key/value
    projections) and additionally adds LoRA residuals:

    * `to_k_ip_lora` / `to_v_ip_lora` (owned by this processor) are always
      added to the image-prompt key/value projections.
    * `q_lora`, `k_lora`, `v_lora`, `out_lora` are used only if the attention
      layer passed to `__call__` has attributes with those names.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        scale_lora (`float`, defaults to 1.0):
            Multiplier applied to every LoRA residual.
        rank (`int`, defaults to 4):
            Rank passed to the `LoRALinearLayer` instances created here.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, scale_lora=1.0, rank=4, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.scale_lora = scale_lora
        self.num_tokens = num_tokens

        # Key/value projections for the image-prompt tokens and their LoRA
        # counterparts, sized to match.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_k_ip_lora = LoRALinearLayer(
            in_features=self.to_k_ip.in_features, out_features=self.to_k_ip.out_features, rank=rank
        )
        self.to_v_ip_lora = LoRALinearLayer(
            in_features=self.to_v_ip.in_features, out_features=self.to_v_ip.out_features, rank=rank
        )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Run attention for `attn` with optional LoRA and image-prompt branches.

        Args:
            attn: attention layer providing `to_q`/`to_k`/`to_v`/`to_out`,
                `heads`, the optional norms and the residual settings, and
                optionally `q_lora`/`k_lora`/`v_lora`/`out_lora`.
            hidden_states: 3-D `(batch, seq, dim)` or 4-D
                `(batch, channel, height, width)` input tensor.
            encoder_hidden_states: optional context whose last `num_tokens`
                entries along dim 1 are the image-prompt tokens. When `None`,
                self-attention is performed and the image-prompt branch is
                skipped.
            attention_mask: optional mask for the non-image-prompt attention.
            temb: forwarded to `attn.spatial_norm` when that norm is set.

        Returns:
            Tensor with the same layout as the `hidden_states` input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention takes a mask laid out per head.
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if hasattr(attn, "q_lora"):
            query = query + self.scale_lora * attn.q_lora(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            # Self-attention carries no image-prompt tokens; without this the
            # name would be unbound further down.
            ip_hidden_states = None
        else:
            # Split the context into its leading tokens and the trailing
            # image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        if hasattr(attn, "k_lora"):
            key = key + self.scale_lora * attn.k_lora(encoder_hidden_states)

        value = attn.to_v(encoder_hidden_states)
        if hasattr(attn, "v_lora"):
            value = value + self.scale_lora * attn.v_lora(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states) + self.scale_lora * self.to_k_ip_lora(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states) + self.scale_lora * self.to_v_ip_lora(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        hidden_states = attn.to_out[0](hidden_states)
        if hasattr(attn, "out_lora"):
            # NOTE(review): unlike q/k/v, this LoRA is fed the *output* of
            # to_out[0] rather than its input. Kept as-is because existing
            # weights may have been trained this way -- confirm intent.
            hidden_states = hidden_states + self.scale_lora * attn.out_lora(hidden_states)

        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
|
class CNAttnProcessor:
    r"""
    Attention processor that ignores the trailing `num_tokens` context tokens.

    When `encoder_hidden_states` is supplied, the last `num_tokens` entries
    along dim 1 are dropped before keys and values are computed; otherwise the
    layer attends over `hidden_states` itself. Attention is evaluated through
    the layer's `head_to_batch_dim` / `get_attention_scores` /
    `batch_to_head_dim` helpers.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_spatial = input_ndim == 4

        if is_spatial:
            # Flatten (batch, channel, height, width) into a token sequence.
            n, c, h, w = hidden_states.shape
            hidden_states = hidden_states.view(n, c, h * w).transpose(1, 2)

        # Batch size and key length come from the context when one is given.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        n, kv_len, _ = shape_source.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, kv_len, n)

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is not None:
            # Keep only the leading tokens; the final `num_tokens` are dropped.
            keep = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :keep]
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)
        else:
            context = hidden_states

        k = attn.to_k(context)
        v = attn.to_v(context)

        q = attn.head_to_batch_dim(q)
        k = attn.head_to_batch_dim(k)
        v = attn.head_to_batch_dim(v)

        probs = attn.get_attention_scores(q, k, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, v))

        # Output projection followed by the second entry of `to_out`.
        projected = attn.to_out[0](attended)
        out = attn.to_out[1](projected)

        if is_spatial:
            out = out.transpose(-1, -2).reshape(n, c, h, w)

        if attn.residual_connection:
            out = out + skip

        return out / attn.rescale_output_factor
|
|
|
|
|
class CNAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor that ignores the trailing
    `num_tokens` context tokens (requires PyTorch 2.0).

    When `encoder_hidden_states` is supplied, the last `num_tokens` entries
    along dim 1 are dropped before keys and values are computed; otherwise the
    layer attends over `hidden_states` itself.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        skip = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        is_spatial = input_ndim == 4

        if is_spatial:
            # Flatten (batch, channel, height, width) into a token sequence.
            n, c, h, w = hidden_states.shape
            hidden_states = hidden_states.view(n, c, h * w).transpose(1, 2)

        # Batch size and key length come from the context when one is given.
        shape_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        n, kv_len, _ = shape_source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, kv_len, n)
            # scaled_dot_product_attention takes a mask laid out per head.
            attention_mask = attention_mask.view(n, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is not None:
            # Keep only the leading tokens; the final `num_tokens` are dropped.
            keep = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :keep]
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)
        else:
            context = hidden_states

        k = attn.to_k(context)
        v = attn.to_v(context)

        per_head = k.shape[-1] // attn.heads

        def to_heads(t):
            # (batch, seq, heads * per_head) -> (batch, heads, seq, per_head)
            return t.view(n, -1, attn.heads, per_head).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge the heads back into a single feature dimension.
        attended = attended.transpose(1, 2).reshape(n, -1, attn.heads * per_head)
        attended = attended.to(q.dtype)

        # Output projection followed by the second entry of `to_out`.
        projected = attn.to_out[0](attended)
        out = attn.to_out[1](projected)

        if is_spatial:
            out = out.transpose(-1, -2).reshape(n, c, h, w)

        if attn.residual_connection:
            out = out + skip

        return out / attn.rescale_output_factor
|
|