SACPolicy: Unsupported action space: Box(-1.0, 1.0, (1,), float32) #1232

@Vladimir19052002

Description

Initializing SACPolicy with a one-dimensional Box action space raises a ValueError: "Unsupported action space: Box(-1.0, 1.0, (1,), float32)".

Reproduction

Minimal example script:

import torch
import torch.nn as nn
from gym.spaces import Box
import numpy as np
from tianshou.policy import SACPolicy
import torch.optim as optim
import logging

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('TestLogger')

# Define Actor and Critic Networks
class SimpleActor(nn.Module):
    def __init__(self, input_dim):
        super(SimpleActor, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.network(x)

class SimpleCritic(nn.Module):
    def __init__(self, input_dim, action_dim):
        super(SimpleCritic, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + action_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 1)
        )
    
    def forward(self, x, a):
        return self.network(torch.cat([x, a], dim=-1))

# Define action_space with shape=(1,)
action_space = Box(low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float32)
logger.info(f"Test action_space: {action_space}")

# Initialize networks
actor = SimpleActor(input_dim=10)
critic = SimpleCritic(input_dim=10, action_dim=1)

# Initialize optimizers
actor_optimizer = optim.Adam(actor.parameters(), lr=1e-3)
critic_optimizer = optim.Adam(critic.parameters(), lr=1e-3)

# Initialize SACPolicy
try:
    policy = SACPolicy(
        actor=actor,
        actor_optim=actor_optimizer,
        critic=critic,
        critic_optim=critic_optimizer,
        action_space=action_space,
        tau=0.005,
        gamma=0.99,
        exploration_noise=0.1,
        action_scaling=False  # Ensuring no redundant scaling
    )
    logger.info("SACPolicy initialized successfully in minimal example.")
except ValueError as ve:
    logger.error(f"SACPolicy initialization failed in minimal example: {ve}")
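
One possible cause, offered as an assumption rather than a confirmed diagnosis: the script imports Box from gym.spaces. If the installed Tianshou version validates action_space with isinstance against gymnasium.spaces classes, a legacy gym Box would not match, whatever its shape. The sketch below uses hypothetical stand-in classes (it imports neither gym, gymnasium, nor tianshou) to show how two same-named classes from different packages fail an isinstance check. The helper space_package is also hypothetical and is only meant for checking where a space object comes from.

```python
# Stand-in classes only: these are NOT the real gym/gymnasium classes.
# They mimic two unrelated classes that happen to share the name "Box".
LegacyBox = type("Box", (), {"__module__": "gym.spaces.box"})
NewBox = type("Box", (), {"__module__": "gymnasium.spaces.box"})


def space_package(space) -> str:
    """Return the top-level package of the object's class (hypothetical helper)."""
    return type(space).__module__.split(".")[0]


legacy_space = LegacyBox()

# Same class name, different class objects: isinstance does not match.
matches_new_box = isinstance(legacy_space, NewBox)

print(space_package(legacy_space))  # package the object's class reports
print(matches_new_box)              # whether the legacy object passes the check
```

If space_package(action_space) reports gym in the real script, importing Box from gymnasium.spaces instead may be worth trying. This has not been verified against the Tianshou version used here.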
