Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ python <run-file> -h
- Load the saved models and optimizers at the beginning.

### Class Diagram
Class diagram drawn on [e447f3e](https://github.com/medipixel/rl_algorithms/commit/e447f3e743f6f85505f2275b646e46f0adcf8f89). This won't be frequently updated.
![rl_algorithms_cls](https://user-images.githubusercontent.com/14961526/55703648-26022a80-5a15-11e9-8099-9bbfdffcb96d.png)
Class diagram drawn on [bd76239](https://github.com/medipixel/rl_algorithms/pull/135/commits/bd76239684d55a92893106a3ceee9cde90294b4d).
This won't be frequently updated.
![RL_Algorithms_ClassDiagram](https://user-images.githubusercontent.com/16010242/55934443-812d5a80-5c6b-11e9-9b31-fa8214965a55.png)

### W&B for logging
We use [W&B](https://www.wandb.com/) for logging of network parameters and others. For more details, read [W&B tutorial](https://docs.wandb.com/docs/started.html).
Expand Down
8 changes: 4 additions & 4 deletions algorithms/a2c/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import torch.nn.functional as F
import wandb

from algorithms.common.abstract.agent import AbstractAgent
from algorithms.common.abstract.agent import Agent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(AbstractAgent):
class A2CAgent(Agent):
"""1-Step Advantage Actor-Critic interacting with environment.

Attributes:
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
optims (tuple): optimizers for actor and critic

"""
AbstractAgent.__init__(self, env, args)
Agent.__init__(self, env, args)

self.actor, self.critic = models
self.actor_optimizer, self.critic_optimizer = optims
Expand Down Expand Up @@ -158,7 +158,7 @@ def save_params(self, n_episode: int):
"critic_optim_state_dict": self.critic_optimizer.state_dict(),
}

AbstractAgent.save_params(self, params, n_episode)
Agent.save_params(self, params, n_episode)

def write_log(self, i: int, score: int, policy_loss: float, value_loss: float):
total_loss = policy_loss + value_loss
Expand Down
15 changes: 7 additions & 8 deletions algorithms/bc/ddpg_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
import torch.nn.functional as F
import wandb

from algorithms.common.abstract.her import AbstractHER
from algorithms.common.abstract.her import HER
from algorithms.common.buffer.replay_buffer import ReplayBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.common.noise import OUNoise
from algorithms.ddpg.agent import Agent as DDPGAgent
from algorithms.ddpg.agent import DDPGAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(DDPGAgent):
class BCDDPGAgent(DDPGAgent):
"""BC with DDPG agent interacting with environment.

Attributes:
HER (AbstractHER): hindsight experience replay
her (HER): hindsight experience replay
transitions_epi (list): transitions per episode (for HER)
desired_state (np.ndarray): desired state of current episode
memory (ReplayBuffer): replay memory
Expand All @@ -47,14 +47,14 @@ def __init__(
models: tuple,
optims: tuple,
noise: OUNoise,
HER: AbstractHER,
her: HER,
):
"""Initialization.
Args:
HER (AbstractHER): hindsight experience replay
her (HER): hindsight experience replay

"""
self.HER = HER
self.her = her
DDPGAgent.__init__(self, env, args, hyper_params, models, optims, noise)

# pylint: disable=attribute-defined-outside-init
Expand All @@ -66,7 +66,6 @@ def _initialize(self):

# HER
if self.hyper_params["USE_HER"]:
self.her = self.HER()
if self.hyper_params["DESIRED_STATES_FROM_DEMO"]:
self.her.fetch_desired_states_from_demo(demo)

Expand Down
15 changes: 7 additions & 8 deletions algorithms/bc/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
import torch.nn.functional as F
import wandb

from algorithms.common.abstract.her import AbstractHER
from algorithms.common.abstract.her import HER
from algorithms.common.buffer.replay_buffer import ReplayBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.sac.agent import Agent as SACAgent
from algorithms.sac.agent import SACAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(SACAgent):
class BCSACAgent(SACAgent):
"""BC with SAC agent interacting with environment.

Attributes:
HER (AbstractHER): hindsight experience replay
her (HER): hindsight experience replay
transitions_epi (list): transitions per episode (for HER)
desired_state (np.ndarray): desired state of current episode
memory (ReplayBuffer): replay memory
Expand All @@ -48,14 +48,14 @@ def __init__(
models: tuple,
optims: tuple,
target_entropy: float,
HER: AbstractHER,
her: HER,
):
"""Initialization.
Args:
HER (AbstractHER): hindsight experience replay
her (HER): hindsight experience replay

"""
self.HER = HER
self.her = her
SACAgent.__init__(self, env, args, hyper_params, models, optims, target_entropy)

# pylint: disable=attribute-defined-outside-init
Expand All @@ -67,7 +67,6 @@ def _initialize(self):

# HER
if self.hyper_params["USE_HER"]:
self.her = self.HER()
if self.hyper_params["DESIRED_STATES_FROM_DEMO"]:
self.her.fetch_desired_states_from_demo(demo)

Expand Down
2 changes: 1 addition & 1 deletion algorithms/common/abstract/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import wandb


class AbstractAgent(ABC):
class Agent(ABC):
"""Abstract Agent used for all agents.

Attributes:
Expand Down
6 changes: 3 additions & 3 deletions algorithms/common/abstract/her.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@

import numpy as np

from algorithms.common.abstract.reward_fn import AbstractRewardFn
from algorithms.common.abstract.reward_fn import RewardFn


class AbstractHER(ABC):
class HER(ABC):
"""Abstract class for HER (final strategy).

Attributes:
reward_func (Callable): returns reward from state, action, next_state

"""

def __init__(self, reward_func: AbstractRewardFn):
def __init__(self, reward_func: RewardFn):
"""Initialization.

Args:
Expand Down
2 changes: 1 addition & 1 deletion algorithms/common/abstract/reward_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np


class AbstractRewardFn(ABC):
class RewardFn(ABC):
"""Abstract class for computing reward.
New compute_reward class should redefine __call__()

Expand Down
8 changes: 4 additions & 4 deletions algorithms/ddpg/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
import torch.nn.functional as F
import wandb

from algorithms.common.abstract.agent import AbstractAgent
from algorithms.common.abstract.agent import Agent
from algorithms.common.buffer.replay_buffer import ReplayBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.common.noise import OUNoise

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(AbstractAgent):
class DDPGAgent(Agent):
"""ActorCritic interacting with environment.

Attributes:
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
noise (OUNoise): random noise for exploration

"""
AbstractAgent.__init__(self, env, args)
Agent.__init__(self, env, args)

self.actor, self.actor_target, self.critic, self.critic_target = models
self.actor_optimizer, self.critic_optimizer = optims
Expand Down Expand Up @@ -196,7 +196,7 @@ def save_params(self, n_episode: int):
"critic_optim_state_dict": self.critic_optimizer.state_dict(),
}

AbstractAgent.save_params(self, params, n_episode)
Agent.save_params(self, params, n_episode)

def write_log(self, i: int, loss: np.ndarray, score: int):
"""Write log about loss and score"""
Expand Down
8 changes: 4 additions & 4 deletions algorithms/dqn/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torch.nn.utils import clip_grad_norm_
import wandb

from algorithms.common.abstract.agent import AbstractAgent
from algorithms.common.abstract.agent import Agent
from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
from algorithms.common.buffer.replay_buffer import NStepTransitionBuffer
import algorithms.common.helper_functions as common_utils
Expand All @@ -33,7 +33,7 @@
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(AbstractAgent):
class DQNAgent(Agent):
"""DQN interacting with environment.

Attribute:
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
optim (torch.optim.Adam): optimizers for dqn

"""
AbstractAgent.__init__(self, env, args)
Agent.__init__(self, env, args)

self.use_n_step = hyper_params["N_STEP"] > 1
self.epsilon = hyper_params["MAX_EPSILON"]
Expand Down Expand Up @@ -267,7 +267,7 @@ def save_params(self, n_episode: int):
"dqn_optim_state_dict": self.dqn_optimizer.state_dict(),
}

AbstractAgent.save_params(self, params, n_episode)
Agent.save_params(self, params, n_episode)

def write_log(self, i: int, loss: np.ndarray, score: float, avg_time_cost: float):
"""Write log about loss and score"""
Expand Down
4 changes: 2 additions & 2 deletions algorithms/fd/ddpg_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBufferfD
from algorithms.common.buffer.replay_buffer import NStepTransitionBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.ddpg.agent import Agent as DDPGAgent
from algorithms.ddpg.agent import DDPGAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(DDPGAgent):
class DDPGfDAgent(DDPGAgent):
"""ActorCritic interacting with environment.

Attributes:
Expand Down
4 changes: 2 additions & 2 deletions algorithms/fd/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBufferfD
from algorithms.common.buffer.replay_buffer import NStepTransitionBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.dqn.agent import Agent as DQNAgent
from algorithms.dqn.agent import DQNAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(DQNAgent):
class DQfDAgent(DQNAgent):
"""DQN interacting with environment.

Attribute:
Expand Down
4 changes: 2 additions & 2 deletions algorithms/fd/sac_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBufferfD
from algorithms.common.buffer.replay_buffer import NStepTransitionBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.sac.agent import Agent as SACAgent
from algorithms.sac.agent import SACAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(SACAgent):
class SACfDAgent(SACAgent):
"""SAC agent interacting with environment.

Attributes:
Expand Down
4 changes: 2 additions & 2 deletions algorithms/per/ddpg_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
import algorithms.common.helper_functions as common_utils
from algorithms.ddpg.agent import Agent as DDPGAgent
from algorithms.ddpg.agent import DDPGAgent

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(DDPGAgent):
class PERDDPGAgent(DDPGAgent):
"""ActorCritic interacting with environment.

Attributes:
Expand Down
8 changes: 4 additions & 4 deletions algorithms/ppo/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import torch.nn as nn
import wandb

from algorithms.common.abstract.agent import AbstractAgent
from algorithms.common.abstract.agent import Agent
from algorithms.common.env.multiprocessing_env import SubprocVecEnv
import algorithms.ppo.utils as ppo_utils

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(AbstractAgent):
class PPOAgent(Agent):
"""PPO Agent.

Attributes:
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
optims (tuple): optimizers for actor and critic

"""
AbstractAgent.__init__(self, env_single, args)
Agent.__init__(self, env_single, args)

if not self.args.test:
self.env = env_multi
Expand Down Expand Up @@ -251,7 +251,7 @@ def save_params(self, n_episode: int):
"actor_optim_state_dict": self.actor_optimizer.state_dict(),
"critic_optim_state_dict": self.critic_optimizer.state_dict(),
}
AbstractAgent.save_params(self, params, n_episode)
Agent.save_params(self, params, n_episode)

def write_log(
self,
Expand Down
8 changes: 4 additions & 4 deletions algorithms/sac/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import torch.optim as optim
import wandb

from algorithms.common.abstract.agent import AbstractAgent
from algorithms.common.abstract.agent import Agent
from algorithms.common.buffer.replay_buffer import ReplayBuffer
import algorithms.common.helper_functions as common_utils

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent(AbstractAgent):
class SACAgent(Agent):
"""SAC agent interacting with environment.

Attributes:
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
target_entropy (float): target entropy for the inequality constraint

"""
AbstractAgent.__init__(self, env, args)
Agent.__init__(self, env, args)

self.actor, self.vf, self.vf_target, self.qf_1, self.qf_2 = models
self.actor_optimizer, self.vf_optimizer = optims[0:2]
Expand Down Expand Up @@ -281,7 +281,7 @@ def save_params(self, n_episode: int):
if self.hyper_params["AUTO_ENTROPY_TUNING"]:
params["alpha_optim"] = self.alpha_optimizer.state_dict()

AbstractAgent.save_params(self, params, n_episode)
Agent.save_params(self, params, n_episode)

def write_log(
self, i: int, loss: np.ndarray, score: float = 0.0, policy_update_freq: int = 1
Expand Down
Loading