
Add Rainbow DQN #386

Merged
merged 40 commits into thu-ml:master from rainbow on Aug 29, 2021
Conversation

nuance1979 (Collaborator)

I am currently running Atari examples. Will update the results soon.

@codecov-commenter commented Jun 30, 2021

Codecov Report

Merging #386 (a040a6d) into master (d161059) will increase coverage by 0.13%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #386      +/-   ##
==========================================
+ Coverage   94.69%   94.82%   +0.13%     
==========================================
  Files          57       58       +1     
  Lines        3749     3807      +58     
==========================================
+ Hits         3550     3610      +60     
+ Misses        199      197       -2     
| Flag | Coverage Δ |
| --- | --- |
| unittests | 94.82% <100.00%> (+0.13%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
| --- | --- |
| tianshou/data/buffer/prio.py | 91.30% <100.00%> (+0.82%) ⬆️ |
| tianshou/data/buffer/vecbuf.py | 100.00% <100.00%> (ø) |
| tianshou/policy/__init__.py | 100.00% <100.00%> (ø) |
| tianshou/policy/modelfree/rainbow.py | 100.00% <100.00%> (ø) |
| tianshou/utils/net/common.py | 95.74% <100.00%> (+2.12%) ⬆️ |
| tianshou/utils/net/discrete.py | 100.00% <100.00%> (ø) |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d161059...a040a6d.

@Trinkle23897 linked an issue (8 tasks) on Jun 30, 2021 that may be closed by this pull request.
@Trinkle23897 linked an issue on Jul 1, 2021 that may be closed by this pull request.
@Trinkle23897 (Collaborator) left a comment


Also could you please check #393?

@Trinkle23897 (Collaborator)

Why did you move the weight norm to the policy side?
[Screenshot: 2021-07-19 09-17-30]

@nuance1979 (Collaborator, Author)

> Why did you move the weight norm to the policy side?

I thought for backward compatibility it might be better to put the weight norm hack on the policy side. Basically it's up to the policy to decide whether it wants to use the weight norm hack or not.

I don't have a strong opinion about it so if you insist, I can move it back to the buffer side.

@Trinkle23897 (Collaborator)

So how about adding an extra argument to the prio-buffer and setting its default to False?
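
For context, a minimal sketch of what such an opt-in flag could look like; the class, method, and `weight_norm` names here are illustrative assumptions, not tianshou's actual API:

```python
import numpy as np


class PrioritizedBufferSketch:
    """Fragment of a prioritized buffer with an opt-in weight-normalization flag."""

    def __init__(self, alpha: float = 0.6, beta: float = 0.4,
                 weight_norm: bool = False) -> None:
        self.alpha = alpha
        self.beta = beta
        # Off by default so existing users see unchanged behaviour.
        self.weight_norm = weight_norm

    def importance_weights(self, priorities: np.ndarray) -> np.ndarray:
        probs = priorities ** self.alpha
        probs = probs / probs.sum()
        weights = (len(priorities) * probs) ** (-self.beta)
        if self.weight_norm:
            # Scale so the largest weight is 1, keeping TD updates bounded.
            weights = weights / weights.max()
        return weights


# Normalization is only applied when explicitly requested.
buf = PrioritizedBufferSketch(weight_norm=True)
print(buf.importance_weights(np.array([1.0, 2.0, 4.0])))
```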

@Trinkle23897 (Collaborator)

@nuance1979 is it ready now...?

@nuance1979 (Collaborator, Author)

> @nuance1979 is it ready now...?

Sorry for the delay.

After fixing the weight normalization and beta annealing, some tasks (Enduro, SpaceInvaders, etc.) still get terrible results so I tried hard to figure out why. Let me summarize my current findings:

I compared the current NoisyLinear implementation with https://github.com/deepmind/dqn_zoo/ and found two differences:

  1. A different but equivalent way of adding the noise;
  2. dqn_zoo removes the bias term for the second NoisyLinear layer.
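
For reference, a minimal factorized-Gaussian noisy layer in the style of Fortunato et al. (2018); the class name, the `sigma0` default, and the optional `bias` switch (mirroring the dqn_zoo difference in point 2) are illustrative assumptions, not the exact code in this PR:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinearSketch(nn.Module):
    """Factorized-Gaussian noisy linear layer (Fortunato et al., 2018)."""

    def __init__(self, in_features: int, out_features: int,
                 sigma0: float = 0.5, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        # bias=False mimics the dqn_zoo choice of dropping the bias term
        # in the second noisy layer (point 2 above).
        self.bias_mu = nn.Parameter(torch.empty(out_features)) if bias else None
        self.bias_sigma = nn.Parameter(torch.empty(out_features)) if bias else None
        bound = 1.0 / math.sqrt(in_features)
        for mu in (self.weight_mu, self.bias_mu):
            if mu is not None:
                nn.init.uniform_(mu, -bound, bound)
        for sigma in (self.weight_sigma, self.bias_sigma):
            if sigma is not None:
                nn.init.constant_(sigma, sigma0 / math.sqrt(in_features))

    @staticmethod
    def _scale_noise(size: int) -> torch.Tensor:
        # f(eps) = sign(eps) * sqrt(|eps|), applied to standard Gaussian noise.
        eps = torch.randn(size)
        return eps.sign() * eps.abs().sqrt()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Factorized noise: one input vector and one output vector instead of
        # a full (out x in) noise matrix.
        eps_in = self._scale_noise(self.in_features)
        eps_out = self._scale_noise(self.out_features)
        weight = self.weight_mu + self.weight_sigma * torch.outer(eps_out, eps_in)
        bias = None
        if self.bias_mu is not None:
            bias = self.bias_mu + self.bias_sigma * eps_out
        return F.linear(x, weight, bias)
```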

I tried to align with the dqn_zoo implementation on both points but found no meaningful difference in Enduro performance.

So I tried turning off some features of Rainbow:

| Experiment | Priority Buffer | NoisyLinear | Dueling | Best reward |
| --- | --- | --- | --- | --- |
| Enduro C51 baseline (copied from README.md) | N | N | N | 1032 |
| weight norm + beta from 0.4 to 1 | Y | Y | Y | 450.2 |
| - prio buffer | N | Y | Y | 487.2 |
| - prio buffer - dueling | N | Y | N | 354.9 |
| - prio buffer - dueling - noisy | N | N | N | 1459.7 |
| - prio buffer - noisy | N | N | Y | 1469 |
| - noisy | Y | N | Y | 1645.6 |

Therefore, for Enduro, the NoisyLinear layer hurts performance. I suspect the same is true for other low-performing tasks.

I feel that I have exhausted my ideas for exploring further. I could add "--no-noisy" as a task-specific parameter for Enduro (and potentially more tasks with similar behavior), or I could keep the low-performing numbers for the sake of consistency.

What do you think? @Trinkle23897

@Trinkle23897 (Collaborator)

That's fine for --no-noisy. But have you ever tried --training-num=4? I suspect it suffers from policy lag.

@nuance1979 (Collaborator, Author)

> That's fine for --no-noisy. But have you ever tried --training-num=4? I suspect it suffers from policy lag.

I tried it before and it didn't make a difference. But things have changed a lot since then. I'll kick off an experiment to confirm.

@nuance1979 (Collaborator, Author)

> That's fine for --no-noisy. But have you ever tried --training-num=4? I suspect it suffers from policy lag.

> I tried it before and it didn't make a difference. But things have changed a lot since then. I'll kick off an experiment to confirm.

I got the result:

| Experiment | Priority Buffer | NoisyLinear | Dueling | Best reward |
| --- | --- | --- | --- | --- |
| Enduro C51 baseline (copied from README.md) | N | N | N | 1032 |
| weight norm + beta from 0.4 to 1 | Y | Y | Y | 450.2 |
| + training_num_4 | Y | Y | Y | 957.3 |
| - noisy (training_num_10) | Y | N | Y | 1645.6 |

Better than default ("--training-num=10") but still far worse than "--no-noisy".

@Trinkle23897 (Collaborator)

okay that's fine. feel free to add --no-noisy

@nuance1979 (Collaborator, Author)

I might have found a bug in my code which could prevent NoisyLinear from working correctly. I'm running experiments to confirm.

@Trinkle23897 (Collaborator)

I have another unrelated issue: usually the Atari network's feature part ends with a linear(3136, 512) instead of nn.Flatten. I see only https://github.com/ku2482/fqf-iqn-qrdqn.pytorch uses the latter setting; is that correct?

@nuance1979 (Collaborator, Author)

> I have another unrelated issue: usually the Atari network's feature part ends with a linear(3136, 512) instead of nn.Flatten. I see only https://github.com/ku2482/fqf-iqn-qrdqn.pytorch uses the latter setting; is that correct?

I think that's correct. Otherwise it will not match the input shape of the following linear layer. I found the following line, which is just another way of flattening, in the repo where there is no nn.Flatten in the model itself:

https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/model.py#L71

    x = x.view(-1, self.conv_output_size)
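
A quick sanity check of that equivalence (shapes assume the standard 84x84 Nature-CNN stack; this is only an illustrative snippet, not code from either repo):

```python
import torch
import torch.nn as nn

conv = nn.Sequential(
    nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
    nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
)
x = torch.zeros(2, 4, 84, 84)
feat = conv(x)
# nn.Flatten(start_dim=1) and view(-1, conv_output_size) give identical tensors.
a = nn.Flatten()(feat)
b = feat.view(-1, 64 * 7 * 7)  # conv_output_size == 3136 for 84x84 inputs
assert a.shape == b.shape == (2, 3136)
assert torch.equal(a, b)
```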

@Trinkle23897 (Collaborator) commented Aug 24, 2021

Well, I mean in the on-policy Atari setting they use the following:

ActorCriticCnnPolicy(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=3136, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (action_net): Linear(in_features=512, out_features=6, bias=True)
  (value_net): Linear(in_features=512, out_features=1, bias=True)
)

Could you please double-check the authors' implementations for the series of offline-RL Atari settings you previously implemented, together with FQF and IQN?

@nuance1979 (Collaborator, Author)

Well, I mean in the on-policy Atari setting they use the following:

ActorCriticCnnPolicy(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=3136, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (action_net): Linear(in_features=512, out_features=6, bias=True)
  (value_net): Linear(in_features=512, out_features=1, bias=True)
)

Could you please double-check the authors' implementations for the series of offline-RL Atari settings you previously implemented, together with FQF and IQN?

I see. The above structure essentially shares one more linear layer of (3136, 512) between the action and value nets. I checked a few repos:

  • dqn_zoo does NOT share this layer. link1 and link2
  • Dopamine does. link1 and link2
  • Kaixhin's rainbow does NOT share this layer. link1 and link2
  • fqf-iqn-qrdqn's rainbow does NOT share this layer. link1 and link2
  • My current Rainbow implementation does NOT share this layer.

So for Rainbow's action and value nets, we are with the majority.
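
To make the comparison concrete, a rough sketch of the two layouts (layer sizes follow the Nature-CNN printout above, with six actions assumed as in that printout; not the actual code of any of these repos):

```python
import torch.nn as nn

feature_dim = 64 * 7 * 7  # 3136 features after the Nature-CNN conv stack on 84x84 frames

# Variant A (Dopamine / the printout above): one shared 3136 -> 512 linear,
# with small action/value heads on top of the shared 512-d features.
shared = nn.Sequential(nn.Linear(feature_dim, 512), nn.ReLU())
action_head_a = nn.Linear(512, 6)
value_head_a = nn.Linear(512, 1)

# Variant B (dqn_zoo / Kaixhin / fqf-iqn-qrdqn style): no shared linear;
# each head owns its full 3136 -> 512 -> out MLP.
action_head_b = nn.Sequential(nn.Linear(feature_dim, 512), nn.ReLU(), nn.Linear(512, 6))
value_head_b = nn.Sequential(nn.Linear(feature_dim, 512), nn.ReLU(), nn.Linear(512, 1))
```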

For IQN, the equivalent question is whether the CosNet has input size 3136 or 512 (i.e., after one linear layer of (3136, 512)). All the repos above have input size 3136, following the original paper. Same for FQF.

For offline-rl methods, CQL doesn't have this problem since no two heads share layers; CRR does have this issue since there are actor and critic nets. However, I can't find the reference implementation from the authors.

@nuance1979 (Collaborator, Author)

> I might have found a bug in my code which could prevent NoisyLinear from working correctly. I'm running experiments to confirm.

I have fixed a bug caused by a misunderstanding of the role NoisyLinear plays in the Rainbow model. Basically, I disabled explore_noise when NoisyLinear layers were used, so the bad results with the NoisyLinear layer were mainly due to the absence of exploration. After the fix, the results with NoisyLinear are comparable to, but not always better than, the results without it. Since we only have a single run, I wouldn't draw any conclusions from it. I will now not use the "--no-noisy" option for the reported results in README.md.
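
In other words (a purely hypothetical illustration of the flag interaction with made-up names, not the actual training script): the buggy version silently set the exploration epsilon to zero whenever noisy layers were on, while the fix keeps the configured schedule either way.

```python
def exploration_eps(noisy_layers: bool, eps_train: float) -> float:
    # Buggy assumption: parametric noise replaces eps-greedy, so exploration
    # was silently disabled whenever NoisyLinear was enabled:
    #     return 0.0 if noisy_layers else eps_train
    # Fix: keep the configured exploration schedule regardless of noisy layers.
    return eps_train
```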

There is one exception: Seaquest. When I accidentally disabled explore_noise, I got much higher results (~16000 vs ~2300). I couldn't figure out the reason.

Anyway I think this PR is ready to merge.

@Trinkle23897 (Collaborator)

Cool, I'll take a look this weekend

@Trinkle23897 Trinkle23897 merged commit 291be08 into thu-ml:master Aug 29, 2021
@nuance1979 nuance1979 deleted the rainbow branch October 6, 2021 17:27
@Trinkle23897 (Collaborator)

There are some differences in the implementation of the function f(x). In this repo, the line x = torch.randn(x.size(0), device=x.device) overrides the input parameter x.

https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/11d70bb428e449fe5384654c05e4ab2c3bbdd4cd/fqf_iqn_qrdqn/network.py#L218-219

def f(self, x: torch.Tensor) -> torch.Tensor:
    x = torch.randn(x.size(0), device=x.device)
    return x.sign().mul_(x.abs().sqrt_())

@nuance1979 is that the cause of bad performance? Could you please check it when you're free? Many thanks!
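
For comparison, a version that separates "draw the noise" from "transform the noise", in the spirit of Kaixhin's `_scale_noise` helper; whether the override in the linked f() actually explains the performance gap is exactly the open question here. The function names below are illustrative:

```python
import torch


def scale_noise(size: int, device: str = "cpu") -> torch.Tensor:
    """Draw eps ~ N(0, 1) and return sign(eps) * sqrt(|eps|)."""
    eps = torch.randn(size, device=device)
    return eps.sign() * eps.abs().sqrt()


def f_transform(x: torch.Tensor) -> torch.Tensor:
    """Apply the same scaling to noise the caller has already sampled,
    instead of discarding x and drawing a fresh sample inside."""
    return x.sign() * x.abs().sqrt()
```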

BFAnas pushed a commit to BFAnas/tianshou that referenced this pull request May 5, 2024
- add RainbowPolicy
- add `set_beta` method in prio_buffer
- add NoisyLinear in utils/network
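
One of the pieces listed above is the `set_beta` hook on the prioritized buffer; below is a hypothetical sketch of how such a hook could drive the 0.4 → 1.0 beta annealing mentioned earlier in the thread. The stub class, `train_fn` signature, and step count are illustrative assumptions, not tianshou's actual API.

```python
class PrioBufferStub:
    """Stand-in for a prioritized replay buffer exposing set_beta."""

    def __init__(self, beta: float = 0.4) -> None:
        self.beta = beta

    def set_beta(self, beta: float) -> None:
        self.beta = beta


buffer = PrioBufferStub()
beta_start, beta_final, anneal_steps = 0.4, 1.0, 5_000_000

def train_fn(epoch: int, env_step: int) -> None:
    # Linearly anneal beta towards 1.0 over the first `anneal_steps` env steps.
    frac = min(env_step / anneal_steps, 1.0)
    buffer.set_beta(beta_start + frac * (beta_final - beta_start))

train_fn(epoch=0, env_step=1_000_000)
print(buffer.beta)  # ~0.52
```
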
Development

Successfully merging this pull request may close these issues.

Rainbow DQN Noisy network implementation
4 participants