# Note: this implementation of DER covers only the dynamic expansion process;
# masking and pruning are not implemented in the source repo.
import logging

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.base import BaseLearner
from utils.inc_net import DERNet, IncrementalNet
from utils.toolkit import count_parameters, target2onehot, tensor2numpy

EPSILON = 1e-8

# Training schedule for the first (base) task
init_epoch = 100
init_lr = 0.1
init_milestones = [40, 60, 80]
init_lr_decay = 0.1
init_weight_decay = 0.0005

# Training schedule for the subsequent incremental tasks
# (batch_size and num_workers are shared by all tasks)
epochs = 80
lrate = 0.1
milestones = [30, 50, 70]
lrate_decay = 0.1
batch_size = 32
weight_decay = 2e-4
num_workers = 8
T = 2
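# NOTE: EPSILON and T are not referenced anywhere in this file; T looks like
# the knowledge-distillation temperature used by other learners in this repo
# and is presumably kept only for consistency.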


class DER(BaseLearner):
    """Dynamically Expandable Representation: one new feature-extractor branch
    is added per task while previously trained branches stay frozen."""

    def __init__(self, args):
        super().__init__(args)
        self._network = DERNet(args, False)

    def after_task(self):
        self._known_classes = self._total_classes
        logging.info("Exemplar size: {}".format(self.exemplar_size))

    def incremental_train(self, data_manager):
        self._cur_task += 1
        self._total_classes = self._known_classes + data_manager.get_task_size(
            self._cur_task
        )
        self._network.update_fc(self._total_classes)
        logging.info(
            "Learning on {}-{}".format(self._known_classes, self._total_classes)
        )

        if self._cur_task > 0:
            # Freeze every previously trained branch; only the newest one
            # remains trainable.
            for i in range(self._cur_task):
                for p in self._network.convnets[i].parameters():
                    p.requires_grad = False

        logging.info("All params: {}".format(count_parameters(self._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(self._network, True))
        )

        train_dataset = data_manager.get_dataset(
            np.arange(self._known_classes, self._total_classes),
            source="train",
            mode="train",
            appendent=self._get_memory(),
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        test_dataset = data_manager.get_dataset(
            np.arange(0, self._total_classes), source="test", mode="test"
        )
        self.test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

        if len(self._multiple_gpus) > 1:
            self._network = nn.DataParallel(self._network, self._multiple_gpus)
        self._train(self.train_loader, self.test_loader)
        self.build_rehearsal_memory(data_manager, self.samples_per_class)
        if len(self._multiple_gpus) > 1:
            self._network = self._network.module

    def train(self):
        # Keep the frozen branches in eval mode so that their BatchNorm
        # statistics are not updated while the newest branch is being trained.
        self._network.train()
        if len(self._multiple_gpus) > 1:
            self._network_module_ptr = self._network.module
        else:
            self._network_module_ptr = self._network
        self._network_module_ptr.convnets[-1].train()
        if self._cur_task >= 1:
            for i in range(self._cur_task):
                self._network_module_ptr.convnets[i].eval()

    def _train(self, train_loader, test_loader):
        self._network.to(self._device)
        if self._cur_task == 0:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                momentum=0.9,
                lr=init_lr,
                weight_decay=init_weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=init_milestones, gamma=init_lr_decay
            )
            self._init_train(train_loader, test_loader, optimizer, scheduler)
        else:
            optimizer = optim.SGD(
                filter(lambda p: p.requires_grad, self._network.parameters()),
                lr=lrate,
                momentum=0.9,
                weight_decay=weight_decay,
            )
            scheduler = optim.lr_scheduler.MultiStepLR(
                optimizer=optimizer, milestones=milestones, gamma=lrate_decay
            )
            self._update_representation(train_loader, test_loader, optimizer, scheduler)
            # Re-balance the classifier: weight_align (defined in DERNet)
            # rescales the new-class weight vectors toward the magnitude of
            # the old-class weights.
            if len(self._multiple_gpus) > 1:
                self._network.module.weight_align(
                    self._total_classes - self._known_classes
                )
            else:
                self._network.weight_align(self._total_classes - self._known_classes)

    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(init_epoch))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                logits = self._network(inputs)["logits"]

                loss = F.cross_entropy(logits, targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    init_epoch,
                    losses / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)

    def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
        prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
            self.train()
            losses = 0.0
            losses_clf = 0.0
            losses_aux = 0.0
            correct, total = 0, 0
            for i, (_, inputs, targets) in enumerate(train_loader):
                inputs, targets = inputs.to(self._device), targets.to(self._device)
                outputs = self._network(inputs)
                logits, aux_logits = outputs["logits"], outputs["aux_logits"]
                loss_clf = F.cross_entropy(logits, targets)

                # Targets for the auxiliary classifier: all old classes are
                # collapsed into a single label 0, while the current task's
                # classes are mapped to 1..n, encouraging the new branch to
                # separate new classes from everything seen before.
                aux_targets = targets.clone()
                aux_targets = torch.where(
                    aux_targets - self._known_classes + 1 > 0,
                    aux_targets - self._known_classes + 1,
                    torch.tensor([0]).to(self._device),
                )
                loss_aux = F.cross_entropy(aux_logits, aux_targets)
                loss = loss_clf + loss_aux

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses += loss.item()
                losses_aux += loss_aux.item()
                losses_clf += loss_clf.item()

                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

            if epoch % 5 == 0:
                test_acc = self._compute_accuracy(self._network, test_loader)
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_aux / len(train_loader),
                    train_acc,
                    test_acc,
                )
            else:
                info = "Task {}, Epoch {}/{} => Loss {:.3f}, Loss_clf {:.3f}, Loss_aux {:.3f}, Train_accy {:.2f}".format(
                    self._cur_task,
                    epoch + 1,
                    epochs,
                    losses / len(train_loader),
                    losses_clf / len(train_loader),
                    losses_aux / len(train_loader),
                    train_acc,
                )
            prog_bar.set_description(info)

        logging.info(info)
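

# Usage sketch: a minimal driver loop in the style of PyCIL's trainer for this
# learner. The DataManager signature and the keys expected in `args` below are
# assumptions for illustration; check utils/data_manager.py and models/base.py
# for the exact interface before relying on them.
#
#     from utils.data_manager import DataManager
#     from models.der import DER
#
#     args = {
#         "convnet_type": "resnet32",  # backbone name forwarded to DERNet
#         "device": ["0"],             # GPU ids, or ["cpu"]
#         "memory_size": 2000,         # total exemplar budget for rehearsal
#         "fixed_memory": False,
#         "seed": 1993,
#     }
#     data_manager = DataManager(
#         "cifar100", shuffle=True, seed=args["seed"], init_cls=10, increment=10
#     )
#     learner = DER(args)
#     for _ in range(data_manager.nb_tasks):
#         learner.incremental_train(data_manager)  # train task, build exemplars
#         learner.after_task()                     # roll known_classes forward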