TorchStateActionCritic
-
class maze.core.agent.torch_state_action_critic.TorchStateActionCritic(networks: Mapping[Union[str, int], torch.nn.Module], num_policies: int, device: str, only_discrete_spaces: Dict[Union[str, int], bool], action_spaces_dict: Dict[Union[str, int], gym.spaces.Dict])

Encapsulates multiple torch state-action critics for training in structured environments.
- Parameters
networks – Mapping of value function (critic) networks to encapsulate.
num_policies – The number of corresponding policies.
device – Device the critic should be located on ("cpu" or "cuda").
only_discrete_spaces – A dict specifying, per step key, whether the corresponding action space holds only discrete sub-spaces.
action_spaces_dict – Dictionary of the environment's action spaces, keyed by step key.
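The class itself is abstract (see num_critics, predict_next_q_values and predict_q_values below), so it is used through a concrete subclass. The following is a minimal sketch of how the constructor arguments could be assembled for a single-step environment with one discrete action space; the subclass name ConcreteStateActionCritic and the tiny Q-network are hypothetical placeholders, not part of the documented API:

    import torch.nn as nn
    from gym import spaces

    # One critic (Q-)network per step key of the structured env (hypothetical toy network).
    networks = {0: nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 4))}

    # Per-step action spaces and a flag marking them as purely discrete.
    action_spaces_dict = {0: spaces.Dict({"action": spaces.Discrete(4)})}
    only_discrete_spaces = {0: True}

    # A concrete subclass of TorchStateActionCritic would be instantiated like this
    # (ConcreteStateActionCritic is a placeholder name):
    critic = ConcreteStateActionCritic(
        networks=networks,
        num_policies=1,
        device="cpu",
        only_discrete_spaces=only_discrete_spaces,
        action_spaces_dict=action_spaces_dict,
    )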
-
compute_state_action_value_step(observation: Dict[str, torch.Tensor], action: Dict[str, torch.Tensor], critic_id: Union[str, int, tuple]) → List[torch.Tensor]

Predict the value for the specified step key, step observation and action.

- Parameters
observation – The observation for the current step.
action – The action performed at the current step.
critic_id – The current step key of the multi-step env.
- Returns
A list of tensors holding the predicted Q-value, one per critic.
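A minimal usage sketch, assuming critic is an already constructed concrete TorchStateActionCritic as set up above; the observation and action keys are illustrative only:

    import torch

    # Batched observation and action dicts for step key 0 (keys are illustrative).
    observation = {"observation": torch.randn(16, 8)}
    action = {"action": torch.randint(0, 4, (16,))}

    q_values = critic.compute_state_action_value_step(observation, action, critic_id=0)
    # -> list of per-critic Q-value tensors (length == num_critics)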
-
compute_state_action_values_step(observation: Dict[str, torch.Tensor], critic_id: Union[str, int, tuple]) → List[Dict[str, torch.Tensor]]

Predict the values for the specified step key and step observation; applicable to discrete action spaces only, where no action argument is needed because all actions can be evaluated at once.

- Parameters
observation – The observation for the current step.
critic_id – The current step key of the multi-step env.
- Returns
A list of dicts (one per critic) holding the predicted Q-values for each action.
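The discrete-only counterpart returns, per critic, a Q-value for every action instead of a single value for the given action; a short sketch continuing the assumptions above:

    # For purely discrete sub-spaces no action has to be passed in:
    per_action_q = critic.compute_state_action_values_step(observation, critic_id=0)
    # -> list (one entry per critic) of dicts mapping each action key to a tensor
    #    of Q-values over all discrete choices of that sub-action.
    for critic_q in per_action_q:
        print({k: v.shape for k, v in critic_q.items()})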
-
property device
implementation of TorchModel
-
eval() → None
(overrides TorchModel) implementation of TorchModel
-
load_state_dict(state_dict: Dict) → None
(overrides TorchModel) implementation of TorchModel
-
abstract property num_critics
Returns the number of critic networks.
- Returns
Number of critic networks.
-
parameters() → List[torch.Tensor]
(overrides TorchModel) implementation of TorchModel
-
per_critic_parameters() → List[List[torch.Tensor]]
Retrieve all trainable critic parameters (to be assigned to optimizers).
- Returns
A list of lists holding all parameters of the base critic, one inner list per critic per step.
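A sketch of the intended use, one optimizer per critic, again assuming the hypothetical critic instance from above (learning rate is an arbitrary example value):

    import torch

    # One Adam optimizer per critic network.
    q_optimizers = [
        torch.optim.Adam(params, lr=3e-4)
        for params in critic.per_critic_parameters()
    ]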
-
abstract predict_next_q_values(next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]], next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]], alpha: Dict[Union[str, int], torch.Tensor]) → Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]

Predict the target Q-value for the next step: \(V(s_t) := E_{a_t \sim \pi}\left[Q(s_t, a_t) - \alpha \log \pi(a_t \mid s_t)\right]\).

- Parameters
next_observations – The next observations.
next_actions – The next actions sampled from the policy.
next_actions_logits – The logits of the next actions (only relevant for the discrete case).
next_actions_log_probs – The log probabilities of the actions.
alpha – The alpha or entropy coefficient for each step.
- Returns
A dict w.r.t. the step holding tensors representing the predicted next Q-value.
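To make the formula concrete, a small illustrative computation of the soft value target for a single step with a discrete action space; this mirrors the equation above, not necessarily the exact implementation in the concrete critic subclasses:

    import torch
    import torch.nn.functional as F

    # Example quantities for one step: per-action Q-values and policy logits (batch of 16, 4 actions).
    q_values = torch.randn(16, 4)          # Q(s_t, a) for every discrete action a
    logits = torch.randn(16, 4)            # policy logits over the same actions
    alpha = torch.tensor(0.2)              # entropy coefficient

    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    # V(s_t) = E_{a~pi}[ Q(s_t, a) - alpha * log pi(a|s_t) ], expectation over the policy.
    soft_value = (probs * (q_values - alpha * log_probs)).sum(dim=-1)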
-
abstract predict_q_values(observations: Dict[Union[str, int], Dict[str, torch.Tensor]], actions: Dict[Union[str, int], Dict[str, torch.Tensor]], gather_output: bool) → Dict[Union[str, int], List[Union[torch.Tensor, Dict[str, torch.Tensor]]]]
(overrides StateActionCritic) implementation of StateActionCritic
-
state_dict() → Dict
(overrides TorchModel) implementation of TorchModel
-
to(device: str) → None
(overrides TorchModel) implementation of TorchModel
-
train() → None
(overrides TorchModel) implementation of TorchModel