viberl.networks.policy_network
Classes:
Name | Description |
---|---|
PolicyNetwork | Policy network for policy gradient methods like REINFORCE. |
PolicyNetwork
PolicyNetwork(
state_size: int, action_size: int, hidden_size: int = 128, num_hidden_layers: int = 2
)
Bases: BaseNetwork
Policy network for policy gradient methods like REINFORCE.
Methods:
Name | Description |
---|---|
forward | Forward pass to get action probabilities. |
act | Select action based on current policy. |
get_action_prob | Get probability of taking a specific action. |
Attributes:
Name | Type | Description |
---|---|---|
action_size | int | |
policy_head | Linear | |
softmax | Softmax | |
Source code in viberl/networks/policy_network.py
action_size
instance-attribute
action_size = action_size
policy_head
instance-attribute
policy_head = Linear(hidden_size, action_size)
softmax
instance-attribute
softmax = Softmax(dim=-1)
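As a rough illustration of how these three attributes might fit together, the sketch below builds a stand-in module. `SketchPolicyNetwork` and its `backbone` attribute are hypothetical names, and the MLP backbone with ReLU activations is an assumption: the real `BaseNetwork` may construct its layers differently. Only `action_size`, `policy_head`, and `softmax` follow the definitions listed above.

```python
import torch
from torch import nn


class SketchPolicyNetwork(nn.Module):
    """Hypothetical stand-in for illustration; not the viberl implementation."""

    def __init__(self, state_size, action_size, hidden_size=128, num_hidden_layers=2):
        super().__init__()
        # Assumption: the backbone is a plain MLP with ReLU activations
        # and at least one hidden layer.
        layers = []
        in_features = state_size
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(in_features, hidden_size), nn.ReLU()]
            in_features = hidden_size
        self.backbone = nn.Sequential(*layers)

        # These three attributes mirror the ones documented above.
        self.action_size = action_size
        self.policy_head = nn.Linear(hidden_size, action_size)
        self.softmax = nn.Softmax(dim=-1)


net = SketchPolicyNetwork(state_size=4, action_size=2)
print(net.policy_head)
print(net.softmax)
```

The policy head maps the backbone's hidden features to one score per action, and `Softmax(dim=-1)` normalizes those scores along the last (action) dimension.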
forward
forward(x: Tensor) -> Tensor
Forward pass to get action probabilities.
Processes state features through the backbone network and policy head to produce normalized action probabilities.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input state tensor of shape (batch_size, state_size) | required |
Returns:
Type | Description |
---|---|
Tensor | Action probabilities tensor of shape (batch_size, action_size) with values summing to 1 along the action dimension |
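The shape and normalization described above can be checked with a small stand-alone sketch. It does not import `viberl`; a single `Linear` + `ReLU` layer stands in for the backbone (an assumption), and the variable names are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, state_size, hidden_size, action_size = 3, 4, 128, 2

# Assumption: one Linear + ReLU layer stands in for the real backbone.
backbone = nn.Sequential(nn.Linear(state_size, hidden_size), nn.ReLU())
policy_head = nn.Linear(hidden_size, action_size)
softmax = nn.Softmax(dim=-1)

x = torch.randn(batch_size, state_size)
with torch.no_grad():
    # Backbone features -> per-action scores -> normalized probabilities.
    probs = softmax(policy_head(backbone(x)))

print(probs.shape)        # torch.Size([3, 2])
print(probs.sum(dim=-1))  # each row sums to 1 up to floating-point error
```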
act
act(state: list | tuple | Tensor, deterministic: bool = False) -> int
Select action based on current policy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | list \| tuple \| Tensor | Current state as list, tuple, or tensor | required |
deterministic | bool | If True, always returns the most probable action. If False, samples from the action distribution. | False |
Returns:
Type | Description |
---|---|
int | Selected action as integer |
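The difference between the two modes of `deterministic` can be sketched as follows. `select_action` is a hypothetical helper, not part of the library, and sampling through `torch.distributions.Categorical` is an assumption about how the draw could be done; the actual method may sample differently.

```python
import torch
from torch.distributions import Categorical


def select_action(probs, deterministic=False):
    """Pick an action index from a 1-D tensor of action probabilities."""
    if deterministic:
        # Greedy choice: index of the largest probability.
        return int(torch.argmax(probs).item())
    # Stochastic choice: draw one index from the categorical distribution.
    return int(Categorical(probs=probs).sample().item())


# A list or tuple state can be converted to a batched tensor like this.
state = torch.as_tensor([0.1, -0.2, 0.3, 0.0], dtype=torch.float32).unsqueeze(0)

probs = torch.tensor([0.1, 0.7, 0.2])  # made-up probabilities for three actions
greedy = select_action(probs, deterministic=True)
sampled = select_action(probs)
print(greedy)   # 1
print(sampled)  # 0, 1, or 2, depending on the random draw
```

Sampling keeps exploration alive during training, while the deterministic mode is the usual choice when evaluating a trained policy.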
get_action_prob
get_action_prob(state: list | tuple | Tensor, action: Tensor) -> Tensor
Get probability of taking a specific action.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | list \| tuple \| Tensor | Current state as list, tuple, or tensor | required |
action | Tensor | Action tensor to get probability for | required |
Returns:
Type | Description |
---|---|
Tensor | Probability of taking the specified action |
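One plausible way to look up the probability of a given action is to index the probability tensor with `gather`, sketched below. This is an assumption about the mechanics, not the library's confirmed implementation, and the probabilities are made-up values rather than network output.

```python
import torch

# Made-up action probabilities for a batch of two states, three actions each.
probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.25, 0.25]])
actions = torch.tensor([1, 0])  # action taken in each state

# Select, for every row, the probability at that row's action index.
chosen = probs.gather(dim=-1, index=actions.unsqueeze(-1)).squeeze(-1)

# Policy gradient losses such as REINFORCE typically work with log-probabilities.
log_probs = torch.log(chosen)

print(chosen)  # tensor([0.7000, 0.5000])
```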