
add policy.update to enable post process and remove collector.sample #180


Merged (11 commits) on Aug 15, 2020
8 changes: 5 additions & 3 deletions README.md
@@ -139,12 +139,14 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page

### Modularized Policy

We decouple all of the algorithms into 4 parts:
We decouple all of the algorithms roughly into the following parts:

- `__init__`: initialize the policy;
- `forward`: to compute actions over given observations;
- `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
- `learn`: to learn from a given batch of data.
- `learn`: to learn from a given batch of data;
- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.

Within this API, we can interact with different policies conveniently.

@@ -165,7 +167,7 @@ result = collector.collect(n_episode=[1, 0, 3])
If you want to train the given policy with a sampled batch:

```python
result = policy.learn(collector.sample(batch_size))
result = policy.update(batch_size, collector.buffer)
```

You can check out the [documentation](https://tianshou.readthedocs.io) for further usage.
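
As a quick illustration of the README change above, a minimal before/after sketch; `policy`, `collector`, and `batch_size` are assumed to be set up as in the surrounding README snippets, and the commented-out lines show the pre-PR interface rather than anything introduced by this diff:

```python
# assumed to already exist, as in the README examples above:
#   policy     - any tianshou policy (e.g. a DQNPolicy)
#   collector  - a Collector bound to that policy and a ReplayBuffer
batch_size = 64

# old interface (removed by this PR):
#   result = policy.learn(collector.sample(batch_size))

# new interface: one call runs process_fn -> learn -> post_process_fn
result = policy.update(batch_size, collector.buffer)
```
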
17 changes: 8 additions & 9 deletions docs/tutorials/concepts.rst
@@ -64,12 +64,14 @@ Policy

Tianshou aims to modularize RL algorithms. It provides several classes of policies, and all of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.

A policy class typically has four parts:
A policy class typically has the following parts:

* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on;
* :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation;
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer);
* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer;
* :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data.
* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes it (such as computing the n-step return), learns from the data, and finally post-processes it (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``.

Take the 2-step return DQN as an example. The 2-step return DQN computes each frame's return as:

@@ -125,10 +127,8 @@ Collector
---------

The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.
In short, :class:`~tianshou.data.Collector` has two main methods:

* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer;
* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data.
:class:`~tianshou.data.Collector` has one main method :meth:`~tianshou.data.Collector.collect`: it lets the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and stores the data in the replay buffer.

Why do we mention **at least** here? For multiple environments, we cannot directly store the collected data in the replay buffer, since that would break the principle of storing data chronologically.

@@ -144,8 +144,6 @@ Once you have a collector and a policy, you can start writing the training metho

Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage.

There will be more types of trainers, for instance, multi-agent trainer.


.. _pseudocode:

@@ -165,7 +163,8 @@ We give a high-level explanation through the pseudocode used in section :ref:`po
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
if i % 1000 == 0: # done in trainer
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # collector.sample(batch_size)
# the following is done in policy.update(batch_size, buffer)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # buffer.sample(batch_size)
# compute 2-step returns. How?
b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice)
# update DQN policy
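
As a concrete companion to the pseudocode above, here is a small self-contained sketch of a 2-step return of the form ``G_t = r_t + gamma * r_{t+1} + gamma^2 * max_a Q(s_{t+2}, a)``, i.e. the quantity that ``compute_2_step_return`` stands for; the function name, argument layout, and the flat single-trajectory assumption are illustrative, not tianshou's actual implementation:

```python
import numpy as np

def two_step_return(rew, done, q_next2, gamma=0.99):
    """2-step target on one flat trajectory, truncated at episode ends.

    q_next2[t] stands for the target network's value max_a Q(s_{t+2}, a).
    """
    rew = np.asarray(rew, dtype=np.float64)
    done = np.asarray(done, dtype=bool)
    q_next2 = np.asarray(q_next2, dtype=np.float64)
    returns = rew.copy()
    for t in range(len(rew)):
        if done[t] or t + 1 >= len(rew):
            continue                  # episode (or buffer) ends right after step t
        returns[t] += gamma * rew[t + 1]
        if not done[t + 1]:
            returns[t] += gamma ** 2 * q_next2[t]
    return returns
```
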
4 changes: 2 additions & 2 deletions docs/tutorials/dqn.rst
@@ -210,8 +210,8 @@ Tianshou supports user-defined training code. Here is the code snippet:
# back to training eps
policy.set_eps(0.1)

# train policy with a sampled batch of data
losses = policy.learn(train_collector.sample(batch_size=64))
# train policy with a batch of data sampled from the buffer
losses = policy.update(64, train_collector.buffer)

For further usage, you can refer to the :doc:`/tutorials/cheatsheet`.

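
To show where the updated call sits in a hand-rolled loop, a hedged sketch of a custom training step; ``policy`` and ``train_collector`` are assumed to be built as earlier in the DQN tutorial, and the step budget, exploration value, and collection size are made-up illustrative numbers:

```python
# assumed from earlier in the tutorial: policy (DQNPolicy), train_collector
for step in range(10000):                # illustrative training budget
    policy.set_eps(0.1)                  # fixed exploration here; the tutorial anneals it
    train_collector.collect(n_step=10)   # push fresh transitions into the buffer
    # one gradient step: sample 64 transitions and run
    # process_fn -> learn -> post_process_fn inside policy.update
    losses = policy.update(64, train_collector.buffer)
```
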
8 changes: 4 additions & 4 deletions test/base/test_collector.py
@@ -174,7 +174,7 @@ def test_collector_with_dict_state():
c1.seed(0)
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([
@@ -184,7 +184,7 @@ def test_collector_with_dict_state():
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
Logger.single_preprocess_fn)
c2.collect(n_episode=[0, 0, 0, 10])
batch = c2.sample(10)
batch, _ = c2.buffer.sample(10)
print(batch['obs_next']['index'])


@@ -209,7 +209,7 @@ def reward_metric(x):
assert np.asanyarray(r).size == 1 and r == 4.
r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
assert np.asanyarray(r).size == 1 and r == 4.
batch = c1.sample(10)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
obs = np.array(np.expand_dims([
@@ -226,7 +226,7 @@ def reward_metric(x):
Logger.single_preprocess_fn, reward_metric=reward_metric)
r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
assert np.asanyarray(r).size == 1 and r == 4.
batch = c2.sample(10)
batch, _ = c2.buffer.sample(10)
print(batch['obs_next'])


19 changes: 7 additions & 12 deletions tianshou/data/collector.py
@@ -64,16 +64,6 @@ class Collector(object):
# sleep time between rendering consecutive frames)
collector.collect(n_episode=1, render=0.03)

# sample data with a given number of batch-size:
batch_data = collector.sample(batch_size=64)
# policy.learn(batch_data) # btw, vanilla policy gradient only
# supports on-policy training, so here we pick all data in the buffer
batch_data = collector.sample(batch_size=0)
policy.learn(batch_data)
# on-policy algorithms use the collected data only once, so here we
# clear the buffer
collector.reset_buffer()

Collected data always consists of full episodes. So if only the ``n_step``
argument is given, the collector may return more data than the
``n_step`` limitation. Same as ``n_episode`` for the multiple environment
@@ -357,13 +347,18 @@ def collect(self,

def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer. It will call
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning the
final batch data.

:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
import warnings
warnings.warn(
'Collector.sample is deprecated and will cause error if you use '
'prioritized experience replay! Collector.sample will be removed '
'upon version 0.3. Use policy.update instead!', Warning)
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
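
For code still calling the deprecated method, the docstring above boils down to the following equivalence (a sketch; ``collector`` and ``policy`` are whatever objects you already have, and the explicit ``process_fn`` call mirrors what ``Collector.sample`` does internally):

```python
# roughly what the deprecated collector.sample(batch_size) returns (sketch):
batch, indice = collector.buffer.sample(batch_size)
batch = policy.process_fn(batch, collector.buffer, indice)
# note: nothing feeds the learned TD error back into a prioritized buffer on
# this path, which is why the warning points users to policy.update instead.
```
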
37 changes: 31 additions & 6 deletions tianshou/policy/base.py
@@ -93,9 +93,8 @@ def forward(self, batch: Batch,

# some code
return Batch(..., policy=Batch(log_prob=dist.log_prob(act)))
# and in the sampled data batch, you can directly call
# batch.policy.log_prob to get your data, although it is stored in
# np.ndarray.
# and in the sampled data batch, you can directly use
# batch.policy.log_prob to get your data.
"""
pass

@@ -123,6 +122,7 @@ def compute_episodic_return(
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: float = 0.99,
gae_lambda: float = 0.95,
rew_norm: bool = False,
) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimator (arXiv:1506.02438).
@@ -136,6 +136,8 @@
to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage
Estimation, should be in [0, 1], defaults to 0.95.
:param bool rew_norm: normalize the reward to Normal(0, 1), defaults
to ``False``.

:return: a Batch. The result will be stored in batch.returns as a numpy
array with shape (bsz, ).
@@ -150,6 +152,8 @@
for i in range(len(rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
returns[i] += gae
if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
batch.returns = returns
return batch

@@ -196,7 +200,7 @@ def compute_nstep_return(
if rew_norm:
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if np.isclose(std, 0):
if np.isclose(std, 0, 1e-2):
mean, std = 0, 1
else:
mean, std = 0, 1
@@ -216,9 +220,30 @@
batch.returns = target_q * gammas + returns
# prio buffer update
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
batch.indice = indice
batch.weight = to_torch_as(batch.weight, target_q)
else:
batch.weight = torch.ones_like(target_q)
return batch

def post_process_fn(self, batch: Batch,
buffer: ReplayBuffer, indice: np.ndarray):
"""Post-process the data from the provided replay buffer. Typical
usage is to update the sampling weight in prioritized experience
replay. Check out :ref:`policy_concept` for more information.
"""
if isinstance(buffer, PrioritizedReplayBuffer):
buffer.update_weight(indice, batch.weight)

def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs):
"""Update the policy network and replay buffer (if needed). It includes
three function steps: process_fn, learn, and post_process_fn.

:param int batch_size: 0 means it will extract all the data from the
buffer, otherwise it will sample a batch with the given batch_size.
:param ReplayBuffer buffer: the corresponding replay buffer.
"""
batch, indice = buffer.sample(batch_size)
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, *args, **kwargs)
self.post_process_fn(batch, buffer, indice)
return result
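
To make the new ``post_process_fn`` round-trip concrete, a self-contained toy in plain numpy (not tianshou classes) that mimics the flow ``update`` wires together: the prioritized buffer hands back sample indices and importance weights, ``learn`` would overwrite ``batch['weight']`` with a TD-error-like quantity, and the post-process step writes it back as new priorities. Every name below is invented for illustration:

```python
import numpy as np

class ToyPrioBuffer:
    """Minimal stand-in for a prioritized replay buffer."""

    def __init__(self, size):
        self.prio = np.ones(size)

    def sample(self, batch_size):
        p = self.prio / self.prio.sum()
        indice = np.random.choice(len(self.prio), batch_size, p=p)
        weight = 1.0 / (len(self.prio) * p[indice])   # importance weights
        return {"weight": weight}, indice

    def update_weight(self, indice, new_weight):
        # new priority = |TD error| + small constant, the usual PER recipe
        self.prio[indice] = np.abs(new_weight) + 1e-6

def toy_update(buffer, batch_size):
    batch, indice = buffer.sample(batch_size)        # process_fn would also run here
    batch["weight"] = np.random.randn(batch_size)    # stand-in for learn()'s TD error
    buffer.update_weight(indice, batch["weight"])    # the post_process_fn step
    return batch

buf = ToyPrioBuffer(128)
toy_update(buf, 16)
```
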
6 changes: 2 additions & 4 deletions tianshou/policy/modelfree/a2c.py
@@ -67,7 +67,8 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer,
v_.append(to_numpy(self.critic(b.obs_next)))
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
rew_norm=self._rew_norm)

def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -97,9 +98,6 @@ def forward(self, batch: Batch,
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
r = batch.returns
if self._rew_norm and not np.isclose(r.std(), 0):
batch.returns = (r - r.mean()) / r.std()
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size):
4 changes: 1 addition & 3 deletions tianshou/policy/modelfree/ddpg.py
@@ -144,10 +144,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
current_q = self.critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
td = current_q - target_q
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, td)
critic_loss = (td.pow(2) * batch.weight).mean()
# critic_loss = F.mse_loss(current_q, target_q)
batch.weight = td # prio-buffer
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
4 changes: 1 addition & 3 deletions tianshou/policy/modelfree/dqn.py
@@ -160,10 +160,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
td = r - q
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, td)
loss = (td.pow(2) * batch.weight).mean()
# loss = F.mse_loss(q, r)
batch.weight = td # prio-buffer
loss.backward()
self.optim.step()
self._cnt += 1
5 changes: 1 addition & 4 deletions tianshou/policy/modelfree/pg.py
@@ -51,7 +51,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer,
# batch.returns = self._vectorized_returns(batch)
# return batch
return self.compute_episodic_return(
batch, gamma=self._gamma, gae_lambda=1.)
batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm)

def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -81,9 +81,6 @@ def forward(self, batch: Batch,
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
losses = []
r = batch.returns
if self._rew_norm and not np.isclose(r.std(), 0):
batch.returns = (r - r.mean()) / r.std()
for _ in range(repeat):
for b in batch.split(batch_size):
self.optim.zero_grad()
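
The normalization that used to live in ``learn`` is now triggered by the ``rew_norm`` flag inside ``compute_episodic_return``; a tiny numeric sketch of that branch (illustrative only, using ``atol`` explicitly for the near-zero-std guard):

```python
import numpy as np

returns = np.array([1.0, 2.0, 3.0, 4.0])
if not np.isclose(returns.std(), 0, atol=1e-2):   # skip normalization for ~constant returns
    returns = (returns - returns.mean()) / returns.std()
print(returns.mean(), returns.std())               # ~0.0 and 1.0
```
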
45 changes: 18 additions & 27 deletions tianshou/policy/modelfree/ppo.py
@@ -79,18 +79,29 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
if not np.isclose(std, 0):
if not np.isclose(std, 0, 1e-2):
batch.rew = (batch.rew - mean) / std
if self._lambda in [0, 1]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
v, v_, old_log_prob = [], [], []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False):
v_.append(self.critic(b.obs_next))
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
v_ = to_numpy(torch.cat(v_, dim=0))
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
batch = self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
rew_norm=self._rew_norm)
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
batch.adv = batch.returns - batch.v
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
if not np.isclose(std.item(), 0, 1e-2):
batch.adv = (batch.adv - mean) / std
return batch

def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
@@ -123,26 +134,6 @@ def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
v = []
old_log_prob = []
with torch.no_grad():
for b in batch.split(batch_size, shuffle=False):
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
if self._rew_norm:
mean, std = batch.returns.mean(), batch.returns.std()
if not np.isclose(std.item(), 0):
batch.returns = (batch.returns - mean) / std
batch.adv = batch.returns - batch.v
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
if not np.isclose(std.item(), 0):
batch.adv = (batch.adv - mean) / std
for _ in range(repeat):
for b in batch.split(batch_size):
dist = self(b).dist
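
Since ``process_fn`` now precomputes the quantities that ``learn`` consumes, here is a hedged sketch of the standard clipped-surrogate term those fields feed into; this mirrors textbook PPO rather than code copied from the repository, with ``dist`` the current policy's action distribution for minibatch ``b`` and ``eps_clip`` the clipping range:

```python
import torch

def clipped_surrogate(dist, b, eps_clip=0.2):
    # b.act, b.adv, b.logp_old are the fields attached in process_fn above
    ratio = (dist.log_prob(b.act) - b.logp_old).exp()
    surr1 = ratio * b.adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * b.adv
    return -torch.min(surr1, surr2).mean()
```
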
4 changes: 1 addition & 3 deletions tianshou/policy/modelfree/sac.py
@@ -154,9 +154,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
# prio-buffer
if hasattr(batch, 'update_weight'):
batch.update_weight(batch.indice, (td1 + td2) / 2.)
batch.weight = (td1 + td2) / 2. # prio-buffer
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act
3 changes: 1 addition & 2 deletions tianshou/policy/modelfree/td3.py
@@ -132,8 +132,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, (td1 + td2) / 2.)
batch.weight = (td1 + td2) / 2. # prio-buffer
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(
batch.obs, self(batch, eps=0).act).mean()