fix #98, support #99 (#102)

Merged: 9 commits, Jun 27, 2020

examples/ant_v2_sac.py: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(

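A note on why the SAC examples switch to unbounded=True (my own explanation, not part of the diff): tianshou's SACPolicy squashes sampled actions with tanh itself, so an actor head that already returns max_action * torch.tanh(mu) gets squashed twice and can no longer reach the edges of the action range. A tiny illustration of that effect, using a made-up scalar mean:

import torch

mu = torch.tensor([3.0])      # hypothetical raw mean from the actor's last linear layer
max_action = 1.0

bounded_mu = max_action * torch.tanh(mu)   # old bounded head: ~0.995
double_squashed = torch.tanh(bounded_mu)   # SAC squashes again: ~0.76, pulled off the bound
single_squashed = torch.tanh(mu)           # unbounded head + SAC's own tanh: ~0.995

print(bounded_mu.item(), double_squashed.item(), single_squashed.item())
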
examples/continuous_net.py: 4 additions & 2 deletions
@@ -28,7 +28,7 @@ def forward(self, s, **kwargs):
 
 class ActorProb(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
+                 max_action, device='cpu', unbounded=False):
         super().__init__()
         self.device = device
         self.model = [
@@ -40,14 +40,16 @@ def __init__(self, layer_num, state_shape, action_shape,
         self.mu = nn.Linear(128, np.prod(action_shape))
         self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
+        self._unbounded = unbounded
 
     def forward(self, s, **kwargs):
         if not isinstance(s, torch.Tensor):
             s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
+        if not self._unbounded:
+            mu = self._max * torch.tanh(self.mu(logits))
         sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None

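For readers skimming the diff, here is a minimal self-contained sketch (my own illustration, with made-up names such as GaussianHead and feature_dim, not code from this repository) of the pattern the new unbounded flag enables: the head computes a raw Gaussian mean and applies the max_action * tanh(...) scaling only when configured as bounded.

import torch
import torch.nn as nn


class GaussianHead(nn.Module):
    """Illustrative Gaussian policy head with an optional tanh-bounded mean."""

    def __init__(self, feature_dim, action_dim, max_action, unbounded=False):
        super().__init__()
        self.mu = nn.Linear(feature_dim, action_dim)
        self.sigma = nn.Linear(feature_dim, action_dim)
        self._max = max_action
        self._unbounded = unbounded

    def forward(self, features):
        mu = self.mu(features)
        if not self._unbounded:
            # Bounded mode: scale the mean into [-max_action, max_action].
            mu = self._max * torch.tanh(mu)
        sigma = torch.exp(self.sigma(features))
        return mu, sigma


# unbounded=True leaves the mean unsquashed, which is what the SAC examples want,
# since the SAC policy applies its own tanh transform to sampled actions.
head = GaussianHead(feature_dim=128, action_dim=4, max_action=1.0, unbounded=True)
mu, sigma = head(torch.zeros(2, 128))
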
examples/halfcheetahBullet_v0_sac.py: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(

examples/sac_mcc.py: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(

tianshou/trainer/offpolicy.py: 6 additions & 2 deletions
@@ -18,6 +18,7 @@ def offpolicy_trainer(
     collect_per_step: int,
     episode_per_test: Union[int, List[int]],
     batch_size: int,
+    update_per_step: int = 1,
     train_fn: Optional[Callable[[int], None]] = None,
     test_fn: Optional[Callable[[int], None]] = None,
     stop_fn: Optional[Callable[[float], bool]] = None,
@@ -42,10 +43,13 @@
         in one epoch.
     :param int collect_per_step: the number of frames the collector would
         collect before the network update. In other words, collect some frames
-        and do one policy network update.
+        and do some policy network update.
     :param episode_per_test: the number of episodes for one policy evaluation.
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
+    :param int update_per_step: the number of times the policy network would
+        be updated after frames be collected. In other words, collect some
+        frames and do some policy network update.
     :param function train_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of training in this
         epoch.
@@ -98,7 +102,7 @@
         policy.train()
         if train_fn:
             train_fn(epoch)
-        for i in range(min(
+        for i in range(update_per_step * min(
                 result['n/st'] // collect_per_step, t.total - t.n)):
             global_step += 1
             losses = policy.learn(train_collector.sample(batch_size))

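Rough summary of the trainer change, with a small sketch (my own, using placeholder names such as run_updates and do_update rather than the trainer's real internals): after each collection round, the number of gradient updates earned by the collected frames is now multiplied by update_per_step, so update_per_step=1 reproduces the old behaviour and larger values perform more updates per collected frame.

def run_updates(collected_steps, collect_per_step, remaining_steps,
                update_per_step, do_update):
    # Updates earned by this collection, capped by the steps left in the
    # epoch, then scaled by update_per_step (1 reproduces the old behaviour).
    n_updates = update_per_step * min(
        collected_steps // collect_per_step, remaining_steps)
    for _ in range(n_updates):
        do_update()
    return n_updates


# Example: 640 frames collected with collect_per_step=10 and update_per_step=2
# -> 128 policy updates instead of the previous 64.
print(run_updates(640, 10, 1000, 2, lambda: None))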