import torch
import einops
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10
from torchvision import transforms
import os
import math
import random
import json
from abc import ABC
import pickle


def pad_to_length(x, common_factor, **config):
    # Flatten x and pad it with config["fill_value"] so its length becomes a multiple of common_factor.
    if x.numel() % common_factor == 0:
        return x.flatten()
    full_length = (x.numel() // common_factor + 1) * common_factor
    padding_length = full_length - x.numel()
    padding = torch.full([padding_length], fill_value=config["fill_value"], dtype=x.dtype, device=x.device)
    x = torch.cat((x.flatten(), padding), dim=0)
    return x

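# Illustrative example: padding a 5-element tensor to a multiple of 4 appends 3 fill values, e.g.
#   pad_to_length(torch.ones(5), 4, fill_value=torch.nan).shape == torch.Size([8])
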
def layer_to_token(x, common_factor, **config):
    # Chop one parameter tensor into tokens of length common_factor.
    # granularity 0: flatten directly, 1: split by layer, 2: split by output dim.
    if config["granularity"] == 2:
        if x.numel() <= common_factor:
            return pad_to_length(x.flatten(), common_factor, **config)[None]
        dim2 = x[0].numel()
        dim1 = x.shape[0]
        if dim2 <= common_factor:
            # pack as many output rows as possible into each token
            i = int(dim1 / (common_factor / dim2))
            while True:
                if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
                    output = x.view(-1, dim2 * (dim1 // i))
                    output = [pad_to_length(item, common_factor, **config) for item in output]
                    return torch.stack(output, dim=0)
                i += 1
        else:  # one output row is still too large: recurse on each row
            output = [layer_to_token(item, common_factor, **config) for item in x]
            return torch.cat(output, dim=0)
    elif config["granularity"] == 1:
        return pad_to_length(x.flatten(), common_factor, **config).view(-1, common_factor)
    elif config["granularity"] == 0:
        return x.flatten()
    else:
        raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")

def token_to_layer(tokens, shape, **config):
    # Inverse of layer_to_token: consume tokens from the front of `tokens`, rebuild one
    # parameter tensor of the given shape, and return (param, remaining_tokens).
    common_factor = tokens.shape[-1]
    if config["granularity"] == 2:
        num_element = math.prod(shape)
        if num_element <= common_factor:
            param = tokens[0][:num_element].view(shape)
            tokens = tokens[1:]
            return param, tokens
        dim2 = num_element // shape[0]
        dim1 = shape[0]
        if dim2 <= common_factor:
            i = int(dim1 / (common_factor / dim2))
            while True:
                if dim1 % i == 0 and dim2 * (dim1 // i) <= common_factor:
                    item_per_token = dim2 * (dim1 // i)
                    length = num_element // item_per_token
                    output = [item[:item_per_token] for item in tokens[:length]]
                    param = torch.cat(output, dim=0).view(shape)
                    tokens = tokens[length:]
                    return param, tokens
                i += 1
        else:  # one output row spans several tokens: recurse on each row
            output = []
            for i in range(shape[0]):
                param, tokens = token_to_layer(tokens, shape[1:], **config)
                output.append(param.flatten())
            param = torch.cat(output, dim=0).view(shape)
            return param, tokens
    elif config["granularity"] == 1:
        num_element = math.prod(shape)
        token_num = num_element // common_factor if num_element % common_factor == 0 \
                else num_element // common_factor + 1
        param = tokens.flatten()[:num_element].view(shape)
        tokens = tokens[token_num:]
        return param, tokens
    elif config["granularity"] == 0:
        num_element = math.prod(shape)
        param = tokens.flatten()[:num_element].view(shape)
        tokens = pad_to_length(tokens.flatten()[num_element:],
                               common_factor, fill_value=torch.nan).view(-1, common_factor)
        return param, tokens
    else:
        raise NotImplementedError("granularity: 0: flatten directly, 1: split by layer, 2: split by output dim")

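# Illustrative example of the tokenizer contract at granularity=2 (split by output dim):
# a conv weight of shape [64, 3, 3, 3] with dim_per_token=1024 is packed row-group-wise
# into 2 tokens of length 1024 (864 real values plus padding each), e.g.
#   cfg = {"fill_value": torch.nan, "granularity": 2}
#   tokens = layer_to_token(torch.randn(64, 3, 3, 3), 1024, **cfg)   # -> [2, 1024]
#   w, rest = token_to_layer(tokens, (64, 3, 3, 3), **cfg)           # w: [64, 3, 3, 3]
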
def positional_embedding_2d(dim1, dim2, d_model):
    # Sin/cos positional encoding over a (dim1, dim2) grid; returns a [dim1, dim2, d_model] tensor.
    assert d_model % 4 == 0, f"Cannot use 2d sin/cos positional encoding with d_model not divisible by 4 (got {d_model})"
    pe = torch.zeros(d_model, dim1, dim2)
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0., d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0., dim2).unsqueeze(1)
    pos_h = torch.arange(0., dim1).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, dim1, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
    pe[d_model+1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, dim2)
    return pe.permute(1, 2, 0)

def positional_embedding_1d(dim1, d_model):
    pe = torch.zeros(dim1, d_model)
    position = torch.arange(0, dim1, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

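# Illustrative example of the returned shapes:
#   positional_embedding_2d(3, 5, 8).shape == torch.Size([3, 5, 8])
#   positional_embedding_1d(10, 16).shape == torch.Size([10, 16])
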
class BaseDataset(Dataset, ABC):
    data_path = None
    generated_path = None
    test_command = None
    config = {
        "fill_value": torch.nan,
        "granularity": 1,
        "pe_granularity": 2,
    }

    def __init__(self, checkpoint_path=None, dim_per_token=8192, **kwargs):
        if not os.path.exists(self.data_path):
            os.makedirs(self.data_path, exist_ok=False)
        if self.generated_path is not None and not os.path.exists(os.path.dirname(self.generated_path)):
            os.makedirs(os.path.dirname(self.generated_path))
        self.config.update(kwargs)
        checkpoint_path = self.data_path if checkpoint_path is None else checkpoint_path
        assert os.path.exists(checkpoint_path)
        self.dim_per_token = dim_per_token
        self.structure = None
        self.sequence_length = None

        # index the checkpoint files
        checkpoint_list = os.listdir(checkpoint_path)
        self.checkpoint_list = [os.path.join(checkpoint_path, item) for item in checkpoint_list]
        self.length = self.real_length = len(self.checkpoint_list)
        self.set_infinite_dataset()

        # load the cached parameter structure if it is valid, otherwise rebuild it
        structure_cache_file = os.path.join(os.path.dirname(self.data_path), "structure.cache")
        try:
            assert os.path.exists(structure_cache_file)
            with open(structure_cache_file, "rb") as f:
                print(f"Loading cache from {structure_cache_file}")
                cache_file = pickle.load(f)
            if len(self.checkpoint_list) != 0:
                assert set(cache_file["checkpoint_list"]) == set(self.checkpoint_list)
                self.structure = cache_file["structure"]
            else:  # no trained checkpoints yet: fabricate one from the cached structure for generation
                print("Cannot find any trained checkpoint, loading cache file for generating!")
                self.structure = cache_file["structure"]
                fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
                torch.save(fake_diction, os.path.join(checkpoint_path, "fake_checkpoint.pth"))
                self.checkpoint_list.append(os.path.join(checkpoint_path, "fake_checkpoint.pth"))
                self.length = self.real_length = len(self.checkpoint_list)
                self.set_infinite_dataset()
                os.remove(os.path.join(checkpoint_path, "fake_checkpoint.pth"))
        except AssertionError:
            print("==> Organizing structure..")
            self.structure = self.get_structure()
            with open(structure_cache_file, "wb") as f:
                pickle.dump({"structure": self.structure, "checkpoint_list": self.checkpoint_list}, f)

        self.sequence_length = self.get_sequence_length()

    def get_sequence_length(self):
        # tokenize a zero-filled state dict with the cached shapes to get the token count
        fake_diction = {key: torch.zeros(item[0]) for key, item in self.structure.items()}
        param = self.preprocess(fake_diction)
        self.sequence_length = param.size(0)
        return self.sequence_length

    def get_structure(self):
        # Collect per-layer statistics across all checkpoints. Each entry maps
        #   key -> (shape, mean, std)            for ordinary float tensors,
        #   key -> (shape, pre_mean, mean, std)  for BN running_var (log-transformed),
        #   key -> (shape, raw_value, None)      for integer / scalar buffers kept verbatim.
        checkpoint_list = self.checkpoint_list
        structures = [{} for _ in range(len(checkpoint_list))]
        for i, checkpoint in enumerate(checkpoint_list):
            diction = torch.load(checkpoint, map_location="cpu")
            for key, value in diction.items():
                if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
                    structures[i][key] = (value.shape, value, None)
                elif "running_var" in key:
                    pre_mean = value.mean() * 0.95
                    value = torch.log(value / pre_mean + 0.05)
                    structures[i][key] = (value.shape, pre_mean, value.mean(), value.std())
                else:
                    structures[i][key] = (value.shape, value.mean(), value.std())
        # average the statistics over all checkpoints
        final_structure = {}
        structure_diction = torch.load(checkpoint_list[0], map_location="cpu")
        for key, param in structure_diction.items():
            if ("num_batches_tracked" in key) or (param.numel() == 1) or not torch.is_floating_point(param):
                final_structure[key] = (param.shape, param, None)
            elif "running_var" in key:
                value = [param.shape, 0., 0., 0.]
                for structure in structures:
                    for i in [1, 2, 3]:
                        value[i] += structure[key][i]
                for i in [1, 2, 3]:
                    value[i] /= len(structures)
                final_structure[key] = tuple(value)
            else:
                value = [param.shape, 0., 0.]
                for structure in structures:
                    for i in [1, 2]:
                        value[i] += structure[key][i]
                for i in [1, 2]:
                    value[i] /= len(structures)
                final_structure[key] = tuple(value)
        self.structure = final_structure
        return self.structure

    def set_infinite_dataset(self, max_num=None):
        if max_num is None:
            max_num = self.length * 1000000
        self.length = max_num
        return self

    @property
    def max_permutation_state(self):
        return self.real_length

    def get_position_embedding(self, positional_embedding_dim=None):
        if positional_embedding_dim is None:
            positional_embedding_dim = self.dim_per_token // 2
        assert self.structure is not None, "run get_structure before get_position_embedding"
        if self.config["pe_granularity"] == 2:
            print("Use 2d positional embedding")
            # one row of token indexes per layer, built from a dummy tensor of the cached shape
            positional_embedding_index = []
            for key, item in self.structure.items():
                if ("num_batches_tracked" in key) or (item[-1] is None):
                    continue
                else:
                    shape, *_ = item
                    fake_param = torch.ones(size=shape)
                    fake_param = layer_to_token(fake_param, self.dim_per_token, **self.config)
                    positional_embedding_index.append(list(range(fake_param.size(0))))
            dim1 = len(positional_embedding_index)
            dim2 = max([len(token_per_layer) for token_per_layer in positional_embedding_index])
            full_pe = positional_embedding_2d(dim1, dim2, positional_embedding_dim)
            positional_embedding = []
            for layer_index, token_indexes in enumerate(positional_embedding_index):
                for token_index in token_indexes:
                    this_pe = full_pe[layer_index, token_index]
                    positional_embedding.append(this_pe)
            positional_embedding = torch.stack(positional_embedding)
            return positional_embedding
        elif self.config["pe_granularity"] == 1:
            print("Use 1d positional embedding")
            return positional_embedding_1d(self.sequence_length, positional_embedding_dim)
        elif self.config["pe_granularity"] == 0:
            print("Not using positional embedding")
            # __getitem__ returns (param, index); build a zero embedding shaped like param
            return torch.zeros_like(self.__getitem__(0)[0])
        else:
            raise NotImplementedError("pe_granularity: 0: no embedding, 1: 1d embedding, 2: 2d embedding")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.real_length
        diction = torch.load(self.checkpoint_list[index], map_location="cpu")
        param = self.preprocess(diction)
        return param, index

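    # Illustrative usage (assumes a concrete subclass, here called `MyDataset`, whose
    # data_path points to a folder of .pth checkpoints):
    #   loader = torch.utils.data.DataLoader(MyDataset(dim_per_token=8192), batch_size=4)
    #   params, indexes = next(iter(loader))  # params: [4, sequence_length, dim_per_token]
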
    def save_params(self, params, save_path):
        diction = self.postprocess(params.cpu().to(torch.float32))
        torch.save(diction, save_path)

    def preprocess(self, diction: dict, **kwargs) -> torch.Tensor:
        # Normalize each floating-point parameter with the cached statistics and tokenize it;
        # the result is a [sequence_length, dim_per_token] float32 tensor.
        param_list = []
        for key, value in diction.items():
            if ("num_batches_tracked" in key) or (value.numel() == 1) or not torch.is_floating_point(value):
                continue
            elif "running_var" in key:
                shape, pre_mean, mean, std = self.structure[key]
                value = torch.log(value / pre_mean + 0.05)
            else:
                shape, mean, std = self.structure[key]
            value = (value - mean) / std
            value = layer_to_token(value, self.dim_per_token, **self.config)
            param_list.append(value)
        param = torch.cat(param_list, dim=0)
        if self.config["granularity"] == 0:
            param = pad_to_length(param, self.dim_per_token, **self.config).view(-1, self.dim_per_token)
        return param.to(torch.float32)

    def postprocess(self, params: torch.Tensor, **kwargs) -> dict:
        # Inverse of preprocess: split the token sequence back into a state dict
        # and undo the per-layer normalization.
        diction = {}
        params = params if len(params.shape) == 2 else params.squeeze(0)
        for key, item in self.structure.items():
            if ("num_batches_tracked" in key) or (item[-1] is None):
                # non-float / scalar buffers were stored verbatim as item[1]
                shape, mean, std = item
                diction[key] = mean
                continue
            elif "running_var" in key:
                shape, pre_mean, mean, std = item
            else:
                shape, mean, std = item
            this_param, params = token_to_layer(params, shape, **self.config)
            this_param = this_param * std + mean
            if "running_var" in key:
                this_param = torch.clip(torch.exp(this_param) - 0.05, min=0.001) * pre_mean
            diction[key] = this_param
        return diction

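# Illustrative sketch of a concrete dataset (class name and paths are placeholders):
#   class Cifar10Checkpoints(BaseDataset):
#       data_path = "./dataset/cifar10_example/checkpoint"
#       generated_path = "./dataset/cifar10_example/generated/generated_model.pth"
#       test_command = "python ./dataset/cifar10_example/test.py"
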
class ConditionalDataset(BaseDataset, ABC):
    def _extract_condition(self, index: int):
        # derive the condition from the checkpoint file name, split on "_"
        name = self.checkpoint_list[index]
        condition_list = os.path.basename(name).split("_")
        return condition_list

    def __getitem__(self, index):
        index = index % self.real_length
        diction = torch.load(self.checkpoint_list[index], map_location="cpu")
        condition = self._extract_condition(index)
        param = self.preprocess(diction)
        return param, condition
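

# Minimal smoke test (an illustrative sketch, not part of the training pipeline):
# round-trip one weight tensor through the tokenizer helpers with granularity=1.
if __name__ == "__main__":
    _cfg = {"fill_value": torch.nan, "granularity": 1}
    _weight = torch.randn(10, 3)
    _tokens = layer_to_token(_weight, 8, **_cfg)                        # -> [4, 8]
    _recovered, _rest = token_to_layer(_tokens, _weight.shape, **_cfg)  # -> [10, 3], leftover tokens
    assert torch.allclose(_recovered, _weight)
    print("round-trip OK:", tuple(_tokens.shape), "->", tuple(_recovered.shape))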