# viberl.networks.policy_network

Classes:

| Name | Description |
| --- | --- |
| `PolicyNetwork` | Policy network for policy gradient methods like REINFORCE. |

## PolicyNetwork

```python
PolicyNetwork(
    state_size: int, action_size: int, hidden_size: int = 128, num_hidden_layers: int = 2
)
```

Bases: `BaseNetwork`

Policy network for policy gradient methods like REINFORCE.

Methods:

| Name | Description |
| --- | --- |
| `forward` | Forward pass to get action probabilities. |
| `act` | Select action based on current policy. |
| `get_action_prob` | Get probability of taking a specific action. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `action_size` | `int` | Number of discrete actions. |
| `policy_head` | `Linear` | Linear layer mapping hidden features to action logits. |
| `softmax` | `Softmax` | Softmax over the last (action) dimension. |
Source code in `viberl/networks/policy_network.py`:

```python
def __init__(
    self, state_size: int, action_size: int, hidden_size: int = 128, num_hidden_layers: int = 2
):
    super().__init__(state_size, hidden_size, num_hidden_layers)
    self.action_size = action_size

    # Policy head
    self.policy_head = nn.Linear(hidden_size, action_size)
    self.softmax = nn.Softmax(dim=-1)

    self.init_weights()
```

Instance attributes:

- `action_size`: set from the `action_size` constructor argument
- `policy_head`: `Linear(hidden_size, action_size)`
- `softmax`: `Softmax(dim=-1)`
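
A minimal usage sketch (the sizes are hypothetical, assuming a CartPole-style task with 4 state features and 2 discrete actions):

```python
import torch

from viberl.networks.policy_network import PolicyNetwork

# Hypothetical sizes: 4 state features, 2 discrete actions.
net = PolicyNetwork(state_size=4, action_size=2, hidden_size=128, num_hidden_layers=2)

# Calling the module runs forward() (BaseNetwork is assumed to be an nn.Module).
probs = net(torch.randn(1, 4))  # action probabilities, shape (1, 2)
```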

### forward

```python
forward(x: Tensor) -> Tensor
```

Forward pass to get action probabilities.

Processes state features through the backbone network and policy head to produce normalized action probabilities.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Input state tensor of shape `(batch_size, state_size)`. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Action probabilities tensor of shape `(batch_size, action_size)`, with values summing to 1 along the action dimension. |

Source code in `viberl/networks/policy_network.py`:

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass to get action probabilities.

    Processes state features through the backbone network and policy head
    to produce normalized action probabilities.

    Args:
        x: Input state tensor of shape (batch_size, state_size)

    Returns:
        Action probabilities tensor of shape (batch_size, action_size)
        with values summing to 1 along the action dimension
    """
    features = self.forward_backbone(x)
    action_logits = self.policy_head(features)
    return self.softmax(action_logits)
```
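
A quick sketch of the shape contract (batch and feature sizes are hypothetical):

```python
import torch

from viberl.networks.policy_network import PolicyNetwork

net = PolicyNetwork(state_size=4, action_size=2)  # hypothetical sizes
states = torch.randn(8, 4)                        # batch of 8 states

probs = net.forward(states)                       # shape (8, 2)
# Softmax(dim=-1) guarantees each row sums to 1.
assert torch.allclose(probs.sum(dim=-1), torch.ones(8))
```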

### act

```python
act(state: list | tuple | Tensor, deterministic: bool = False) -> int
```

Select action based on current policy.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `list \| tuple \| Tensor` | Current state as a list, tuple, or tensor. | required |
| `deterministic` | `bool` | If `True`, always returns the most probable action; if `False`, samples from the action distribution. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `int` | Selected action as an integer. |

Source code in `viberl/networks/policy_network.py`:

```python
def act(self, state: list | tuple | torch.Tensor, deterministic: bool = False) -> int:
    """Select action based on current policy.

    Args:
        state: Current state as list, tuple, or tensor
        deterministic: If True, always returns the most probable action.
                      If False, samples from the action distribution.

    Returns:
        Selected action as integer
    """
    # Convert sequences or tensors to a float tensor and ensure a batch
    # dimension, matching forward's expected (batch_size, state_size) shape.
    state = torch.as_tensor(state, dtype=torch.float32)
    if state.dim() == 1:
        state = state.unsqueeze(0)

    action_probs = self.forward(state)

    if deterministic:
        return action_probs.argmax().item()
    else:
        m = Categorical(action_probs)
        return m.sample().item()
```
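
Both selection modes in a short sketch (the observation values are hypothetical):

```python
from viberl.networks.policy_network import PolicyNetwork

net = PolicyNetwork(state_size=4, action_size=2)  # hypothetical sizes
state = [0.1, -0.2, 0.0, 0.3]                     # raw observation as a list

greedy = net.act(state, deterministic=True)       # always the most probable action
sampled = net.act(state)                          # sampled from the action distribution
```

Sampling (`deterministic=False`) is the mode REINFORCE uses during training, since exploration comes from the stochastic policy itself; the greedy mode is useful for evaluation.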

### get_action_prob

```python
get_action_prob(state: list | tuple | Tensor, action: Tensor) -> Tensor
```

Get probability of taking a specific action.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `list \| tuple \| Tensor` | Current state as a list, tuple, or tensor. | required |
| `action` | `Tensor` | Action tensor to get the probability for. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Probability of taking the specified action. |

Source code in `viberl/networks/policy_network.py`:

```python
def get_action_prob(
    self, state: list | tuple | torch.Tensor, action: torch.Tensor
) -> torch.Tensor:
    """Get probability of taking a specific action.

    Args:
        state: Current state as list, tuple, or tensor
        action: Action tensor to get probability for

    Returns:
        Probability of taking the specified action
    """
    # Convert raw sequences to a batched float tensor so the annotated
    # list/tuple inputs work; gather below needs 2-D (batch, action) probs.
    if isinstance(state, list | tuple):
        state = torch.FloatTensor(state).unsqueeze(0)

    action_probs = self.forward(state)
    return action_probs.gather(1, action.unsqueeze(1)).squeeze(1)
```
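
This is the hook for policy-gradient losses. A REINFORCE-style sketch (the batch data and returns are hypothetical placeholders; computing real discounted returns is up to the training loop):

```python
import torch

from viberl.networks.policy_network import PolicyNetwork

net = PolicyNetwork(state_size=4, action_size=2)  # hypothetical sizes

states = torch.randn(8, 4)               # batch of visited states
actions = torch.randint(0, 2, (8,))      # actions taken, shape (batch,)
returns = torch.randn(8)                 # placeholder discounted returns

probs = net.get_action_prob(states, actions)  # pi(a_t | s_t), shape (8,)
loss = -(probs.log() * returns).mean()        # REINFORCE surrogate loss
loss.backward()                               # populates gradients for an optimizer
```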