File size: 18,517 Bytes
c43f41f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 |
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from copy import deepcopy
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, AutoModel
from transformers.utils import ModelOutput
from transformers.models.llama import LlamaModel
import logging
logger = logging.getLogger(__name__)
# NOTE(review): configuring the root logger at import time affects every program that imports
# this module; consider moving basicConfig to the application entry point.
logging.basicConfig(level=logging.INFO)

# Decoder-only (LLM) text encoders are fed the user prompt wrapped in an instruction template.
# `{}` in each template is replaced with the user text (TextEncoder.apply_text_to_template).
# --------------------------------------------------------------------
# Template used for image prompts.
PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
# Template used for video prompts. The adjacent literals are concatenated with no separator
# (e.g. "...theme of the video.2. The color..."). Do not "fix" the wording without also updating
# the matching "crop_start" below, which presumably counts the template's prefix tokens.
PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
# Default negative prompt. It is not referenced elsewhere in this file; presumably used by callers.
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
# Named template configurations. "crop_start" is the number of leading sequence positions that
# TextEncoder.encode drops from the encoder output, so that only the user-prompt part remains.
# Presumably this equals the token count of the template text before `{}` for the LLM tokenizer.
# TODO confirm.
PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
        "crop_start": 36,
    },
    "dit-llm-encode-video": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
        "crop_start": 95,
    },
}
def use_default(value, default):
    """Return ``value``, or ``default`` when ``value`` is ``None``.

    Only ``None`` triggers the fallback; other falsy values such as ``0``,
    ``False`` or ``""`` are returned unchanged.
    """
    if value is None:
        return default
    return value
def load_text_encoder(
    text_encoder_type: str,
    text_encoder_path: str,
    text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
):
    """Load a text encoder model and freeze it.

    Args:
        text_encoder_type: ``"clipL"`` loads a ``CLIPTextModel``; ``"llm"`` loads a
            model through ``AutoModel``.
        text_encoder_path: Location passed to ``from_pretrained``.
        text_encoder_dtype: Optional dtype used at load time and for a final ``.to()`` cast.

    Returns:
        Tuple ``(model, text_encoder_path)``. The model has ``requires_grad`` disabled
        and exposes its final normalization layer as ``final_layer_norm``.

    Raises:
        ValueError: If ``text_encoder_type`` is not supported.
    """
    logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
    # Passing the dtype to from_pretrained keeps peak memory down during loading.
    dtype = text_encoder_dtype

    if text_encoder_type == "clipL":
        model = CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
        model.final_layer_norm = model.text_model.final_layer_norm
    elif text_encoder_type == "llm":
        model = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)
        model.final_layer_norm = model.norm
    else:
        raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")

    # from_pretrained already leaves the model in eval mode, so only dtype and grads are handled.
    if dtype is not None:
        model = model.to(dtype=dtype)
    model.requires_grad_(False)

    logger.info(f"Text encoder to dtype: {model.dtype}")
    return model, text_encoder_path
def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
    """Load the tokenizer that matches a text encoder type.

    Args:
        tokenizer_type: ``"llm"`` loads via ``AutoTokenizer`` with ``padding_side``;
            ``"clipL"`` loads a ``CLIPTokenizer`` with ``max_length=77``.
        tokenizer_path: Location passed to ``from_pretrained``.
        padding_side: Padding side for the ``"llm"`` tokenizer only.

    Returns:
        Tuple ``(tokenizer, tokenizer_path)``.

    Raises:
        ValueError: If ``tokenizer_type`` is not supported.
    """
    logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")

    if tokenizer_type == "llm":
        tok = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side)
    elif tokenizer_type == "clipL":
        tok = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
    else:
        raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")

    return tok, tokenizer_path
@dataclass
class TextEncoderModelOutput(ModelOutput):
    """
    Output of ``TextEncoder.encode``: the selected encoder output plus optional extras.

    ``TextEncoder.encode`` builds this positionally, so the field order below matters.

    Args:
        hidden_state (`torch.FloatTensor`):
            The selected encoder output. It is either an entry of the model's hidden states
            (when a hidden-state skip layer is used) or the value of the encoder's output key.
            For sequence outputs the shape is `(batch_size, sequence_length, hidden_size)`, with
            any template prefix (``crop_start``) already removed. With a pooled output key the
            shape is presumably `(batch_size, hidden_size)`; confirm against the model.
        attention_mask (`torch.LongTensor`, *optional*):
            Tokenizer attention mask, cropped the same way as ``hidden_state``. It is ``None``
            when the encoder was run without an attention mask.
        hidden_states_list (`tuple(torch.FloatTensor)`, *optional*):
            All hidden states returned by the model. Only set when ``output_hidden_states=True``
            is passed to ``encode``.
        text_outputs (`list`, *optional*):
            List of decoded texts. Not populated by ``TextEncoder.encode`` in this file.
    """
    # Selected (and possibly cropped) encoder output.
    hidden_state: torch.FloatTensor = None
    # Cropped attention mask, or None when no mask was used.
    attention_mask: Optional[torch.LongTensor] = None
    # Full tuple of model hidden states, only when requested.
    hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Decoded texts; unused by this module.
    text_outputs: Optional[list] = None
class TextEncoder(nn.Module):
    """Bundle a Hugging Face text encoder with its tokenizer and optional prompt templates.

    The model is loaded (and frozen) by ``load_text_encoder``. The tokenizer is loaded by
    ``load_tokenizer`` with right padding. When a prompt template is configured, user text is
    wrapped in the template before tokenization. ``encode`` then drops the first ``crop_start``
    sequence positions, so only the hidden states of the user-prompt part are returned.

    Args:
        text_encoder_type: Encoder kind (``"clipL"`` or ``"llm"`` are loadable). It also
            selects the default ``output_key``.
        max_length: Tokenizer ``max_length``; inputs are truncated and padded to it.
        text_encoder_dtype: Optional dtype forwarded to ``load_text_encoder``.
        text_encoder_path: Location of the encoder weights.
        tokenizer_type: Tokenizer kind; defaults to ``text_encoder_type``.
        tokenizer_path: Tokenizer location; defaults to ``text_encoder_path``.
        output_key: Model-output key read when no hidden-state skip layer is in effect.
        use_attention_mask: Whether to pass the tokenizer attention mask to the model.
        input_max_length: Stored only; defaults to ``max_length``.
        prompt_template: Dict with a ``"template"`` string containing ``{}`` and an optional
            ``"crop_start"``. Used for ``data_type="image"``, and its presence enables templating.
        prompt_template_video: Same structure, used for ``data_type="video"``.
        hidden_state_skip_layer: Default number of final layers to skip (0 = last layer).
        apply_final_norm: Apply the model's ``final_layer_norm`` to an intermediate hidden state.
        reproduce: Stored; only feeds the (ignored) ``do_sample`` default.
    """

    def __init__(
        self,
        text_encoder_type: str,
        max_length: int,
        text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
        text_encoder_path: Optional[str] = None,
        tokenizer_type: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        output_key: Optional[str] = None,
        use_attention_mask: bool = True,
        input_max_length: Optional[int] = None,
        prompt_template: Optional[dict] = None,
        prompt_template_video: Optional[dict] = None,
        hidden_state_skip_layer: Optional[int] = None,
        apply_final_norm: bool = False,
        reproduce: bool = False,
    ):
        super().__init__()
        self.text_encoder_type = text_encoder_type
        self.max_length = max_length
        self.model_path = text_encoder_path
        self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
        self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
        self.use_attention_mask = use_attention_mask
        if prompt_template_video is not None:
            assert use_attention_mask is True, "Attention mask is True required when training videos."

        self.input_max_length = input_max_length if input_max_length is not None else max_length
        self.prompt_template = prompt_template
        self.prompt_template_video = prompt_template_video
        self.hidden_state_skip_layer = hidden_state_skip_layer
        self.apply_final_norm = apply_final_norm
        self.reproduce = reproduce

        # Templating is enabled by the image template alone; the video template is optional.
        self.use_template = self.prompt_template is not None
        if self.use_template:
            assert (
                isinstance(self.prompt_template, dict) and "template" in self.prompt_template
            ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
            assert "{}" in str(self.prompt_template["template"]), (
                "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template['template']}"
            )

        self.use_video_template = self.prompt_template_video is not None
        if self.use_video_template:
            assert (
                isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
            ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
            assert "{}" in str(self.prompt_template_video["template"]), (
                "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template_video['template']}"
            )

        if "t5" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        elif "clip" in text_encoder_type:
            self.output_key = output_key or "pooler_output"
        elif "llm" in text_encoder_type or "glm" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        else:
            raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")

        self.model, self.model_path = load_text_encoder(
            text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
        )
        self.dtype = self.model.dtype

        self.tokenizer, self.tokenizer_path = load_tokenizer(
            tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
        )

    def __repr__(self):
        # Report the dtype recorded at load time; no separate "precision" attribute exists.
        return f"{self.text_encoder_type} ({self.dtype} - {self.model_path})"

    @property
    def device(self):
        """Device of the wrapped model."""
        return self.model.device

    @staticmethod
    def apply_text_to_template(text, template, prevent_empty_text=True):
        """
        Insert ``text`` into the ``{}`` placeholder of ``template``.

        Args:
            text (str): Input text.
            template (str): Template string containing ``{}``.
            prevent_empty_text (bool): Currently unused; kept for API compatibility.

        Returns:
            str: The formatted prompt, which is sent to the tokenizer as a plain string.

        Raises:
            TypeError: If ``template`` is not a string.
        """
        if not isinstance(template, str):
            raise TypeError(f"Unsupported template type: {type(template)}")
        return template.format(text)

    def _template_for(self, data_type):
        """Return the template dict for ``data_type``, raising ValueError if none applies."""
        if data_type == "image":
            return self.prompt_template
        if data_type == "video":
            if self.prompt_template_video is None:
                raise ValueError("data_type='video' requires `prompt_template_video`, but none was configured.")
            return self.prompt_template_video
        raise ValueError(f"Unsupported data type: {data_type}")

    def text2tokens(self, text, data_type="image"):
        """
        Tokenize the input text, applying the prompt template for ``data_type`` if configured.

        Args:
            text (str or list/tuple of str): Input text or a batch of texts.
            data_type (str): ``"image"`` or ``"video"``. It selects the template and is only
                consulted when templating is enabled.

        Returns:
            The tokenizer output as PyTorch tensors, truncated and padded to ``self.max_length``.

        Raises:
            ValueError: Unsupported ``data_type``, or ``"video"`` without a video template.
            TypeError: ``text`` is neither a string nor a list/tuple (with templating enabled).
        """
        tokenize_input_type = "str"
        if self.use_template:
            prompt_template = self._template_for(data_type)["template"]
            if isinstance(text, (list, tuple)):
                text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
                # The chat path is only taken for list-valued (chat-style) templated outputs.
                # `text and` avoids an IndexError on an empty batch.
                if text and isinstance(text[0], list):
                    tokenize_input_type = "list"
            elif isinstance(text, str):
                text = self.apply_text_to_template(text, prompt_template)
                if isinstance(text, list):
                    tokenize_input_type = "list"
            else:
                raise TypeError(f"Unsupported text type: {type(text)}")

        kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        if tokenize_input_type == "str":
            return self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                **kwargs,
            )
        elif tokenize_input_type == "list":
            return self.tokenizer.apply_chat_template(
                text,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                **kwargs,
            )
        else:
            raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")

    def encode(
        self,
        batch_encoding,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=None,
        hidden_state_skip_layer=None,
        return_texts=False,
        data_type="image",
        device=None,
    ):
        """
        Run the text encoder on an already tokenized batch.

        Args:
            batch_encoding (dict): Tokenizer output containing ``"input_ids"`` and
                ``"attention_mask"``.
            use_attention_mask (bool): Whether to pass the attention mask to the model.
                If None, use ``self.use_attention_mask``.
            output_hidden_states (bool): If True, also return all model hidden states in
                ``hidden_states_list``. Hidden states are always requested from the model when a
                skip layer is in effect.
            do_sample (bool): Ignored; kept for API compatibility.
            hidden_state_skip_layer (int): Number of final layers to skip; 0 selects the last
                layer. If None, use ``self.hidden_state_skip_layer``. If that is also None, read
                ``self.output_key`` from the model output instead.
            return_texts (bool): Ignored; kept for API compatibility.
            data_type (str): ``"image"`` or ``"video"``. It selects which template's
                ``crop_start`` is applied when templating is enabled.
            device: Device for the inputs; defaults to the model's device.

        Returns:
            TextEncoderModelOutput: the selected hidden state and attention mask (both cropped by
            ``crop_start`` when it is positive). All hidden states are included when
            ``output_hidden_states`` is True.

        Raises:
            ValueError: Unsupported ``data_type``, or ``"video"`` without a video template,
                when templating is enabled.
        """
        device = self.model.device if device is None else device
        use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
        hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
        attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
        outputs = self.model(
            input_ids=batch_encoding["input_ids"].to(device),
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
        )
        if hidden_state_skip_layer is not None:
            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
            # The real last hidden state already has the final norm applied, so the norm is only
            # applied to intermediate layers.
            if hidden_state_skip_layer > 0 and self.apply_final_norm:
                last_hidden_state = self.model.final_layer_norm(last_hidden_state)
        else:
            last_hidden_state = outputs[self.output_key]

        # Remove hidden states of the instruction (template) tokens; keep only the prompt tokens.
        if self.use_template:
            crop_start = self._template_for(data_type).get("crop_start", -1)
            if crop_start > 0:
                last_hidden_state = last_hidden_state[:, crop_start:]
                attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None

        if output_hidden_states:
            return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
        return TextEncoderModelOutput(last_hidden_state, attention_mask)

    def forward(
        self,
        text,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=False,
        hidden_state_skip_layer=None,
        return_texts=False,
        data_type="image",
    ):
        """Tokenize ``text`` and encode it; see ``text2tokens`` and ``encode``.

        ``data_type`` is forwarded to both steps, so the same template is used for tokenization
        and cropping. It defaults to ``"image"``, matching the previous behavior.
        """
        batch_encoding = self.text2tokens(text, data_type=data_type)
        return self.encode(
            batch_encoding,
            use_attention_mask=use_attention_mask,
            output_hidden_states=output_hidden_states,
            do_sample=do_sample,
            hidden_state_skip_layer=hidden_state_skip_layer,
            return_texts=return_texts,
            data_type=data_type,
        )
# region HunyuanVideo architecture
def load_text_encoder_1(
    text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
) -> TextEncoder:
    """Build the LLM-based text encoder (encoder 1) and move it to ``device``.

    Args:
        text_encoder_dir: Passed to ``TextEncoder`` as the model path; the tokenizer path
            defaults to it as well.
        device: Target device.
        fp8_llm: If True, cast the module to ``torch.float8_e4m3fn`` on ``device``. Then cast
            ``Embedding`` modules back to the load dtype, and replace each ``LlamaRMSNorm.forward``
            with a version that normalizes in float32.
        dtype: Load dtype; ``None`` (or any falsy value) falls back to ``torch.float16``.

    Returns:
        The ``TextEncoder`` in eval mode, placed on ``device``.
    """
    text_encoder_dtype = dtype or torch.float16
    text_encoder_type = "llm"
    # Positions left after cropping the video-template prefix (max_length - crop_start below).
    text_len = 256
    # encode() reads hidden_states[-(2 + 1)], i.e. skips the last two layers.
    hidden_state_skip_layer = 2
    # Do not re-apply the final norm to that intermediate hidden state.
    apply_final_norm = False
    reproduce = False

    prompt_template = "dit-llm-encode"
    prompt_template = PROMPT_TEMPLATE[prompt_template]
    prompt_template_video = "dit-llm-encode-video"
    prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]

    # The tokenizer max_length is the prompt budget plus the video template prefix. The same
    # max_length is used for image prompts, whose template prefix (crop_start) is shorter.
    crop_start = prompt_template_video["crop_start"]  # .get("crop_start", 0)
    max_length = text_len + crop_start

    text_encoder_1 = TextEncoder(
        text_encoder_type=text_encoder_type,
        max_length=max_length,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=text_encoder_dir,
        tokenizer_type=text_encoder_type,
        prompt_template=prompt_template,
        prompt_template_video=prompt_template_video,
        hidden_state_skip_layer=hidden_state_skip_layer,
        apply_final_norm=apply_final_norm,
        reproduce=reproduce,
    )
    text_encoder_1.eval()

    if fp8_llm:
        # Dtype recorded by TextEncoder at load time. `text_encoder_1.dtype` is a plain attribute,
        # so it keeps reporting this value after the fp8 cast below.
        org_dtype = text_encoder_1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)
        # NOTE(review): all other weights (e.g. Linear layers) stay in float8_e4m3fn. Callers
        # presumably handle the resulting dtype mismatch (e.g. via autocast); confirm.

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            """Cast Embedding modules to ``target_dtype`` and patch LlamaRMSNorm forwards."""

            def forward_hook(module):
                # Build a replacement forward bound to this `module`. It normalizes in float32,
                # then scales by the module weight cast to the incoming activation dtype.
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            # Modules are matched by class name.
            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder_1.model, org_dtype)
    else:
        text_encoder_1.to(device=device)
    return text_encoder_1
def load_text_encoder_2(
    text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
) -> TextEncoder:
    """Build the CLIP-L text encoder (encoder 2) and move it to ``device``.

    Args:
        text_encoder_dir: Passed to ``TextEncoder`` as the model path; the tokenizer path
            defaults to it as well.
        device: Target device.
        dtype: Load dtype; ``None`` (or any falsy value) falls back to ``torch.float16``.

    Returns:
        The ``TextEncoder`` in eval mode, placed on ``device``, tokenizing to 77 tokens.
    """
    clip_encoder = TextEncoder(
        text_encoder_type="clipL",
        max_length=77,
        text_encoder_dtype=dtype or torch.float16,
        text_encoder_path=text_encoder_dir,
        tokenizer_type="clipL",
        reproduce=False,
    )
    clip_encoder.eval()
    clip_encoder.to(device=device)
    return clip_encoder
# endregion
|