Fix DiscreteSACExperimentBuilder not exposing with_actor_factory_default #1250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged on Mar 2, 2025 (5 commits)

CHANGELOG.md: 4 changes (3 additions, 1 deletion)
@@ -1,11 +1,13 @@
# Changelog

## Release 1.2.0
## Unreleased

### Changes/Improvements

- trainer:
    - Custom scoring now supported for selecting the best model. #1202
- highlevel:
    - `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250

### Breaking Changes

docs/04_contributing/04_contributing.rst: 11 changes (11 additions, 0 deletions)
@@ -13,6 +13,17 @@ to install all relevant requirements in editable mode you can simply call
$ poetry install --with dev


Platform-Specific Configuration
-------------------------------

**Windows**:
Since the repository contains symbolic links, make sure your system supports them:

* Enable Windows Developer Mode to allow symbolic links to be created: search the Start menu for "Developer Settings" and enable "Developer Mode".
* Enable symbolic links for this repository: ``git config core.symlinks true``
* Re-checkout the current git state: ``git checkout .``


PEP8 Code Style Check and Formatting
----------------------------------------

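The Windows steps above exist because creating a symbolic link can fail without Developer Mode or administrator rights. One way to check whether the current Python process may create symlinks is simply to try it in a temporary directory. This is a minimal sketch: the helper name `can_create_symlinks` is hypothetical and not part of Tianshou, and it only probes the operating-system capability, not the `core.symlinks` setting of your git checkout.

```python
import os
import tempfile
from pathlib import Path


def can_create_symlinks() -> bool:
    """Return True if this process can create a symbolic link in a temp directory."""
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "target.txt"
        target.write_text("probe")
        link = Path(tmp) / "link.txt"
        try:
            os.symlink(target, link)
        except (OSError, NotImplementedError):
            # On Windows without Developer Mode or admin rights this typically
            # raises OSError because the required privilege is missing.
            return False
        # Confirm the link exists and resolves to the target's contents.
        return link.is_symlink() and link.read_text() == "probe"


if __name__ == "__main__":
    print("symlinks supported:", can_create_symlinks())
```

If this prints `False` on Windows, enable Developer Mode as described and then re-run `git checkout .` so git recreates the links.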
tianshou/highlevel/experiment.py: 28 changes (25 additions, 3 deletions)
@@ -801,6 +801,28 @@ def with_actor_factory_default(
        return super()._with_actor_factory_default(hidden_sizes, hidden_activation)


class _BuilderMixinActorFactory_DiscreteOnly(_BuilderMixinActorFactory):
    """Specialization of the actor mixin where only environments with discrete action spaces are supported."""

    def __init__(self) -> None:
        super().__init__(ContinuousActorType.UNSUPPORTED)

    def with_actor_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Defines use of the default actor factory, allowing its parameters to be customized.

        The default actor factory uses an MLP-style architecture.

        :param hidden_sizes: dimensions of hidden layers used by the network
        :param hidden_activation: the activation function to use for hidden layers
        :return: the builder
        """
        return super()._with_actor_factory_default(hidden_sizes, hidden_activation)


class _BuilderMixinCriticsFactory:
    def __init__(self, num_critics: int, actor_future_provider: ActorFutureProviderProtocol):
        self._actor_future_provider = actor_future_provider
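The new `_BuilderMixinActorFactory_DiscreteOnly` class follows a simple pattern: the subclass pins the constructor argument (`ContinuousActorType.UNSUPPORTED`) and exposes a public `with_actor_factory_default` that forwards to the base mixin's protected `_with_actor_factory_default`. The sketch below reproduces that pattern with standard-library stand-ins only. `ActorMixinSketch`, `ActorMixinDiscreteOnlySketch` and `ContinuousActorTypeSketch` are hypothetical names, and the dictionary recorded by the protected method is an assumption standing in for a real actor factory.

```python
from __future__ import annotations

from collections.abc import Sequence
from enum import Enum
from typing import Any, TypeVar

T = TypeVar("T", bound="ActorMixinSketch")


class ContinuousActorTypeSketch(Enum):
    """Hypothetical stand-in for the continuous-actor setting."""

    GAUSSIAN = "gaussian"
    UNSUPPORTED = "unsupported"


class ActorMixinSketch:
    """Base mixin: the configuring method stays protected."""

    def __init__(self, continuous_actor_type: ContinuousActorTypeSketch) -> None:
        self._continuous_actor_type = continuous_actor_type
        self._actor_config: dict[str, Any] | None = None

    def _with_actor_factory_default(
        self: T, hidden_sizes: Sequence[int], hidden_activation: str = "relu"
    ) -> T:
        # Record the MLP settings; a real builder would configure an actor factory here.
        self._actor_config = {
            "hidden_sizes": tuple(hidden_sizes),
            "hidden_activation": hidden_activation,
        }
        return self  # returning self enables method chaining


class ActorMixinDiscreteOnlySketch(ActorMixinSketch):
    """Specialization: pins the constructor argument and exposes a public method."""

    def __init__(self) -> None:
        super().__init__(ContinuousActorTypeSketch.UNSUPPORTED)

    def with_actor_factory_default(
        self: T, hidden_sizes: Sequence[int], hidden_activation: str = "relu"
    ) -> T:
        return super()._with_actor_factory_default(hidden_sizes, hidden_activation)
```

Keeping the implementation protected in the base mixin lets each specialization choose which public signature to offer. Before this change, `DiscreteSACExperimentBuilder` inherited only the base mixin, so the method was not reachable under its public name.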
@@ -959,7 +981,7 @@ def with_critic2_factory_default(
        return self

    def with_critic2_factory_use_actor(self) -> Self:
-        """Makes the first critic reuse the actor's preprocessing network (parameter sharing)."""
+        """Makes the second critic reuse the actor's preprocessing network (parameter sharing)."""
        return self._with_critic_factory_use_actor(1)
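The docstring fix hinges on zero-based indexing: `with_critic2_factory_use_actor` passes index `1` to `_with_critic_factory_use_actor`, which addresses the second critic. A minimal stand-in makes this concrete. `DualCriticBuilderSketch`, its list of critic slots, and the `with_critic1_factory_use_actor` counterpart using index `0` are assumptions for illustration, not Tianshou code.

```python
from __future__ import annotations


class DualCriticBuilderSketch:
    """Hypothetical stand-in for a builder with two critic factories."""

    def __init__(self) -> None:
        # One slot per critic: index 0 is the first critic, index 1 the second.
        self.critic_factories: list[str] = ["default", "default"]

    def _with_critic_factory_use_actor(self, critic_idx: int) -> DualCriticBuilderSketch:
        # Mark the selected critic as sharing the actor's preprocessing network.
        self.critic_factories[critic_idx] = "shares actor preprocessing"
        return self

    def with_critic1_factory_use_actor(self) -> DualCriticBuilderSketch:
        """Makes the first critic reuse the actor's preprocessing network."""
        return self._with_critic_factory_use_actor(0)

    def with_critic2_factory_use_actor(self) -> DualCriticBuilderSketch:
        """Makes the second critic reuse the actor's preprocessing network."""
        return self._with_critic_factory_use_actor(1)


builder = DualCriticBuilderSketch().with_critic2_factory_use_actor()
print(builder.critic_factories)  # ['default', 'shares actor preprocessing']
```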


@@ -1333,7 +1355,7 @@ def _create_agent_factory(self) -> AgentFactory:

class DiscreteSACExperimentBuilder(
    ExperimentBuilder,
-    _BuilderMixinActorFactory,
+    _BuilderMixinActorFactory_DiscreteOnly,
    _BuilderMixinDualCriticFactory,
):
    def __init__(
@@ -1343,7 +1365,7 @@ def __init__(
        sampling_config: SamplingConfig | None = None,
    ):
        super().__init__(env_factory, experiment_config, sampling_config)
-        _BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
+        _BuilderMixinActorFactory_DiscreteOnly.__init__(self)
        _BuilderMixinDualCriticFactory.__init__(self, self)
        self._params: DiscreteSACParams = DiscreteSACParams()

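`DiscreteSACExperimentBuilder` initialises each mixin with an explicit `Mixin.__init__(self, ...)` call after `super().__init__(...)`, so the diff updates that call alongside the base-class list. The sketch below mimics this structure with hypothetical stand-ins (`ExperimentBuilderSketch`, `ActorMixinSketch`, `ActorMixinDiscreteOnlySketch`, `DualCriticMixinSketch`, `DiscreteSACBuilderSketch`). It assumes, for illustration only, that the experiment-builder base does not chain to the mixins through `super()`; whether Tianshou's `ExperimentBuilder` behaves that way is not shown in this diff.

```python
from __future__ import annotations

from collections.abc import Sequence


class ExperimentBuilderSketch:
    # Hypothetical base: in this sketch it does NOT call super().__init__().
    def __init__(self, name: str) -> None:
        self.name = name


class ActorMixinSketch:
    def __init__(self, continuous_actor_type: str) -> None:
        self.continuous_actor_type = continuous_actor_type
        self.hidden_sizes: tuple[int, ...] | None = None

    def _with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> ActorMixinSketch:
        self.hidden_sizes = tuple(hidden_sizes)
        return self


class ActorMixinDiscreteOnlySketch(ActorMixinSketch):
    def __init__(self) -> None:
        super().__init__("unsupported")

    def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> ActorMixinSketch:
        return super()._with_actor_factory_default(hidden_sizes)


class DualCriticMixinSketch:
    def __init__(self, actor_future_provider: object) -> None:
        self.actor_future_provider = actor_future_provider


class DiscreteSACBuilderSketch(
    ExperimentBuilderSketch,
    ActorMixinDiscreteOnlySketch,
    DualCriticMixinSketch,
):
    def __init__(self, name: str) -> None:
        # super() resolves to ExperimentBuilderSketch, which does not chain further,
        # so each mixin is initialised explicitly.
        super().__init__(name)
        ActorMixinDiscreteOnlySketch.__init__(self)
        DualCriticMixinSketch.__init__(self, self)


builder = DiscreteSACBuilderSketch("demo").with_actor_factory_default([128, 128])
print(builder.hidden_sizes)  # (128, 128)
```

The method resolution order here is `DiscreteSACBuilderSketch`, `ExperimentBuilderSketch`, `ActorMixinDiscreteOnlySketch`, `ActorMixinSketch`, `DualCriticMixinSketch`, `object`, which is why the public `with_actor_factory_default` becomes available on the builder once the discrete-only mixin is listed as a base.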