Spaces:
Sleeping
Sleeping
import logging | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import numpy as np | |
class PolicyNet(torch.nn.Module):
    """Deterministic policy with one hidden ReLU layer.

    The output is ``tanh(fc2(relu(fc1(x)))) * action_bound``, so every
    component lies in ``[-action_bound, action_bound]``.

    Args:
        state_dim: width of the input state vector.
        hidden_dim: width of the hidden layer.
        action_dim: number of action components produced.
        action_bound: scale applied to the tanh output.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        """Map states ``x`` to bounded actions."""
        hidden = torch.relu(self.fc1(x))
        squashed = torch.tanh(self.fc2(hidden))
        return self.action_bound * squashed
class RMMPolicyNet(torch.nn.Module):
    """Two-head policy whose second head is conditioned on the first head.

    ``fc1`` maps the state to ``a1 = sigmoid(fc1(x))`` (values in (0, 1)).
    ``fc2`` maps the state concatenated with ``a1`` to
    ``a2 = tanh(fc2([x, a1]))`` (values in (-1, 1)).
    ``forward`` returns ``[a1, a2]`` concatenated along the feature
    dimension, i.e. ``2 * action_dim`` values per state.

    Args:
        state_dim: width of the input state vector.
        hidden_dim: hidden width of both heads.
        action_dim: output width of *each* head.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(RMMPolicyNet, self).__init__()
        # Sequential indices (fc1.0 / fc1.2, ...) are state_dict keys; keep layout.
        self.fc1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        """Return ``cat([a1, a2])`` for states ``x``.

        ``x`` may be a single 1-D state or a batch with any number of leading
        dimensions; the last dimension must be ``state_dim``.
        """
        a1 = torch.sigmoid(self.fc1(x))
        # Concatenate on the last (feature) dim: identical to dim=1 for
        # (batch, features) input, and also works for an unbatched 1-D state.
        a2 = torch.tanh(self.fc2(torch.cat([x, a1], dim=-1)))
        return torch.cat([a1, a2], dim=-1)
class QValueNet(torch.nn.Module):
    """Critic Q(s, a) with one hidden ReLU layer and a scalar output.

    The state and action are concatenated along dim 1 before the first layer,
    so both inputs are expected to be batched 2-D tensors.

    Args:
        state_dim: width of the state vector.
        hidden_dim: width of the hidden layer.
        action_dim: width of the action vector.
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        """Return one Q-value per row of ``(x, a)``, shape ``(batch, 1)``."""
        state_action = torch.cat((x, a), dim=1)
        return self.fc2(torch.relu(self.fc1(state_action)))
class TwoLayerFC(torch.nn.Module):
    """MLP with two hidden layers (three ``nn.Linear`` layers in total).

    ``activation`` is applied after each of the two hidden layers; ``out_fn``
    is applied to the output of the final layer.

    Args:
        num_in: input width.
        num_out: output width.
        hidden_dim: width of both hidden layers.
        activation: element-wise nonlinearity for the hidden layers.
        out_fn: function applied to the final layer's output (identity by default).
    """

    def __init__(
        self, num_in, num_out, hidden_dim, activation=F.relu, out_fn=lambda x: x
    ):
        super().__init__()
        # fc1/fc2/fc3 names are state_dict keys; keep them stable.
        self.fc1 = nn.Linear(num_in, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_out)
        self.activation = activation
        self.out_fn = out_fn

    def forward(self, x):
        """Run the two hidden layers, then the output layer and ``out_fn``."""
        act = self.activation
        for hidden_layer in (self.fc1, self.fc2):
            x = act(hidden_layer(x))
        return self.out_fn(self.fc3(x))
class DDPG:
    """DDPG algo.

    Holds an online actor and critic, a target copy of each (initialised from
    the online weights), and one Adam optimizer per online network.
    """

    def __init__(
        self,
        num_in_actor,
        num_out_actor,
        num_in_critic,
        hidden_dim,
        discrete,
        action_bound,
        sigma,
        actor_lr,
        critic_lr,
        tau,
        gamma,
        device,
        use_rmm=True,
    ):
        """Build the networks, sync the targets and create the optimizers.

        Args:
            num_in_actor: input width of the actor (state width).
            num_out_actor: ``action_dim`` handed to the actor. With
                ``use_rmm`` the actor returns ``2 * num_out_actor`` values
                (two heads of ``num_out_actor`` each).
            num_in_critic: input width of the critic; ``update`` feeds it
                ``cat([state, action], dim=1)``.
            hidden_dim: hidden width of every network.
            discrete: only used when ``use_rmm`` is False. If true, the actor
                output is left unsquashed; otherwise it is
                ``tanh(.) * action_bound``.
            action_bound: tanh scale for the non-RMM actor.
            sigma: std of the Gaussian exploration noise in ``take_action``.
            actor_lr: Adam learning rate for the actor.
            critic_lr: Adam learning rate for the critic.
            tau: soft-update coefficient for the target networks.
            gamma: discount factor.
            device: torch device the networks are moved to.
            use_rmm: use ``RMMPolicyNet`` as the actor instead of ``TwoLayerFC``.
        """
        out_fn = (lambda x: x) if discrete else (lambda x: torch.tanh(x) * action_bound)

        def make_actor():
            # Builds one actor; called twice so online and target share an architecture.
            if use_rmm:
                return RMMPolicyNet(num_in_actor, hidden_dim, num_out_actor).to(device)
            return TwoLayerFC(
                num_in_actor,
                num_out_actor,
                hidden_dim,
                activation=F.relu,
                out_fn=out_fn,
            ).to(device)

        def make_critic():
            # Builds one critic mapping cat([state, action]) to a scalar Q-value.
            return TwoLayerFC(num_in_critic, 1, hidden_dim).to(device)

        # Construction order (actor, target_actor, critic, target_critic) is
        # kept so parameter initialisation consumes the RNG in the same order.
        self.actor = make_actor()
        self.target_actor = make_actor()
        self.critic = make_critic()
        self.target_critic = make_critic()
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.sigma = sigma
        self.action_bound = action_bound
        self.tau = tau
        self.action_dim = num_out_actor
        self.device = device

    def take_action(self, state):
        """Return a noisy, clipped numpy action for a single (unbatched) state.

        The actor is evaluated on a batch of one. Independent Gaussian noise
        with std ``sigma`` is added to every output component. Component 0 is
        then clipped to [0, 1] and component 1 to [-1, 1]; any further
        components are returned unclipped. These ranges match the sigmoid and
        tanh heads of ``RMMPolicyNet``.

        NOTE(review): for the ``TwoLayerFC`` actor these clip ranges ignore
        ``action_bound`` — confirm that is intended.
        """
        state = torch.tensor(
            np.expand_dims(state, 0), dtype=torch.float, device=self.device
        )
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            action = self.actor(state)[0].cpu().numpy()
        # One noise sample per output component. The RMM actor emits
        # 2 * action_dim values, so sizing the noise by self.action_dim would
        # either broadcast a single sample onto both heads or fail to broadcast.
        action = action + self.sigma * np.random.randn(*action.shape)
        action[0] = np.clip(action[0], 0, 1)
        action[1] = np.clip(action[1], -1, 1)
        return action

    def save_state_dict(self, name):
        """Save all four networks' state dicts to the file ``name``."""
        dicts = {
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
        }
        torch.save(dicts, name)

    def load_state_dict(self, name):
        """Restore all four networks from a file written by ``save_state_dict``."""
        # map_location lets a checkpoint saved on one device (e.g. GPU) be
        # restored onto self.device (e.g. a CPU-only machine).
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        dicts = torch.load(name, map_location=self.device)
        self.critic.load_state_dict(dicts["critic"])
        self.target_critic.load_state_dict(dicts["target_critic"])
        self.actor.load_state_dict(dicts["actor"])
        self.target_actor.load_state_dict(dicts["target_actor"])

    def soft_update(self, net, target_net):
        """Move ``target_net`` toward ``net``: ``t <- (1 - tau) * t + tau * p``."""
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(
                param_target.data * (1.0 - self.tau) + param.data * self.tau
            )

    def update(self, transition_dict):
        """Run one critic step and one actor step, then soft-update the targets.

        Args:
            transition_dict: mapping with keys "states", "actions", "rewards",
                "next_states" and "dones". Each value is converted to a float
                tensor on ``self.device``; rewards and dones are reshaped to
                ``(-1, 1)``. Presumably states/actions are sequences of
                equal-length vectors — TODO confirm against the replay buffer.
        """

        def to_tensor(key):
            # np.asarray first: building a tensor directly from a list of
            # ndarrays is converted element by element and is very slow.
            return torch.tensor(
                np.asarray(transition_dict[key]), dtype=torch.float, device=self.device
            )

        states = to_tensor("states")
        actions = to_tensor("actions")
        rewards = to_tensor("rewards").view(-1, 1)
        next_states = to_tensor("next_states")
        dones = to_tensor("dones").view(-1, 1)

        # TD targets come from the target networks and must not carry
        # gradients; otherwise critic_loss.backward() fills .grad on the target
        # networks' parameters, which no optimizer ever zeroes.
        with torch.no_grad():
            next_actions = self.target_actor(next_states)
            next_q_values = self.target_critic(
                torch.cat([next_states, next_actions], dim=1)
            )
            q_targets = rewards + self.gamma * next_q_values * (1 - dones)

        q_values = self.critic(torch.cat([states, actions], dim=1))
        critic_loss = F.mse_loss(q_values, q_targets)  # already a scalar mean
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Deterministic policy gradient: maximise Q(s, actor(s)).
        actor_loss = -self.critic(torch.cat([states, self.actor(states)], dim=1)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        logging.info(
            "update DDPG: actor loss %.3f, critic loss %.3f, ",
            actor_loss.item(),
            critic_loss.item(),
        )
        self.soft_update(self.actor, self.target_actor)  # soft-update the target policy net
        self.soft_update(self.critic, self.target_critic)  # soft-update the target Q value net