'''
Re-implementation of SimpleCIL (https://arxiv.org/abs/2303.07338) without pre-trained weights.
Training proceeds in two stages: the model is trained with cross-entropy on the first task only,
and the classifier is then replaced with class-mean prototypes for all classes on every task.
Please refer to the original implementation (https://github.com/zhoudw-zdw/RevisitingCIL) if you are using pre-trained weights.
'''
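
# Prototype rule used below: for each class c with training samples D_c, the fc
# weight row is set to the mean backbone embedding
#     p_c = (1 / |D_c|) * sum_{x in D_c} f(x),
# where f(.) is the feature extractor (see replace_fc).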
import logging

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import SimpleCosineIncrementalNet
from utils.toolkit import tensor2numpy
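
# Data-loading constants used across this learner.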
num_workers = 8
batch_size = 32


class SimpleCIL(BaseLearner):
    def __init__(self, args):
        super().__init__(args)
        self._network = SimpleCosineIncrementalNet(args, False)
        min_lr = args.get('min_lr')
        self.min_lr = 1e-8 if min_lr is None else min_lr
        self.args = args

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)
        self._total_classes = len(checkpoint["classes"])
        self.class_list = np.array(checkpoint["classes"])
        self.label_list = checkpoint["label_list"]
        logging.info("Class list: {}".format(self.class_list))
        self._network.update_fc(self._total_classes)
        self._network.load_checkpoint(checkpoint["network"])
        self._network.to(self._device)

    def after_task(self):
        self._known_classes = self._total_classes

    def save_checkpoint(self, filename):
        self._network.cpu()
        save_dict = {
            "classes": self.data_manager.get_class_list(self._cur_task),
            "network": {
                "convnet": self._network.convnet.state_dict(),
                "fc": self._network.fc.state_dict(),
            },
            "label_list": self.data_manager.get_label_list(self._cur_task),
        }
        torch.save(save_dict, "./{}/{}_{}.pkl".format(filename, self.args['model_name'], self._cur_task))
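
    # Build the prototype classifier: run the backbone in eval mode over the
    # task's training data, average the embeddings per class, and write each
    # class mean directly into the matching fc weight row.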
    def replace_fc(self, trainloader, model, args):
        model = model.eval()
        embedding_list = []
        label_list = []
        with torch.no_grad():
            for batch in trainloader:
                (_, data, label) = batch
                data = data.to(self._device)
                label = label.to(self._device)
                embedding = model(data)["features"]
                embedding_list.append(embedding.cpu())
                label_list.append(label.cpu())
        embedding_list = torch.cat(embedding_list, dim=0)
        label_list = torch.cat(label_list, dim=0)

        class_list = np.unique(self.train_dataset.labels)
        for class_index in class_list:
            data_index = torch.nonzero(label_list == class_index).squeeze(-1)
            embedding = embedding_list[data_index]
            proto = embedding.mean(0)
            if len(self._multiple_gpus) > 1:
                self._network.module.fc.weight.data[class_index] = proto
            else:
                self._network.fc.weight.data[class_index] = proto
        return model
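
    # One call per task: grow the classifier head to cover the new classes,
    # build the loaders (the prototype loader reuses the training split with
    # test-mode transforms so embeddings are computed without augmentation),
    # train, and restore the single-GPU module if DataParallel was used.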
    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
        self._network.update_fc(self._total_classes)
        logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
        self.class_list = np.array(data_manager.get_class_list(self._cur_task))

        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="train")
        self.train_dataset = train_dataset
        self.data_manager = data_manager
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="test")
        self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        if len(self._multiple_gpus) > 1:
            logging.info("Training on multiple GPUs")
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module
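
    # Stage logic: only the very first task updates the backbone with
    # cross-entropy (SGD + cosine annealing); every task, including the first,
    # then refreshes the fc weights with class prototypes, so no gradient
    # steps are taken on later tasks.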
    def _train(self, train_loader, test_loader, train_loader_for_protonet):
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                self._network.parameters(),
                momentum=0.9,
                lr=self.args["init_lr"],
                weight_decay=self.args["init_weight_decay"],
            )
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer, T_max=self.args['init_epoch'], eta_min=self.min_lr
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        self.replace_fc(train_loader_for_protonet, self._network, None)
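
    # Standard supervised loop for the first task: per-epoch cross-entropy
    # training with running loss/accuracy, a test-accuracy probe every five
    # epochs, and progress/ETA reporting through tqdm and the logger.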
    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(self.args["init_epoch"]))
        for epoch in prog_bar:
            self._network.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]
                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)
            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args['init_epoch'],
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    self.args['init_epoch'],
                    losses / len(train_loader),
                    train_acc,
                )
            elapsed = prog_bar.format_dict["elapsed"]
            rate = prog_bar.format_dict["rate"]
            remaining = (prog_bar.total - prog_bar.n) / rate if rate and prog_bar.total else 0  # seconds
            prog_bar.set_description(info)
            logging.info("Working on task {}: {:.2f}:{:.2f}".format(
                self._cur_task,
                elapsed,
                remaining))

        logging.info(info)
        logging.info("Finished on task {}: {:.2f}".format(
            self._cur_task, elapsed))
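

# A minimal usage sketch (hypothetical: the exact DataManager construction and
# the full set of args keys depend on the surrounding framework, e.g. PyCIL;
# the keys shown are the ones this file actually reads):
#
#   args = {"model_name": "simplecil", "init_lr": 0.01, "init_epoch": 100,
#           "init_weight_decay": 5e-4, "min_lr": 1e-8}
#   learner = SimpleCIL(args)
#   for _ in range(data_manager.nb_tasks):
#       learner.incremental_train(data_manager)
#       learner.after_task()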