def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
    """Advance the game by one move and report the outcome.

    Args:
        action: Value handed to ``Direction(action)`` to select the requested
            heading. A request that is the exact opposite of the current
            heading is ignored and the snake keeps its current heading.

    Returns:
        A 5-tuple ``(observation, reward, terminated, truncated, info)`` where
        ``observation`` comes from ``self._get_observation()`` and ``info``
        from ``self._get_info()``. The fourth element is always ``False``.

        Rewards:
            * ``0.0`` when the game was already over before this call
              (no state is changed in that case).
            * ``-10.0`` when the move leaves the grid or hits the snake.
            * Otherwise ``10.0`` if food was eaten (else ``0.0``), plus a
              ``-0.1`` per-move penalty.

    Raises:
        Whatever ``Direction(action)`` raises for an unsupported ``action``
        (presumably ``ValueError`` if ``Direction`` is an ``Enum`` -- confirm).

    NOTE(review): reaching ``self.max_steps`` is reported through the
    ``terminated`` flag, not the ``truncated`` flag -- confirm this is the
    intended contract for consumers of this method.
    """

    def result(reward, terminated):
        # The observation is built before the info dict on every exit path.
        return self._get_observation(), reward, terminated, False, self._get_info()

    # A finished game is inert: nothing is counted or moved.
    if self.game_over:
        return result(0.0, True)

    self.steps += 1

    # Resolve the requested heading, discarding a 180-degree turn so the
    # snake cannot fold back onto its own neck.
    requested = Direction(action)
    reversals = (
        (Direction.UP, Direction.DOWN),
        (Direction.DOWN, Direction.UP),
        (Direction.LEFT, Direction.RIGHT),
        (Direction.RIGHT, Direction.LEFT),
    )
    is_reversal = (requested, self.direction) in reversals
    self.direction = self.direction if is_reversal else requested

    # The head is the last element of self.snake. The first coordinate
    # decreases when heading UP, i.e. it behaves like a row index.
    row, col = self.snake[-1]
    offsets = {
        Direction.UP: (-1, 0),
        Direction.RIGHT: (0, 1),
        Direction.DOWN: (1, 0),
        Direction.LEFT: (0, -1),
    }
    d_row, d_col = offsets[self.direction]
    target = (row + d_row, col + d_col)

    # Fatal moves: leaving the square grid, or entering any body cell other
    # than the current head (the tail cell counts as occupied here).
    limit = self.grid_size
    inside = 0 <= target[0] < limit and 0 <= target[1] < limit
    if not inside or target in self.snake[:-1]:
        self.game_over = True
        return result(-10.0, True)

    # Commit the move: the new head goes on the end of the list.
    self.snake.append(target)
    if target == self.food:
        # Growing: keep the tail, bump the score, then place new food.
        self.score += 1
        self.food = self._place_food()
        reward = 10.0
    else:
        # Not growing: drop the oldest cell so the length stays constant.
        del self.snake[0]
        reward = 0.0
    # Small per-move penalty to encourage faster completion.
    reward += -0.1

    # The step budget ends the game but keeps the reward earned this move.
    out_of_time = self.steps >= self.max_steps
    if out_of_time:
        self.game_over = True
    return result(reward, bool(out_of_time))