TorchActorCritic¶

class maze.core.agent.torch_actor_critic.TorchActorCritic(policy: maze.core.agent.torch_policy.TorchPolicy, critic: Union[maze.core.agent.torch_state_critic.TorchStateCritic, maze.core.agent.torch_state_action_critic.TorchStateActionCritic], device: str)¶

Encapsulates a structured torch policy and critic for training actor-critic algorithms in structured environments.
- Parameters
policy – A structured torch policy for training in structured environments.
critic – A structured torch critic for training in structured environments.
device – Device the model (networks) should be located on ("cpu" or "cuda").
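The container pattern described above, where a policy and a critic are bundled and kept on the same device, can be sketched with simplified stand-ins. Note that StubPolicy, StubCritic, and ActorCriticContainer below are hypothetical illustrations of the pattern, not the actual Maze classes:

```python
class StubPolicy:
    """Hypothetical stand-in for a structured torch policy."""
    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device


class StubCritic:
    """Hypothetical stand-in for a structured torch critic."""
    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device


class ActorCriticContainer:
    """Holds a policy and a critic and keeps both on the same device."""
    def __init__(self, policy, critic, device):
        self.policy = policy
        self.critic = critic
        self.to(device)

    def to(self, device):
        # Move both sub-models together so they never diverge.
        self.policy.to(device)
        self.critic.to(device)
        self._device = device

    @property
    def device(self):
        return self._device


ac = ActorCriticContainer(StubPolicy(), StubCritic(), device="cpu")
print(ac.device)  # -> cpu
```

Keeping device placement in one place avoids the common bug of training with the policy on the GPU while the critic's value estimates are computed on the CPU.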
-
compute_actor_critic_output(record: maze.core.trajectory_recording.records.structured_spaces_record.StructuredSpacesRecord, temperature: float = 1.0) → Tuple[maze.core.agent.torch_policy_output.PolicyOutput, maze.core.agent.state_critic_input_output.StateCriticOutput]¶

Computes the policy and critic output in one go, managing the sub-steps, the individual critic types, and the shared network embeddings.

- Parameters
record – The StructuredSpacesRecord holding the observations and actor IDs.
temperature – (Optional) The temperature used for initializing the probability distribution of the action heads.
- Returns
A tuple of the policy and critic outputs.
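To illustrate what the temperature parameter does: action-head logits are typically divided by the temperature before the softmax, so a higher temperature flattens the resulting probability distribution. This is a generic sketch of temperature-scaled softmax, not Maze's actual distribution code:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: higher temperature -> flatter distribution."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
sharp = softmax(logits, temperature=0.5)  # close to greedy
flat = softmax(logits, temperature=5.0)   # close to uniform
# The top action's probability shrinks as temperature grows.
print(max(sharp) > max(flat))  # -> True
```

A temperature of 1.0 (the default) leaves the logits unchanged; values above 1.0 encourage exploration, values below 1.0 sharpen the distribution toward the argmax action.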
-
property device¶
Implementation of TorchModel.

-
eval() → None¶ (overrides TorchModel)
Implementation of TorchModel.

-
load_state_dict(state_dict: Dict) → None¶ (overrides TorchModel)
Implementation of TorchModel.

-
parameters() → List[torch.Tensor]¶ (overrides TorchModel)
Implementation of TorchModel.

-
state_dict() → Dict¶ (overrides TorchModel)
Implementation of TorchModel.

-
to(device: str)¶ (overrides TorchModel)
Implementation of TorchModel.

-
train() → None¶ (overrides TorchModel)
Implementation of TorchModel.
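The TorchModel interface methods above plausibly delegate to the wrapped policy and critic. A simplified sketch of how state_dict and load_state_dict might merge the two sub-models into a single checkpoint (SimpleActorCritic and its dict-backed "parameters" are hypothetical, not Maze's actual implementation):

```python
class SimpleActorCritic:
    """Hypothetical sketch: dict-backed stand-ins for policy/critic networks."""
    def __init__(self):
        self.policy_params = {"w": 1.0}
        self.critic_params = {"v": 2.0}

    def state_dict(self):
        # Merge both sub-models into one checkpointable dict.
        return {"policy": dict(self.policy_params),
                "critic": dict(self.critic_params)}

    def load_state_dict(self, state_dict):
        # Restore each sub-model from its own section of the checkpoint.
        self.policy_params = dict(state_dict["policy"])
        self.critic_params = dict(state_dict["critic"])


src = SimpleActorCritic()
src.policy_params["w"] = 42.0
ckpt = src.state_dict()

dst = SimpleActorCritic()
dst.load_state_dict(ckpt)
print(dst.policy_params["w"])  # -> 42.0
```

Namespacing the checkpoint by sub-model keeps policy and critic weights separable, so one can be restored or swapped without touching the other.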