# RMM pretraining script (hosting-page scrape residue "Spaces: Sleeping" removed
# so the file parses as Python).
'''
We implemented `iCaRL+RMM` and `FOSTER+RMM` in [rmm.py](models/rmm.py), and the
`Pretraining Stage` of `RMM` in [rmm_train.py](rmm_train.py).
Use the following training script to run it:

```bash
python rmm_train.py --config=./exps/rmm-pretrain.json
```
'''
import json | |
import argparse | |
from trainer import train | |
import sys | |
import logging | |
import copy | |
import torch | |
from utils import factory | |
from utils.data_manager import DataManager | |
from utils.rl_utils.ddpg import DDPG | |
from utils.rl_utils.rl_utils import ReplayBuffer | |
from utils.toolkit import count_parameters | |
import os | |
import numpy as np | |
import random | |
class CILEnv:
    """RL environment wrapping a class-incremental-learning (CIL) run.

    Each episode is one full CIL run on a randomly sampled
    ``(init_cls, increment)`` curriculum; each :meth:`step` trains one
    incremental task using the memory/class-rate action chosen by the agent.
    """

    def __init__(self, args) -> None:
        self._args = copy.deepcopy(args)
        # Candidate (init_cls, increment) curricula, sampled per episode.
        self.settings = [(50, 2), (50, 5), (50, 10), (50, 20), (10, 10), (20, 20), (5, 5)]
        # self.settings = [(5,5)] # Debug
        self._sample_setting()

    def _sample_setting(self):
        """Pick a random curriculum and (re)build the data manager and model."""
        self._args["init_cls"], self._args["increment"] = self.settings[
            np.random.randint(len(self.settings))
        ]
        self.data_manager = DataManager(
            self._args["dataset"],
            self._args["shuffle"],
            self._args["seed"],
            self._args["init_cls"],
            self._args["increment"],
        )
        self.model = factory.get_model(self._args["model_name"], self._args)

    @property
    def nb_task(self):
        # Total number of incremental tasks in the current curriculum.
        # BUG FIX: was a plain method, but every use site reads it as an
        # attribute (e.g. `self.nb_task - 1`), which raised TypeError.
        return self.data_manager.nb_tasks

    @property
    def cur_task(self):
        # Index of the task the wrapped model is currently on.
        # BUG FIX: same as nb_task — used as an attribute, so it must be a property.
        return self.model._cur_task

    def get_task_size(self, task_id):
        """Return the number of classes introduced at task ``task_id``."""
        return self.data_manager.get_task_size(task_id)

    def reset(self):
        """Start a new episode on a freshly sampled curriculum.

        Returns ``(state, reward, done, info)`` where the state is
        ``[next_task_size / 100, memory_ratio]`` and reward is ``None``.
        """
        self._sample_setting()
        info = "start new task: dataset: {}, init_cls: {}, increment: {}".format(
            self._args["dataset"], self._args["init_cls"], self._args["increment"]
        )
        return np.array([self.get_task_size(0) / 100, 0]), None, False, info

    def step(self, action):
        """Train one incremental task with ``action = (m_rate, c_rate)``.

        Returns ``(next_state, reward, done, info)``; reward is the CNN
        top-1 accuracy scaled to [0, 1].
        """
        self.model._m_rate_list.append(action[0])
        self.model._c_rate_list.append(action[1])
        self.model.incremental_train(self.data_manager)
        cnn_accy, nme_accy = self.model.eval_task()
        self.model.after_task()
        done = self.cur_task == self.nb_task - 1
        info = "running task [{}/{}]: dataset: {}, increment: {}, cnn_accy top1: {}, top5: {}".format(
            self.model._known_classes,
            100,
            self._args["dataset"],
            self._args["increment"],
            cnn_accy["top1"],
            cnn_accy["top5"],
        )
        return (
            np.array(
                [
                    # Size of the upcoming task, or 0 at episode end.
                    self.get_task_size(self.cur_task + 1) / 100 if not done else 0.,
                    self.model.memory_size
                    / (self.model.memory_size + self.model.new_memory_size),
                ]
            ),
            cnn_accy["top1"] / 100,
            done,
            info,
        )
def _train(args):
    """Run one DDPG pretraining session over randomly sampled CIL curricula.

    Sets up file+stdout logging, seeds RNGs, resolves devices, then trains
    the agent for a fixed number of episodes on the CIL environment.
    """
    log_dir = "logs/RL-CIL/{}/".format(args["model_name"])
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_path = "logs/RL-CIL/{}/{}_{}_{}_{}_{}".format(
        args["model_name"],
        args["prefix"],
        args["seed"],
        args["model_name"],
        args["convnet_type"],
        args["dataset"],
    )
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=log_path + ".log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    _set_random()
    _set_device(args)
    print_args(args)

    # DDPG hyper-parameters.
    actor_lr, critic_lr = 5e-4, 5e-3
    num_episodes = 200
    hidden_dim = 32
    gamma, tau = 0.98, 0.005
    buffer_size, minimal_size, batch_size = 1000, 50, 32
    sigma = 0.2  # action noise, encouraging the off-policy algo to explore.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = CILEnv(args)
    buffer = ReplayBuffer(buffer_size)
    agent = DDPG(
        2, 1, 4, hidden_dim, False, 1, sigma, actor_lr, critic_lr, tau, gamma, device
    )

    for _episode in range(num_episodes):
        state, *_, info = env.reset()
        logging.info(info)
        done = False
        while not done:
            action = agent.take_action(state)
            logging.info(f"take action: m_rate {action[0]}, c_rate {action[1]}")
            next_state, reward, done, info = env.step(action)
            logging.info(info)
            buffer.add(state, action, reward, next_state, done)
            state = next_state
            # Only start updating once the buffer holds enough transitions.
            if buffer.size() > minimal_size:
                states, actions, rewards, next_states, dones = buffer.sample(batch_size)
                agent.update(
                    {
                        "states": states,
                        "actions": actions,
                        "next_states": next_states,
                        "rewards": rewards,
                        "dones": dones,
                    }
                )
def _set_device(args): | |
device_type = args["device"] | |
gpus = [] | |
for device in device_type: | |
if device_type == -1: | |
device = torch.device("cpu") | |
else: | |
device = torch.device("cuda:{}".format(device)) | |
gpus.append(device) | |
args["device"] = gpus | |
def _set_random(): | |
random.seed(1) | |
torch.manual_seed(1) | |
torch.cuda.manual_seed(1) | |
torch.cuda.manual_seed_all(1) | |
torch.backends.cudnn.deterministic = True | |
torch.backends.cudnn.benchmark = False | |
def print_args(args):
    """Write every configuration entry to the log, one `key: value` line each."""
    for key, value in args.items():
        logging.info(f"{key}: {value}")
def train(args):
    """Run `_train` once for each seed listed in ``args["seed"]``.

    The seed list and device spec are snapshotted up front because `_train`
    mutates ``args`` in place (e.g. `_set_device` rewrites ``args["device"]``).
    """
    seeds = copy.deepcopy(args["seed"])
    devices = copy.deepcopy(args["device"])
    for seed in seeds:
        args["seed"], args["device"] = seed, devices
        _train(args)
def main():
    """Entry point: merge CLI arguments with the JSON config, then train."""
    namespace = setup_parser().parse_args()
    config = load_json(namespace.config)
    merged = vars(namespace)   # argparse Namespace -> plain dict
    merged.update(config)      # JSON settings take precedence over CLI defaults
    train(merged)
def load_json(settings_path):
    """Parse the JSON settings file at `settings_path` and return its contents."""
    with open(settings_path) as fp:
        return json.load(fp)
def setup_parser():
    """Build the command-line parser: a single --config option naming a JSON file."""
    description = "Reproduce of multiple continual learning algorthms."
    parser = argparse.ArgumentParser(description=description)
    config_opts = dict(
        type=str,
        default="./exps/finetune.json",
        help="Json file of settings.",
    )
    parser.add_argument("--config", **config_opts)
    return parser
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()