import torch
import einops
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10
from torchvision import transforms
import os
import math
import random
import json
from abc import ABC
import pickle
def pad_to_length(x, common_factor, **config):
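    """Flatten x and right-pad it with config["fill_value"] so its length becomes a multiple of common_factor."""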
if x.numel() % common_factor == 0:
return x.flatten()
# print(f"padding {x.shape} according to {common_factor}")
full_length = (x.numel() // common_factor + 1) * common_factor
padding_length = full_length - len(x.flatten())
padding = torch.full([padding_length, ], dtype=x.dtype, device=x.device, fill_value=config["fill_value"])
x = torch.cat((x.flatten(), padding), dim=0)
return x
def layer_to_token(x, common_factor, **config):
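    """Tokenize one layer's parameter tensor into rows of width common_factor.

    granularity 0 returns the raw flattened tensor, 1 pads the flattened layer
    into full-width tokens, and 2 additionally splits along the output dimension
    so each token only contains weights from the same group of output channels.
    """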
if config["granularity"] == 2: # split by output
if x.numel() <= common_factor:
return pad_to_length(x.flatten(), common_factor, **config)[None]
dim2 = x[0].numel()
dim1 = x.shape[0]
if dim2 <= common_factor:
i = int(dim1 / (common_factor / dim2))
while True:
if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
output = x.view(-1, dim2 * (dim1 // i))
output = [pad_to_length(item, common_factor, **config) for item in output]
return torch.stack(output, dim=0)
i += 1
else: # dim2 > common_factor
output = [layer_to_token(item, common_factor, **config) for item in x]
return torch.cat(output, dim=0)
elif config["granularity"] == 1: # split by layer
return pad_to_length(x.flatten(), common_factor, **config).view(-1, common_factor)
elif config["granularity"] == 0: # flatten directly
return x.flatten()
else: # NotImplementedError
raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")
def token_to_layer(tokens, shape, **config):
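    """Inverse of layer_to_token: consume the leading tokens, rebuild a tensor of `shape`,
    and return (param, remaining_tokens)."""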
common_factor = tokens.shape[-1]
if config["granularity"] == 2: # split by output
num_element = math.prod(shape)
if num_element <= common_factor:
param = tokens[0][:num_element].view(shape)
tokens = tokens[1:]
return param, tokens
dim2 = num_element // shape[0]
dim1 = shape[0]
if dim2 <= common_factor:
i = int(dim1 / (common_factor / dim2))
while True:
if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
item_per_token = dim2 * (dim1 // i)
length = num_element // item_per_token
output = [item[:item_per_token] for item in tokens[:length]]
param = torch.cat(output, dim=0).view(shape)
tokens = tokens[length:]
return param, tokens
i += 1
else: # dim2 > common_factor
output = []
for i in range(shape[0]):
param, tokens = token_to_layer(tokens, shape[1:], **config)
output.append(param.flatten())
param = torch.cat(output, dim=0).view(shape)
return param, tokens
elif config["granularity"] == 1: # split by layer
num_element = math.prod(shape)
token_num = num_element // common_factor if num_element % common_factor == 0 \
else num_element // common_factor + 1
param = tokens.flatten()[:num_element].view(shape)
tokens = tokens[token_num:]
return param, tokens
elif config["granularity"] == 0: # flatten directly
num_element = math.prod(shape)
param = tokens.flatten()[:num_element].view(shape)
tokens = pad_to_length(tokens.flatten()[num_element:],
common_factor, fill_value=torch.nan).view(-1, common_factor)
return param, tokens
else: # NotImplementedError
raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")
def positional_embedding_2d(dim1, dim2, d_model):
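    """Standard 2d sin/cos positional embedding of shape (dim1, dim2, d_model):
    the first half of d_model encodes the second index, the second half the first index."""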
    assert d_model % 4 == 0, f"Cannot use 2d sin/cos positional encoding with d_model={d_model}; it must be divisible by 4"
pe = torch.zeros(d_model, dim1, dim2)
d_model = int(d_model / 2) # Each dimension use half of d_model
div_term = torch.exp(torch.arange(0., d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
pos_w = torch.arange(0., dim2).unsqueeze(1)
pos_h = torch.arange(0., dim1).unsqueeze(1)
pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
pe[d_model+1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
return pe.permute(1, 2, 0)
def positional_embedding_1d(dim1, d_model):
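    """Standard 1d sin/cos positional embedding of shape (dim1, d_model)."""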
pe = torch.zeros(dim1, d_model)
position = torch.arange(0, dim1, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe
class BaseDataset(Dataset, ABC):
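    """Dataset of model checkpoints flattened into fixed-width token sequences.

    Subclasses are expected to set data_path (and optionally generated_path / test_command).
    Per-layer normalization statistics are cached in a structure.cache file beside the
    checkpoint folder and reused by preprocess() / postprocess().
    """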
data_path = None
generated_path = None
test_command = None
config = {
"fill_value": torch.nan,
"granularity": 1, # 0: flatten directly, 1: split by layer, 2: split by output
"pe_granularity": 2, # 0: no embedding, 1: 1d embedding, 2: 2d embedding
}
def __init__(self, checkpoint_path=None, dim_per_token=8192, **kwargs):
if not os.path.exists(self.data_path):
os.makedirs(self.data_path, exist_ok=False)
if self.generated_path is not None and not os.path.exists(os.path.dirname(self.generated_path)):
os.makedirs(os.path.dirname(self.generated_path))
self.config.update(kwargs)
checkpoint_path = self.data_path if checkpoint_path is None else checkpoint_path
assert os.path.exists(checkpoint_path)
self.dim_per_token = dim_per_token
self.structure = None # set in get_structure()
self.sequence_length = None # set in get_structure()
# load checkpoint_list
checkpoint_list = os.listdir(checkpoint_path)
        self.checkpoint_list = [os.path.join(checkpoint_path, item) for item in checkpoint_list]
self.length = self.real_length = len(self.checkpoint_list)
self.set_infinite_dataset()
# get structure
structure_cache_file = os.path.join(os.path.dirname(self.data_path), "structure.cache")
try: # try to load cache file
assert os.path.exists(structure_cache_file)
with open(structure_cache_file, "rb") as f:
print(f"Loading cache from {structure_cache_file}")
cache_file = pickle.load(f)
if len(self.checkpoint_list) != 0:
assert set(cache_file["checkpoint_list"]) == set(self.checkpoint_list)
self.structure = cache_file["structure"]
else: # empty checkpoint_list, only generate
print("Cannot find any trained checkpoint, loading cache file for generating!")
self.structure = cache_file["structure"]
fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
torch.save(fake_diction, os.path.join(checkpoint_path, "fake_checkpoint.pth"))
self.checkpoint_list.append(os.path.join(checkpoint_path, "fake_checkpoint.pth"))
self.length = self.real_length = len(self.checkpoint_list)
self.set_infinite_dataset()
os.system(f"rm {os.path.join(checkpoint_path, 'fake_checkpoint.pth')}")
except AssertionError: # recompute cache file
print("==> Organizing structure..")
self.structure = self.get_structure()
with open(structure_cache_file, "wb") as f:
pickle.dump({"structure": self.structure, "checkpoint_list": self.checkpoint_list}, f)
# get sequence_length
self.sequence_length = self.get_sequence_length()
def get_sequence_length(self):
fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
# get sequence_length
param = self.preprocess(fake_diction)
self.sequence_length = param.size(0)
return self.sequence_length
def get_structure(self):
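        """Compute per-key normalization statistics, averaged over every checkpoint in checkpoint_list."""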
# get structure
checkpoint_list = self.checkpoint_list
structures = [{} for _ in range(len(checkpoint_list))]
for i, checkpoint in enumerate(checkpoint_list):
diction = torch.load(checkpoint, map_location="cpu")
for key, value in diction.items():
if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
structures[i][key] = (value.shape, value, None)
elif "running_var" in key:
pre_mean = value.mean() * 0.95
value = torch.log(value / pre_mean + 0.05)
structures[i][key] = (value.shape, pre_mean, value.mean(), value.std())
else: # conv & linear
structures[i][key] = (value.shape, value.mean(), value.std())
final_structure = {}
structure_diction = torch.load(checkpoint_list[0], map_location="cpu")
for key, param in structure_diction.items():
if ("num_batches_tracked" in key) or (param.numel() == 1) or not torch.is_floating_point(param):
final_structure[key] = (param.shape, param, None)
elif "running_var" in key:
value = [param.shape, 0., 0., 0.]
for structure in structures:
for i in [1, 2, 3]:
value[i] += structure[key][i]
for i in [1, 2, 3]:
value[i] /= len(structures)
final_structure[key] = tuple(value)
else: # conv & linear
value = [param.shape, 0., 0.]
for structure in structures:
for i in [1, 2]:
value[i] += structure[key][i]
for i in [1, 2]:
value[i] /= len(structures)
final_structure[key] = tuple(value)
self.structure = final_structure
return self.structure
def set_infinite_dataset(self, max_num=None):
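        """Inflate __len__ so the dataset can be sampled indefinitely; __getitem__ wraps indices by real_length."""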
if max_num is None:
max_num = self.length * 1000000
self.length = max_num
return self
@property
def max_permutation_state(self):
return self.real_length
def get_position_embedding(self, positional_embedding_dim=None):
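        """Build positional embeddings for the token sequence according to config["pe_granularity"]."""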
if positional_embedding_dim is None:
positional_embedding_dim = self.dim_per_token // 2
assert self.structure is not None, "run get_structure before get_position_embedding"
if self.config["pe_granularity"] == 2:
print("Use 2d positional embedding")
positional_embedding_index = []
for key, item in self.structure.items():
if ("num_batches_tracked" in key) or (item[-1] is None):
continue
else: # conv & linear
shape, *_ = item
fake_param = torch.ones(size=shape)
fake_param = layer_to_token(fake_param, self.dim_per_token, **self.config)
positional_embedding_index.append(list(range(fake_param.size(0))))
dim1 = len(positional_embedding_index)
dim2 = max([len(token_per_layer) for token_per_layer in positional_embedding_index])
full_pe = positional_embedding_2d(dim1, dim2, positional_embedding_dim)
positional_embedding = []
for layer_index, token_indexes in enumerate(positional_embedding_index):
for token_index in token_indexes:
this_pe = full_pe[layer_index, token_index]
positional_embedding.append(this_pe)
positional_embedding = torch.stack(positional_embedding)
return positional_embedding
elif self.config["pe_granularity"] == 1:
print("Use 1d positional embedding")
return positional_embedding_1d(self.sequence_length, positional_embedding_dim)
elif self.config["pe_granularity"] == 0:
print("Not use positional embedding")
return torch.zeros_like(self.__getitem__(0))
else: # NotImplementedError
raise NotImplementedError("pe_granularity: 0: no embedding, 1: 1d embedding, 2: 2d embedding")
def __len__(self):
return self.length
def __getitem__(self, index):
index = index % self.real_length
diction = torch.load(self.checkpoint_list[index], map_location="cpu")
param = self.preprocess(diction)
return param, index
def save_params(self, params, save_path):
diction = self.postprocess(params.cpu().to(torch.float32))
torch.save(diction, save_path)
def preprocess(self, diction: dict, **kwargs) -> torch.Tensor:
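        """Normalize a state dict with the cached statistics and flatten it into a (sequence_length, dim_per_token) tensor."""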
param_list = []
for key, value in diction.items():
if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
continue
elif "running_var" in key:
shape, pre_mean, mean, std = self.structure[key]
value = torch.log(value / pre_mean + 0.05)
else: # normal
shape, mean, std = self.structure[key]
value = (value - mean) / std
value = layer_to_token(value, self.dim_per_token, **self.config)
param_list.append(value)
param = torch.cat(param_list, dim=0)
if self.config["granularity"] == 0: # padding directly process tail
param = pad_to_length(param, self.dim_per_token, **self.config).view(-1, self.dim_per_token)
# print("Sequence length:", param.size(0))
return param.to(torch.float32)
def postprocess(self, params: torch.Tensor, **kwargs) -> dict:
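        """Invert preprocess: split the token sequence back into per-layer tensors and undo the normalization."""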
diction = {}
params = params if len(params.shape) == 2 else params.squeeze(0)
for key, item in self.structure.items():
if ("num_batches_tracked" in key) or (item[-1] is None):
shape, mean, std = item
diction[key] = mean
continue
elif "running_var" in key:
shape, pre_mean, mean, std = item
else: # conv & linear
shape, mean, std = item
this_param, params = token_to_layer(params, shape, **self.config)
this_param = this_param * std + mean
if "running_var" in key:
this_param = torch.clip(torch.exp(this_param) - 0.05, min=0.001) * pre_mean
diction[key] = this_param
return diction
class ConditionalDataset(BaseDataset, ABC):
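    """BaseDataset variant whose __getitem__ also returns a condition parsed from the checkpoint file name."""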
def _extract_condition(self, index: int):
name = self.checkpoint_list[index]
condition_list = os.path.basename(name).split("_")
return condition_list
def __getitem__(self, index):
index = index % self.real_length
diction = torch.load(self.checkpoint_list[index], map_location="cpu")
condition = self._extract_condition(index)
param = self.preprocess(diction)
return param, condition
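

if __name__ == "__main__":
    # Minimal round-trip sanity check for the tokenize/detokenize helpers above.
    # The layer shape and token width below are illustrative values only, not the
    # settings of any particular training configuration.
    _config = {"fill_value": torch.nan, "granularity": 1}
    _weight = torch.randn(64, 3, 3, 3)  # e.g. a conv kernel
    _tokens = layer_to_token(_weight, 256, **_config)
    _recovered, _rest = token_to_layer(_tokens, _weight.shape, **_config)
    assert torch.equal(_recovered, _weight) and _rest.numel() == 0
    print("token round-trip OK:", tuple(_tokens.shape))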