# Source code for vertiport_autonomy.agents.heuristic

"""
Simplified Heuristic Baseline Agent

Provides a working baseline for comparison with DRL agents.
Uses simple FCFS strategy with basic conflict avoidance.
"""

from typing import Any, Dict

import numpy as np

from ..config.loader import load_scenario_config
from ..core.environment import VertiportEnv
from .base import BaseAgent


class SimpleHeuristicAgent(BaseAgent):
    """Simple heuristic baseline agent.

    Intended as a first-come-first-served (FCFS) style baseline: on every
    step it commands every drone to move to its waypoint. No conflict
    avoidance or prioritisation logic is implemented in this class.
    """

    # Discrete action id issued to every drone. The original code annotates
    # the value 1 as MOVE_TO_WAYPOINT; NOTE(review): confirm this matches the
    # action encoding expected by VertiportEnv.
    MOVE_TO_WAYPOINT = 1

    def __init__(self, name: str = "SimpleHeuristic"):
        """Initialize the simple heuristic agent.

        Args:
            name: Agent name, forwarded to ``BaseAgent.__init__``.
        """
        super().__init__(name)

    def act(self, observation: Dict[str, Any]) -> np.ndarray:
        """Select one action per drone using the fixed heuristic.

        Args:
            observation: Current environment observation. Must contain a
                ``"drones_state"`` array whose first dimension is the number
                of drones.

        Returns:
            ``int32`` array of length ``num_drones`` where every entry is
            ``MOVE_TO_WAYPOINT``.
        """
        num_drones = observation["drones_state"].shape[0]
        return np.full(num_drones, self.MOVE_TO_WAYPOINT, dtype=np.int32)

    def reset(self) -> None:
        """Reset agent state for a new episode.

        The heuristic keeps no per-episode state, so there is nothing to reset.
        """
def run_simple_heuristic(scenario_path: str, max_steps: int = 200):
    """Run the simple heuristic policy for one episode on a scenario.

    Every drone is commanded with action ``1`` (annotated in this module as
    MOVE_TO_WAYPOINT) on every step. The loop stops when the episode
    terminates, is truncated, or ``max_steps`` steps have been taken.

    Args:
        scenario_path: Path to the scenario file passed to
            ``load_scenario_config``.
        max_steps: Upper bound on the number of environment steps.

    Returns:
        Dict with the following keys:
        - ``"episode_length"``: number of steps taken.
        - ``"total_reward"``: sum of per-step rewards.
        - ``"average_reward"``: total reward divided by the number of steps,
          or 0 when no step was taken.
    """
    print(f"\n🤖 Testing Simple Heuristic: {scenario_path}")

    config = load_scenario_config(scenario_path)
    env = VertiportEnv(config)

    total_reward = 0
    step_count = 0
    # try/finally guarantees the environment is closed even if reset() or
    # step() raises. main() catches exceptions and moves on to the next
    # scenario, so an unclosed env would otherwise be leaked.
    try:
        obs, _ = env.reset()
        while step_count < max_steps:
            # Simple strategy: all drones move to waypoint
            num_drones = obs["drones_state"].shape[0]
            actions = np.ones(num_drones, dtype=np.int32)  # All MOVE_TO_WAYPOINT
            obs, reward, terminated, truncated, _ = env.step(actions)
            total_reward += reward
            step_count += 1
            if terminated or truncated:
                break
    finally:
        env.close()

    return {
        "episode_length": step_count,
        "total_reward": total_reward,
        "average_reward": total_reward / step_count if step_count > 0 else 0,
    }
def main():
    """Run the simple heuristic baseline on each bundled scenario and print its metrics.

    An exception raised for one scenario is reported on stdout and does not
    prevent the remaining scenarios from running.
    """
    scenario_files = (
        "scenarios/easy_world.yaml",
        "scenarios/intermediate_world.yaml",
        "scenarios/steady_flow.yaml",
    )

    print("🎯 Simple Heuristic Baseline Results")
    print("=" * 50)

    for path in scenario_files:
        try:
            results = run_simple_heuristic(path)
            print(f"📊 {path}:")
            print(f" Episode Length: {results['episode_length']}")
            print(f" Total Reward: {results['total_reward']:.2f}")
            print(f" Average Reward: {results['average_reward']:.2f}")
        except Exception as err:
            # Top-level best-effort boundary: report and continue.
            print(f"❌ Error: {err}")
if __name__ == "__main__": main()