# Source code for sequifier.hyperparameter_search

import ctypes
import json
import os
import signal
import subprocess
import sys
import time
import warnings
from typing import Union

import optuna
import torch._dynamo
import yaml
from beartype import beartype

torch._dynamo.config.suppress_errors = True
from sequifier.config.hyperparameter_search_config import (  # noqa: E402
    load_hyperparameter_search_config,
)
from sequifier.helpers import (  # noqa: E402
    get_best_model_path,
    get_last_training_batch_timedelta,
)
from sequifier.io.yaml import TrainModelDumper  # noqa: E402


def set_pdeathsig():
    """Request SIGTERM for the calling process when its parent dies (Linux only).

    Issues ``prctl(PR_SET_PDEATHSIG, SIGTERM)`` through ``libc.so.6``. On any
    platform whose ``sys.platform`` does not start with ``"linux"`` the
    function returns without doing anything. The return value of ``prctl``
    is not inspected.
    """
    if not sys.platform.startswith("linux"):
        return

    pr_set_pdeathsig = 1  # numeric value of the PR_SET_PDEATHSIG prctl option
    ctypes.CDLL("libc.so.6").prctl(pr_set_pdeathsig, signal.SIGTERM)
def _stop_training_process(
    process: subprocess.Popen, grace_seconds: float = 10.0
) -> None:
    """Stop a still-running child process and reap it; no-op if it already exited."""
    if process.poll() is not None:
        return
    process.terminate()
    try:
        process.wait(timeout=grace_seconds)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()


def _collect_evaluation_metrics(
    config, run_name: str, last_epoch
) -> Union[float, tuple[float, ...]]:
    """Run the configured evaluation script and return the configured metrics from its JSON output."""
    evaluation_name = f"{run_name}-best-{last_epoch}"
    cmd = [sys.executable, config.evaluation_script, evaluation_name]
    eval_process = subprocess.run(
        cmd, capture_output=True, text=True, cwd=config.project_root
    )
    if eval_process.returncode != 0:
        raise RuntimeError(
            f"Evaluation script failed (exit code {eval_process.returncode}):\n{eval_process.stderr}"
        )

    eval_json_path = os.path.join(
        config.project_root,
        "outputs",
        "evaluations",
        f"{evaluation_name}.json",
    )
    if not os.path.exists(eval_json_path):
        raise FileNotFoundError(
            f"Evaluation JSON not found at expected path: {eval_json_path}"
        )
    with open(eval_json_path, "r") as f:
        eval_results = json.load(f)

    eval_results_keys = set(eval_results.keys())
    evaluation_metrics = set(config.evaluation_metrics)
    missing_metrics = evaluation_metrics - eval_results_keys
    excess_metrics = eval_results_keys - evaluation_metrics

    if missing_metrics:
        raise ValueError(
            f"Some of the configured evaluation metrics are not in the script output: {missing_metrics}"
        )
    if excess_metrics:
        warnings.warn(
            f"Some metrics output by the script are not used in hyperparameter optimization: {excess_metrics}"
        )

    # Every configured metric is known to be present here (checked above),
    # so the values can be read directly, in the configured order.
    metrics = [float(eval_results[metric]) for metric in config.evaluation_metrics]
    if len(metrics) == 1:
        return metrics[0]
    return tuple(metrics)


def objective(trial: optuna.Trial, config) -> Union[float, tuple[float, ...]]:
    """Run one Optuna trial by training a model through the ``sequifier`` CLI.

    Steps:
      1. Sample a run configuration for the trial and write it to a YAML file.
      2. Launch ``sequifier train`` on that file as a subprocess.
      3. Poll the run's metrics JSONL file every two seconds, reporting
         ``val_loss`` to Optuna (single-objective runs only) and, if pruning
         is enabled and Optuna asks for it, creating a ``.prune`` file and
         waiting for the subprocess to exit (killing it on timeout).
      4. Check the subprocess exit code.
      5. Optionally run ``sequifier infer`` and an evaluation script, and
         return the metrics the script wrote.

    The training subprocess is always stopped and reaped before this function
    returns or raises.

    Args:
        trial (optuna.Trial): The Optuna trial for the current hyperparameter
            combination.
        config (HyperparameterSearchConfig): The parsed hyperparameter search
            configuration.

    Returns:
        float | tuple[float, ...]: The evaluation metric(s) produced by the
        evaluation script when one is configured (a single float for one
        metric, a tuple otherwise); else the lowest ``val_loss`` seen in the
        metrics file.

    Raises:
        optuna.TrialPruned: If the trial is pruned.
        RuntimeError: If training exits with a non-zero code, or the
            evaluation script fails.
        subprocess.CalledProcessError: If the inference command fails.
        FileNotFoundError: If the evaluation script did not write its JSON.
        ValueError: If configured metrics are missing from that JSON.
    """
    run_config = config.sample_trial(trial, trial.number)
    run_name = run_config.model_name

    # 1. YAML generation
    config_path = os.path.join(
        config.project_root, config.model_config_write_path, f"{run_name}.yaml"
    )
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, "w") as f:
        yaml.dump(run_config, f, Dumper=TrainModelDumper, sort_keys=False)

    # 2. Launch training. The flag is set on os.environ itself (not only on
    # the copy), so subprocesses started later without an explicit ``env``
    # inherit it as well.
    os.environ["SEQUIFIER_HYPERPARAMETER_SEARCH_RUN"] = "1"
    env = os.environ.copy()

    cmd = ["sequifier", "train", f"--config-path={config_path}"]
    process = subprocess.Popen(
        cmd,
        env=env,
        preexec_fn=set_pdeathsig if sys.platform.startswith("linux") else None,
    )

    metrics_path = os.path.join(
        config.project_root, "logs", f"sequifier-{run_name}-metrics.jsonl"
    )
    prune_path = os.path.join(
        config.project_root, "logs", f"sequifier-{run_name}.prune"
    )

    # Optuna supports neither Trial.report nor Trial.should_prune on
    # multi-objective studies, so both are gated on the same condition.
    is_multi_objective = (
        config.evaluation_metrics is not None and len(config.evaluation_metrics) > 1
    )
    pruning_enabled = bool(config.prune_trials) and not is_multi_objective
    if config.prune_trials and is_multi_objective:
        warnings.warn(
            "prune_trials is enabled but pruning is skipped: Optuna does not "
            "support pruning for multi-objective optimization."
        )

    last_read_pos = 0
    best_val_loss = float("inf")

    def consume_metrics(last_read_pos: int, best_val_loss: float) -> tuple[int, float]:
        """Read newly written metric lines; report them and prune if requested."""
        if not os.path.exists(metrics_path):
            return last_read_pos, best_val_loss

        with open(metrics_path, "r") as f:
            f.seek(last_read_pos)
            while True:
                line = f.readline()
                if not line or not line.endswith("\n"):
                    break  # Reached end of currently written (complete) data
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Position is not advanced, so the line is retried on the
                    # next poll.
                    break

                val_loss = data.get("val_loss")
                global_step = data.get("global_step")
                if global_step is not None and val_loss is not None:
                    if not is_multi_objective:
                        trial.report(val_loss, global_step)
                    # NOTE(review): nesting reconstructed from a flattened
                    # listing; the best loss is tracked for every run here.
                    best_val_loss = min(best_val_loss, val_loss)

                    if pruning_enabled and trial.should_prune():
                        # Cooperative pruning: create the marker file, then
                        # give the trainer time to exit on its own.
                        # presumably the trainer watches for this file —
                        # confirm in the training loop.
                        with open(prune_path, "w"):
                            pass
                        try:
                            timedelta = get_last_training_batch_timedelta(
                                run_name, 0, config.project_root
                            )
                            timeout_val = (timedelta * 2) + 30
                        except (ValueError, FileNotFoundError):
                            timeout_val = 60.0  # Safe default fallback
                        try:
                            process.wait(timeout=timeout_val)
                        except subprocess.TimeoutExpired:
                            process.kill()  # Escalation
                            process.wait()  # Reap so no zombie is left behind
                        raise optuna.TrialPruned()

                last_read_pos = f.tell()
        return last_read_pos, best_val_loss

    # 3. Poll while the trainer runs, then read once more after it exits to
    # pick up lines written just before termination.
    try:
        while process.poll() is None:
            last_read_pos, best_val_loss = consume_metrics(
                last_read_pos, best_val_loss
            )
            time.sleep(2)
        _, best_val_loss = consume_metrics(last_read_pos, best_val_loss)
    finally:
        # No-op when the trainer has already exited; otherwise (an exception
        # is propagating) make sure it does not keep running unattended.
        _stop_training_process(process)

    # 4. Exit-code handling.
    # NOTE(review): Popen reports a child killed by an unhandled SIGTERM as
    # -15, so 143 is only seen if the trainer itself exits with that status —
    # confirm against the trainer's signal handling.
    exit_code = process.returncode
    if exit_code == 143:
        if os.path.exists(prune_path):
            raise optuna.TrialPruned()
        raise RuntimeError(
            f"Trial pre-empted externally by cluster (SIGTERM). Exit code: {exit_code}"
        )
    if exit_code != 0:
        raise RuntimeError(f"Training failed with exit code {exit_code}")

    model_type = "onnx" if run_config.export_onnx else "pt"
    model_path, last_epoch = get_best_model_path(
        config.project_root, run_name, model_type
    )

    # 5. Optional inference and evaluation.
    if config.evaluation_inference_config:
        subprocess.run(
            [
                "sequifier",
                "infer",
                f"--config-path={config.evaluation_inference_config}",
                f"--model-path={model_path}",
            ],
            check=True,
        )

    # NOTE(review): treated as independent of the inference step above; the
    # flattened listing does not show whether it was nested under it.
    if config.evaluation_script and config.evaluation_metrics:
        return _collect_evaluation_metrics(config, run_name, last_epoch)

    return best_val_loss