viberl.networks.value_network

Classes:

VNetwork: Value network for PPO and other policy gradient methods (returns single scalar value for state).
QNetwork: Q-network for value-based methods like DQN (returns Q-values for all actions).

VNetwork

VNetwork(state_size: int, hidden_size: int = 128, num_hidden_layers: int = 2)

Bases: BaseNetwork

Value network for PPO and other policy gradient methods (returns single scalar value for state).

Methods:

forward: Forward pass to get state value.

Attributes:

value_head (Linear): Single-output linear head mapping backbone features to the state value.
Source code in viberl/networks/value_network.py
def __init__(self, state_size: int, hidden_size: int = 128, num_hidden_layers: int = 2):
    super().__init__(state_size, hidden_size, num_hidden_layers)

    # Single output for state value
    self.value_head = nn.Linear(hidden_size, 1)
    self.init_weights()
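
A minimal construction sketch (hedged: it assumes the package is installed and importable as viberl.networks.value_network; the state size below is illustrative):

from viberl.networks.value_network import VNetwork

# 8-dimensional observation, default hidden width 128 and 2 hidden layers
value_net = VNetwork(state_size=8)
print(value_net.value_head)  # Linear(in_features=128, out_features=1, bias=True)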

value_head instance-attribute

value_head = Linear(hidden_size, 1)

forward

forward(x: Tensor) -> Tensor

Forward pass to get state value.

Parameters:

x (Tensor, required): Input state tensor of shape (batch_size, state_size).

Returns:

Tensor: State value tensor of shape (batch_size,) representing the estimated value of the given state.

Source code in viberl/networks/value_network.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass to get state value.

    Args:
        x: Input state tensor of shape (batch_size, state_size)

    Returns:
        State value tensor of shape (batch_size,) representing the
        estimated value of the given state
    """
    features = self.forward_backbone(x)
    return self.value_head(features).squeeze(-1)  # Remove last dim
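
A short usage sketch for the forward pass, assuming BaseNetwork subclasses torch.nn.Module so the network can be called directly; batch size and state size are arbitrary:

import torch

from viberl.networks.value_network import VNetwork

value_net = VNetwork(state_size=4)
states = torch.randn(32, 4)   # (batch_size, state_size)
values = value_net(states)    # dispatches to forward
print(values.shape)           # torch.Size([32]): one scalar value per state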

QNetwork

QNetwork(state_size: int, action_size: int, hidden_size: int = 128, num_hidden_layers: int = 2)

Bases: BaseNetwork

Q-network for value-based methods like DQN (returns Q-values for all actions).

Methods:

forward: Forward pass to get Q-values for all actions.
get_q_values: Get Q-values for a given state.
get_action: Get action using epsilon-greedy policy.

Attributes:

action_size (int): Number of available discrete actions.
q_head (Linear): Linear head mapping backbone features to one Q-value per action.
Source code in viberl/networks/value_network.py
def __init__(
    self, state_size: int, action_size: int, hidden_size: int = 128, num_hidden_layers: int = 2
):
    super().__init__(state_size, hidden_size, num_hidden_layers)
    self.action_size = action_size

    # Q-value head for all actions
    self.q_head = nn.Linear(hidden_size, action_size)

    self.init_weights()
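
A minimal construction sketch (hypothetical sizes, e.g. a CartPole-like task with a 4-dimensional state and 2 actions):

from viberl.networks.value_network import QNetwork

q_net = QNetwork(state_size=4, action_size=2)
print(q_net.q_head)  # Linear(in_features=128, out_features=2, bias=True)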

action_size instance-attribute

action_size = action_size

q_head instance-attribute

q_head = Linear(hidden_size, action_size)

forward

forward(x: Tensor) -> Tensor

Forward pass to get Q-values for all actions.

Parameters:

x (Tensor, required): Input state tensor of shape (batch_size, state_size).

Returns:

Tensor: Q-values tensor of shape (batch_size, action_size) containing Q-values for each action in the given state.

Source code in viberl/networks/value_network.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass to get Q-values for all actions.

    Args:
        x: Input state tensor of shape (batch_size, state_size)

    Returns:
        Q-values tensor of shape (batch_size, action_size) containing
        Q-values for each action in the given state
    """
    features = self.forward_backbone(x)
    return self.q_head(features)
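
A hedged example of a batched forward pass; shapes follow the docstring above and assume BaseNetwork is an nn.Module subclass:

import torch

from viberl.networks.value_network import QNetwork

q_net = QNetwork(state_size=4, action_size=2)
q_values = q_net(torch.randn(16, 4))  # input of shape (batch_size, state_size)
print(q_values.shape)                 # torch.Size([16, 2]): one Q-value per action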

get_q_values

get_q_values(state: list | tuple | Tensor) -> Tensor

Get Q-values for a given state.

Convenience method that handles various input types and ensures proper tensor formatting before forward pass.

Parameters:

state (list | tuple | Tensor, required): Current state as list, tuple, or tensor.

Returns:

Tensor: Q-values tensor of shape (1, action_size) if given a single state, or (batch_size, action_size) if given a batch of states.

Source code in viberl/networks/value_network.py
def get_q_values(self, state: list | tuple | torch.Tensor) -> torch.Tensor:
    """Get Q-values for a given state.

    Convenience method that handles various input types and ensures
    proper tensor formatting before forward pass.

    Args:
        state: Current state as list, tuple, or tensor

    Returns:
        Q-values tensor of shape (1, action_size) if single state,
        or (batch_size, action_size) if batch of states
    """
    # Convert list, tuple, or tensor input to a float tensor
    state = torch.FloatTensor(state)

    if len(state.shape) == 1:
        state = state.unsqueeze(0)

    return self.forward(state)
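
A small sketch of the convenience handling for a single unbatched state passed as a plain Python list (sizes are illustrative):

from viberl.networks.value_network import QNetwork

q_net = QNetwork(state_size=4, action_size=2)
q_values = q_net.get_q_values([0.1, -0.2, 0.0, 0.3])  # converted to a tensor and unsqueezed to a batch of 1
print(q_values.shape)  # torch.Size([1, 2])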

get_action

get_action(state: list | tuple | Tensor, epsilon: float = 0.0) -> int

Get action using epsilon-greedy policy.

Implements the epsilon-greedy action selection strategy:

- With probability epsilon: choose random action (exploration)
- With probability 1-epsilon: choose best action (exploitation)

Parameters:

state (list | tuple | Tensor, required): Current state as list, tuple, or tensor.
epsilon (float, default 0.0): Probability of choosing random action (0.0 to 1.0).

Returns:

int: Selected action as integer.

Source code in viberl/networks/value_network.py
def get_action(self, state: list | tuple | torch.Tensor, epsilon: float = 0.0) -> int:
    """Get action using epsilon-greedy policy.

    Implements the epsilon-greedy action selection strategy where:
    - With probability epsilon: choose random action (exploration)
    - With probability 1-epsilon: choose best action (exploitation)

    Args:
        state: Current state as list, tuple, or tensor
        epsilon: Probability of choosing random action (0.0 to 1.0)

    Returns:
        Selected action as integer
    """
    q_values = self.get_q_values(state)

    if torch.rand(1).item() < epsilon:
        return torch.randint(0, self.action_size, (1,)).item()
    else:
        return q_values.argmax(dim=1).item()
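
A hedged end-to-end sketch of epsilon-greedy selection; the state values and epsilon below are illustrative, not part of this module:

from viberl.networks.value_network import QNetwork

q_net = QNetwork(state_size=4, action_size=2)
state = [0.0, 0.1, -0.05, 0.2]

greedy_action = q_net.get_action(state)                 # epsilon=0.0: always exploit
explore_action = q_net.get_action(state, epsilon=0.1)   # 10% chance of a random action
print(greedy_action, explore_action)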