"""Basic training utilities for vertiport autonomy agents."""
import os
from typing import Any, Dict, Optional
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from ..config.loader import load_scenario_config
from ..core.environment import VertiportEnv
class Trainer:
    """Basic trainer for PPO agents in vertiport environments."""
    def __init__(
        self,
        log_dir: str = "logs",
        model_dir: str = "models",
        n_envs: int = 50,
        **ppo_kwargs,
    ):
        """Initialize the trainer.

        Args:
            log_dir: Directory for training logs
            model_dir: Directory for saving models
            n_envs: Number of parallel environments
            **ppo_kwargs: Additional arguments for PPO
        """
        self.log_dir = log_dir
        self.model_dir = model_dir
        self.n_envs = n_envs
        self.ppo_kwargs = ppo_kwargs

        # Create output directories up front
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)

        # Default PPO parameters
        self.default_ppo_params = {
            "gamma": 0.99,
            "n_steps": 1024,
            "batch_size": 128,
            "learning_rate": 1e-4,
            "gae_lambda": 0.95,
            "clip_range": 0.2,
            "ent_coef": 0.01,
            "vf_coef": 0.5,
            "max_grad_norm": 0.5,
            "policy_kwargs": {"net_arch": [64, 64], "activation_fn": torch.nn.Tanh},
        }
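
    # Illustrative construction (the hyperparameter values below are
    # assumptions, not tuned defaults):
    #     trainer = Trainer(n_envs=8, learning_rate=3e-4, ent_coef=0.0)
    # Keyword arguments captured in ppo_kwargs take precedence over
    # default_ppo_params when create_model() merges them.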

    def create_environment(self, scenario_path: str) -> VecNormalize:
        """Create a vectorized and normalized environment.

        Args:
            scenario_path: Path to scenario configuration file

        Returns:
            Normalized vectorized environment
        """
        config = load_scenario_config(scenario_path)

        # Create vectorized environment
        env = make_vec_env(
            VertiportEnv, n_envs=self.n_envs, env_kwargs={"config": config}
        )

        # Normalize observations and rewards; clip observations to +/-10
        env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
        return env
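
    # At evaluation time, the saved normalization statistics can be restored
    # onto a fresh environment (sketch; the stats path is an assumed
    # convention, see train() below):
    #     eval_env = make_vec_env(VertiportEnv, n_envs=1, env_kwargs={"config": config})
    #     eval_env = VecNormalize.load("models/ppo_vertiport_final_vecnormalize.pkl", eval_env)
    #     eval_env.training = False
    #     eval_env.norm_reward = False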

    def create_model(self, env: VecNormalize, **override_params) -> PPO:
        """Create a PPO model.

        Args:
            env: Environment for training
            **override_params: Parameters to override defaults

        Returns:
            PPO model instance
        """
        # Merge parameters; later dicts take precedence
        params = {**self.default_ppo_params, **self.ppo_kwargs, **override_params}
        model = PPO(
            "MultiInputPolicy", env, verbose=1, tensorboard_log=self.log_dir, **params
        )
        return model
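
    # Illustrative per-model override (the learning-rate value is an
    # assumption):
    #     model = trainer.create_model(env, learning_rate=3e-4)
    # Precedence: create_model overrides > __init__ ppo_kwargs > defaults.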

    def create_callbacks(
        self,
        save_freq: int = 50000,
        eval_freq: int = 10000,
        n_eval_episodes: int = 5,
        name_prefix: str = "ppo_vertiport",
        eval_env: Optional[VecNormalize] = None,
    ) -> list:
        """Create training callbacks.

        Args:
            save_freq: Frequency for saving checkpoints (counted in calls to
                the callback, i.e. per step of each of the n_envs workers)
            eval_freq: Frequency for evaluation
            n_eval_episodes: Number of episodes for evaluation
            name_prefix: Prefix for saved model names
            eval_env: Optional separate environment for periodic evaluation;
                if omitted, no EvalCallback is added

        Returns:
            List of callbacks
        """
        checkpoint_callback = CheckpointCallback(
            save_freq=save_freq, save_path=self.model_dir, name_prefix=name_prefix
        )
        callbacks = [checkpoint_callback]

        # EvalCallback needs an environment separate from the training one,
        # so it is only added when a dedicated eval_env is supplied.
        if eval_env is not None:
            eval_callback = EvalCallback(
                eval_env,
                best_model_save_path=self.model_dir,
                log_path=self.log_dir,
                eval_freq=eval_freq,
                n_eval_episodes=n_eval_episodes,
            )
            callbacks.append(eval_callback)
        return callbacks
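
    # Illustrative: supply a dedicated evaluation environment so periodic
    # evaluation runs alongside checkpointing (the scenario path is assumed):
    #     eval_env = trainer.create_environment("scenarios/steady_flow.yaml")
    #     callbacks = trainer.create_callbacks(eval_freq=10000, eval_env=eval_env)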

    def train(
        self,
        scenario_path: str,
        total_timesteps: int,
        tb_log_name: str = "PPO_Vertiport",
        save_final: bool = True,
        final_model_name: str = "ppo_vertiport_final",
        **model_params,
    ) -> PPO:
        """Train a PPO agent.

        Args:
            scenario_path: Path to scenario configuration
            total_timesteps: Total training timesteps
            tb_log_name: TensorBoard log name
            save_final: Whether to save the final model
            final_model_name: Name for the final model
            **model_params: Additional model parameters

        Returns:
            Trained PPO model
        """
print("--- Starting Training ---")
print(f"Scenario: {scenario_path}")
print(f"Total timesteps: {total_timesteps:,}")
print(f"Parallel environments: {self.n_envs}")
# Create environment and model
env = self.create_environment(scenario_path)
model = self.create_model(env, **model_params)
# Create callbacks
callbacks = self.create_callbacks()
# Train the model
model.learn(
total_timesteps=total_timesteps, callback=callbacks, tb_log_name=tb_log_name
)
print("--- Training Finished ---")
# Save final model
if save_final:
final_path = os.path.join(self.model_dir, final_model_name)
model.save(final_path)
print(f"Final model saved to {final_path}")
        return model

def main():
    """Basic training entry point."""
    trainer = Trainer()

    # Train with default parameters
    model = trainer.train(
        scenario_path="scenarios/steady_flow.yaml", total_timesteps=5_000_000
    )
    print("Training complete!")
if __name__ == "__main__":
main()