Update Feather training runtime image
- Dockerfile +7 -19
- entrypoint.py +9 -67
- mamba_ssm_init.py +3 -35
- overlay/configs/harness_config.py +47 -47
- overlay/harness/eval_agent.py +188 -188
- overlay/harness/orchestrator.py +16 -16
- overlay/htm_rust/src/gpu/fused.rs +73 -73
- overlay/hydra/eval.py +1 -8
- overlay/hydra/model.py +296 -296
- overlay/hydra/training.py +387 -387
- overlay/prepare.py +60 -60
- overlay/prepare_nemotron.py +159 -162
- overlay/scripts/__init__.py +1 -1
- overlay/scripts/audit_overlay_sync.py +100 -100
- overlay/scripts/benchmark_assets.py +62 -124
- overlay/scripts/benchmark_checkpoint.py +19 -118
- overlay/scripts/benchmark_checkpoint_report.py +50 -50
- overlay/scripts/benchmark_contract.py +67 -67
- overlay/scripts/benchmark_datasets.py +18 -190
- overlay/scripts/benchmark_hyena_stack.py +41 -66
- overlay/scripts/benchmark_preflight.py +31 -35
- overlay/scripts/benchmark_runner.py +248 -327
- overlay/scripts/benchmark_suite.py +84 -84
- overlay/scripts/bootstrap_benchmark_env.py +63 -63
- overlay/scripts/cycle1a_report.py +52 -52
- overlay/scripts/cycle_executor.py +312 -332
- overlay/scripts/export_hpo_priors.py +94 -94
- overlay/scripts/hf_routing.py +94 -94
- overlay/scripts/hpo_component_report.py +130 -130
- overlay/scripts/hpo_leaderboard.py +156 -156
- overlay/scripts/hpo_orchestrator.py +118 -118
- overlay/scripts/hpo_retest.py +151 -151
- overlay/scripts/hydra_generation.py +180 -183
- overlay/scripts/launch_benchmark_hf_job.py +157 -222
- overlay/scripts/launch_feather_hf_job.py +337 -343
- overlay/scripts/optuna_hpo.py +575 -575
- overlay/scripts/run_cycle1a.py +45 -46
- overlay/scripts/setup.sh +0 -1
- overlay/scripts/sweep_depth_aggregate.py +184 -184
- overlay/scripts/watch_benchmark_hf_job.py +33 -81
- overlay/subsystems/htm.py +128 -128
Dockerfile
CHANGED
@@ -1,6 +1,4 @@
-FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
-
-ARG HTM_CUDA_ARCH=sm_86
+FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
 
 ENV DEBIAN_FRONTEND=noninteractive \
     PIP_NO_CACHE_DIR=1 \
@@ -107,22 +105,12 @@ COPY overlay /workspace/feather
 COPY entrypoint.py /app/entrypoint.py
 WORKDIR /workspace/feather
 
-RUN python - <<'PY'
-from pathlib import Path
-for sh in Path('/workspace/feather/scripts').glob('*.sh'):
-    raw = sh.read_bytes()
-    norm = raw.replace(b'\r\n', b'\n')
-    if norm != raw:
-        sh.write_bytes(norm)
-PY
-
-RUN python -m py_compile hydra/training.py prepare.py train.py && \
-    bash -n scripts/run_domain_expanded_pretrain.sh
+RUN python -m py_compile hydra/training.py prepare.py train.py && \
+    bash -n scripts/run_domain_expanded_pretrain.sh
 
-RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
-    export HTM_CUDA_ARCH=${HTM_CUDA_ARCH} && \
-    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
-    pip install htm_rust/target/wheels/htm_rust-*.whl
+RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
+    export HTM_CUDA_ARCH=sm_90 && \
+    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
+    pip install htm_rust/target/wheels/htm_rust-*.whl
 
 CMD ["python", "/app/entrypoint.py"]
entrypoint.py
CHANGED
@@ -68,11 +68,7 @@ try:
 except ImportError:
     print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
 
-from huggingface_hub import HfApi  # noqa: E402 (import after cuda kick)
-if '/workspace/feather' not in sys.path:  # noqa: E402
-    sys.path.insert(0, '/workspace/feather')
-from scripts.benchmark_assets import hydrate_benchmark_assets  # noqa: E402
-from subsystems.sdr_retina import build_retina  # noqa: E402
+from huggingface_hub import HfApi  # noqa: E402 (import after cuda kick)
 
 REPO_ROOT = Path('/workspace/feather')
 CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
@@ -114,7 +110,7 @@ def _start_health_server() -> HTTPServer:
     return server
 
 
-def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
+def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
     if not path.exists():
         print(f'[upload] skip missing {path}', flush=True)
         return
@@ -124,20 +120,7 @@ def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
         repo_id=OUTPUT_REPO,
         repo_type='model',
     )
-    print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
-
-
-def build_benchmark_mode_command() -> list[str]:
-    return [
-        'python',
-        str(REPO_ROOT / 'scripts' / 'benchmark_runner.py'),
-        '--benchmark', os.environ.get('HYDRA_BENCHMARK_NAME', 'GSM8K'),
-        '--generator-mode', 'hydra',
-        '--variant', os.environ.get('HYDRA_BENCHMARK_VARIANT', 'hydra_full'),
-        '--seed', os.environ.get('HYDRA_SEED', '42'),
-        '--out', str(REPO_ROOT / 'benchmark_result.json'),
-        '--ledger', str(REPO_ROOT / 'benchmark_ledger.json'),
-    ] + sys.argv[1:]
+    print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
 
 
 def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
@@ -175,7 +158,7 @@ def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
     print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
 
 
-def run_job_mode() -> int:
+def run_job_mode() -> int:
     os.chdir(REPO_ROOT)
     os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
     os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
@@ -220,46 +203,7 @@ def run_job_mode() -> int:
     else:
         print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
 
-    return proc.returncode
-
-
-def run_benchmark_mode() -> int:
-    os.chdir(REPO_ROOT)
-    os.environ.setdefault('HYDRA_USE_NEMOTRON', '1')
-    if TOKEN:
-        try:
-            hydrate_benchmark_assets(
-                cache_dir=CACHE_ROOT,
-                output_repo=OUTPUT_REPO,
-                tokenizer_repo=os.environ.get('HYDRA_TOKENIZER_CACHE_REPO', OUTPUT_REPO),
-                token=TOKEN,
-            )
-        except Exception as e:
-            print(f'[benchmark] asset hydrate warning: {type(e).__name__}: {e}', flush=True)
-    try:
-        build_retina()
-    except Exception as e:
-        print(f'[benchmark] retina materialize warning: {type(e).__name__}: {e}', flush=True)
-    cmd = build_benchmark_mode_command()
-    print(f'[benchmark] command={cmd}', flush=True)
-    proc = subprocess.run(cmd, check=False)
-
-    if TOKEN:
-        api = HfApi(token=TOKEN)
-        try:
-            api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
-        except Exception as e:
-            print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
-        prefix = f'jobs/{JOB_ID}'
-        try:
-            upload_artifact(api, REPO_ROOT / 'benchmark_result.json', f'{prefix}/benchmark_result.json')
-            upload_artifact(api, REPO_ROOT / 'benchmark_ledger.json', f'{prefix}/benchmark_ledger.json')
-        except Exception as e:
-            print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
-    else:
-        print('[upload] HF_TOKEN not set; skipping benchmark artifact upload', flush=True)
-
-    return proc.returncode
+    return proc.returncode
 
 
 def run_space_mode() -> int:
@@ -273,12 +217,10 @@ def run_space_mode() -> int:
     server.server_close()
 
 
-def main() -> int:
-    if RUNTIME_MODE == 'job':
-        return run_job_mode()
-    if RUNTIME_MODE == 'benchmark':
-        return run_benchmark_mode()
-    return run_space_mode()
+def main() -> int:
+    if RUNTIME_MODE == 'job':
+        return run_job_mode()
+    return run_space_mode()
 
 
 if __name__ == '__main__':
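For reference, a minimal sketch of the job-mode artifact upload this entrypoint retains. HfApi, create_repo, and upload_file are the real huggingface_hub APIs; the OUTPUT_REPO env name and file paths here are illustrative stand-ins for the entrypoint's module-level constants, not values from this commit.

import os
from pathlib import Path

from huggingface_hub import HfApi

# Hypothetical stand-ins for the entrypoint's module-level constants.
OUTPUT_REPO = os.environ.get('HYDRA_OUTPUT_REPO', 'user/feather-runs')
JOB_ID = os.environ.get('JOB_ID', 'demo')

api = HfApi(token=os.environ['HF_TOKEN'])
# exist_ok=True keeps repeated job runs idempotent.
api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
# upload_artifact() in the entrypoint wraps this call plus an existence check.
api.upload_file(
    path_or_fileobj=Path('/workspace/feather/run.log'),
    path_in_repo=f'jobs/{JOB_ID}/run.log',
    repo_id=OUTPUT_REPO,
    repo_type='model',
)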
mamba_ssm_init.py
CHANGED
@@ -24,8 +24,8 @@ mamba_inner_fn = None
 # stub is never actually invoked at runtime because the codebase does not use
 # torch.compile — but importing torch._inductor.* still requires the symbol to
 # exist at module load time.
-import triton as _triton  # noqa: E402
-if not hasattr(_triton, "set_allocator"):
+import triton as _triton  # noqa: E402
+if not hasattr(_triton, "set_allocator"):
     def _noop_set_allocator(_fn):  # pragma: no cover
         return None
     _triton.set_allocator = _noop_set_allocator
@@ -53,39 +53,7 @@ if not hasattr(_tcc, "triton_key"):
     def _triton_key_shim():
         import triton as _t
         return f"triton-{_t.__version__}-shim"
-    _tcc.triton_key = _triton_key_shim
-
-# Triton 3.5 wheels can occasionally load with an empty backend registry in
-# HF Jobs environments (driver.active -> "0 active drivers"), even though the
-# NVIDIA backend module is present and CudaDriver.is_active() is True.
-# Patch _create_driver to directly select CudaDriver when registry discovery
-# returns empty.
-import importlib as _importlib  # noqa: E402
-_triton_driver_mod = _importlib.import_module("triton.runtime.driver")
-if getattr(_triton_driver_mod, "backends", None) == {}:
-    from triton.backends.nvidia import driver as _nvidia_driver  # noqa: E402
-
-    def _create_driver_shim():
-        if hasattr(_nvidia_driver, "CudaDriver") and _nvidia_driver.CudaDriver.is_active():
-            return _nvidia_driver.CudaDriver()
-        raise RuntimeError(
-            "Triton backend registry is empty and NVIDIA CudaDriver is not active"
-        )
-
-    _triton_driver_mod._create_driver = _create_driver_shim
-    if hasattr(_triton_driver_mod, "driver") and hasattr(_triton_driver_mod.driver, "reset_active"):
-        _triton_driver_mod.driver.reset_active()
-
-_triton_compiler_mod = _importlib.import_module("triton.compiler.compiler")
-if getattr(_triton_compiler_mod, "backends", None) == {}:
-    from triton.backends import Backend as _Backend  # noqa: E402
-    from triton.backends.nvidia.compiler import CUDABackend as _CUDABackend  # noqa: E402
-    from triton.backends.nvidia.driver import CudaDriver as _CudaDriver  # noqa: E402
-
-    _triton_compiler_mod.backends["nvidia"] = _Backend(
-        compiler=_CUDABackend,
-        driver=_CudaDriver,
-    )
+    _tcc.triton_key = _triton_key_shim
 
 # Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
 # for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
overlay/configs/harness_config.py
CHANGED
@@ -1,13 +1,13 @@
-"""Harness configuration for HYDRA's self-evolving outer loop."""
-from typing import Literal
-
-from pydantic import BaseModel, Field
-
-type GateThresholds = dict[str, float]
-type GateConfig = dict[str, GateThresholds]
-
-
-class HarnessConfig(BaseModel):
+"""Harness configuration for HYDRA's self-evolving outer loop."""
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+type GateThresholds = dict[str, float]
+type GateConfig = dict[str, GateThresholds]
+
+
+class HarnessConfig(BaseModel):
     """Configuration for the HYDRA harness behavior."""
 
     # Inner loop
@@ -50,19 +50,19 @@ class HarnessConfig(BaseModel):
         default=5.0, description="Max % regression from best known val_bpb"
     )
 
-    # Keep/discard criteria
-    primary_metric: str = "val_bpb"
-    secondary_metrics: GateConfig = Field(
-        default_factory=lambda: {
-            "mhc_spectral_norm": {"max": 2.0},
-            "engram_hit_rate": {"min": 0.1},
-            "factual_english_score": {"min": 0.5},
-            "instruction_following_score": {"min": 0.5},
-            "distinct_2": {"min": 0.1},
-            "repetition_rate": {"max": 0.2},
-            "hestia_quant_error": {"max": 0.05},
-        }
-    )
+    # Keep/discard criteria
+    primary_metric: str = "val_bpb"
+    secondary_metrics: GateConfig = Field(
+        default_factory=lambda: {
+            "mhc_spectral_norm": {"max": 2.0},
+            "engram_hit_rate": {"min": 0.1},
+            "factual_english_score": {"min": 0.5},
+            "instruction_following_score": {"min": 0.5},
+            "distinct_2": {"min": 0.1},
+            "repetition_rate": {"max": 0.2},
+            "hestia_quant_error": {"max": 0.05},
+        }
+    )
 
     # Experiment execution
     experiment_timeout: int = Field(
@@ -80,29 +80,29 @@ class HarnessConfig(BaseModel):
     gate_mhc_spectral_norm: float | None = Field(
         default=None, description="Max mhc_spectral_norm for keep (None=disabled)"
     )
-    gate_engram_hit_rate: float | None = Field(
-        default=None, description="Min engram_hit_rate for keep (None=disabled)"
-    )
-    gate_tps_median: float | None = Field(
-        default=None,
-        description="Min steady-state tps_median for keep (None=disabled)",
-    )
-    gate_tps_p10: float | None = Field(
-        default=None,
-        description="Min steady-state tps_p10 for keep (None=disabled)",
-    )
-
-    def to_secondary_gates(self) -> GateConfig:
-        """Build active keep/discard gates from defaults plus gate_* overrides."""
-        gates = {metric: thresholds.copy() for metric, thresholds in self.secondary_metrics.items()}
-
-        if self.gate_mhc_spectral_norm is not None:
-            gates.setdefault("mhc_spectral_norm", {})["max"] = self.gate_mhc_spectral_norm
-        if self.gate_engram_hit_rate is not None:
-            gates.setdefault("engram_hit_rate", {})["min"] = self.gate_engram_hit_rate
-        if self.gate_tps_median is not None:
-            gates.setdefault("tps_median", {})["min"] = self.gate_tps_median
-        if self.gate_tps_p10 is not None:
-            gates.setdefault("tps_p10", {})["min"] = self.gate_tps_p10
-
-        return gates
+    gate_engram_hit_rate: float | None = Field(
+        default=None, description="Min engram_hit_rate for keep (None=disabled)"
+    )
+    gate_tps_median: float | None = Field(
+        default=None,
+        description="Min steady-state tps_median for keep (None=disabled)",
+    )
+    gate_tps_p10: float | None = Field(
+        default=None,
+        description="Min steady-state tps_p10 for keep (None=disabled)",
+    )
+
+    def to_secondary_gates(self) -> GateConfig:
+        """Build active keep/discard gates from defaults plus gate_* overrides."""
+        gates = {metric: thresholds.copy() for metric, thresholds in self.secondary_metrics.items()}
+
+        if self.gate_mhc_spectral_norm is not None:
+            gates.setdefault("mhc_spectral_norm", {})["max"] = self.gate_mhc_spectral_norm
+        if self.gate_engram_hit_rate is not None:
+            gates.setdefault("engram_hit_rate", {})["min"] = self.gate_engram_hit_rate
+        if self.gate_tps_median is not None:
+            gates.setdefault("tps_median", {})["min"] = self.gate_tps_median
+        if self.gate_tps_p10 is not None:
+            gates.setdefault("tps_p10", {})["min"] = self.gate_tps_p10
+
+        return gates
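As a usage note, a minimal sketch of how the gate_* overrides above merge into the default secondary_metrics. Field and method names are exactly those in this diff; the threshold value is illustrative, and the remaining fields are assumed to keep their defaults.

from configs.harness_config import HarnessConfig

cfg = HarnessConfig(gate_tps_median=60000.0)  # illustrative threshold override
gates = cfg.to_secondary_gates()

assert gates["repetition_rate"] == {"max": 0.2}  # default from secondary_metrics
assert gates["tps_median"] == {"min": 60000.0}   # injected via setdefault
assert "tps_p10" not in gates                    # gate left at None stays disabled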
overlay/harness/eval_agent.py
CHANGED
@@ -1,15 +1,15 @@
-"""Eval agent: parse run.log and extract metrics from training runs."""
-import re
-import statistics
-from dataclasses import dataclass
-
-
-type GateThresholds = dict[str, float]
-type GateConfig = dict[str, GateThresholds]
+"""Eval agent: parse run.log and extract metrics from training runs."""
+import re
+import statistics
+from dataclasses import dataclass
+
+
+type GateThresholds = dict[str, float]
+type GateConfig = dict[str, GateThresholds]
 
 
 @dataclass
-class ExperimentResult:
+class ExperimentResult:
     """Parsed result from a single experiment run.
 
     All float fields default to 0.0; integer fields default to 0.
@@ -28,38 +28,38 @@ class ExperimentResult:
     peak_vram_mb: float = 0.0
     mfu_percent: float = 0.0
 
-    # Throughput
-    total_tokens_m: float = 0.0
-    num_steps: int = 0
-    tps_median: float = 0.0
-    tps_p10: float = 0.0
-    tps_min: float = 0.0
-    tps_max: float = 0.0
-    tps_samples: int = 0
+    # Throughput
+    total_tokens_m: float = 0.0
+    num_steps: int = 0
+    tps_median: float = 0.0
+    tps_p10: float = 0.0
+    tps_min: float = 0.0
+    tps_max: float = 0.0
+    tps_samples: int = 0
 
     # Model shape (echoed by train.py summary block)
     num_params_m: float = 0.0
     n_layer: int = 0
     d_model: int = 0
 
-    # Secondary health metrics
-    mhc_spectral_norm: float = 0.0
-    engram_hit_rate: float = 0.0
-    sr_bypass_rate: float = 0.0
-
-    # Evaluation breadth metrics
-    factual_english_score: float = 0.0
-    instruction_following_score: float = 0.0
-    distinct_1: float = 0.0
-    distinct_2: float = 0.0
-    repetition_rate: float = 0.0
-    repetition_bigram_rate: float = 0.0
-    calibration_ece: float = 0.0
-    calibration_brier: float = 0.0
-    calibration_accuracy: float = 0.0
-    calibration_tokens: int = 0
-    eval_seed: int = 0
-    eval_seed_group: str = ""
+    # Secondary health metrics
+    mhc_spectral_norm: float = 0.0
+    engram_hit_rate: float = 0.0
+    sr_bypass_rate: float = 0.0
+
+    # Evaluation breadth metrics
+    factual_english_score: float = 0.0
+    instruction_following_score: float = 0.0
+    distinct_1: float = 0.0
+    distinct_2: float = 0.0
+    repetition_rate: float = 0.0
+    repetition_bigram_rate: float = 0.0
+    calibration_ece: float = 0.0
+    calibration_brier: float = 0.0
+    calibration_accuracy: float = 0.0
+    calibration_tokens: int = 0
+    eval_seed: int = 0
+    eval_seed_group: str = ""
 
     # Status
     crashed: bool = False
@@ -80,48 +80,48 @@ _PATTERNS: dict[str, str] = {
     "n_layer": r"^n_layer:\s+(\d+)",
    "d_model": r"^d_model:\s+(\d+)",
    "mhc_spectral_norm": r"^mhc_spectral_norm:\s+([\d.]+)",
-    "engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
-    "sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
-    "factual_english_score": r"^factual_english_score:\s+([\d.]+)",
-    "instruction_following_score": r"^instruction_following_score:\s+([\d.]+)",
-    "distinct_1": r"^distinct_1:\s+([\d.]+)",
-    "distinct_2": r"^distinct_2:\s+([\d.]+)",
-    "repetition_rate": r"^repetition_rate:\s+([\d.]+)",
-    "repetition_bigram_rate": r"^repetition_bigram_rate:\s+([\d.]+)",
-    "calibration_ece": r"^calibration_ece:\s+([\d.]+)",
-    "calibration_brier": r"^calibration_brier:\s*([\d.]+)",
-    "calibration_accuracy": r"^calibration_accuracy:\s+([\d.]+)",
-    "calibration_tokens": r"^calibration_tokens:\s+(\d+)",
-    "eval_seed": r"^eval_seed:\s+(\d+)",
-    "eval_seed_group": r"^eval_seed_group:\s+(.+)",
-}
+    "engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
+    "sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
+    "factual_english_score": r"^factual_english_score:\s+([\d.]+)",
+    "instruction_following_score": r"^instruction_following_score:\s+([\d.]+)",
+    "distinct_1": r"^distinct_1:\s+([\d.]+)",
+    "distinct_2": r"^distinct_2:\s+([\d.]+)",
+    "repetition_rate": r"^repetition_rate:\s+([\d.]+)",
+    "repetition_bigram_rate": r"^repetition_bigram_rate:\s+([\d.]+)",
+    "calibration_ece": r"^calibration_ece:\s+([\d.]+)",
+    "calibration_brier": r"^calibration_brier:\s*([\d.]+)",
+    "calibration_accuracy": r"^calibration_accuracy:\s+([\d.]+)",
+    "calibration_tokens": r"^calibration_tokens:\s+(\d+)",
+    "eval_seed": r"^eval_seed:\s+(\d+)",
+    "eval_seed_group": r"^eval_seed_group:\s+(.+)",
+}
 
 # Attributes that should be parsed as int rather than float.
-_INT_ATTRS: frozenset[str] = frozenset(
-    {
-        "num_steps",
-        "n_layer",
-        "d_model",
-        "calibration_tokens",
-        "eval_seed",
-    }
-)
-_STR_ATTRS: frozenset[str] = frozenset({"eval_seed_group"})
-_STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
-_TPS_PATTERN = re.compile(r"\btps=(\d+)\b")
-
-
-def _percentile_linear(sorted_values: list[float], pct: float) -> float:
-    """Compute percentile via linear interpolation (0 <= pct <= 100)."""
-    if not sorted_values:
-        return 0.0
-    if len(sorted_values) == 1:
-        return sorted_values[0]
-    rank = (len(sorted_values) - 1) * (pct / 100.0)
-    lo = int(rank)
-    hi = min(lo + 1, len(sorted_values) - 1)
-    frac = rank - lo
-    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
+_INT_ATTRS: frozenset[str] = frozenset(
+    {
+        "num_steps",
+        "n_layer",
+        "d_model",
+        "calibration_tokens",
+        "eval_seed",
+    }
+)
+_STR_ATTRS: frozenset[str] = frozenset({"eval_seed_group"})
+_STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
+_TPS_PATTERN = re.compile(r"\btps=(\d+)\b")
+
+
+def _percentile_linear(sorted_values: list[float], pct: float) -> float:
+    """Compute percentile via linear interpolation (0 <= pct <= 100)."""
+    if not sorted_values:
+        return 0.0
+    if len(sorted_values) == 1:
+        return sorted_values[0]
+    rank = (len(sorted_values) - 1) * (pct / 100.0)
+    lo = int(rank)
+    hi = min(lo + 1, len(sorted_values) - 1)
+    frac = rank - lo
+    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
 
 
 def parse_run_log(log_path: str) -> ExperimentResult:
@@ -144,60 +144,60 @@ def parse_run_log(log_path: str) -> ExperimentResult:
         result.error_message = f"Log file not found: {log_path}"
         return result
 
-    # Detect crash signals in output. Keep this strict to avoid false positives
-    # from benign log lines that include "error" in a non-fatal context.
-    if (
-        "Traceback" in content
-        or "\nFAIL\n" in content
-        or "[TPS_GUARD] FAIL" in content
-        or "raise SystemExit(1)" in content
-    ):
-        result.crashed = True
-        lines = content.strip().splitlines()
-        result.error_message = "\n".join(lines[-20:])
-
-    for attr, pattern in _PATTERNS.items():
-        match = re.search(pattern, content, re.MULTILINE)
-        if match:
-            raw = match.group(1)
-            if attr in _INT_ATTRS:
-                setattr(result, attr, int(raw))
-            elif attr in _STR_ATTRS:
-                setattr(result, attr, raw.strip())
-            else:
-                setattr(result, attr, float(raw))
-
-    warmup_steps = 10
-    warmup_match = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", content)
-    if warmup_match:
-        warmup_steps = int(warmup_match.group(1))
-
-    step_tps_samples: list[tuple[int, int]] = []
-    for m in _STEP_TPS_PATTERN.finditer(content):
-        step_tps_samples.append((int(m.group(1)), int(m.group(2))))
-
-    tps_values: list[float] = []
-    if step_tps_samples:
-        for step, tps in step_tps_samples:
-            if step >= warmup_steps:
-                tps_values.append(float(tps))
-        if not tps_values:
-            tps_values = [float(tps) for _, tps in step_tps_samples]
-    else:
-        tps_values = [float(m.group(1)) for m in _TPS_PATTERN.finditer(content)]
-
-    if tps_values:
-        sorted_tps = sorted(tps_values)
-        result.tps_samples = len(tps_values)
-        result.tps_median = float(statistics.median(tps_values))
-        result.tps_p10 = float(_percentile_linear(sorted_tps, 10.0))
-        result.tps_min = float(sorted_tps[0])
-        result.tps_max = float(sorted_tps[-1])
-
-    return result
-
-
-def check_secondary_alarms(result: ExperimentResult) -> list[str]:
+    # Detect crash signals in output. Keep this strict to avoid false positives
+    # from benign log lines that include "error" in a non-fatal context.
+    if (
+        "Traceback" in content
+        or "\nFAIL\n" in content
+        or "[TPS_GUARD] FAIL" in content
+        or "raise SystemExit(1)" in content
+    ):
+        result.crashed = True
+        lines = content.strip().splitlines()
+        result.error_message = "\n".join(lines[-20:])
+
+    for attr, pattern in _PATTERNS.items():
+        match = re.search(pattern, content, re.MULTILINE)
+        if match:
+            raw = match.group(1)
+            if attr in _INT_ATTRS:
+                setattr(result, attr, int(raw))
+            elif attr in _STR_ATTRS:
+                setattr(result, attr, raw.strip())
+            else:
+                setattr(result, attr, float(raw))
+
+    warmup_steps = 10
+    warmup_match = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", content)
+    if warmup_match:
+        warmup_steps = int(warmup_match.group(1))
+
+    step_tps_samples: list[tuple[int, int]] = []
+    for m in _STEP_TPS_PATTERN.finditer(content):
+        step_tps_samples.append((int(m.group(1)), int(m.group(2))))
+
+    tps_values: list[float] = []
+    if step_tps_samples:
+        for step, tps in step_tps_samples:
+            if step >= warmup_steps:
+                tps_values.append(float(tps))
+        if not tps_values:
+            tps_values = [float(tps) for _, tps in step_tps_samples]
+    else:
+        tps_values = [float(m.group(1)) for m in _TPS_PATTERN.finditer(content)]
+
+    if tps_values:
+        sorted_tps = sorted(tps_values)
+        result.tps_samples = len(tps_values)
+        result.tps_median = float(statistics.median(tps_values))
+        result.tps_p10 = float(_percentile_linear(sorted_tps, 10.0))
+        result.tps_min = float(sorted_tps[0])
+        result.tps_max = float(sorted_tps[-1])
+
+    return result
+
+
+def check_secondary_alarms(result: ExperimentResult) -> list[str]:
     """Check secondary metrics against fixed alarm thresholds.
 
     Args:
@@ -216,44 +216,44 @@ def check_secondary_alarms(result: ExperimentResult) -> list[str]:
         alarms.append(
             f"engram_hit_rate={result.engram_hit_rate:.4f} < 0.1 (memory underused)"
         )
-    if 0 < result.mfu_percent < 10:
-        alarms.append(
-            f"mfu_percent={result.mfu_percent:.2f}% < 10% (GPU underutilized)"
-        )
-    if result.calibration_ece > 0.35:
-        alarms.append(
-            f"calibration_ece={result.calibration_ece:.4f} > 0.35 (poor calibration)"
-        )
-    if result.tps_median > 0 and result.tps_median < 50000:
-        alarms.append(
-            f"tps_median={result.tps_median:.0f} < 50000 (throughput below A10 objective)"
-        )
-
-    return alarms
-
-
-def _check_gate(
-    result: ExperimentResult,
-    gates: GateConfig,
-    metric: str,
-) -> tuple[bool, str] | None:
-    """Evaluate a single min/max gate against an ExperimentResult metric."""
-    gate = gates.get(metric, {})
-    value = getattr(result, metric)
-    max_value = gate.get("max")
-    if max_value is not None and value > max_value:
-        return False, f"{metric} {value:.4f} > gate {max_value}"
-    min_value = gate.get("min")
-    if min_value is not None and value < min_value:
-        return False, f"{metric} {value:.4f} < gate {min_value}"
-    return None
-
-
-def should_keep(
-    result: ExperimentResult,
-    best_bpb: float,
-    gates: GateConfig | None = None,
-) -> tuple[bool, str]:
+    if 0 < result.mfu_percent < 10:
+        alarms.append(
+            f"mfu_percent={result.mfu_percent:.2f}% < 10% (GPU underutilized)"
+        )
+    if result.calibration_ece > 0.35:
+        alarms.append(
+            f"calibration_ece={result.calibration_ece:.4f} > 0.35 (poor calibration)"
+        )
+    if result.tps_median > 0 and result.tps_median < 50000:
+        alarms.append(
+            f"tps_median={result.tps_median:.0f} < 50000 (throughput below A10 objective)"
+        )
+
+    return alarms
+
+
+def _check_gate(
+    result: ExperimentResult,
+    gates: GateConfig,
+    metric: str,
+) -> tuple[bool, str] | None:
+    """Evaluate a single min/max gate against an ExperimentResult metric."""
+    gate = gates.get(metric, {})
+    value = getattr(result, metric)
+    max_value = gate.get("max")
+    if max_value is not None and value > max_value:
+        return False, f"{metric} {value:.4f} > gate {max_value}"
+    min_value = gate.get("min")
+    if min_value is not None and value < min_value:
+        return False, f"{metric} {value:.4f} < gate {min_value}"
+    return None
+
+
+def should_keep(
+    result: ExperimentResult,
+    best_bpb: float,
+    gates: GateConfig | None = None,
+) -> tuple[bool, str]:
     """Decide whether to keep or discard an experiment.
 
     The primary criterion is strictly lower val_bpb than the current best.
@@ -277,24 +277,24 @@ def should_keep(
     if result.val_bpb >= best_bpb:
         return False, "discard"
 
-    # Secondary gate checks.
-    if gates:
-        gate_metrics = (
-            "mhc_spectral_norm",
-            "engram_hit_rate",
-            "factual_english_score",
-            "instruction_following_score",
-            "distinct_1",
-            "distinct_2",
-            "repetition_rate",
-            "repetition_bigram_rate",
-            "calibration_ece",
-            "tps_median",
-            "tps_p10",
-        )
-        for metric in gate_metrics:
-            gate_result = _check_gate(result, gates, metric)
-            if gate_result is not None:
-                return gate_result
-
-    return True, "keep"
+    # Secondary gate checks.
+    if gates:
+        gate_metrics = (
+            "mhc_spectral_norm",
+            "engram_hit_rate",
+            "factual_english_score",
+            "instruction_following_score",
+            "distinct_1",
+            "distinct_2",
+            "repetition_rate",
+            "repetition_bigram_rate",
+            "calibration_ece",
+            "tps_median",
+            "tps_p10",
+        )
+        for metric in gate_metrics:
+            gate_result = _check_gate(result, gates, metric)
+            if gate_result is not None:
+                return gate_result
+
+    return True, "keep"
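A short sketch of how these pieces compose in the harness loop. The log path and best_bpb value are illustrative assumptions; imports match the overlay layout shown in this commit.

from configs.harness_config import HarnessConfig
from harness.eval_agent import check_secondary_alarms, parse_run_log, should_keep

result = parse_run_log("runs/exp_042/run.log")  # hypothetical log location
for alarm in check_secondary_alarms(result):    # fixed-threshold health warnings
    print(f"[alarm] {alarm}")

gates = HarnessConfig().to_secondary_gates()    # defaults plus gate_* overrides
keep, reason = should_keep(result, best_bpb=1.05, gates=gates)
print(f"keep={keep} reason={reason}")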
overlay/harness/orchestrator.py
CHANGED
@@ -20,12 +20,12 @@ provides the infrastructure ("rails") that the autoresearch loop runs on.
 """
 import argparse
 import csv
-import os
-import subprocess
-import time
-
-from configs.harness_config import HarnessConfig
-from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
+import os
+import subprocess
+import time
+
+from configs.harness_config import HarnessConfig
+from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
 from harness.git_utils import REPO_DIR, commit_all, current_commit_short, reset_to
 from harness.health_monitor import check_health, reset_peak_stats
 from harness.meta_agent import run_meta_iteration
@@ -145,12 +145,12 @@ def run_experiment(timeout: int = 600) -> str:
 # ---------------------------------------------------------------------------
 
 
-def run_loop(
-    meta_interval: int = 20,
-    max_experiments: int | None = None,
-    experiment_timeout: int = 600,
-    secondary_gates: dict[str, dict[str, float]] | None = None,
-) -> None:
+def run_loop(
+    meta_interval: int = 20,
+    max_experiments: int | None = None,
+    experiment_timeout: int = 600,
+    secondary_gates: dict[str, dict[str, float]] | None = None,
+) -> None:
     """Run the HYDRA autoresearch loop.
 
     This function runs indefinitely (or until ``max_experiments`` is reached
@@ -163,10 +163,10 @@ def run_loop(
         secondary_gates: Optional gate thresholds forwarded to
             :func:`~harness.eval_agent.should_keep`.
     """
-    init_results_tsv()
-    if secondary_gates is None:
-        secondary_gates = HarnessConfig().to_secondary_gates()
-    best_bpb = _load_best_bpb()
+    init_results_tsv()
+    if secondary_gates is None:
+        secondary_gates = HarnessConfig().to_secondary_gates()
+    best_bpb = _load_best_bpb()
     experiment_num = count_experiments()
 
     print(
overlay/htm_rust/src/gpu/fused.rs
CHANGED
@@ -513,41 +513,41 @@ pub(super) fn launch_fused_batched_raw(
     assert_eq!(anom_per_region.len(), b);
     assert!(b >= 1, "need at least one region");
 
-    // Reset per-region step_scratch before each launch.
-    for &rp in region_ptrs.iter() {
-        let r = unsafe { &mut *rp };
-        let dev = r.sp_gpu.dev_ref().clone();
-        let fused = r
-            .fused_state
-            .as_mut()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        dev.memset_zeros(&mut fused.step_scratch)?;
-        fused.iter_counter = fused.iter_counter.wrapping_add(1);
-    }
+    // Reset per-region step_scratch before each launch.
+    for &rp in region_ptrs.iter() {
+        let r = unsafe { &mut *rp };
+        let dev = r.sp_gpu.dev_ref().clone();
+        let fused = r
+            .fused_state
+            .as_mut()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        dev.memset_zeros(&mut fused.step_scratch)?;
+        fused.iter_counter = fused.iter_counter.wrapping_add(1);
+    }
 
     // Shared config — all regions use identical sp/tm parameters.
-    let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
-        let r0 = unsafe { &*region_ptrs[0] };
-        let fused = r0
-            .fused_state
-            .as_ref()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        (
-            fused.grid_dim_x,
-            fused.block_dim_x,
-            fused.raw_kernel.function_batched,
-            *r0.sp_gpu.dev_ref().cu_stream(),
-            *r0.sp_gpu.dev_ref().cu_primary_ctx(),
-        )
+    let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
+        let r0 = unsafe { &*region_ptrs[0] };
+        let fused = r0
+            .fused_state
+            .as_ref()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        (
+            fused.grid_dim_x,
+            fused.block_dim_x,
+            fused.raw_kernel.function_batched,
+            *r0.sp_gpu.dev_ref().cu_stream(),
+            *r0.sp_gpu.dev_ref().cu_primary_ctx(),
+        )
     };
 
-    let cfg = {
-        let r = unsafe { &*region_ptrs[0] };
-        let fused = r
-            .fused_state
-            .as_ref()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        FusedConfig {
+    let cfg = {
+        let r = unsafe { &*region_ptrs[0] };
+        let fused = r
+            .fused_state
+            .as_ref()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        FusedConfig {
             input_bits: input_bits as u32,
             n_columns: r.sp_gpu.n_columns_accessor() as u32,
             synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
@@ -572,42 +572,42 @@ pub(super) fn launch_fused_batched_raw(
             initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
             t: t as u32,
             learn: if learn { 1 } else { 0 },
-            iter_seed: fused.iter_counter,
-            cooperative_grid_sync: 1,
-        }
-    };
+            iter_seed: fused.iter_counter,
+            cooperative_grid_sync: 1,
+        }
+    };
 
     // Build B FusedPtrs per-region.
-    let ptrs_vec: Vec<FusedPtrs> = (0..b)
-        .map(|i| {
-            let r = unsafe { &*region_ptrs[i] };
-            let fused = r
-                .fused_state
-                .as_ref()
-                .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-            Ok(FusedPtrs {
-                syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
-                syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
-                boost: *r.sp_gpu.boost_accessor().device_ptr(),
-                active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
-                inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
-                seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
-                seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
-                syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
-                tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
-                cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
-                cell_active_a: *fused.cell_active_bits_a.device_ptr(),
-                cell_active_b: *fused.cell_active_bits_b.device_ptr(),
-                cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
-                cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
-                inputs: inputs_per_region[i],
-                cols_out: cols_per_region[i],
-                anom_out: anom_per_region[i],
-                barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
-                step_scratch: *fused.step_scratch.device_ptr(),
-            })
-        })
-        .collect::<Result<Vec<_>, DriverError>>()?;
+    let ptrs_vec: Vec<FusedPtrs> = (0..b)
+        .map(|i| {
+            let r = unsafe { &*region_ptrs[i] };
+            let fused = r
+                .fused_state
+                .as_ref()
+                .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+            Ok(FusedPtrs {
+                syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
+                syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
+                boost: *r.sp_gpu.boost_accessor().device_ptr(),
+                active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
+                inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
+                seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
+                seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
+                syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
+                tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
+                cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
+                cell_active_a: *fused.cell_active_bits_a.device_ptr(),
+                cell_active_b: *fused.cell_active_bits_b.device_ptr(),
+                cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
+                cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
+                inputs: inputs_per_region[i],
+                cols_out: cols_per_region[i],
+                anom_out: anom_per_region[i],
+                barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
+                step_scratch: *fused.step_scratch.device_ptr(),
+            })
+        })
+        .collect::<Result<Vec<_>, DriverError>>()?;
 
     // Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
     // FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
@@ -619,14 +619,14 @@ pub(super) fn launch_fused_batched_raw(
     // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
     // occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
     // on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
-    let use_cluster = {
-        let r0 = unsafe { &*region_ptrs[0] };
-        let fused = r0
-            .fused_state
-            .as_ref()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        fused.cluster_info.max_cluster_size > 0
-    };
+    let use_cluster = {
+        let r0 = unsafe { &*region_ptrs[0] };
+        let fused = r0
+            .fused_state
+            .as_ref()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        fused.cluster_info.max_cluster_size > 0
+    };
 
     unsafe {
         result::ctx::set_current(cu_ctx)?;
overlay/hydra/eval.py
CHANGED
@@ -138,9 +138,6 @@ def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
     num_samples = FACTUAL_SAMPLES
     batch = FACTUAL_BATCH
     gen_tokens = FACTUAL_GEN_TOKENS
-    # Optional fast incremental decode path for recurrence-capable backbones.
-    # If disabled, we preserve the original full-context re-forward behavior.
-    incremental_decode = os.environ.get("HYDRA_FACTUAL_GEN_INCREMENTAL", "1") == "1"
     temps = [0.7, 0.9, 1.1]
     hits = 0
 
@@ -157,18 +154,14 @@ def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
         temp = temps[batch_idx % len(temps)]
         batch_idx += 1
         ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
-        logits = model(ctx, targets=None)
         for _ in range(gen_tokens):
+            logits = model(ctx, targets=None)
             next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
             probs = torch.softmax(next_logits.float() / temp, dim=-1)
             next_id = torch.multinomial(probs, num_samples=1)
            ctx = torch.cat([ctx, next_id], dim=1)
            if ctx.size(1) >= max_seq_len:
                break
-            if incremental_decode:
-                logits = model(ctx[:, -1:], targets=None)
-            else:
-                logits = model(ctx, targets=None)
         # Transfer to CPU in one shot, no per-row sync
         all_rows.extend(ctx.cpu().tolist())
         samples_done += b
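The hunk above removes the HYDRA_FACTUAL_GEN_INCREMENTAL fast path, so every sampled token now re-forwards the full context. A standalone sketch of that decode loop, mirroring the loop in the diff (the generic `model(ctx, targets=None)` call signature is assumed):

import torch

@torch.no_grad()
def sample_full_reforward(model, ctx: torch.Tensor, gen_tokens: int,
                          max_seq_len: int, temp: float = 0.9) -> torch.Tensor:
    """Full-context decode: one forward over all of ctx per generated token.

    Simpler and backbone-agnostic, at the cost of one forward over a growing
    prefix per token (the removed incremental path re-forwarded only the
    last token for recurrence-capable backbones).
    """
    for _ in range(gen_tokens):
        logits = model(ctx, targets=None)  # (B, T, V) or (B, V)
        next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
        probs = torch.softmax(next_logits.float() / temp, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ctx = torch.cat([ctx, next_id], dim=1)
        if ctx.size(1) >= max_seq_len:
            break
    return ctx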
overlay/hydra/model.py
CHANGED
|
@@ -32,58 +32,58 @@ from __future__ import annotations
|
|
| 32 |
|
| 33 |
import os
|
| 34 |
|
| 35 |
-
import torch
|
| 36 |
-
import torch.nn as nn
|
| 37 |
-
import torch.nn.functional as F
|
| 38 |
-
|
| 39 |
-
try:
|
| 40 |
-
from mamba_ssm import Mamba3
|
| 41 |
-
except Exception: # pragma: no cover - depends on optional runtime install
|
| 42 |
-
Mamba3 = None # type: ignore[assignment]
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _get_mamba3_cls():
|
| 46 |
-
global Mamba3
|
| 47 |
-
if Mamba3 is None:
|
| 48 |
-
try:
|
| 49 |
-
from mamba_ssm import Mamba3 as _Mamba3 # type: ignore
|
| 50 |
-
Mamba3 = _Mamba3 # type: ignore[assignment]
|
| 51 |
-
except Exception as exc: # pragma: no cover - environment dependent
|
| 52 |
-
raise ImportError(
|
| 53 |
-
"mamba_ssm is required for Mamba-based HYDRA blocks. "
|
| 54 |
-
"Install mamba-ssm or use HYDRA_BASELINE_ARCH=transformer."
|
| 55 |
-
) from exc
|
| 56 |
-
return Mamba3
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _ensure_triton_cuda_backend_registered() -> None:
|
| 60 |
-
"""Ensure Triton sees exactly one CUDA backend in HF Jobs A10 runtime.
|
| 61 |
-
|
| 62 |
-
In some Triton 3.5.1 environments, `triton.compiler.compiler.backends`
|
| 63 |
-
and `triton.runtime.driver.backends` are empty even though
|
| 64 |
-
`triton.backends.nvidia` is available and CUDA is active. When that
|
| 65 |
-
happens, Mamba3 layernorm path crashes at first forward with
|
| 66 |
-
"0 compatible backends for target (cuda)".
|
| 67 |
-
"""
|
| 68 |
-
try:
|
| 69 |
-
import triton.compiler.compiler as cc
|
| 70 |
-
import triton.runtime.driver as rd
|
| 71 |
-
from triton.backends import Backend
|
| 72 |
-
from triton.backends.nvidia.compiler import CUDABackend
|
| 73 |
-
from triton.backends.nvidia.driver import CudaDriver
|
| 74 |
-
|
| 75 |
-
if hasattr(rd, "backends") and isinstance(rd.backends, dict) and not rd.backends:
|
| 76 |
-
rd.backends["nvidia"] = Backend(compiler=CUDABackend, driver=CudaDriver)
|
| 77 |
-
|
| 78 |
-
if hasattr(cc, "backends") and isinstance(cc.backends, dict) and not cc.backends:
|
| 79 |
-
cc.backends["nvidia"] = Backend(compiler=CUDABackend, driver=CudaDriver)
|
| 80 |
-
except Exception:
|
| 81 |
-
# Keep model construction resilient; runtime will raise explicit Triton
|
| 82 |
-
# errors later if backend setup is still invalid.
|
| 83 |
-
pass
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
_ensure_triton_cuda_backend_registered()
|
| 87 |
|
| 88 |
from subsystems.hestia_mini import HestiaQAT
|
| 89 |
from subsystems.htm import HTMLayer
|
|
@@ -98,30 +98,30 @@ from hydra.hyena_block import HyenaBlock
|
|
| 98 |
from hydra.optimizer import MuonAdamW
|
| 99 |
|
| 100 |
|
| 101 |
-
def norm(x: torch.Tensor) -> torch.Tensor:
|
| 102 |
-
"""RMSNorm over the last dim — stateless, autocast-friendly."""
|
| 103 |
-
return F.rms_norm(x, (x.size(-1),))
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
class TransformerBaselineBlock(nn.Module):
|
| 107 |
-
"""Transformer-style delta block for matched baseline experiments.
|
| 108 |
-
|
| 109 |
-
This block returns a transformed delta tensor rather than owning the outer
|
| 110 |
-
residual connection, because ManifoldHyperConnection already handles stream
|
| 111 |
-
mixing and residual injection around the block function.
|
| 112 |
-
"""
|
| 113 |
-
|
| 114 |
-
def __init__(self, d_model: int, n_heads: int, expand: int, dropout: float) -> None:
|
| 115 |
-
super().__init__()
|
| 116 |
-
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
|
| 117 |
-
self.ff_in = nn.Linear(d_model, expand * d_model, bias=False)
|
| 118 |
-
self.ff_out = nn.Linear(expand * d_model, d_model, bias=False)
|
| 119 |
-
self.dropout = nn.Dropout(dropout)
|
| 120 |
-
|
| 121 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 122 |
-
attn_out, _ = self.self_attn(x, x, x, need_weights=False)
|
| 123 |
-
ff = self.ff_out(F.gelu(self.ff_in(attn_out)))
|
| 124 |
-
return self.dropout(attn_out + ff)
|
| 125 |
|
| 126 |
|
| 127 |
class PostSemClawModel(nn.Module):
|
|
@@ -136,12 +136,12 @@ class PostSemClawModel(nn.Module):
    model(x, y, reduction='mean') -> scalar loss
    """

    def __init__(self, config):
        super().__init__()
        _ensure_triton_cuda_backend_registered()
        self.config = config
        self._throughput_mode = os.environ.get("HYDRA_THROUGHPUT_MODE", "0") == "1"
        self._baseline_arch = os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower()

        # Token embedding
        self.wte = nn.Embedding(config.vocab_size, config.d_model)
@@ -163,31 +163,31 @@ class PostSemClawModel(nn.Module):
            print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True)
            _gdn_layer_set -= _hyena_layer_set

        if _gdn_layer_set:
            from hydra.gdn_block import GDNBlock  # requires `fla` package

        def _build_block(i: int) -> nn.Module:
            if self._baseline_arch == "transformer":
                return TransformerBaselineBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    expand=config.expand,
                    dropout=float(os.environ.get("HYDRA_DROPOUT", "0.2")),
                )
            if i in _hyena_layer_set:
                return HyenaBlock(
                    d_model=config.d_model,
                    seq_len=config.sequence_len,
                    order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")),
                    filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")),
                )
            if i in _gdn_layer_set:
                return GDNBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                )
            mamba3_cls = _get_mamba3_cls()
            return mamba3_cls(
                d_model=config.d_model,
                d_state=config.d_state,
                expand=config.expand,
@@ -201,43 +201,43 @@ class PostSemClawModel(nn.Module):
        self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)])

        # Full-architecture SDR: offline semantic retina + STE (no-bypass).
        if self._throughput_mode:
            self.sdr_semantic = None
            self.htm = None
            self.htm_proj = None
            self.htm_anom_proj = None
            self.engram = None
            self.engram_layer_idx = -1
        else:
            self.sdr_semantic = SemanticFoldingSDR(
                vocab_size=config.vocab_size,
                n_bits=config.sdr_n_bits,
                target_active=config.sdr_target_active,
                delta_rank=config.sdr_delta_rank,
                som_warmup_steps=config.sdr_som_warmup,
                som_update_interval=config.sdr_som_interval,
            )

            # HTM spatial pooler + temporal memory (Rust, Hebbian).
            self.htm = HTMLayer(
                input_bits=config.sdr_n_bits,
                n_columns=config.htm_n_columns,
                cells_per_column=config.htm_cells_per_column,
                batch_size=1,
                seed=42,
                learn=True,
                reset_each_forward=True,
            )

            self.htm_proj = nn.Linear(config.htm_n_columns, config.d_model, bias=False)
            self.htm_anom_proj = nn.Linear(1, config.d_model, bias=False)

            self.engram = GPUEngram(
                d_model=config.d_model,
                n_columns=config.engram_n_columns,
                max_ngram=3,
            )
            self.engram_layer_idx = config.engram_layer_idx

        # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
        self.mhc = nn.ModuleList([
@@ -258,18 +258,18 @@ class PostSemClawModel(nn.Module):
        # additional CE losses; no new params. Activated via HYDRA_MTP_K.
        self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1")))

        # Learnability knob 3: gradient checkpointing on Mamba3 blocks.
        self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"

        # Full-arch throughput knob: Engram remains in the architecture, but
        # can run every N forwards and reuse its most recent residual delta on
        # skipped forwards. This amortizes top-k Hopfield retrieval while still
        # injecting Engram signal every microbatch. N=1 preserves exact legacy
        # behavior.
        self._engram_subsample = max(1, int(os.environ.get("HYDRA_ENGRAM_SUBSAMPLE", "1")))
        self._engram_call_idx = 0
        self._engram_delta_cache = None
        self._engram_hit_rate_cache = None

        # Learnability knob 4: doc-separator BOS masking in packed sequences.
        self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
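The weight-free multi-token-prediction knob can be sketched as follows (illustrative only; `mtp_aux_loss` is a hypothetical helper, not the in-tree implementation): the same next-token logits are scored against targets shifted k-1 steps further ahead, adding CE terms without any new parameters.

import torch
import torch.nn.functional as F

def mtp_aux_loss(logits: torch.Tensor, y: torch.Tensor, k: int) -> torch.Tensor:
    # logits: (B, T, V) next-token logits; y: (B, T) next-token targets.
    # k=1 reduces to the ordinary LM loss.
    if k == 1:
        return F.cross_entropy(logits.flatten(0, 1), y.flatten())
    return F.cross_entropy(
        logits[:, : -(k - 1)].reshape(-1, logits.size(-1)),
        y[:, k - 1 :].reshape(-1),
    )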
@@ -373,8 +373,8 @@ class PostSemClawModel(nn.Module):
        # Required because to_empty() only moves params/buffers, and _retina_indices
        # is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__.
        device = self.wte.weight.device
        if self.sdr_semantic is not None and hasattr(self.sdr_semantic, '_retina_indices'):
            self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)

        # Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for
        # vocab=8192; at larger vocabs, smaller std prevents logit blowup.
@@ -413,20 +413,20 @@ class PostSemClawModel(nn.Module):
            ))
            nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)

        if self.htm_proj is not None:
            nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
        if self.htm_anom_proj is not None:
            nn.init.normal_(self.htm_anom_proj.weight, mean=0.0, std=s)

        # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
        # dtypes in the same shape group would break lerp_ dtype checks.
        self.wte.to(dtype=torch.bfloat16)
        if self.htm_proj is not None:
            self.htm_proj.to(dtype=torch.bfloat16)
        if self.htm_anom_proj is not None:
            self.htm_anom_proj.to(dtype=torch.bfloat16)
        if self.engram is not None:
            self.engram.to(dtype=torch.bfloat16)

    def set_bos_token_id(self, bos_id: int) -> None:
        """Inform the model of the tokenizer's BOS id so doc-separator
@@ -472,10 +472,10 @@ class PostSemClawModel(nn.Module):
        wte = sum(p.numel() for p in self.wte.parameters())
        lm_head = sum(p.numel() for p in self.lm_head.parameters())
        blocks = sum(p.numel() for p in self.blocks.parameters())
        sdr = sum(p.numel() for p in self.sdr_semantic.parameters()) if self.sdr_semantic is not None else 0
        htm_proj = sum(p.numel() for p in self.htm_proj.parameters()) if self.htm_proj is not None else 0
        htm_anom_proj = sum(p.numel() for p in self.htm_anom_proj.parameters()) if self.htm_anom_proj is not None else 0
        engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0
        total = sum(p.numel() for p in self.parameters())
        return {
            'wte': wte, 'lm_head': lm_head, 'blocks': blocks,
@@ -544,19 +544,19 @@ class PostSemClawModel(nn.Module):
        # for p in self.sdr_semantic.parameters():
        #     if p.dim() == 2:
        #         matrix_params.append(p)
        if self.htm_proj is not None:
            for name, p in self.htm_proj.named_parameters():
                if _muon_eligible(name, p):
                    matrix_params.append(p)
        if self.engram is not None:
            for name, p in self.engram.named_parameters():
                if _muon_eligible(name, p):
                    matrix_params.append(p)

        # SDR params are intentionally not in any optimizer group — they
        # receive no gradient in the current forward, so any update would be
        # pure noise (weight_decay × lr on a zero-grad param).
        sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters()) if self.sdr_semantic is not None else set()
        assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params)
        scalar_params = [
            p for p in self.parameters()
@@ -565,7 +565,7 @@ class PostSemClawModel(nn.Module):

        total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params)
        total_params = len(list(self.parameters()))
        sdr_excluded = len(list(self.sdr_semantic.parameters())) if self.sdr_semantic is not None else 0
        assert total_assigned + sdr_excluded == total_params, (
            f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded "
            f"{sdr_excluded} vs total {total_params}"
@@ -633,59 +633,59 @@ class PostSemClawModel(nn.Module):
        else:
            _t0 = None

        dense_emb = self.wte(idx)  # (B, T, d_model) bf16
        if self._throughput_mode:
            self._last_sdr = None
            sdr_active_bits = 0.0
            htm_anomaly = dense_emb.new_tensor(0.0)
            x = norm(dense_emb)
            if _profile:
                _t_htm_async = _ev()
                _t_wte = _ev()
                _t_htm_await = _ev()
                _t_htm_proj = _ev()
        else:
            sdr_binary = self.sdr_semantic.binary_only(idx)
            self._last_sdr = sdr_binary
            _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
            if not hasattr(self, '_htm_call_idx'):
                self._htm_call_idx = 0

            _run_htm = (self._htm_call_idx % _htm_sub == 0)
            self._htm_call_idx += 1
            if _run_htm:
                htm_handle = self.htm.forward_async(sdr_binary)
            else:
                htm_handle = None

            if _profile: _t_htm_async = _ev()
            if _profile: _t_wte = _ev()

            if _run_htm:
                htm_out = self.htm.forward_await(htm_handle)
                self._htm_cache = htm_out.detach()
            elif hasattr(self, '_htm_cache') and self._htm_cache is not None and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T:
                htm_out = self._htm_cache
            else:
                htm_handle = self.htm.forward_async(sdr_binary)
                htm_out = self.htm.forward_await(htm_handle)
                self._htm_cache = htm_out.detach()

            if _profile: _t_htm_await = _ev()
            with torch.no_grad():
                sdr_active_bits = float(self.sdr_semantic.target_active)
                htm_anomaly = htm_out[..., -1].mean()
            if self._htm_stop_grad:
                htm_out = htm_out.detach()
            htm_cols = htm_out[..., :-1].to(dense_emb.dtype)
            htm_anom = htm_out[..., -1:].to(dense_emb.dtype)
            htm_proj_out = self.htm_proj(htm_cols) + self.htm_anom_proj(htm_anom)
            x = norm(dense_emb + htm_proj_out)
            if _profile: _t_htm_proj = _ev()

        # mHC-routed Mamba-3 stack with Engram injection at configured layer.
        streams = self.mhc[0].init_streams(x)
        _profile_layer_events = []

        # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
        # measure residual contribution of each layer: delta_N = h_post - h_pre.
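The HTM cadence above follows a run-every-N-with-detached-cache pattern. A generic sketch of the idea (a hypothetical helper, assuming the wrapped function returns a tensor of stable shape between refreshes):

class CadenceCache:
    """Run an expensive fn every n-th call; replay its detached output otherwise."""

    def __init__(self, n: int) -> None:
        self.n = max(1, n)
        self.calls = 0
        self.cached = None

    def __call__(self, fn, *args):
        run = (self.calls % self.n == 0) or self.cached is None
        self.calls += 1
        if run:
            self.cached = fn(*args).detach()
        return self.cached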
@@ -702,20 +702,20 @@ class PostSemClawModel(nn.Module):
            h_pre = self.mhc[0].merge_streams(streams).detach().float()
            _run_svd = (self._diag_step % self._diag_svd_every) == 0

        for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
            if _profile:
                _t_layer_start = _ev()
                _t_layer_mhc = None
                _t_engram_start = None
                _t_engram_end = None
            else:
                _t_layer_start = None
                _t_layer_mhc = None
                _t_engram_start = None
                _t_engram_end = None

            def _block_fn(h, _block=block):
                return self.drop(_block(norm(h)))

            # Learnability #3: gradient checkpointing. Wrap the block-fn so
            # the mhc layer's internal uses of it re-run the block in backward
|
@@ -726,31 +726,31 @@ class PostSemClawModel(nn.Module):
|
|
| 726 |
_raw_fn = _block_fn
|
| 727 |
def _block_fn(h, _raw=_raw_fn): # noqa: E731
|
| 728 |
return _ckpt.checkpoint(_raw, h, use_reentrant=False)
|
| 729 |
-
|
| 730 |
-
streams = mhc_layer(streams, _block_fn)
|
| 731 |
-
if _profile:
|
| 732 |
-
_t_layer_mhc = _ev()
|
| 733 |
-
|
| 734 |
-
if self.engram is not None and i == self.engram_layer_idx:
|
| 735 |
-
if _profile: _t_engram_start = _ev()
|
| 736 |
-
x_mid = mhc_layer.merge_streams(streams)
|
| 737 |
-
_run_engram = (self._engram_call_idx % self._engram_subsample == 0)
|
| 738 |
-
self._engram_call_idx += 1
|
| 739 |
-
if _run_engram or self._engram_delta_cache is None or self._engram_delta_cache.shape != x_mid.shape:
|
| 740 |
-
x_engram, hit_rate = self.engram(x_mid, idx)
|
| 741 |
-
self._engram_delta_cache = (x_engram - x_mid).detach()
|
| 742 |
-
self._engram_hit_rate_cache = hit_rate.detach() if torch.is_tensor(hit_rate) else hit_rate
|
| 743 |
-
x_mid = x_engram
|
| 744 |
-
else:
|
| 745 |
-
# Preserve gradient flow through the identity path while
|
| 746 |
-
# reusing a detached Engram residual. The Engram module
|
| 747 |
-
# still contributes to every forward; its expensive top-k
|
| 748 |
-
# retrieval and parameter gradients run on the cadence.
|
| 749 |
-
x_mid = x_mid + self._engram_delta_cache.to(dtype=x_mid.dtype, device=x_mid.device)
|
| 750 |
-
hit_rate = self._engram_hit_rate_cache
|
| 751 |
-
streams = mhc_layer.init_streams(x_mid)
|
| 752 |
-
self._metrics['engram_hit_rate'] = hit_rate
|
| 753 |
-
if _profile: _t_engram_end = _ev()
|
| 754 |
|
| 755 |
if _diag:
|
| 756 |
with torch.no_grad():
|
|
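The detached-residual trick in the skipped-forward branch keeps gradients flowing through the identity path while blocking them through the cache; a self-contained check of that behavior (plain PyTorch, illustrative):

import torch

x = torch.randn(2, 3, requires_grad=True)
delta = torch.randn(2, 3)
out = x + delta.detach()
out.sum().backward()
# Gradient w.r.t. x is exactly the identity-path gradient.
assert torch.allclose(x.grad, torch.ones_like(x))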
@@ -773,20 +773,20 @@ class PostSemClawModel(nn.Module):
                        self._metrics[f'layer_{i}_eff_rank'] = eff_rank
                    except Exception:
                        pass
                h_pre = h_post

            if _profile:
                _profile_layer_events.append(
                    (i, _t_layer_start, _t_layer_mhc, _t_engram_start, _t_engram_end, _ev())
                )

        if _diag:
            self._diag_step += 1

        if _profile: _t_blocks = _ev()

        self._metrics['sdr_active_bits'] = sdr_active_bits
        self._metrics['htm_anomaly'] = htm_anomaly

        x = self.mhc[-1].merge_streams(streams)
        x = norm(x)
|
@@ -1000,8 +1000,8 @@ class PostSemClawModel(nn.Module):
|
|
| 1000 |
_t_end = _ev()
|
| 1001 |
torch.cuda.synchronize()
|
| 1002 |
def _ms(a, b): return a.elapsed_time(b)
|
| 1003 |
-
print(
|
| 1004 |
-
f"[PROFILE B={B} T={T}] "
|
| 1005 |
f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
|
| 1006 |
f"wte={_ms(_t_htm_async, _t_wte):.2f} "
|
| 1007 |
f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
|
|
@@ -1010,23 +1010,23 @@ class PostSemClawModel(nn.Module):
                f"merge={_ms(_t_blocks, _t_merge):.2f} "
                f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
                f"total={_ms(_t0, _t_end):.2f} ms",
                flush=True,
            )
            for _li, _start, _after_mhc, _engram_start, _engram_end, _end in _profile_layer_events:
                print(
                    f"[PROFILE_LAYER B={B} T={T} layer={_li}] "
                    f"mhc_block={_ms(_start, _after_mhc):.2f} "
                    f"layer_total={_ms(_start, _end):.2f} ms",
                    flush=True,
                )
                if _engram_start is not None and _engram_end is not None:
                    print(
                        f"[PROFILE_ENGRAM B={B} T={T} layer={_li}] "
                        f"engram={_ms(_engram_start, _engram_end):.2f} "
                        f"post_layer_total={_ms(_after_mhc, _end):.2f} ms",
                        flush=True,
                    )
            return out

        logits = self.lm_head(x).float()
        if _softcap_clamp:
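The [PROFILE] numbers come from paired CUDA events; the minimal standalone pattern (assuming a CUDA device is available) is:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
a = torch.randn(4096, 4096, device="cuda")
b = a @ a
end.record()
torch.cuda.synchronize()  # elapsed_time is valid only after both events complete
print(f"matmul: {start.elapsed_time(end):.2f} ms")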
overlay/hydra/training.py
CHANGED

@@ -4,20 +4,20 @@ Extracted from the monolithic train.py (W1 modularization). Semantics
preserved. Public entrypoint: `main()`.
"""

from __future__ import annotations

import gc
import hashlib
import json
import math
import os
import sys
import threading
import time
from collections.abc import Mapping
from dataclasses import asdict
from pathlib import Path
from typing import Any

import torch
@@ -133,7 +133,7 @@ def _ckpt_snapshot_state_dicts(
    return msd, osd


def save_ckpt(
    model: PostSemClawModel,
    optimizer: torch.optim.Optimizer,
    config: PostSemClawConfig,
@@ -214,233 +214,233 @@ save_ckpt(
            target=_write, daemon=True, name=f"ckpt-save-{step}"
        )
        _CKPT_WORKER_THREAD.start()
    except Exception as e:
        print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)


def _env_flag_enabled(env: Mapping[str, str], key: str) -> bool:
    value = str(env.get(key, "0") or "0").strip().lower()
    return value not in {"", "0", "false", "no", "off"}


def _env_int(env: Mapping[str, str], key: str, default: int) -> int:
    try:
        return int(str(env.get(key, str(default)) or str(default)))
    except ValueError:
        return default


def architecture_compliance_payload(env: Mapping[str, str]) -> dict[str, bool | int | str]:
    throughput_mode = _env_flag_enabled(env, "HYDRA_THROUGHPUT_MODE")
    fastpath = _env_flag_enabled(env, "HYDRA_FASTPATH")
    force_htm_cpu = _env_flag_enabled(env, "HYDRA_FORCE_HTM_CPU")
    inert_mamba = _env_flag_enabled(env, "HYDRA_INERT_MAMBA")
    synthetic_retina = _env_flag_enabled(env, "HYDRA_ALLOW_SYNTHETIC_RETINA")
    hyena_layers = str(env.get("HYDRA_HYENA_LAYERS", "") or "")
    engram_subsample = _env_int(env, "HYDRA_ENGRAM_SUBSAMPLE", 1)
    htm_subsample = _env_int(env, "HYDRA_HTM_SUBSAMPLE", 1)
    full_arch_compliant = not any((
        throughput_mode,
        fastpath,
        force_htm_cpu,
        inert_mamba,
        synthetic_retina,
        bool(hyena_layers.strip()),
    ))
    return {
        'full_arch_compliant': full_arch_compliant,
        'throughput_mode': throughput_mode,
        'fastpath': fastpath,
        'force_htm_cpu': force_htm_cpu,
        'inert_mamba': inert_mamba,
        'synthetic_retina': synthetic_retina,
        'hyena_layers': hyena_layers,
        'engram_subsample': engram_subsample,
        'htm_subsample': htm_subsample,
    }


def eval_attempt_batches(*, requested_batch: int, min_batch: int) -> list[int]:
    requested = max(1, int(requested_batch))
    minimum = max(1, int(min_batch))
    batches: list[int] = []
    current = requested
    while current >= minimum:
        if current not in batches:
            batches.append(current)
        if current == minimum:
            break
        next_batch = max(minimum, current // 2)
        if next_batch == current:
            break
        current = next_batch
    if minimum not in batches:
        batches.append(minimum)
    return batches
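For instance, a requested eval batch of 16 with a floor of 2 yields a halving ladder, and the floor is always included (illustrative asserts against the function above):

assert eval_attempt_batches(requested_batch=16, min_batch=2) == [16, 8, 4, 2]
assert eval_attempt_batches(requested_batch=3, min_batch=1) == [3, 1]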
def build_eval_plan(*, eval_tokens: int, requested_batch: int, max_seq_len: int, chunk_tokens: int, min_batch: int) -> dict[str, Any]:
    effective_chunk_tokens = max(int(chunk_tokens), int(requested_batch) * int(max_seq_len))
    chunk_count = max(1, math.ceil(int(eval_tokens) / effective_chunk_tokens))
    return {
        'eval_tokens': int(eval_tokens),
        'eval_requested_batch': int(requested_batch),
        'eval_chunk_tokens': int(effective_chunk_tokens),
        'eval_chunk_count': int(chunk_count),
        'eval_attempt_batches': eval_attempt_batches(requested_batch=requested_batch, min_batch=min_batch),
        'eval_min_batch': int(max(1, min_batch)),
    }


def _fingerprint_descriptor(descriptor: Mapping[str, Any]) -> str:
    payload = json.dumps(dict(descriptor), sort_keys=True, separators=(",", ":"))
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:12]
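Because the payload is serialized with sorted keys, the fingerprint is insensitive to key order but changes with any value change (illustrative asserts):

a = _fingerprint_descriptor({"backend": "climbmix_parquet", "vocab_size": 8192})
b = _fingerprint_descriptor({"vocab_size": 8192, "backend": "climbmix_parquet"})
assert a == b and len(a) == 12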
def dataset_domain_payload(*, env: Mapping[str, str], prepare_module: Any, nemotron_module: Any | None) -> dict[str, Any]:
    use_nemotron = _env_flag_enabled(env, "HYDRA_USE_NEMOTRON")
    vocab_size = int(getattr(prepare_module, "VOCAB_SIZE", 0))

    if use_nemotron and nemotron_module is not None:
        use_full_blend = _env_flag_enabled(env, "HYDRA_USE_FULL_BLEND")
        phase = str(env.get("HYDRA_NEMOTRON_PHASE", "phase1") or "phase1").strip().lower()
        if use_full_blend:
            train_weights = dict(getattr(nemotron_module, "FULL_BLEND_WEIGHTS", {}))
            val_weights = dict(train_weights)
        else:
            train_weights = dict(
                getattr(nemotron_module, "PHASE2_WEIGHTS", {}) if phase == "phase2" else getattr(nemotron_module, "PHASE1_WEIGHTS", {})
            )
            val_weights = {"Nemotron-Pretraining-Multiple-Choice": 1.0}
        train_descriptor = {
            "backend": "nemotron_stream",
            "phase": "full_blend" if use_full_blend else phase,
            "weights": train_weights,
            "factual_inject_rate": _env_int(env, "HYDRA_FACTUAL_INJECT_RATE", 50),
            "vocab_size": vocab_size,
        }
        val_descriptor = {
            "backend": "nemotron_stream",
            "phase": "full_blend" if use_full_blend else "val_multiple_choice",
            "weights": val_weights,
            "vocab_size": vocab_size,
        }
        data_backend = "nemotron_stream"
    else:
        all_files = list(getattr(prepare_module, "list_parquet_files", lambda: [])())
        val_filename = str(getattr(prepare_module, "VAL_FILENAME", ""))
        train_files = [str(path) for path in all_files if not str(path).endswith(val_filename)]
        val_files = [str(path) for path in all_files if str(path).endswith(val_filename)]
        train_descriptor = {
            "backend": "climbmix_parquet",
            "train_shard_count": len(train_files),
            "train_shard_examples": sorted(Path(path).name for path in train_files[:3]),
            "vocab_size": vocab_size,
        }
        val_descriptor = {
            "backend": "climbmix_parquet",
            "val_filename": val_filename,
            "val_shard_count": len(val_files),
            "vocab_size": vocab_size,
        }
        data_backend = "climbmix_parquet"

    train_fingerprint = _fingerprint_descriptor(train_descriptor)
    val_fingerprint = _fingerprint_descriptor(val_descriptor)
    return {
        "data_backend": data_backend,
        "train_domain_descriptor": train_descriptor,
        "val_domain_descriptor": val_descriptor,
        "train_domain_fingerprint": train_fingerprint,
        "val_domain_fingerprint": val_fingerprint,
        "train_val_domain_match": train_fingerprint == val_fingerprint,
    }


def build_lineage_payload(
    *,
    env: Mapping[str, str],
    seed: int,
    resume_requested: bool,
    resume_requested_path: str | None,
    resume_loaded_path: str | None,
    resume_step: int,
    resume_epoch: int,
) -> dict[str, Any]:
    warmstart = _env_flag_enabled(env, "HYDRA_WARMSTART")
    resume_applied = resume_loaded_path is not None and int(resume_step) > 0
    if resume_applied and warmstart:
        lineage_mode = "warmstart_resume"
    elif resume_applied:
        lineage_mode = "resume"
    else:
        lineage_mode = "fresh"
    return {
        "seed": int(seed),
        "warmstart": warmstart,
        "resume_requested": bool(resume_requested),
        "resume_applied": resume_applied,
        "resume_requested_path": resume_requested_path,
        "resume_loaded_path": resume_loaded_path,
        "resume_step": int(resume_step),
        "resume_epoch": int(resume_epoch),
        "lineage_mode": lineage_mode,
    }


def build_final_metrics_payload(
    *,
    secondary_metrics: dict[str, Any],
    val_bpb: float | None,
    val_ppl: float | None,
    eval_status: str,
    eval_error: str | None,
    n_layer: int,
    d_model: int,
    num_params: int,
    step: int,
    total_tokens: int,
    peak_vram_mb: float,
    total_training_time: float,
    sdr_target_active: int,
    architecture_env: Mapping[str, str] | None = None,
    eval_diagnostics: Mapping[str, Any] | None = None,
    domain_fingerprints: Mapping[str, Any] | None = None,
    lineage_payload: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
    """Build final run metrics without conflating skipped eval and validation.

    This helper deliberately preserves ``val_bpb=None`` when final eval did not
    complete. HPO can then prune or explicitly label a fallback instead of
    accidentally treating live training BPB as validation BPB.
    """
    payload = dict(secondary_metrics)
    payload.update({
        'eval_status': eval_status,
        'eval_error': eval_error,
        'objective_source': 'final_val' if val_bpb is not None else 'missing_final_val',
        'val_bpb': float(val_bpb) if val_bpb is not None else None,
        'val_ppl': float(val_ppl) if val_ppl is not None else None,
        'n_layer': int(n_layer),
        'd_model': int(d_model),
        'num_params_M': float(num_params / 1e6),
        'num_steps': int(step),
        'total_tokens_M': float(total_tokens / 1e6),
        'peak_vram_mb': float(peak_vram_mb),
        'training_seconds': float(total_training_time),
        'sdr_target_active': int(sdr_target_active),
    })
    payload.update(architecture_compliance_payload(architecture_env or dict(os.environ)))
    if eval_diagnostics:
        payload.update(dict(eval_diagnostics))
    if domain_fingerprints:
        payload.update(dict(domain_fingerprints))
    if lineage_payload:
        payload.update(dict(lineage_payload))
    return payload


def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
    """Reconstruct a PostSemClawConfig from a checkpoint's asdict() payload.

    Newly-added fields (e.g. `hyena_layers`) are defaulted when absent in
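With an empty environment every shortcut flag is off, so a run is reported as fully architecture-compliant with default cadences (illustrative usage of the helper above):

payload = architecture_compliance_payload({})
assert payload["full_arch_compliant"] is True
assert payload["engram_subsample"] == 1 and payload["htm_subsample"] == 1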
@@ -500,14 +500,14 @@ def _try_load_ckpt(path: Path, model, optimizer, device):
    return step, total_training_time, smooth_train_loss, bpt_ema, epoch


def maybe_resume_ckpt(
    model: PostSemClawModel,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> tuple[int, float, float, float, int, str | None]:
    if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
        print("[ckpt] resume disabled; starting fresh", flush=True)
        return 0, 0.0, 0.0, 0.0, 0, None

    resume_path = Path(os.path.expanduser(RESUME_CKPT))
    # Try the primary path, then rotated backups. This is crucial because a
@@ -521,18 +521,18 @@ def maybe_resume_ckpt(
        if not cand.exists():
            continue
        try:
            result = _try_load_ckpt(cand, model, optimizer, device)
            if result is not None:
                if cand != resume_path:
                    print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
                step, total_training_time, smooth_train_loss, bpt_ema, epoch = result
                return step, total_training_time, smooth_train_loss, bpt_ema, epoch, str(cand)
        except Exception as e:
            print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
            continue

    print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
    return 0, 0.0, 0.0, 0.0, 0, None


# ---------------------------------------------------------------------------
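The rotation candidate list itself is built outside this hunk; a hypothetical illustration of the fallback order (names and suffix scheme are assumptions, not the in-tree implementation):

from pathlib import Path

def rotation_candidates(primary: Path, n: int = 3) -> list[Path]:
    # Primary checkpoint first, then numbered rotated backups.
    return [primary] + [primary.with_name(primary.name + f".{i}") for i in range(1, n + 1)]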
@@ -561,14 +561,14 @@ def main() -> None:

    # Streaming path skips prepare.py (which normally trains the tokenizer
    # and builds the retina), so we must materialize both before model init.
    if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
        _p_nemo.ensure_tokenizer()
        if os.environ.get("HYDRA_THROUGHPUT_MODE", "0") != "1":
            # Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
            # returns in seconds; otherwise build_retina streams Nemotron docs to
            # compute cooccurrence + train SOM, then uploads back to the cache.
            import subsystems.sdr_retina as _sdr_retina
            _sdr_retina.build_retina()
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"Vocab size: {vocab_size:,}")
@@ -614,18 +614,18 @@ def main() -> None:
        weight_decay=WEIGHT_DECAY,
    )

    step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch, resume_loaded_path = maybe_resume_ckpt(
        model, optimizer, device,
    )
    lineage_payload = build_lineage_payload(
        env=dict(os.environ),
        seed=SEED,
        resume_requested=bool(RESUME_CKPT and RESUME_CKPT.lower() != "none"),
        resume_requested_path=RESUME_CKPT if RESUME_CKPT and RESUME_CKPT.lower() != "none" else None,
        resume_loaded_path=resume_loaded_path,
        resume_step=step,
        resume_epoch=resume_epoch,
    )

    # Learnability #4: inform the model of the BOS token id so it can mask
    # doc-separator positions in packed sequences. Always set (the mask only
@@ -1020,22 +1020,22 @@ def main() -> None:
    # does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
    # how many val tokens to sweep (default 2 M, short enough for autoresearch
    # 5-min budgets).
    val_bpb: float | None = None
    val_ppl: float | None = None
    eval_status = "not_started"
    eval_error: str | None = None
    _eval_B = int(os.environ.get("HYDRA_EVAL_BATCH", str(max(1, DEVICE_BATCH_SIZE // 2))))
    _eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(2 * 524288)))
    _eval_chunk_tokens = int(os.environ.get("HYDRA_EVAL_CHUNK_TOKENS", str(_eval_tokens)))
    _eval_min_batch = int(os.environ.get("HYDRA_EVAL_MIN_BATCH", "1"))
    eval_diagnostics = build_eval_plan(
        eval_tokens=_eval_tokens,
        requested_batch=_eval_B,
        max_seq_len=MAX_SEQ_LEN,
        chunk_tokens=_eval_chunk_tokens,
        min_batch=_eval_min_batch,
    )
    try:
        # Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
        # which leaves < 1GB for the eval forward — the driver can't satisfy
        # the allocation. Free EVERY tensor we don't strictly need:
@@ -1057,70 +1057,70 @@ def main() -> None:
        model._last_sdr = None
        import gc as _gc
        _gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        try:
            _free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
            eval_diagnostics["eval_free_vram_before_mb"] = float(_free_mb)
            print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
        except Exception:
            pass
        print(
            f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B} "
            f"chunk_tokens={eval_diagnostics['eval_chunk_tokens']} attempts={eval_diagnostics['eval_attempt_batches']}...",
            flush=True,
        )
        model.eval()
        _orig = _prepare_mod.EVAL_TOKENS
        _orig_chunk = getattr(_prepare_mod, "EVAL_CHUNK_TOKENS", _eval_tokens)
        _prepare_mod.EVAL_TOKENS = _eval_tokens
        _prepare_mod.EVAL_CHUNK_TOKENS = int(eval_diagnostics["eval_chunk_tokens"])
        _successful_batch: int | None = None
        _attempts: list[int] = []
        try:
            for _attempt_batch in eval_diagnostics["eval_attempt_batches"]:
                _attempts.append(int(_attempt_batch))
                eval_diagnostics["eval_attempted_batch"] = int(_attempt_batch)
                try:
                    with autocast_ctx:
                        val_bpb = evaluate_bpb(model, tokenizer, int(_attempt_batch))
                    _successful_batch = int(_attempt_batch)
                    break
                except torch.cuda.OutOfMemoryError as _attempt_oom:
                    eval_error = str(_attempt_oom)
                    eval_status = "oom"
                    torch.cuda.empty_cache()
                    if int(_attempt_batch) == eval_diagnostics["eval_attempt_batches"][-1]:
                        raise
        finally:
            _prepare_mod.EVAL_TOKENS = _orig
            _prepare_mod.EVAL_CHUNK_TOKENS = _orig_chunk
            eval_diagnostics["eval_attempt_batches"] = _attempts
            eval_diagnostics["eval_effective_batch"] = _successful_batch
        val_ppl = 2 ** val_bpb
        eval_status = "completed"
        print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
    except torch.cuda.OutOfMemoryError as e:
        eval_status = "oom"
        eval_error = str(e)
        print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
        torch.cuda.empty_cache()
        try:
            eval_diagnostics["eval_free_vram_after_mb"] = float(torch.cuda.mem_get_info()[0] / 1024 / 1024)
        except Exception:
            pass
    except Exception as e:
        import traceback as _tb
        eval_status = type(e).__name__
        eval_error = str(e)
        print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
        _tb.print_exc()
        try:
            _free = torch.cuda.mem_get_info()[0] / 1024 / 1024
            eval_diagnostics["eval_free_vram_after_mb"] = float(_free)
            print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
        except Exception:
            pass

    # Final ckpts with val_bpb filled in (if eval succeeded).
    save_ckpt(
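Note the perplexity here is per byte: val_ppl = 2 ** val_bpb, so for example val_bpb = 0.85 reports val_ppl = 2^0.85 ≈ 1.80.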
@@ -1164,13 +1164,13 @@ def main() -> None:
        / total_training_time / GPU_BF16_PEAK_FLOPS
        if total_training_time > 0 else 0
    )
    peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    metrics = model.get_secondary_metrics()
    domain_fingerprints = dataset_domain_payload(
        env=dict(os.environ),
        prepare_module=_prepare_mod,
        nemotron_module=globals().get("_p_nemo"),
    )

    print("---")
    print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")
@@ -1206,28 +1206,28 @@ def main() -> None:
    # Emit full metrics dictionary as JSON for sweep aggregation. Path from
    # HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
    # written (even without diagnostics) so the aggregator can compare runs.
    _metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
    try:
        _dump = build_final_metrics_payload(
            secondary_metrics=metrics,
            val_bpb=val_bpb,
            val_ppl=val_ppl,
            eval_status=eval_status,
            eval_error=eval_error,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            num_params=num_params,
            step=step,
            total_tokens=total_tokens,
            peak_vram_mb=peak_vram_mb,
            total_training_time=total_training_time,
            sdr_target_active=int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
            architecture_env=dict(os.environ),
            eval_diagnostics=eval_diagnostics,
            domain_fingerprints=domain_fingerprints,
            lineage_payload=lineage_payload,
        )
        Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
        with open(_metrics_out, 'w') as _f:
            json.dump(_dump, _f, indent=2, sort_keys=True)
        print(f"[METRICS] wrote {_metrics_out}", flush=True)
|
|
|
|
…
preserved. Public entrypoint: `main()`.
"""

+from __future__ import annotations
+
+import gc
+import hashlib
+import json
+import math
+import os
+import sys
+import threading
+import time
+from collections.abc import Mapping
+from dataclasses import asdict
+from pathlib import Path
+from typing import Any

import torch

…

    return msd, osd


+def save_ckpt(
    model: PostSemClawModel,
    optimizer: torch.optim.Optimizer,
    config: PostSemClawConfig,
…
            target=_write, daemon=True, name=f"ckpt-save-{step}"
        )
        _CKPT_WORKER_THREAD.start()
+    except Exception as e:
+        print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)
+
+
+def _env_flag_enabled(env: Mapping[str, str], key: str) -> bool:
+    value = str(env.get(key, "0") or "0").strip().lower()
+    return value not in {"", "0", "false", "no", "off"}
+
+
+def _env_int(env: Mapping[str, str], key: str, default: int) -> int:
+    try:
+        return int(str(env.get(key, str(default)) or str(default)))
+    except ValueError:
+        return default
+
+
+def architecture_compliance_payload(env: Mapping[str, str]) -> dict[str, bool | int | str]:
+    throughput_mode = _env_flag_enabled(env, "HYDRA_THROUGHPUT_MODE")
+    fastpath = _env_flag_enabled(env, "HYDRA_FASTPATH")
+    force_htm_cpu = _env_flag_enabled(env, "HYDRA_FORCE_HTM_CPU")
+    inert_mamba = _env_flag_enabled(env, "HYDRA_INERT_MAMBA")
+    synthetic_retina = _env_flag_enabled(env, "HYDRA_ALLOW_SYNTHETIC_RETINA")
+    hyena_layers = str(env.get("HYDRA_HYENA_LAYERS", "") or "")
+    engram_subsample = _env_int(env, "HYDRA_ENGRAM_SUBSAMPLE", 1)
+    htm_subsample = _env_int(env, "HYDRA_HTM_SUBSAMPLE", 1)
+    full_arch_compliant = not any((
+        throughput_mode,
+        fastpath,
+        force_htm_cpu,
+        inert_mamba,
+        synthetic_retina,
+        bool(hyena_layers.strip()),
+    ))
+    return {
+        'full_arch_compliant': full_arch_compliant,
+        'throughput_mode': throughput_mode,
+        'fastpath': fastpath,
+        'force_htm_cpu': force_htm_cpu,
+        'inert_mamba': inert_mamba,
+        'synthetic_retina': synthetic_retina,
+        'hyena_layers': hyena_layers,
+        'engram_subsample': engram_subsample,
+        'htm_subsample': htm_subsample,
+    }
+
+
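The compliance payload reduces "is this run architecturally honest?" to a single boolean downstream tools can filter on. A minimal usage sketch (the env value here is hypothetical):

payload = architecture_compliance_payload({"HYDRA_FASTPATH": "1"})  # hypothetical env
assert payload["fastpath"] is True
assert payload["full_arch_compliant"] is False  # any shortcut flag disqualifies the run
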
+def eval_attempt_batches(*, requested_batch: int, min_batch: int) -> list[int]:
+    requested = max(1, int(requested_batch))
+    minimum = max(1, int(min_batch))
+    batches: list[int] = []
+    current = requested
+    while current >= minimum:
+        if current not in batches:
+            batches.append(current)
+        if current == minimum:
+            break
+        next_batch = max(minimum, current // 2)
+        if next_batch == current:
+            break
+        current = next_batch
+    if minimum not in batches:
+        batches.append(minimum)
+    return batches
+
+
+def build_eval_plan(*, eval_tokens: int, requested_batch: int, max_seq_len: int, chunk_tokens: int, min_batch: int) -> dict[str, Any]:
+    effective_chunk_tokens = max(int(chunk_tokens), int(requested_batch) * int(max_seq_len))
+    chunk_count = max(1, math.ceil(int(eval_tokens) / effective_chunk_tokens))
+    return {
+        'eval_tokens': int(eval_tokens),
+        'eval_requested_batch': int(requested_batch),
+        'eval_chunk_tokens': int(effective_chunk_tokens),
+        'eval_chunk_count': int(chunk_count),
+        'eval_attempt_batches': eval_attempt_batches(requested_batch=requested_batch, min_batch=min_batch),
+        'eval_min_batch': int(max(1, min_batch)),
+    }
+
+
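The attempt list is a halving schedule from the requested batch down to the floor, so an OOM at one size retries at the next. Worked examples (hypothetical batch sizes):

eval_attempt_batches(requested_batch=16, min_batch=2)  # -> [16, 8, 4, 2]
eval_attempt_batches(requested_batch=12, min_batch=5)  # -> [12, 6, 5]

plan = build_eval_plan(eval_tokens=1_048_576, requested_batch=8,
                       max_seq_len=512, chunk_tokens=262_144, min_batch=1)
# plan["eval_chunk_tokens"] == 262_144  (already >= 8 * 512)
# plan["eval_chunk_count"] == 4         (ceil(1_048_576 / 262_144))
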
+def _fingerprint_descriptor(descriptor: Mapping[str, Any]) -> str:
+    payload = json.dumps(dict(descriptor), sort_keys=True, separators=(",", ":"))
+    return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:12]
+
+
+def dataset_domain_payload(*, env: Mapping[str, str], prepare_module: Any, nemotron_module: Any | None) -> dict[str, Any]:
+    use_nemotron = _env_flag_enabled(env, "HYDRA_USE_NEMOTRON")
+    vocab_size = int(getattr(prepare_module, "VOCAB_SIZE", 0))
+
+    if use_nemotron and nemotron_module is not None:
+        use_full_blend = _env_flag_enabled(env, "HYDRA_USE_FULL_BLEND")
+        phase = str(env.get("HYDRA_NEMOTRON_PHASE", "phase1") or "phase1").strip().lower()
+        if use_full_blend:
+            train_weights = dict(getattr(nemotron_module, "FULL_BLEND_WEIGHTS", {}))
+            val_weights = dict(train_weights)
+        else:
+            train_weights = dict(
+                getattr(nemotron_module, "PHASE2_WEIGHTS", {}) if phase == "phase2" else getattr(nemotron_module, "PHASE1_WEIGHTS", {})
+            )
+            val_weights = {"Nemotron-Pretraining-Multiple-Choice": 1.0}
+        train_descriptor = {
+            "backend": "nemotron_stream",
+            "phase": "full_blend" if use_full_blend else phase,
+            "weights": train_weights,
+            "factual_inject_rate": _env_int(env, "HYDRA_FACTUAL_INJECT_RATE", 50),
+            "vocab_size": vocab_size,
+        }
+        val_descriptor = {
+            "backend": "nemotron_stream",
+            "phase": "full_blend" if use_full_blend else "val_multiple_choice",
+            "weights": val_weights,
+            "vocab_size": vocab_size,
+        }
+        data_backend = "nemotron_stream"
+    else:
+        all_files = list(getattr(prepare_module, "list_parquet_files", lambda: [])())
+        val_filename = str(getattr(prepare_module, "VAL_FILENAME", ""))
+        train_files = [str(path) for path in all_files if not str(path).endswith(val_filename)]
+        val_files = [str(path) for path in all_files if str(path).endswith(val_filename)]
+        train_descriptor = {
+            "backend": "climbmix_parquet",
+            "train_shard_count": len(train_files),
+            "train_shard_examples": sorted(Path(path).name for path in train_files[:3]),
+            "vocab_size": vocab_size,
+        }
+        val_descriptor = {
+            "backend": "climbmix_parquet",
+            "val_filename": val_filename,
+            "val_shard_count": len(val_files),
+            "vocab_size": vocab_size,
+        }
+        data_backend = "climbmix_parquet"
+
+    train_fingerprint = _fingerprint_descriptor(train_descriptor)
+    val_fingerprint = _fingerprint_descriptor(val_descriptor)
+    return {
+        "data_backend": data_backend,
+        "train_domain_descriptor": train_descriptor,
+        "val_domain_descriptor": val_descriptor,
+        "train_domain_fingerprint": train_fingerprint,
+        "val_domain_fingerprint": val_fingerprint,
+        "train_val_domain_match": train_fingerprint == val_fingerprint,
+    }
+
+
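Because the descriptor is serialized as canonical JSON (sorted keys, no whitespace), the fingerprint is insensitive to dict insertion order, so two runs with the same data recipe always hash alike. A small sketch (the vocab size is a hypothetical value):

a = _fingerprint_descriptor({"backend": "nemotron_stream", "vocab_size": 32768})
b = _fingerprint_descriptor({"vocab_size": 32768, "backend": "nemotron_stream"})
assert a == b and len(a) == 12  # key order does not change the fingerprint
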
+def build_lineage_payload(
+    *,
+    env: Mapping[str, str],
+    seed: int,
+    resume_requested: bool,
+    resume_requested_path: str | None,
+    resume_loaded_path: str | None,
+    resume_step: int,
+    resume_epoch: int,
+) -> dict[str, Any]:
+    warmstart = _env_flag_enabled(env, "HYDRA_WARMSTART")
+    resume_applied = resume_loaded_path is not None and int(resume_step) > 0
+    if resume_applied and warmstart:
+        lineage_mode = "warmstart_resume"
+    elif resume_applied:
+        lineage_mode = "resume"
+    else:
+        lineage_mode = "fresh"
+    return {
+        "seed": int(seed),
+        "warmstart": warmstart,
+        "resume_requested": bool(resume_requested),
+        "resume_applied": resume_applied,
+        "resume_requested_path": resume_requested_path,
+        "resume_loaded_path": resume_loaded_path,
+        "resume_step": int(resume_step),
+        "resume_epoch": int(resume_epoch),
+        "lineage_mode": lineage_mode,
+    }
+
+
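The lineage mode distinguishes a run that actually loaded prior weights from one that merely asked to; a requested-but-failed resume still reports "fresh" with resume_requested=True. A sketch with hypothetical paths:

payload = build_lineage_payload(
    env={"HYDRA_WARMSTART": "1"},
    seed=0,
    resume_requested=True,
    resume_requested_path="/ckpts/latest.pt",    # hypothetical path
    resume_loaded_path="/ckpts/latest.pt.bak1",  # hypothetical rotated backup
    resume_step=1200,
    resume_epoch=2,
)
assert payload["lineage_mode"] == "warmstart_resume"
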
+def build_final_metrics_payload(
+    *,
+    secondary_metrics: dict[str, Any],
+    val_bpb: float | None,
+    val_ppl: float | None,
+    eval_status: str,
+    eval_error: str | None,
+    n_layer: int,
+    d_model: int,
+    num_params: int,
+    step: int,
+    total_tokens: int,
+    peak_vram_mb: float,
+    total_training_time: float,
+    sdr_target_active: int,
+    architecture_env: Mapping[str, str] | None = None,
+    eval_diagnostics: Mapping[str, Any] | None = None,
+    domain_fingerprints: Mapping[str, Any] | None = None,
+    lineage_payload: Mapping[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Build final run metrics without conflating skipped eval and validation.
+
+    This helper deliberately preserves ``val_bpb=None`` when final eval did not
+    complete. HPO can then prune or explicitly label a fallback instead of
+    accidentally treating live training BPB as validation BPB.
+    """
+    payload = dict(secondary_metrics)
+    payload.update({
+        'eval_status': eval_status,
+        'eval_error': eval_error,
+        'objective_source': 'final_val' if val_bpb is not None else 'missing_final_val',
+        'val_bpb': float(val_bpb) if val_bpb is not None else None,
+        'val_ppl': float(val_ppl) if val_ppl is not None else None,
+        'n_layer': int(n_layer),
+        'd_model': int(d_model),
+        'num_params_M': float(num_params / 1e6),
+        'num_steps': int(step),
+        'total_tokens_M': float(total_tokens / 1e6),
+        'peak_vram_mb': float(peak_vram_mb),
+        'training_seconds': float(total_training_time),
+        'sdr_target_active': int(sdr_target_active),
+    })
+    payload.update(architecture_compliance_payload(architecture_env or dict(os.environ)))
+    if eval_diagnostics:
+        payload.update(dict(eval_diagnostics))
+    if domain_fingerprints:
+        payload.update(dict(domain_fingerprints))
+    if lineage_payload:
+        payload.update(dict(lineage_payload))
+    return payload
+
+
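The objective_source field is what downstream HPO keys on. A minimal sketch of the skipped-eval case, with hypothetical run stats:

m = build_final_metrics_payload(
    secondary_metrics={}, val_bpb=None, val_ppl=None,
    eval_status="oom", eval_error="CUDA out of memory",
    n_layer=12, d_model=512, num_params=50_000_000,  # hypothetical run stats
    step=1000, total_tokens=5_000_000, peak_vram_mb=5100.0,
    total_training_time=300.0, sdr_target_active=327,
    architecture_env={},
)
assert m["objective_source"] == "missing_final_val"
assert m["val_bpb"] is None  # never silently replaced by a training-loss proxy
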
+def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
    """Reconstruct a PostSemClawConfig from a checkpoint's asdict() payload.

    Newly-added fields (e.g. `hyena_layers`) are defaulted when absent in
…
    return step, total_training_time, smooth_train_loss, bpt_ema, epoch


+def maybe_resume_ckpt(
+    model: PostSemClawModel,
+    optimizer: torch.optim.Optimizer,
+    device: torch.device,
+) -> tuple[int, float, float, float, int, str | None]:
+    if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
+        print("[ckpt] resume disabled; starting fresh", flush=True)
+        return 0, 0.0, 0.0, 0.0, 0, None

    resume_path = Path(os.path.expanduser(RESUME_CKPT))
    # Try the primary path, then rotated backups. This is crucial because a
…
        if not cand.exists():
            continue
        try:
+            result = _try_load_ckpt(cand, model, optimizer, device)
+            if result is not None:
+                if cand != resume_path:
+                    print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
+                step, total_training_time, smooth_train_loss, bpt_ema, epoch = result
+                return step, total_training_time, smooth_train_loss, bpt_ema, epoch, str(cand)
        except Exception as e:
            print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
            continue

+    print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
+    return 0, 0.0, 0.0, 0.0, 0, None


# ---------------------------------------------------------------------------
…
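The rotation fallback matters because a job can die mid-write and leave the primary checkpoint truncated; a rotated predecessor then still loads. The candidate-list construction itself is elided above; a minimal sketch of the idea, where the rotation_candidates helper and the ".bakN" suffix scheme are hypothetical reconstructions:

from pathlib import Path

def rotation_candidates(resume_path: Path, depth: int = 2) -> list[Path]:
    # Hypothetical sketch: try the primary checkpoint first, then numbered
    # backup rotations written by earlier successful saves.
    return [resume_path] + [
        resume_path.with_name(resume_path.name + f".bak{i}")
        for i in range(1, depth + 1)
    ]

# rotation_candidates(Path("/ckpts/latest.pt"))
#   -> [/ckpts/latest.pt, /ckpts/latest.pt.bak1, /ckpts/latest.pt.bak2]
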

    # Streaming path skips prepare.py (which normally trains the tokenizer
    # and builds the retina), so we must materialize both before model init.
+    if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
+        _p_nemo.ensure_tokenizer()
+        if os.environ.get("HYDRA_THROUGHPUT_MODE", "0") != "1":
+            # Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
+            # returns in seconds; otherwise build_retina streams Nemotron docs to
+            # compute cooccurrence + train SOM, then uploads back to the cache.
+            import subsystems.sdr_retina as _sdr_retina
+            _sdr_retina.build_retina()
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"Vocab size: {vocab_size:,}")
…
        weight_decay=WEIGHT_DECAY,
    )

+    step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch, resume_loaded_path = maybe_resume_ckpt(
+        model, optimizer, device,
+    )
+    lineage_payload = build_lineage_payload(
+        env=dict(os.environ),
+        seed=SEED,
+        resume_requested=bool(RESUME_CKPT and RESUME_CKPT.lower() != "none"),
+        resume_requested_path=RESUME_CKPT if RESUME_CKPT and RESUME_CKPT.lower() != "none" else None,
+        resume_loaded_path=resume_loaded_path,
+        resume_step=step,
+        resume_epoch=resume_epoch,
+    )

    # Learnability #4: inform the model of the BOS token id so it can mask
    # doc-separator positions in packed sequences. Always set (the mask only
…
    # does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
    # how many val tokens to sweep (default 2 M, short enough for autoresearch
    # 5-min budgets).
+    val_bpb: float | None = None
+    val_ppl: float | None = None
+    eval_status = "not_started"
+    eval_error: str | None = None
+    _eval_B = int(os.environ.get("HYDRA_EVAL_BATCH", str(max(1, DEVICE_BATCH_SIZE // 2))))
+    _eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(2 * 524288)))
+    _eval_chunk_tokens = int(os.environ.get("HYDRA_EVAL_CHUNK_TOKENS", str(_eval_tokens)))
+    _eval_min_batch = int(os.environ.get("HYDRA_EVAL_MIN_BATCH", "1"))
+    eval_diagnostics = build_eval_plan(
+        eval_tokens=_eval_tokens,
+        requested_batch=_eval_B,
+        max_seq_len=MAX_SEQ_LEN,
+        chunk_tokens=_eval_chunk_tokens,
+        min_batch=_eval_min_batch,
+    )
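With the defaults this final sweep is small: HYDRA_EVAL_TOKENS = 2 * 524288 is 1,048,576 tokens, so at a hypothetical B=8 and MAX_SEQ_LEN=512 the loop runs 1,048,576 / (8 * 512) = 256 forward passes, and a 262,144-token chunk size splits that into 4 chunks with a cache flush between them:

eval_tokens = 2 * 524288          # 1,048,576 tokens (the default)
B, T = 8, 512                     # hypothetical eval batch and context
steps = eval_tokens // (B * T)    # 256 forward passes
chunk_steps = 262_144 // (B * T)  # 64 steps per chunk -> 4 chunks
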
+    try:
        # Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
        # which leaves < 1GB for the eval forward — the driver can't satisfy
        # the allocation. Free EVERY tensor we don't strictly need:
…
        model._last_sdr = None
        import gc as _gc
        _gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+        try:
+            _free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
+            eval_diagnostics["eval_free_vram_before_mb"] = float(_free_mb)
+            print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
+        except Exception:
+            pass
+        print(
+            f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B} "
+            f"chunk_tokens={eval_diagnostics['eval_chunk_tokens']} attempts={eval_diagnostics['eval_attempt_batches']}...",
+            flush=True,
+        )
+        model.eval()
+        _orig = _prepare_mod.EVAL_TOKENS
+        _orig_chunk = getattr(_prepare_mod, "EVAL_CHUNK_TOKENS", _eval_tokens)
+        _prepare_mod.EVAL_TOKENS = _eval_tokens
+        _prepare_mod.EVAL_CHUNK_TOKENS = int(eval_diagnostics["eval_chunk_tokens"])
+        _successful_batch: int | None = None
+        _attempts: list[int] = []
+        try:
+            for _attempt_batch in eval_diagnostics["eval_attempt_batches"]:
+                _attempts.append(int(_attempt_batch))
+                eval_diagnostics["eval_attempted_batch"] = int(_attempt_batch)
+                try:
+                    with autocast_ctx:
+                        val_bpb = evaluate_bpb(model, tokenizer, int(_attempt_batch))
+                    _successful_batch = int(_attempt_batch)
+                    break
+                except torch.cuda.OutOfMemoryError as _attempt_oom:
+                    eval_error = str(_attempt_oom)
+                    eval_status = "oom"
+                    torch.cuda.empty_cache()
+                    if int(_attempt_batch) == eval_diagnostics["eval_attempt_batches"][-1]:
+                        raise
+        finally:
+            _prepare_mod.EVAL_TOKENS = _orig
+            _prepare_mod.EVAL_CHUNK_TOKENS = _orig_chunk
+        eval_diagnostics["eval_attempt_batches"] = _attempts
+        eval_diagnostics["eval_effective_batch"] = _successful_batch
+        val_ppl = 2 ** val_bpb
+        eval_status = "completed"
+        print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
+    except torch.cuda.OutOfMemoryError as e:
+        eval_status = "oom"
+        eval_error = str(e)
+        print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
+        torch.cuda.empty_cache()
+        try:
+            eval_diagnostics["eval_free_vram_after_mb"] = float(torch.cuda.mem_get_info()[0] / 1024 / 1024)
+        except Exception:
+            pass
+    except Exception as e:
+        import traceback as _tb
+        eval_status = type(e).__name__
+        eval_error = str(e)
+        print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
+        _tb.print_exc()
+        try:
+            _free = torch.cuda.mem_get_info()[0] / 1024 / 1024
+            eval_diagnostics["eval_free_vram_after_mb"] = float(_free)
+            print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
+        except Exception:
+            pass

    # Final ckpts with val_bpb filled in (if eval succeeded).
    save_ckpt(
…
        / total_training_time / GPU_BF16_PEAK_FLOPS
        if total_training_time > 0 else 0
    )
+    peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
+    metrics = model.get_secondary_metrics()
+    domain_fingerprints = dataset_domain_payload(
+        env=dict(os.environ),
+        prepare_module=_prepare_mod,
+        nemotron_module=globals().get("_p_nemo"),
+    )

    print("---")
    print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")
…
    # Emit full metrics dictionary as JSON for sweep aggregation. Path from
    # HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
    # written (even without diagnostics) so the aggregator can compare runs.
+    _metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
+    try:
+        _dump = build_final_metrics_payload(
+            secondary_metrics=metrics,
+            val_bpb=val_bpb,
+            val_ppl=val_ppl,
+            eval_status=eval_status,
+            eval_error=eval_error,
+            n_layer=N_LAYER,
+            d_model=D_MODEL,
+            num_params=num_params,
+            step=step,
+            total_tokens=total_tokens,
+            peak_vram_mb=peak_vram_mb,
+            total_training_time=total_training_time,
+            sdr_target_active=int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
+            architecture_env=dict(os.environ),
+            eval_diagnostics=eval_diagnostics,
+            domain_fingerprints=domain_fingerprints,
+            lineage_payload=lineage_payload,
+        )
+        Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
        with open(_metrics_out, 'w') as _f:
            json.dump(_dump, _f, indent=2, sort_keys=True)
        print(f"[METRICS] wrote {_metrics_out}", flush=True)
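An aggregator on the other side only has to read that JSON back and filter on the fields written above; a minimal sketch using the default output path:

import json

with open("/tmp/hydra_run_metrics.json") as f:
    run = json.load(f)

# Only runs whose final eval actually completed are comparable on val_bpb.
if run["eval_status"] == "completed" and run["objective_source"] == "final_val":
    print(run["val_bpb"], run["full_arch_compliant"], run["lineage_mode"])
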
overlay/prepare.py
CHANGED
|
@@ -13,10 +13,10 @@ import os
 import sys
 import time
 import math
+import argparse
+import pickle
+from multiprocessing import Pool
+from typing import Any

 import requests
 import pyarrow.parquet as pq
@@ -30,8 +30,8 @@ import torch

 MAX_SEQ_LEN = int(os.environ.get("HYDRA_SEQ_LEN", "512"))  # context length
 TIME_BUDGET = 300  # training time budget in seconds (5 minutes)
+EVAL_TOKENS = 40 * 524288  # number of tokens for val eval
+EVAL_CHUNK_TOKENS = int(os.environ.get("HYDRA_EVAL_CHUNK_TOKENS", str(EVAL_TOKENS)))

 # ---------------------------------------------------------------------------
 # Configuration
@@ -160,8 +160,8 @@ def train_tokenizer():
     print("Tokenizer: training BPE tokenizer...")
     t0 = time.time()

+    tokenizer_cls = getattr(rustbpe, "Tokenizer")
+    tokenizer: Any = tokenizer_cls()
     vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
     tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)
@@ -228,10 +228,10 @@ class Tokenizer:
     def get_bos_token_id(self):
         return self.bos_token_id

+    def encode(self, text, prepend=None, num_threads=8):
+        prepend_id = None
+        if prepend is not None:
+            prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
         if isinstance(text, str):
             ids = self.enc.encode_ordinary(text)
             if prepend is not None:
@@ -249,7 +249,7 @@ class Tokenizer:
         return self.enc.decode(ids)


+_TOKEN_BYTES_CACHE: dict[str, torch.Tensor] = {}

 def get_token_bytes(device="cpu"):
     key = str(device)
@@ -345,30 +345,30 @@ def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
         gpu_buffer.copy_(cpu_buffer, non_blocking=True)
         yield inputs, targets, epoch

+# ---------------------------------------------------------------------------
+# Evaluation (DO NOT CHANGE — this is the fixed metric)
+# ---------------------------------------------------------------------------
+
+def compute_bpb_from_totals(total_nats: torch.Tensor, total_bytes: torch.Tensor) -> torch.Tensor:
+    if int(total_bytes.item()) <= 0:
+        raise ValueError("BPB normalization requires at least one non-special token")
+    return total_nats.to(dtype=torch.float64) / (math.log(2) * total_bytes.to(dtype=torch.float64))
+
+
+def compute_bpb_from_losses(loss_flat: torch.Tensor, nbytes: torch.Tensor) -> torch.Tensor:
+    """Convert per-token losses and token byte lengths into bits-per-byte.
+
+    Tokens with zero byte length (special tokens) are excluded from both the
+    numerator and denominator so BPB remains comparable across tokenizer
+    special-token conventions.
+    """
+    mask = nbytes > 0
+    total_nats = (loss_flat * mask).sum(dtype=torch.float64)
+    total_bytes = nbytes[mask].sum(dtype=torch.int64)
+    return compute_bpb_from_totals(total_nats, total_bytes)
+
+@torch.no_grad()
+def evaluate_bpb(model, tokenizer, batch_size):
     """
     Bits per byte (BPB): vocab size-independent evaluation metric.
     Sums per-token cross-entropy (in nats), sums target byte lengths,
@@ -379,35 +379,35 @@ def evaluate_bpb(model, tokenizer, batch_size):
     Perf: accumulates on GPU (single sync at end), prefetches next batch
     while current forward runs.
     """
+    token_bytes = get_token_bytes(device="cuda")
+    val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
+    steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
+    chunk_steps = max(1, EVAL_CHUNK_TOKENS // (batch_size * MAX_SEQ_LEN))

     # GPU-resident accumulators — avoid per-batch .item() sync
     total_nats_t = torch.zeros(1, device="cuda", dtype=torch.float64)
     total_bytes_t = torch.zeros(1, device="cuda", dtype=torch.int64)

     # Prefetch first batch
+    next_batch = next(val_loader)
+    steps_done = 0
+    while steps_done < steps:
+        this_chunk = min(chunk_steps, steps - steps_done)
+        for _ in range(this_chunk):
+            x, y, _epoch = next_batch
+            # Prefetch NEXT batch while GPU computes current forward
+            next_batch = next(val_loader)
+            loss_flat = model(x, y, reduction='none').view(-1)
+            y_flat = y.view(-1)
+            nbytes = token_bytes[y_flat]
+            total_nats_t += (loss_flat * (nbytes > 0)).sum(dtype=torch.float64)
+            total_bytes_t += nbytes[nbytes > 0].sum(dtype=torch.int64)
+        steps_done += this_chunk
+        if steps_done < steps:
+            torch.cuda.empty_cache()
+
+    # Single GPU→CPU sync at end
+    return float(compute_bpb_from_totals(total_nats_t, total_bytes_t).item())

 # ---------------------------------------------------------------------------
 # Main
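BPB normalizes summed cross-entropy (in nats) by the UTF-8 byte count of the non-special targets, so tokenizers with different vocabularies stay comparable: bpb = total_nats / (ln 2 × total_bytes). A worked example with made-up totals — 1,200 nats over 3,800 bytes gives 1200 / (0.6931 × 3800) ≈ 0.456 bits per byte:

import math
import torch

total_nats = torch.tensor(1200.0, dtype=torch.float64)  # hypothetical summed CE in nats
total_bytes = torch.tensor(3800, dtype=torch.int64)     # hypothetical summed target bytes
bpb = compute_bpb_from_totals(total_nats, total_bytes)
print(float(bpb))  # ≈ 0.4556
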
overlay/prepare_nemotron.py
CHANGED
|
@@ -20,16 +20,15 @@ Full blend mode (env HYDRA_USE_FULL_BLEND=1):
 """
 from __future__ import annotations

 import os
 import random
 import importlib
-import
-from
+from itertools import cycle
 from typing import Any, Iterator, cast

 import torch

 import prepare as _p  # reuse tokenizer, BOS, byte-length helpers

 NEMOTRON_REPO = "nvidia/Nemotron-Pretraining-Specialized-v1.1"

@@ -66,96 +65,94 @@ PHASE1_WEIGHTS = {
     "Nemotron-Pretraining-Formal-Logic": 0.20,
     "Nemotron-Pretraining-Multiple-Choice": 0.20,
 }
+PHASE2_WEIGHTS = {
     "Nemotron-Pretraining-Multiple-Choice": 0.45,
     "Nemotron-Pretraining-Economics": 0.20,
     "Nemotron-Pretraining-Formal-Logic": 0.15,
     "Nemotron-Pretraining-Code-Concepts": 0.10,
     "Nemotron-Pretraining-Unconditional-Algorithmic": 0.10,
+}
+
-StreamBatch = tuple[list[str], int]
-TokenBatch = tuple[list[list[int]], int]
+type StreamBatch = tuple[list[str], int]
+type TokenBatch = tuple[list[list[int]], int]
+
+
+def _tokenizer_cache_repo() -> str:
+    return (
+        os.environ.get("HYDRA_TOKENIZER_CACHE_REPO")
+        or os.environ.get("FEATHER_HF_OUTPUT_REPO")
+        or os.environ.get("HF_REPO_ID")
+        or os.environ.get("HYDRA_RETINA_CACHE_REPO")
+        or os.environ.get("FEATHER_HF_RETINA_CACHE_REPO")
+        or ""
+    )
+
+
+def _tokenizer_cache_prefix() -> str:
+    return f"tokenizer/vocab{_p.VOCAB_SIZE}"
+
+
+def maybe_hydrate_tokenizer_cache() -> bool:
+    """Try to download tokenizer artifacts from HF cache storage."""
+    repo_id = _tokenizer_cache_repo()
+    token = os.environ.get("HF_TOKEN")
+    if not repo_id or not token:
+        return False
+
+    try:
+        from huggingface_hub import hf_hub_download
+    except Exception as e:  # noqa: BLE001
+        print(f"[nemotron] tokenizer cache unavailable: {type(e).__name__}: {e}", flush=True)
+        return False
+
+    os.makedirs(_p.TOKENIZER_DIR, exist_ok=True)
+    prefix = _tokenizer_cache_prefix()
+    try:
-        repo_id=repo_id,
-        repo_type="model",
-        subfolder=prefix,
-        filename="tokenizer.pkl",
-        token=token,
-        local_dir=_p.TOKENIZER_DIR,
-        )
-        repo_id=repo_id,
-        repo_type="model",
-        subfolder=prefix,
-        filename="token_bytes.pt",
-        token=token,
-        local_dir=_p.TOKENIZER_DIR,
-        )
…
-    api =
-    prefix =
-    print(f"[nemotron]
+        hf_hub_download(
+            repo_id=repo_id,
+            repo_type="model",
+            subfolder=prefix,
+            filename="tokenizer.pkl",
+            token=token,
+            local_dir=_p.TOKENIZER_DIR,
+        )
+        hf_hub_download(
+            repo_id=repo_id,
+            repo_type="model",
+            subfolder=prefix,
+            filename="token_bytes.pt",
+            token=token,
+            local_dir=_p.TOKENIZER_DIR,
+        )
+    except Exception as e:  # noqa: BLE001
+        print(f"[nemotron] tokenizer cache miss in {repo_id}/{prefix}: {type(e).__name__}: {e}", flush=True)
+        return False
+
+    print(f"[nemotron] hydrated tokenizer cache from {repo_id}/{prefix}", flush=True)
+    return True
+
+
+def upload_tokenizer_cache() -> None:
+    """Upload tokenizer artifacts for reuse by future jobs."""
+    repo_id = _tokenizer_cache_repo()
+    token = os.environ.get("HF_TOKEN")
+    if not repo_id or not token:
+        return
+
+    path = os.path.join(_p.TOKENIZER_DIR, "tokenizer.pkl")
+    token_bytes_path = os.path.join(_p.TOKENIZER_DIR, "token_bytes.pt")
+    if not (os.path.exists(path) and os.path.exists(token_bytes_path)):
+        return
+
+    try:
+        from huggingface_hub import HfApi
+        api = HfApi(token=token)
+        prefix = _tokenizer_cache_prefix()
+        api.upload_file(path_or_fileobj=path, path_in_repo=f"{prefix}/tokenizer.pkl", repo_id=repo_id, repo_type="model")
+        api.upload_file(path_or_fileobj=token_bytes_path, path_in_repo=f"{prefix}/token_bytes.pt", repo_id=repo_id, repo_type="model")
+        print(f"[nemotron] uploaded tokenizer cache to {repo_id}/{prefix}", flush=True)
     except Exception as e:  # noqa: BLE001
         print(f"[nemotron] tokenizer cache upload skipped: {type(e).__name__}: {e}", flush=True)

@@ -166,7 +163,7 @@ def _phase_weights() -> dict[str, float]:
     return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS


+def _open_stream(config: str, split: str):
     """Open a streaming iterator over one dataset config.

     Handles two modes:
@@ -177,17 +174,17 @@ def _open_stream(config: str, split: str):

     Yields dicts; text extraction handled downstream by _extract_text.
     """
+    load_dataset = importlib.import_module("datasets").load_dataset
+    token = os.environ.get("HF_TOKEN")
+    shuffle_buf = int(os.environ.get("HYDRA_STREAM_SHUFFLE_BUFFER", "2048"))
+
+    if config in _BLEND_REGISTRY:
+        repo, name, _text_col = _BLEND_REGISTRY[config]
+        kwargs: dict[str, object] = dict(
+            split="train",
+            streaming=True,
+            token=token,
+        )
         if name is not None:
             kwargs["name"] = name
         # nemotron-specialized has multiple sub-configs; pick the first one
@@ -209,18 +206,18 @@ def _open_stream(config: str, split: str):
     return iter(ds)


+def _extract_text(row: dict[str, object]) -> str:
     """Pick the right text column — datasets have different column names.

     Priority order: text, content, prompt_completion, question, body.
     For math datasets that split into problem+solution, concatenate both.
     Fallback: concatenate all string-valued fields.
     """
+    # Fast path: most datasets use "text" or "content".
+    for k in ("text", "content", "prompt_completion", "question", "body"):
+        value = row.get(k)
+        if isinstance(value, str) and value:
+            return value
     # Math datasets may have problem + solution as separate fields.
     if "problem" in row and "solution" in row:
         p = row["problem"] or ""
@@ -236,20 +233,20 @@ def _extract_text(row: dict[str, object]) -> str:
     return "\n".join(parts)


+class _WeightedStream:
     """Infinite weighted-round-robin over configs' streaming iterators."""

+    def __init__(self, weights: dict[str, float], seed: int = 0):
+        self.configs = list(weights.keys())
+        self.weights = [weights[c] for c in self.configs]
+        self.streams: dict[str, Iterator[dict[str, object]]] = {
+            c: _open_stream(c, "train") for c in self.configs
+        }
+        self.rng = random.Random(seed)
+        self.epoch = 1
+        self._factual_docs: list[str] | None = None
+        self._factual_idx = 0
+        self._inject_counter = 0

     def _reopen(self, config: str):
         # stream exhausted — reopen (HF streaming typically infinite but restart on edge)
@@ -265,20 +262,20 @@ class _WeightedStream:
         # exist in the Nemotron configs. Controlled by HYDRA_FACTUAL_INJECT_RATE
         # (default 50 = inject one factual doc every 50 Nemotron docs = ~2%).
         inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
+        if inject_rate > 0 and self._factual_docs is None:
+            factual_path = os.path.join(
+                os.path.dirname(os.path.abspath(__file__)), "data", "factual", "facts.txt")
+            if os.path.exists(factual_path):
+                self._factual_docs = open(factual_path).read().strip().split('\n')
+                self._factual_idx = 0
+                self._inject_counter = 0
+        if inject_rate > 0 and self._factual_docs:
+            self._inject_counter += 1
+            if self._inject_counter >= inject_rate:
+                self._inject_counter = 0
+                doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
+                self._factual_idx += 1
+                return doc, self.epoch

         config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
         try:
@@ -311,9 +308,9 @@ def _document_batches(split: str, tokenizer_batch_size: int = 128) -> Iterator[t
     stream = _WeightedStream(_phase_weights(), seed=0)

     prefetch_depth = int(os.environ.get("HYDRA_STREAM_PREFETCH", "32"))
+    q: queue.Queue[StreamBatch | object] = queue.Queue(maxsize=prefetch_depth)
+    sentinel_stop = object()
+    error_box: list[BaseException] = []

     def producer():
         try:
@@ -338,7 +335,7 @@ def _document_batches(split: str, tokenizer_batch_size: int = 128) -> Iterator[t
             if error_box:
                 raise error_box[0]
             return
+        yield cast(StreamBatch, item)


 def make_dataloader(tokenizer, B: int, T: int, split: str, buffer_size: int = 1000):
@@ -364,9 +361,9 @@ def make_dataloader(tokenizer, B: int, T: int, split: str, buffer_size: int = 10
     # Stage 2: tokenization prefetch thread. Each queue element is a list of
     # token-id lists (pre-tokenized docs). HYDRA_TOKEN_PREFETCH controls depth.
     tok_prefetch = int(os.environ.get("HYDRA_TOKEN_PREFETCH", "8"))
+    tok_q: queue.Queue[TokenBatch | object] = queue.Queue(maxsize=tok_prefetch)
+    tok_sentinel = object()
+    tok_err_box: list[BaseException] = []

     def tokenizer_producer():
         try:
@@ -390,8 +387,8 @@ def make_dataloader(tokenizer, B: int, T: int, split: str, buffer_size: int = 10
             if tok_err_box:
                 raise tok_err_box[0]
             raise StopIteration
+        token_lists, epoch = cast(TokenBatch, item)
+        doc_buffer.extend(token_lists)

     row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
     cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
@@ -465,24 +462,24 @@ def evaluate_bpb(model, tokenizer, B: int) -> float:
     return total_nats / (math.log(2) * max(total_bytes, 1))


+def ensure_tokenizer():
     """Ensure rustbpe tokenizer exists. If absent, train on a Nemotron stream
     sample using the same rustbpe.train_from_iterator API that prepare.py uses
     (production path — don't fork tokenizer training logic).
     """
     import pickle
     import torch
+    path = os.path.join(_p.TOKENIZER_DIR, "tokenizer.pkl")
+    token_bytes_path = os.path.join(_p.TOKENIZER_DIR, "token_bytes.pt")
+    if os.path.exists(path) and os.path.exists(token_bytes_path):
+        print(f"[nemotron] tokenizer + token_bytes already trained at {_p.TOKENIZER_DIR}", flush=True)
+        return
+    if maybe_hydrate_tokenizer_cache() and os.path.exists(path) and os.path.exists(token_bytes_path):
+        return
+    os.makedirs(_p.TOKENIZER_DIR, exist_ok=True)
     print(f"[nemotron] training BPE (vocab_size={_p.VOCAB_SIZE}) on stream sample…", flush=True)
+    import rustbpe
+    import tiktoken

     # Pull a sample of docs — use full blend if active so BPE covers all 7 sources.
     n_docs = int(os.environ.get("HYDRA_BPE_TRAIN_DOCS", "20000"))
@@ -498,8 +495,8 @@ def ensure_tokenizer():
     print(f"[nemotron] collected {len(sample_texts)} sample docs; training BPE…", flush=True)

     # Train rustbpe — identical API to prepare.py's train_tokenizer().
+    tokenizer_cls = getattr(rustbpe, "Tokenizer")
+    tokenizer: Any = tokenizer_cls()
     vocab_size_no_special = _p.VOCAB_SIZE - len(_p.SPECIAL_TOKENS)
     tokenizer.train_from_iterator(iter(sample_texts), vocab_size_no_special, pattern=_p.SPLIT_PATTERN)
@@ -524,7 +521,7 @@ def ensure_tokenizer():
     for token_id in range(enc.n_vocab):
         tstr = enc.decode([token_id])
         token_bytes_list.append(0 if tstr in special_set else len(tstr.encode("utf-8")))
+    token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
+    torch.save(token_bytes_tensor, token_bytes_path)
+    print(f"[nemotron] BPE + token_bytes saved to {_p.TOKENIZER_DIR}", flush=True)
+    upload_tokenizer_cache()
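The weighted round-robin is just random.Random.choices over the config names, so weights are proportions after normalization, and the factual injection deterministically replaces every 50th draw (~2%). A small sketch of both mechanisms using two of the phase-2 weights above:

import random
from collections import Counter

weights = {"Nemotron-Pretraining-Multiple-Choice": 0.45, "Nemotron-Pretraining-Economics": 0.20}
rng = random.Random(0)
draws = Counter(rng.choices(list(weights), weights=list(weights.values()), k=10_000))
# draws["Nemotron-Pretraining-Multiple-Choice"] / 10_000 ≈ 0.45 / (0.45 + 0.20) ≈ 0.69

inject_rate = 50
injected = sum(1 for i in range(1, 1001) if i % inject_rate == 0)  # 20 of 1000 docs = 2%
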
overlay/scripts/__init__.py
CHANGED
|
@@ -1 +1 @@
-# Package marker for script-level shared utilities.
+# Package marker for script-level shared utilities.
overlay/scripts/audit_overlay_sync.py
CHANGED
|
@@ -1,100 +1,100 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-from pathlib import Path
-
-
-DEFAULT_INCLUDE_PATHS = [
-    "hydra",
-    "subsystems",
-    "scripts",
-    "htm_rust",
-    "harness",
-    "configs",
-    "prepare.py",
-    "prepare_nemotron.py",
-    "train.py",
-    "pyproject.toml",
-    "uv.lock",
-]
-
-
-def _iter_files(path: Path) -> list[Path]:
-    if not path.exists():
-        return []
-    if path.is_file():
-        return [path]
-    return sorted(p for p in path.rglob("*") if p.is_file())
-
-
-def classify_overlay_pairs(*, repo_root: Path, include_paths: list[str]) -> dict[str, list[str]]:
-    overlay_root = repo_root / "hf_jobs" / "feather_h200_image" / "overlay"
-    identical: list[str] = []
-    root_ahead: list[str] = []
-    overlay_only: list[str] = []
-    missing_overlay: list[str] = []
-
-    for rel in include_paths:
-        root_path = repo_root / rel
-        overlay_path = overlay_root / rel
-
-        root_files = {p.relative_to(root_path).as_posix(): p for p in _iter_files(root_path)} if root_path.exists() and root_path.is_dir() else {}
-        overlay_files = {p.relative_to(overlay_path).as_posix(): p for p in _iter_files(overlay_path)} if overlay_path.exists() and overlay_path.is_dir() else {}
-
-        if root_path.is_file() or overlay_path.is_file():
-            rel_name = rel.replace("\\", "/")
-            if root_path.exists() and overlay_path.exists():
-                if root_path.read_bytes() == overlay_path.read_bytes():
-                    identical.append(rel_name)
-                else:
-                    root_ahead.append(rel_name)
-            elif root_path.exists():
-                missing_overlay.append(rel_name)
-            elif overlay_path.exists():
-                overlay_only.append(rel_name)
-            continue
-
-        for subrel, root_file in root_files.items():
-            rel_name = f"{rel}/{subrel}".replace("\\", "/")
-            overlay_file = overlay_files.get(subrel)
-            if overlay_file is None:
-                missing_overlay.append(rel_name)
-            elif root_file.read_bytes() == overlay_file.read_bytes():
-                identical.append(rel_name)
-            else:
-                root_ahead.append(rel_name)
-
-        for subrel in overlay_files:
-            if subrel not in root_files:
-                overlay_only.append(f"{rel}/{subrel}".replace("\\", "/"))
-
-    for bucket in (identical, root_ahead, overlay_only, missing_overlay):
-        bucket.sort()
-
-    return {
-        "identical": identical,
-        "root_ahead": root_ahead,
-        "overlay_only": overlay_only,
-        "missing_overlay": missing_overlay,
-    }
-
-
-def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Audit mirrored H200 overlay files against root source-of-truth paths")
-    parser.add_argument("--repo-root", type=Path, default=Path(__file__).resolve().parents[1])
-    parser.add_argument("--include-path", action="append", default=[])
-    return parser.parse_args(argv)
-
-
-def main(argv: list[str] | None = None) -> int:
-    args = parse_args(argv)
-    include_paths = args.include_path or DEFAULT_INCLUDE_PATHS
-    payload = classify_overlay_pairs(repo_root=args.repo_root, include_paths=include_paths)
-    print(json.dumps(payload, indent=2, sort_keys=True))
-    return 0
-
-
-if __name__ == "__main__":
-    raise SystemExit(main())
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
DEFAULT_INCLUDE_PATHS = [
|
| 10 |
+
"hydra",
|
| 11 |
+
"subsystems",
|
| 12 |
+
"scripts",
|
| 13 |
+
"htm_rust",
|
| 14 |
+
"harness",
|
| 15 |
+
"configs",
|
| 16 |
+
"prepare.py",
|
| 17 |
+
"prepare_nemotron.py",
|
| 18 |
+
"train.py",
|
| 19 |
+
"pyproject.toml",
|
| 20 |
+
"uv.lock",
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _iter_files(path: Path) -> list[Path]:
|
| 25 |
+
if not path.exists():
|
| 26 |
+
return []
|
| 27 |
+
if path.is_file():
|
| 28 |
+
return [path]
|
| 29 |
+
return sorted(p for p in path.rglob("*") if p.is_file())
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def classify_overlay_pairs(*, repo_root: Path, include_paths: list[str]) -> dict[str, list[str]]:
|
| 33 |
+
overlay_root = repo_root / "hf_jobs" / "feather_h200_image" / "overlay"
|
| 34 |
+
identical: list[str] = []
|
| 35 |
+
root_ahead: list[str] = []
|
| 36 |
+
overlay_only: list[str] = []
|
| 37 |
+
missing_overlay: list[str] = []
|
| 38 |
+
|
| 39 |
+
for rel in include_paths:
|
| 40 |
+
root_path = repo_root / rel
|
| 41 |
+
overlay_path = overlay_root / rel
|
| 42 |
+
|
| 43 |
+
root_files = {p.relative_to(root_path).as_posix(): p for p in _iter_files(root_path)} if root_path.exists() and root_path.is_dir() else {}
|
| 44 |
+
overlay_files = {p.relative_to(overlay_path).as_posix(): p for p in _iter_files(overlay_path)} if overlay_path.exists() and overlay_path.is_dir() else {}
|
| 45 |
+
|
| 46 |
+
if root_path.is_file() or overlay_path.is_file():
|
| 47 |
+
rel_name = rel.replace("\\", "/")
|
| 48 |
+
if root_path.exists() and overlay_path.exists():
|
| 49 |
+
if root_path.read_bytes() == overlay_path.read_bytes():
|
| 50 |
+
identical.append(rel_name)
|
| 51 |
+
else:
|
| 52 |
+
root_ahead.append(rel_name)
|
| 53 |
+
elif root_path.exists():
|
| 54 |
+
missing_overlay.append(rel_name)
|
| 55 |
+
elif overlay_path.exists():
|
| 56 |
+
overlay_only.append(rel_name)
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
for subrel, root_file in root_files.items():
|
| 60 |
+
rel_name = f"{rel}/{subrel}".replace("\\", "/")
|
| 61 |
+
overlay_file = overlay_files.get(subrel)
|
| 62 |
+
if overlay_file is None:
|
| 63 |
+
missing_overlay.append(rel_name)
|
| 64 |
+
elif root_file.read_bytes() == overlay_file.read_bytes():
|
| 65 |
+
identical.append(rel_name)
|
| 66 |
+
else:
|
| 67 |
+
root_ahead.append(rel_name)
|
| 68 |
+
|
| 69 |
+
for subrel in overlay_files:
|
| 70 |
+
if subrel not in root_files:
|
| 71 |
+
overlay_only.append(f"{rel}/{subrel}".replace("\\", "/"))
|
| 72 |
+
|
| 73 |
+
for bucket in (identical, root_ahead, overlay_only, missing_overlay):
|
| 74 |
+
bucket.sort()
|
| 75 |
+
|
| 76 |
+
return {
|
| 77 |
+
"identical": identical,
|
| 78 |
+
"root_ahead": root_ahead,
|
| 79 |
+
"overlay_only": overlay_only,
|
| 80 |
+
"missing_overlay": missing_overlay,
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 85 |
+
parser = argparse.ArgumentParser(description="Audit mirrored H200 overlay files against root source-of-truth paths")
|
| 86 |
+
parser.add_argument("--repo-root", type=Path, default=Path(__file__).resolve().parents[1])
|
| 87 |
+
parser.add_argument("--include-path", action="append", default=[])
|
| 88 |
+
return parser.parse_args(argv)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def main(argv: list[str] | None = None) -> int:
|
| 92 |
+
args = parse_args(argv)
|
| 93 |
+
include_paths = args.include_path or DEFAULT_INCLUDE_PATHS
|
| 94 |
+
payload = classify_overlay_pairs(repo_root=args.repo_root, include_paths=include_paths)
|
| 95 |
+
print(json.dumps(payload, indent=2, sort_keys=True))
|
| 96 |
+
return 0
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
|
| 100 |
+
raise SystemExit(main())
|
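A direct-call sketch of the audit, not part of this commit; it assumes the repo layout above and the paths are illustrative:

    from pathlib import Path

    from scripts.audit_overlay_sync import classify_overlay_pairs

    # Audit just the two entrypoints; repo root assumed to be the cwd.
    report = classify_overlay_pairs(repo_root=Path("."), include_paths=["train.py", "prepare.py"])
    print(report["root_ahead"])  # files where the root copy has drifted ahead of the overlay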
overlay/scripts/benchmark_assets.py CHANGED
@@ -1,124 +1,62 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import os
-import shutil
…(removed lines not fully recoverable in this view)
-        return latest[-1]
-    return None
-
-
-def hydrate_benchmark_assets(*, cache_dir: Path, output_repo: str, tokenizer_repo: str, token: str | None) -> dict[str, str]:
-    cache_dir.mkdir(parents=True, exist_ok=True)
-    tok_dir = cache_dir / "tokenizer"
-    tok_dir.mkdir(parents=True, exist_ok=True)
-    tok_repo = resolve_tokenizer_cache_repo(output_repo=tokenizer_repo, retina_cache_repo=tokenizer_repo)
-    tok_prefix = tokenizer_cache_prefix()
-
-    ckpt_path = None
-    for candidate in checkpoint_candidates(cache_dir):
-        if candidate.exists():
-            ckpt_path = candidate
-            break
-        try:
-            ckpt_path = _download_file(repo_id=output_repo, filename=candidate.name, local_dir=str(cache_dir), token=token)
-            break
-        except Exception:
-            continue
-    if ckpt_path is None:
-        try:
-            if HfApi is None:
-                raise RuntimeError("huggingface_hub unavailable")
-            files = HfApi(token=token).list_repo_files(repo_id=output_repo, repo_type="model", token=token)
-            remote_path = choose_remote_checkpoint_path(files)
-            if remote_path is not None:
-                parent, filename = remote_path.rsplit("/", 1)
-                downloaded_path = _download_file(
-                    repo_id=output_repo,
-                    filename=filename,
-                    local_dir=str(cache_dir),
-                    token=token,
-                    subfolder=parent,
-                )
-                canonical_path = cache_dir / filename
-                if downloaded_path != canonical_path:
-                    canonical_path.parent.mkdir(parents=True, exist_ok=True)
-                    shutil.copy2(downloaded_path, canonical_path)
-                ckpt_path = canonical_path
-        except Exception:
-            pass
-    if ckpt_path is None:
-        raise FileNotFoundError(f"No benchmark checkpoint found in cache or repo {output_repo}")
-
-    tok_path = tok_dir / "tokenizer.pkl"
-    if not tok_path.exists():
-        downloaded_tok = _download_file(repo_id=tok_repo, filename="tokenizer.pkl", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
-        if downloaded_tok != tok_path:
-            shutil.copy2(downloaded_tok, tok_path)
-
-    token_bytes_path = tok_dir / "token_bytes.pt"
-    if not token_bytes_path.exists():
-        downloaded_token_bytes = _download_file(repo_id=tok_repo, filename="token_bytes.pt", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
-        if downloaded_token_bytes != token_bytes_path:
-            shutil.copy2(downloaded_token_bytes, token_bytes_path)
-
-    return {
-        "checkpoint_path": str(ckpt_path),
-        "tokenizer_dir": str(tok_dir),
-    }
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+
+def _download_file(*, repo_id: str, filename: str, local_dir: str, token: str | None, subfolder: str | None = None) -> Path:
+    from huggingface_hub import hf_hub_download
+
+    path = hf_hub_download(
+        repo_id=repo_id,
+        repo_type="model",
+        filename=filename,
+        subfolder=subfolder,
+        token=token,
+        local_dir=local_dir,
+        local_dir_use_symlinks=False,
+    )
+    return Path(path)
+
+
+def resolve_tokenizer_cache_repo(*, output_repo: str, retina_cache_repo: str) -> str:
+    return (
+        os.environ.get("HYDRA_TOKENIZER_CACHE_REPO")
+        or os.environ.get("FEATHER_HF_OUTPUT_REPO")
+        or os.environ.get("HF_REPO_ID")
+        or os.environ.get("HYDRA_RETINA_CACHE_REPO")
+        or os.environ.get("FEATHER_HF_RETINA_CACHE_REPO")
+        or output_repo
+        or retina_cache_repo
+    )
+
+
+def tokenizer_cache_prefix() -> str:
+    vocab_size = int(os.environ.get("HYDRA_VOCAB_SIZE", "65536"))
+    return f"tokenizer/vocab{vocab_size}"
+
+
+def hydrate_benchmark_assets(*, cache_dir: Path, output_repo: str, tokenizer_repo: str, token: str | None) -> dict[str, str]:
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    tok_dir = cache_dir / "tokenizer"
+    tok_dir.mkdir(parents=True, exist_ok=True)
+    tok_repo = resolve_tokenizer_cache_repo(output_repo=tokenizer_repo, retina_cache_repo=tokenizer_repo)
+    tok_prefix = tokenizer_cache_prefix()
+
+    ckpt_path = cache_dir / "best_bpb.pt"
+    if not ckpt_path.exists():
+        ckpt_path = _download_file(repo_id=output_repo, filename="best_bpb.pt", local_dir=str(cache_dir), token=token)
+
+    tok_path = tok_dir / "tokenizer.pkl"
+    if not tok_path.exists():
+        tok_path = _download_file(repo_id=tok_repo, filename="tokenizer.pkl", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
+
+    token_bytes_path = tok_dir / "token_bytes.pt"
+    if not token_bytes_path.exists():
+        token_bytes_path = _download_file(repo_id=tok_repo, filename="token_bytes.pt", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
+
+    return {
+        "checkpoint_path": str(ckpt_path),
+        "tokenizer_dir": str(tok_dir),
+    }
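A usage sketch for the new hydration path, not part of this commit; the repo ids are placeholders, the real ones come from the FEATHER_HF_* and HYDRA_* environment variables read above:

    import os
    from pathlib import Path

    from scripts.benchmark_assets import hydrate_benchmark_assets

    # Repo ids below are illustrative; env vars take precedence inside the helper.
    assets = hydrate_benchmark_assets(
        cache_dir=Path(".cache/benchmark_assets"),
        output_repo="user/feather-runs",
        tokenizer_repo="user/feather-runs",
        token=os.environ.get("HF_TOKEN"),
    )
    print(assets["checkpoint_path"], assets["tokenizer_dir"])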
overlay/scripts/benchmark_checkpoint.py CHANGED
@@ -1,118 +1,19 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import shutil
…(removed lines not fully recoverable in this view)
-def choose_remote_checkpoint_path(files: list[str]) -> str | None:
…
-    if not preferred:
-        return None
-    pretrain = sorted([p for p in preferred if p.endswith("/pretrain_final.pt")])
-    if pretrain:
-        return pretrain[-1]
-    best = sorted([p for p in preferred if p.endswith("/best_bpb.pt")])
-    if best:
-        return best[-1]
-    latest = sorted([p for p in preferred if p.endswith("/latest.pt")])
-    if latest:
-        return latest[-1]
-    return None
-
-
-def checkpoint_candidates(cache_dir: Path) -> list[Path]:
-    return [
-        cache_dir / "best_bpb.pt",
-        cache_dir / "pretrain_final.pt",
-        cache_dir / "latest.pt",
-    ]
-
-
-def choose_checkpoint_candidate(cache_dir: Path) -> Path | None:
-    for path in checkpoint_candidates(cache_dir):
-        if path.exists():
-            return path
-    return None
-
-
-def resolve_checkpoint_source(*, cache_dir: Path, output_repo: str | None) -> dict[str, str]:
-    local = choose_checkpoint_candidate(cache_dir)
-    if local is not None:
-        return {"mode": "local", "path": str(local)}
-    if output_repo:
-        return {"mode": "remote", "repo_id": output_repo}
-    routing = resolve_routing(token=None)
-    return {"mode": "remote", "repo_id": routing.output_repo}
-
-
-def _download_checkpoint_file(*, repo_id: str, filename: str, local_dir: str, token: str | None, subfolder: str | None = None) -> str:
-    from huggingface_hub import hf_hub_download
-
-    return hf_hub_download(
-        repo_id=repo_id,
-        repo_type="model",
-        filename=filename,
-        subfolder=subfolder,
-        token=token,
-        local_dir=local_dir,
-        local_dir_use_symlinks=False,
-    )
-
-
-def hydrate_checkpoint(*, cache_dir: Path, output_repo: str | None, token: str | None) -> Path | None:
-    local = choose_checkpoint_candidate(cache_dir)
-    if local is not None:
-        return local
-    source = resolve_checkpoint_source(cache_dir=cache_dir, output_repo=output_repo)
-    if source["mode"] != "remote":
-        return None
-    cache_dir.mkdir(parents=True, exist_ok=True)
-    for filename in ("best_bpb.pt", "pretrain_final.pt", "latest.pt"):
-        try:
-            path = Path(
-                _download_checkpoint_file(
-                    repo_id=source["repo_id"],
-                    filename=filename,
-                    local_dir=str(cache_dir),
-                    token=token,
-                )
-            )
-            if path.exists():
-                return path
-        except Exception:
-            continue
-    try:
-        if HfApi is None:
-            raise RuntimeError("huggingface_hub unavailable")
-        files = HfApi(token=token).list_repo_files(repo_id=source["repo_id"], repo_type="model", token=token)
-        remote_path = choose_remote_checkpoint_path(files)
-        if remote_path is not None:
-            parent, filename = remote_path.rsplit("/", 1)
-            downloaded = Path(
-                _download_checkpoint_file(
-                    repo_id=source["repo_id"],
-                    filename=filename,
-                    local_dir=str(cache_dir),
-                    token=token,
-                    subfolder=parent,
-                )
-            )
-            canonical = cache_dir / filename
-            if downloaded != canonical:
-                shutil.copy2(downloaded, canonical)
-            if canonical.exists():
-                return canonical
-    except Exception:
-        pass
-    return None
+#!/usr/bin/env python3
+from __future__ import annotations
+
+from pathlib import Path
+
+
+def checkpoint_candidates(cache_dir: Path) -> list[Path]:
+    return [
+        cache_dir / "best_bpb.pt",
+        cache_dir / "pretrain_final.pt",
+        cache_dir / "latest.pt",
+    ]
+
+
+def choose_checkpoint_candidate(cache_dir: Path) -> Path | None:
+    for path in checkpoint_candidates(cache_dir):
+        if path.exists():
+            return path
+    return None
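Typical use of the slimmed-down helper, not part of this commit; the cache path is illustrative:

    from pathlib import Path

    from scripts.benchmark_checkpoint import choose_checkpoint_candidate

    # Returns the first of best_bpb.pt / pretrain_final.pt / latest.pt that exists.
    ckpt = choose_checkpoint_candidate(Path(".cache/benchmark_assets"))
    if ckpt is None:
        raise SystemExit("no local checkpoint; hydrate assets first")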
overlay/scripts/benchmark_checkpoint_report.py CHANGED
@@ -1,50 +1,50 @@ (whole file removed and re-added with identical content; shown once)
#!/usr/bin/env python3
from __future__ import annotations

import json


def build_checkpoint_report(files: list[str]) -> dict[str, object]:
    by_job: dict[str, dict[str, object]] = {}
    for path in files:
        parts = path.split("/")
        if len(parts) < 3 or parts[0] != "jobs":
            continue
        job_id = parts[1]
        filename = parts[-1]
        if filename not in {"best_bpb.pt", "pretrain_final.pt", "latest.pt"}:
            continue
        row = by_job.setdefault(job_id, {"job_id": job_id, "paths": []})
        row["paths"].append(path)

    candidates = []
    for job_id, row in by_job.items():
        paths = list(row["paths"])
        preferred = None
        for suffix in ("pretrain_final.pt", "best_bpb.pt", "latest.pt"):
            for path in paths:
                if path.endswith(suffix):
                    preferred = path
                    break
            if preferred is not None:
                break
        candidates.append({
            "job_id": job_id,
            "preferred_path": preferred,
            "available_paths": sorted(paths),
        })

    candidates.sort(key=lambda row: row["job_id"], reverse=True)
    return {
        "n_candidates": len(candidates),
        "candidates": candidates,
    }


def main() -> int:
    print(json.dumps(build_checkpoint_report([]), indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
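Given the preference order above (pretrain_final.pt over best_bpb.pt over latest.pt, newest job first), a small worked example; the job ids are illustrative:

    from scripts.benchmark_checkpoint_report import build_checkpoint_report

    report = build_checkpoint_report([
        "jobs/job-002/latest.pt",
        "jobs/job-002/pretrain_final.pt",
        "jobs/job-001/best_bpb.pt",
        "README.md",  # ignored: not under jobs/<id>/
    ])
    # Newest job sorts first; pretrain_final.pt wins over latest.pt within a job.
    assert report["n_candidates"] == 2
    assert report["candidates"][0]["preferred_path"] == "jobs/job-002/pretrain_final.pt"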
overlay/scripts/benchmark_contract.py CHANGED
@@ -1,67 +1,67 @@ (whole file removed and re-added with identical content; shown once)
#!/usr/bin/env python3
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


def _require_path(payload: dict[str, Any], path: str) -> None:
    current: Any = payload
    for part in path.split('.'):
        if not isinstance(current, dict) or part not in current:
            raise ValueError(f"missing required field: {path}")
        current = current[part]


def validate_benchmark_contract(payload: dict[str, Any]) -> None:
    for field in [
        "cycle_id",
        "hardware_class",
        "seeds",
        "budget_modes",
        "coding_benchmarks.fast_iteration",
        "coding_benchmarks.milestone",
        "reasoning_benchmarks.fast_iteration",
        "reasoning_benchmarks.milestone",
        "variants.hydra_full",
        "variants.baseline_mamba_matched",
    ]:
        _require_path(payload, field)

    for section in [
        payload["coding_benchmarks"]["fast_iteration"],
        payload["coding_benchmarks"]["milestone"],
        payload["reasoning_benchmarks"]["fast_iteration"],
        payload["reasoning_benchmarks"]["milestone"],
    ]:
        if "name" not in section or "primary_metric" not in section or "decode" not in section:
            raise ValueError("benchmark sections require name, primary_metric, and decode")

    if not isinstance(payload["seeds"], list) or len(payload["seeds"]) < 3:
        raise ValueError("seeds must contain at least three values")

    if payload["variants"]["hydra_full"].get("status") != "runnable_now":
        raise ValueError("hydra_full must be runnable_now")

    if payload["variants"]["baseline_mamba_matched"].get("status") != "runnable_now":
        raise ValueError("baseline_mamba_matched must be runnable_now")


def load_benchmark_contract(path: Path) -> dict[str, Any]:
    payload = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError("benchmark contract must be a JSON object")
    validate_benchmark_contract(payload)
    return payload


def main() -> int:
    path = Path("artifacts/cycle_1_execution_freeze.json")
    payload = load_benchmark_contract(path)
    print(json.dumps(payload, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
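A minimal payload that satisfies the validator above; every value is illustrative, only the required shape comes from the code:

    from scripts.benchmark_contract import validate_benchmark_contract

    def section(name: str, metric: str) -> dict:
        return {"name": name, "primary_metric": metric, "decode": "greedy"}

    payload = {
        "cycle_id": "cycle-1",
        "hardware_class": "H200",
        "seeds": [0, 1, 2],  # validator requires at least three
        "budget_modes": ["fast_iteration", "milestone"],
        "coding_benchmarks": {
            "fast_iteration": section("MBPP", "pass_at_1"),
            "milestone": section("HumanEval", "pass_at_1"),
        },
        "reasoning_benchmarks": {
            "fast_iteration": section("GSM8K", "exact_match"),
            "milestone": section("ARC-Challenge", "accuracy"),
        },
        "variants": {
            "hydra_full": {"status": "runnable_now"},
            "baseline_mamba_matched": {"status": "runnable_now"},
        },
    }
    validate_benchmark_contract(payload)  # raises ValueError if anything is missing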
overlay/scripts/benchmark_datasets.py CHANGED
@@ -1,190 +1,18 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import …
…(removed import block not fully recoverable in this view)
-CANONICAL_SUBSETS = {
-    "MBPP": Path("data/benchmarks/mbpp.cycle1.jsonl"),
-    "GSM8K": Path("data/benchmarks/gsm8k.cycle1.jsonl"),
-    "HumanEval": Path("data/benchmarks/humaneval.cycle1.jsonl"),
-    "ARC-Challenge": Path("data/benchmarks/arc_challenge.cycle1.jsonl"),
-}
-
-
-DATASET_SOURCES: dict[str, dict[str, str]] = {
-    "MBPP": {"repo_id": "Muennighoff/mbpp", "subset": "full", "split": "test", "raw_path": "data/mbpp.jsonl"},
-    "GSM8K": {"repo_id": "openai/gsm8k", "subset": "main", "split": "test"},
-    "HumanEval": {"repo_id": "openai/openai_humaneval", "subset": "default", "split": "test"},
-    "ARC-Challenge": {"repo_id": "allenai/ai2_arc", "subset": "ARC-Challenge", "split": "validation"},
-}
-
-
-def resolve_benchmark_dataset(benchmark_name: str, explicit_path: Path | None) -> Path:
-    if explicit_path is not None:
-        return explicit_path
-    if benchmark_name not in CANONICAL_SUBSETS:
-        raise ValueError(f"Unsupported benchmark dataset: {benchmark_name}")
-    return Path.cwd() / CANONICAL_SUBSETS[benchmark_name]
-
-
-def _normalize_gsm8k_answer(answer: str) -> str:
-    if "####" in answer:
-        return answer.split("####")[-1].strip()
-    return answer.strip()
-
-
-def transform_dataset_row(benchmark_name: str, row: dict[str, Any], *, row_id: int) -> dict[str, Any]:
-    if benchmark_name == "GSM8K":
-        return {
-            "question": str(row["question"]),
-            "answer": _normalize_gsm8k_answer(str(row["answer"])),
-        }
-    if benchmark_name == "ARC-Challenge":
-        choices = row["choices"]
-        labels = list(choices["label"])
-        texts = list(choices["text"])
-        answer_key = str(row["answerKey"])
-        answer_index = labels.index(answer_key)
-        return {
-            "question": str(row["question"]),
-            "choices": [str(choice) for choice in texts],
-            "answer": str(texts[answer_index]),
-        }
-    if benchmark_name == "MBPP":
-        task_id = row.get("task_id", row_id)
-        return {
-            "task_id": str(task_id),
-            "prompt": str(row["text"]),
-            "tests": [str(test) for test in row["test_list"]],
-        }
-    if benchmark_name == "HumanEval":
-        task_id = row.get("task_id", row_id)
-        return {
-            "task_id": str(task_id),
-            "prompt": str(row["prompt"]),
-            "test": str(row["test"]),
-        }
-    raise ValueError(f"Unsupported benchmark dataset: {benchmark_name}")
-
-
-def write_canonical_dataset(*, benchmark_name: str, rows: list[dict[str, Any]], out_path: Path, limit: int) -> int:
-    out_path.parent.mkdir(parents=True, exist_ok=True)
-    transformed = [transform_dataset_row(benchmark_name, row, row_id=index) for index, row in enumerate(rows[:limit])]
-    out_path.write_text("".join(json.dumps(row) + "\n" for row in transformed), encoding="utf-8")
-    return len(transformed)
-
-
-def choose_dataset_parquet_path(benchmark_name: str, files: list[str]) -> str | None:
-    source = DATASET_SOURCES[benchmark_name]
-    subset = source["subset"].lower()
-    split = source["split"].lower()
-    candidates = [path for path in files if path.endswith(".parquet")]
-    preferred = [path for path in candidates if subset in path.lower() and split in path.lower()]
-    if preferred:
-        return sorted(preferred)[0]
-    split_only = [path for path in candidates if split in path.lower()]
-    if split_only:
-        return sorted(split_only)[0]
-    return sorted(candidates)[0] if candidates else None
-
-
-def download_dataset_snapshot(benchmark_name: str, *, cache_dir: Path, token: str | None) -> Path:
-    source = DATASET_SOURCES[benchmark_name]
-    if HfApi is None or hf_hub_download is None:
-        raise RuntimeError("huggingface_hub unavailable")
-    raw_path = source.get("raw_path")
-    if raw_path:
-        if "/" in raw_path:
-            subfolder, filename = raw_path.rsplit("/", 1)
-        else:
-            subfolder, filename = None, raw_path
-        downloaded = hf_hub_download(
-            repo_id=source["repo_id"],
-            repo_type="dataset",
-            filename=filename,
-            subfolder=subfolder,
-            token=token,
-            local_dir=str(cache_dir),
-        )
-        return Path(downloaded)
-    files = HfApi(token=token).list_repo_files(repo_id=source["repo_id"], repo_type="dataset", token=token)
-    parquet_path = choose_dataset_parquet_path(benchmark_name, files)
-    if parquet_path is None:
-        raise FileNotFoundError(f"No parquet dataset file found for {benchmark_name}")
-    if "/" in parquet_path:
-        subfolder, filename = parquet_path.rsplit("/", 1)
-    else:
-        subfolder, filename = None, parquet_path
-    downloaded = hf_hub_download(
-        repo_id=source["repo_id"],
-        repo_type="dataset",
-        filename=filename,
-        subfolder=subfolder,
-        token=token,
-        local_dir=str(cache_dir),
-    )
-    return Path(downloaded)
-
-
-def hydrate_canonical_dataset(
-    *,
-    benchmark_name: str,
-    out_path: Path,
-    limit: int,
-    cache_dir: Path,
-    token: str | None,
-) -> int:
-    source_path = download_dataset_snapshot(benchmark_name, cache_dir=cache_dir, token=token)
-    if source_path.suffix == ".jsonl":
-        rows = [json.loads(line) for line in source_path.read_text(encoding="utf-8").splitlines() if line.strip()]
-    else:
-        table = pq.read_table(source_path)
-        rows = table.to_pylist()
-    return write_canonical_dataset(benchmark_name=benchmark_name, rows=rows, out_path=out_path, limit=limit)
-
-
-def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Hydrate a canonical benchmark dataset JSONL from a public source")
-    parser.add_argument("--benchmark", required=True, choices=list(CANONICAL_SUBSETS))
-    parser.add_argument("--samples", type=Path)
-    parser.add_argument("--out", type=Path)
-    parser.add_argument("--limit", type=int, default=20)
-    parser.add_argument("--cache-dir", type=Path, default=Path(".cache/benchmarks"))
-    parser.add_argument("--token")
-    return parser.parse_args(argv)
-
-
-def main(argv: list[str] | None = None) -> int:
-    args = parse_args(argv)
-    if args.samples is not None:
-        rows = [json.loads(line) for line in args.samples.read_text(encoding="utf-8").splitlines() if line.strip()]
-        out_path = args.out or resolve_benchmark_dataset(args.benchmark, None)
-        write_canonical_dataset(benchmark_name=args.benchmark, rows=rows, out_path=out_path, limit=args.limit)
-        return 0
-
-    out_path = args.out or resolve_benchmark_dataset(args.benchmark, None)
-    hydrate_canonical_dataset(
-        benchmark_name=args.benchmark,
-        out_path=out_path,
-        limit=args.limit,
-        cache_dir=args.cache_dir,
-        token=args.token,
-    )
-    return 0
-
-
-if __name__ == "__main__":
-    raise SystemExit(main())
+#!/usr/bin/env python3
+from __future__ import annotations
+
+from pathlib import Path
+
+
+CANONICAL_SUBSETS = {
+    "MBPP": Path("data/benchmarks/mbpp.cycle1.jsonl"),
+    "GSM8K": Path("data/benchmarks/gsm8k.cycle1.jsonl"),
+}
+
+
+def resolve_benchmark_dataset(benchmark_name: str, explicit_path: Path | None) -> Path:
+    if explicit_path is not None:
+        return explicit_path
+    if benchmark_name not in CANONICAL_SUBSETS:
+        raise ValueError(f"Unsupported benchmark dataset: {benchmark_name}")
+    return Path.cwd() / CANONICAL_SUBSETS[benchmark_name]
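Resolution sketch, not part of this commit; the paths come straight from CANONICAL_SUBSETS:

    from pathlib import Path

    from scripts.benchmark_datasets import resolve_benchmark_dataset

    # An explicit path wins; otherwise the canonical cycle-1 subset under data/benchmarks/.
    path = resolve_benchmark_dataset("GSM8K", None)
    assert path == Path.cwd() / "data/benchmarks/gsm8k.cycle1.jsonl"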
overlay/scripts/benchmark_hyena_stack.py CHANGED
@@ -26,11 +26,8 @@ Invocation:
 # On A100/A10G (production cloud hardware), use time=900 (15 min) for
 # stable steady-state numbers.
 
-After each run the script prints:
-    BENCHMARK config=<name> tps_steady=<avg> bpb_at_500=<val> vram_peak=<MiB>
-
-If `--min-tps` is set (>0), the script exits non-zero when steady-state TPS
-falls below the threshold.
+After each run the script prints:
+    BENCHMARK config=<name> tps_steady=<avg> bpb_at_500=<val> vram_peak=<MiB>
 
 Collate those lines into the matrix table manually, then pick the winner
 for the 6-hour production run (HYDRA_TIME_BUDGET=21600).
@@ -50,34 +47,30 @@ REPO = Path(__file__).resolve().parents[1]
 
 CONFIGS = {
     # Baseline: B=8, no flash, no train-cache. Current reference point.
-    "baseline": {
-        "HYDRA_BATCH_SIZE": "8",
-        "…
-        "HYDRA_HYENA_LAYERS": "3,7",
+    "baseline": {
+        "HYDRA_BATCH_SIZE": "8",
+        "HYDRA_HYENA_LAYERS": "3,7",
         "HYDRA_HYENA_FLASH_FFT": "0",
         "HYDRA_HYENA_TRAIN_CACHE": "0",
         "HYDRA_HYENA_FILTER_CACHE": "0",
     },
-    "b16": {
-        "HYDRA_BATCH_SIZE": "16",
-        "…
-        "HYDRA_HYENA_LAYERS": "3,7",
+    "b16": {
+        "HYDRA_BATCH_SIZE": "16",
+        "HYDRA_HYENA_LAYERS": "3,7",
         "HYDRA_HYENA_FLASH_FFT": "0",
         "HYDRA_HYENA_TRAIN_CACHE": "0",
         "HYDRA_HYENA_FILTER_CACHE": "0",
     },
-    "cache": {
-        "HYDRA_BATCH_SIZE": "16",
-        "…
-        "HYDRA_HYENA_LAYERS": "3,7",
+    "cache": {
+        "HYDRA_BATCH_SIZE": "16",
+        "HYDRA_HYENA_LAYERS": "3,7",
         "HYDRA_HYENA_FLASH_FFT": "0",
         "HYDRA_HYENA_TRAIN_CACHE": "1",
         "HYDRA_HYENA_FILTER_CACHE": "1",
     },
-    "kernel": {
-        "HYDRA_BATCH_SIZE": "16",
-        "…
-        "HYDRA_HYENA_LAYERS": "3,7",
+    "kernel": {
+        "HYDRA_BATCH_SIZE": "16",
+        "HYDRA_HYENA_LAYERS": "3,7",
         "HYDRA_HYENA_FLASH_FFT": "1",
         "HYDRA_HYENA_TRAIN_CACHE": "1",
         "HYDRA_HYENA_FILTER_CACHE": "1",
@@ -88,7 +81,7 @@ CONFIGS = {
 }
 
 
-def build_env(cfg_overrides: dict[str, str]) -> dict[str, str]:
+def build_env(cfg_overrides: dict) -> dict:
     """Compose a full env dict from the inherited env + config overrides."""
     env = os.environ.copy()
     # Ensure the Hyena layer selection is always present (defaults to off).
@@ -98,7 +91,7 @@ def build_env(cfg_overrides: dict[str, str]) -> dict[str, str]:
     return env
 
 
-def parse_step_line(line: str) -> dict[str, float] | None:
+def parse_step_line(line: str) -> dict | None:
     """Parse a single step=... line into a dict of metrics, or None."""
     if not line.startswith("step="):
         return None
@@ -109,7 +102,7 @@ def parse_step_line(line: str) -> dict[str, float] | None:
     return None
 
 
-def summarize(log_path: Path, warmup_steps: int = 50) -> dict[str, float]:
+def summarize(log_path: Path, warmup_steps: int = 50) -> dict:
     """Tail log_path, compute steady-state TPS / BPB@500 / VRAM peak.
 
     Skips the first `warmup_steps` to discard CUDA graph capture / autotune
@@ -145,29 +138,20 @@ def summarize(log_path: Path, warmup_steps: int = 50) -> dict[str, float]:
     tps_sorted = sorted(tps_vals)
     tps_steady = tps_sorted[len(tps_sorted) // 2]  # median
 
-    return {
-        "tps_steady": tps_steady,
-        "bpb_at_500": bpb_at_500 or (bpbs[-1] if bpbs else 0.0),
-        "vram_peak": vram_peak,
-        "steps": len(tps_vals) + warmup_steps,
-    }
-
-
-def …
…(removed TPS-floor helper not fully recoverable in this view)
-def main() -> int:
-    ap = argparse.ArgumentParser()
-    ap.add_argument("--config", required=True, choices=list(CONFIGS))
-    ap.add_argument("--time", type=int, default=300, help="training seconds")
-    ap.add_argument("--log", default=None, help="output log path (default: run_bench_<cfg>.log)")
-    ap.add_argument("--min-tps", type=float, default=50000.0, help="Required steady-state TPS floor (set 0 to disable)")
-    ap.add_argument("--warmup-steps", type=int, default=50, help="Number of initial steps to skip before TPS median")
-    args = ap.parse_args()
+    return {
+        "tps_steady": tps_steady,
+        "bpb_at_500": bpb_at_500 or (bpbs[-1] if bpbs else 0.0),
+        "vram_peak": vram_peak,
+        "steps": len(tps_vals) + warmup_steps,
+    }
+
+
+def main() -> int:
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--config", required=True, choices=list(CONFIGS))
+    ap.add_argument("--time", type=int, default=300, help="training seconds")
+    ap.add_argument("--log", default=None, help="output log path (default: run_bench_<cfg>.log)")
+    args = ap.parse_args()
 
     cfg = CONFIGS[args.config]
     log_path = Path(args.log or (REPO / f"run_bench_{args.config}.log"))
@@ -194,25 +178,16 @@ def main() -> int:
         print(f"BENCH FAIL config={args.config}", flush=True)
         return proc.returncode
 
-    summary = summarize(log_path, args.warmup_steps)
-    print(
-        f"BENCHMARK config={args.config} "
-        f"tps_steady={summary['tps_steady']:.0f} "
-        f"bpb_at_500={summary['bpb_at_500']:.4f} "
-        f"vram_peak={summary['vram_peak']:.0f}MiB "
-        f"steps={summary['steps']}",
-        flush=True,
-    )
-
-    if fails_tps_floor(summary, args.min_tps):
-        print(
-            f"BENCH FAIL config={args.config} tps_steady={summary['tps_steady']:.0f} < min_tps={args.min_tps:.0f}",
-            flush=True,
-        )
-        return 2
-
-    print(f"BENCH PASS config={args.config} min_tps={args.min_tps:.0f}", flush=True)
-    return 0
+    summary = summarize(log_path)
+    print(
+        f"BENCHMARK config={args.config} "
+        f"tps_steady={summary['tps_steady']:.0f} "
+        f"bpb_at_500={summary['bpb_at_500']:.4f} "
+        f"vram_peak={summary['vram_peak']:.0f}MiB "
+        f"steps={summary['steps']}",
+        flush=True,
+    )
+    return 0
 
 
 if __name__ == "__main__":
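A quick way to inspect what a config will actually export before launching a run; a sketch, not part of this commit, assuming the script is importable as a module from the repo root:

    import json

    from scripts.benchmark_hyena_stack import CONFIGS, build_env

    # build_env copies os.environ and applies the config's HYDRA_* overrides.
    env = build_env(CONFIGS["cache"])
    print(json.dumps({k: v for k, v in env.items() if k.startswith("HYDRA_")}, indent=2))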
overlay/scripts/benchmark_preflight.py CHANGED
@@ -1,35 +1,31 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-from pathlib import Path
-
-from scripts.bootstrap_benchmark_env import build_bootstrap_report
-from scripts.benchmark_checkpoint import choose_checkpoint_candidate
-
-
-def build_readiness_report(*, cache_dir: Path, hf_token_present: bool, dependencies_present: bool = True, missing_dependencies: list[str] | None = None, output_repo: str | None = None, tokenizer_repo: str | None = None) -> dict[str, object]:
-    checkpoint = choose_checkpoint_candidate(cache_dir)
-    tokenizer_dir = cache_dir / "tokenizer"
…(removed lines not fully recoverable in this view)
-        "tokenizer_repo": tokenizer_repo,
-        "hydration_possible": bool(hf_token_present and output_repo and tokenizer_repo),
-        "ready_for_hydra_benchmarks": checkpoint_present and tokenizer_ready and retina_ready and dependencies_present,
-    }
+#!/usr/bin/env python3
+from __future__ import annotations
+
+from pathlib import Path
+
+from scripts.bootstrap_benchmark_env import build_bootstrap_report
+from scripts.benchmark_checkpoint import choose_checkpoint_candidate
+
+
+def build_readiness_report(*, cache_dir: Path, hf_token_present: bool, dependencies_present: bool = True, missing_dependencies: list[str] | None = None, output_repo: str | None = None, tokenizer_repo: str | None = None) -> dict[str, object]:
+    checkpoint = choose_checkpoint_candidate(cache_dir)
+    tokenizer_dir = cache_dir / "tokenizer"
+    tokenizer_ready = (tokenizer_dir / "tokenizer.pkl").exists() and (tokenizer_dir / "token_bytes.pt").exists()
+    checkpoint_present = checkpoint is not None
+    runtime = build_bootstrap_report(missing_dependencies=list(missing_dependencies or []))
+    return {
+        "cache_dir": str(cache_dir),
+        "checkpoint_present": checkpoint_present,
+        "checkpoint_path": str(checkpoint) if checkpoint is not None else None,
+        "tokenizer_ready": tokenizer_ready,
+        "hf_token_present": hf_token_present,
+        "dependencies_present": dependencies_present,
+        "missing_dependencies": list(missing_dependencies or []),
+        "install_hint": runtime["install_hint"],
+        "install_command": runtime["install_command"],
+        "install_blockers": runtime["install_blockers"],
+        "output_repo": output_repo,
+        "tokenizer_repo": tokenizer_repo,
+        "hydration_possible": bool(hf_token_present and output_repo and tokenizer_repo),
+        "ready_for_hydra_benchmarks": checkpoint_present and tokenizer_ready and dependencies_present,
+    }
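A preflight sketch, not part of this commit; the env var names are the ones used elsewhere in this image and the cache path is illustrative:

    import os
    from pathlib import Path

    from scripts.benchmark_preflight import build_readiness_report

    report = build_readiness_report(
        cache_dir=Path(".cache/benchmark_assets"),
        hf_token_present=bool(os.environ.get("HF_TOKEN")),
        missing_dependencies=[],
        output_repo=os.environ.get("FEATHER_HF_OUTPUT_REPO"),
        tokenizer_repo=os.environ.get("HYDRA_TOKENIZER_CACHE_REPO"),
    )
    if not report["ready_for_hydra_benchmarks"]:
        print(report["install_hint"])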
overlay/scripts/benchmark_runner.py CHANGED
@@ -1,327 +1,248 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-import re
-import sys
-from pathlib import Path
-from typing import Any, Callable
-
-REPO_ROOT = Path(__file__).resolve().parents[1]
-if str(REPO_ROOT) not in sys.path:
-    sys.path.insert(0, str(REPO_ROOT))
-
-LEDGER_TEMPLATE_PATH = REPO_ROOT / "artifacts" / "benchmark_ledger.template.json"
-
-from scripts.hydra_generation import build_hydra_generator
-from scripts.benchmark_datasets import resolve_benchmark_dataset as resolve_canonical_dataset
-from scripts.benchmark_suite import build_prompt, validate_sample
…(removed lines not fully recoverable in this view)
-    parser.add_argument("--checkpoint", type=Path)
-    parser.add_argument("--device")
-    parser.add_argument("--max-new-tokens", type=int, default=256)
-    parser.add_argument("--temperature", type=float, default=0.2)
-    parser.add_argument("--top-p", type=float, default=0.95)
-    return parser.parse_args(argv)
-
-
-def main(argv: list[str] | None = None) -> int:
-    args = parse_args(argv)
-    sample_path = resolve_samples_path(args.benchmark, args.samples, args.suite)
-    try:
-        if args.generator_mode == "hydra":
-            generator = build_hydra_generator(
-                checkpoint_path=args.checkpoint,
-                device=args.device,
-                max_new_tokens=args.max_new_tokens,
-                temperature=args.temperature,
-                top_p=args.top_p,
-            )
-        else:
-            def generator(prompt: str) -> str:
-                return prompt
-
-        result = run_benchmark(args.benchmark, sample_path, generator)
-        exit_code = 0
-    except FileNotFoundError as exc:
-        result = {
-            "benchmark": args.benchmark,
-            "status": "failed",
-            "failure_type": "missing_checkpoint",
-            "error": str(exc),
-            "n_samples": 0,
-        }
-        exit_code = 1
-    except BenchmarkExecutionError as exc:
-        result = {
-            "benchmark": args.benchmark,
-            "status": "failed",
-            "failure_type": type(exc.cause).__name__,
-            "error": str(exc.cause),
-            "n_samples": 0,
-            "debug": {
-                "sample": {
-                    "task_id": exc.sample.get("task_id"),
-                    "question": exc.sample.get("question"),
-                },
-                "generated_output_preview": _preview_text(exc.generated_output),
-                "extracted_code_preview": _preview_text(exc.extracted_output) if exc.extracted_output is not None else None,
-            },
-        }
-        exit_code = 1
-    except Exception as exc:  # noqa: BLE001
-        result = {
-            "benchmark": args.benchmark,
-            "status": "failed",
-            "failure_type": type(exc).__name__,
-            "error": str(exc),
-            "n_samples": 0,
-        }
-        exit_code = 1
-
-    if args.out is not None:
-        write_benchmark_result(args.out, result)
-    if args.ledger is not None and exit_code == 0:
-        append_benchmark_run_record(
-            args.ledger,
-            result,
-            benchmark_name=args.benchmark,
-            variant=args.variant,
-            seed=args.seed,
-            samples_path=sample_path,
-        )
-    print(json.dumps(result, indent=2, sort_keys=True))
-    return exit_code
-
-
-if __name__ == "__main__":
-    raise SystemExit(main())
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import json
+import re
+import sys
+from pathlib import Path
+from typing import Any, Callable
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+LEDGER_TEMPLATE_PATH = REPO_ROOT / "artifacts" / "benchmark_ledger.template.json"
+
+from scripts.hydra_generation import build_hydra_generator
+from scripts.benchmark_datasets import resolve_benchmark_dataset as resolve_canonical_dataset
+from scripts.benchmark_suite import build_prompt, validate_sample
+
+
+def load_jsonl_samples(path: Path) -> list[dict[str, Any]]:
+    rows: list[dict[str, Any]] = []
+    for line in path.read_text(encoding="utf-8").splitlines():
+        if line.strip():
+            rows.append(json.loads(line))
+    return rows
+
+
+def _score_mbpp(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+    passed = 0
+    for sample in samples:
+        validate_sample("MBPP", sample)
+        code = generate_fn(build_prompt("MBPP", sample))
+        namespace: dict[str, Any] = {}
+        exec(code, namespace, namespace)
+        for test in sample["tests"]:
+            exec(test, namespace, namespace)
+        passed += 1
+    return passed / len(samples) if samples else 0.0
+
+
+def _extract_last_number(text: str) -> str | None:
+    matches = re.findall(r"-?\d+(?:\.\d+)?", text)
+    return matches[-1] if matches else None
+
+
+def _score_gsm8k(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+    passed = 0
+    for sample in samples:
+        validate_sample("GSM8K", sample)
+        output = generate_fn(build_prompt("GSM8K", sample))
+        pred = _extract_last_number(output)
+        if pred is not None and pred == str(sample["answer"]):
+            passed += 1
+    return passed / len(samples) if samples else 0.0
+
+
+def _score_humaneval(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+    passed = 0
+    for sample in samples:
+        validate_sample("HumanEval", sample)
+        code = generate_fn(build_prompt("HumanEval", sample))
+        namespace: dict[str, Any] = {}
+        exec(code, namespace, namespace)
+        exec(sample["test"], namespace, namespace)
+        passed += 1
+    return passed / len(samples) if samples else 0.0
+
+
+def _score_arc(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+    passed = 0
+    for sample in samples:
+        validate_sample("ARC-Challenge", sample)
+        output = generate_fn(build_prompt("ARC-Challenge", sample)).strip()
+        if output == str(sample["answer"]):
+            passed += 1
+    return passed / len(samples) if samples else 0.0
+
+
+def run_benchmark(benchmark_name: str, path: Path, generate_fn: Callable[[str], str]) -> dict[str, Any]:
+    samples = load_jsonl_samples(path)
+    if benchmark_name == "MBPP":
+        return {
+            "benchmark": "MBPP",
+            "primary_metric": "pass_at_1",
+            "score": _score_mbpp(samples, generate_fn),
+            "n_samples": len(samples),
+        }
+    if benchmark_name == "GSM8K":
+        return {
+            "benchmark": "GSM8K",
+            "primary_metric": "exact_match",
+            "score": _score_gsm8k(samples, generate_fn),
+            "n_samples": len(samples),
+        }
+    if benchmark_name == "HumanEval":
+        return {
+            "benchmark": "HumanEval",
+            "primary_metric": "pass_at_1",
+            "score": _score_humaneval(samples, generate_fn),
+            "n_samples": len(samples),
+        }
+    if benchmark_name == "ARC-Challenge":
+        return {
+            "benchmark": "ARC-Challenge",
+            "primary_metric": "accuracy",
+            "score": _score_arc(samples, generate_fn),
+            "n_samples": len(samples),
+        }
+    raise ValueError(f"Unsupported runnable benchmark: {benchmark_name}")
+
+
+def write_benchmark_result(path: Path, payload: dict[str, Any]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+
+
+def append_benchmark_run_record(
+    ledger_path: Path,
+    result: dict[str, Any],
+    *,
+    benchmark_name: str,
+    variant: str,
+    seed: int,
+    samples_path: Path,
+) -> None:
+    if not ledger_path.exists():
+        ledger_path.parent.mkdir(parents=True, exist_ok=True)
ledger_path.parent.mkdir(parents=True, exist_ok=True)
|
| 130 |
+
ledger_path.write_text(LEDGER_TEMPLATE_PATH.read_text(encoding="utf-8"), encoding="utf-8")
|
| 131 |
+
payload = json.loads(ledger_path.read_text(encoding="utf-8"))
|
| 132 |
+
run_records = payload.setdefault("run_records", [])
|
| 133 |
+
if len(run_records) == 1 and run_records[0].get("run_id") == "example-run-0001":
|
| 134 |
+
run_records.clear()
|
| 135 |
+
run_records.append(
|
| 136 |
+
{
|
| 137 |
+
"run_id": result.get("run_id", f"{benchmark_name.lower()}-{seed}"),
|
| 138 |
+
"commit": "HEAD",
|
| 139 |
+
"model_family": "hydra",
|
| 140 |
+
"variant": variant,
|
| 141 |
+
"seed": seed,
|
| 142 |
+
"hardware": {
|
| 143 |
+
"hardware_class": payload.get("benchmark_cycle", {}).get("hardware_class", "unknown"),
|
| 144 |
+
},
|
| 145 |
+
"budget": {
|
| 146 |
+
"budget_mode": payload.get("benchmark_cycle", {}).get("budget_modes", [None])[0],
|
| 147 |
+
},
|
| 148 |
+
"capability": {
|
| 149 |
+
"coding_score": result["score"] if benchmark_name in {"MBPP", "HumanEval"} else None,
|
| 150 |
+
"reasoning_score": result["score"] if benchmark_name in {"GSM8K", "ARC-Challenge"} else None,
|
| 151 |
+
},
|
| 152 |
+
"artifacts": {
|
| 153 |
+
"samples_path": str(samples_path),
|
| 154 |
+
},
|
| 155 |
+
}
|
| 156 |
+
)
|
| 157 |
+
ledger_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def resolve_samples_path(benchmark_name: str, samples: Path | None, suite_path: Path) -> Path:
|
| 161 |
+
if samples is not None:
|
| 162 |
+
return samples
|
| 163 |
+
payload = json.loads(suite_path.read_text(encoding="utf-8"))
|
| 164 |
+
for section in ("coding_benchmarks", "reasoning_benchmarks"):
|
| 165 |
+
if section not in payload:
|
| 166 |
+
continue
|
| 167 |
+
for slot in ("fast_iteration", "milestone"):
|
| 168 |
+
entry = payload[section].get(slot)
|
| 169 |
+
if isinstance(entry, dict) and entry.get("name") == benchmark_name and "sample_path" in entry:
|
| 170 |
+
return Path(entry["sample_path"])
|
| 171 |
+
try:
|
| 172 |
+
return resolve_canonical_dataset(benchmark_name, None)
|
| 173 |
+
except ValueError:
|
| 174 |
+
raise ValueError(f"No sample path found for benchmark: {benchmark_name}")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 178 |
+
parser = argparse.ArgumentParser(description="Run a local benchmark against JSONL samples")
|
| 179 |
+
parser.add_argument("--benchmark", required=True, choices=["MBPP", "GSM8K", "HumanEval", "ARC-Challenge"])
|
| 180 |
+
parser.add_argument("--samples", type=Path)
|
| 181 |
+
parser.add_argument("--suite", type=Path, default=REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json")
|
| 182 |
+
parser.add_argument("--out", type=Path)
|
| 183 |
+
parser.add_argument("--ledger", type=Path)
|
| 184 |
+
parser.add_argument("--variant", default="hydra_full")
|
| 185 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 186 |
+
parser.add_argument("--generator-mode", choices=["stub", "hydra"], default="stub")
|
| 187 |
+
parser.add_argument("--checkpoint", type=Path)
|
| 188 |
+
parser.add_argument("--device")
|
| 189 |
+
parser.add_argument("--max-new-tokens", type=int, default=256)
|
| 190 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
| 191 |
+
parser.add_argument("--top-p", type=float, default=0.95)
|
| 192 |
+
return parser.parse_args(argv)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def main(argv: list[str] | None = None) -> int:
|
| 196 |
+
args = parse_args(argv)
|
| 197 |
+
sample_path = resolve_samples_path(args.benchmark, args.samples, args.suite)
|
| 198 |
+
try:
|
| 199 |
+
if args.generator_mode == "hydra":
|
| 200 |
+
generator = build_hydra_generator(
|
| 201 |
+
checkpoint_path=args.checkpoint,
|
| 202 |
+
device=args.device,
|
| 203 |
+
max_new_tokens=args.max_new_tokens,
|
| 204 |
+
temperature=args.temperature,
|
| 205 |
+
top_p=args.top_p,
|
| 206 |
+
)
|
| 207 |
+
else:
|
| 208 |
+
def generator(prompt: str) -> str:
|
| 209 |
+
return prompt
|
| 210 |
+
|
| 211 |
+
result = run_benchmark(args.benchmark, sample_path, generator)
|
| 212 |
+
exit_code = 0
|
| 213 |
+
except FileNotFoundError as exc:
|
| 214 |
+
result = {
|
| 215 |
+
"benchmark": args.benchmark,
|
| 216 |
+
"status": "failed",
|
| 217 |
+
"failure_type": "missing_checkpoint",
|
| 218 |
+
"error": str(exc),
|
| 219 |
+
"n_samples": 0,
|
| 220 |
+
}
|
| 221 |
+
exit_code = 1
|
| 222 |
+
except Exception as exc: # noqa: BLE001
|
| 223 |
+
result = {
|
| 224 |
+
"benchmark": args.benchmark,
|
| 225 |
+
"status": "failed",
|
| 226 |
+
"failure_type": type(exc).__name__,
|
| 227 |
+
"error": str(exc),
|
| 228 |
+
"n_samples": 0,
|
| 229 |
+
}
|
| 230 |
+
exit_code = 1
|
| 231 |
+
|
| 232 |
+
if args.out is not None:
|
| 233 |
+
write_benchmark_result(args.out, result)
|
| 234 |
+
if args.ledger is not None and exit_code == 0:
|
| 235 |
+
append_benchmark_run_record(
|
| 236 |
+
args.ledger,
|
| 237 |
+
result,
|
| 238 |
+
benchmark_name=args.benchmark,
|
| 239 |
+
variant=args.variant,
|
| 240 |
+
seed=args.seed,
|
| 241 |
+
samples_path=sample_path,
|
| 242 |
+
)
|
| 243 |
+
print(json.dumps(result, indent=2, sort_keys=True))
|
| 244 |
+
return exit_code
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
if __name__ == "__main__":
|
| 248 |
+
raise SystemExit(main())
|
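For orientation, a minimal sketch of driving the new scorer directly, assuming the overlay root is on sys.path and the generation stack imports cleanly; the scratch file, its contents, and the echo-style stub generator are illustrative, not part of the commit:

# Hypothetical smoke test: score two GSM8K samples with a stub generator.
from pathlib import Path

from scripts.benchmark_runner import run_benchmark

samples = Path("gsm8k_smoke.jsonl")  # illustrative scratch file
samples.write_text(
    '{"question": "What is 2 + 3?", "answer": 5}\n'
    '{"question": "What is 10 - 4?", "answer": 6}\n',
    encoding="utf-8",
)
result = run_benchmark("GSM8K", samples, lambda prompt: "The answer is 5")
print(result["score"])  # 0.5 (only the first sample's last number matches)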
overlay/scripts/benchmark_suite.py
CHANGED
@@ -1,84 +1,84 @@
(all 84 lines removed and re-added with identical content; shown once)
#!/usr/bin/env python3
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any


@dataclass(frozen=True)
class BenchmarkSpec:
    name: str
    family: str
    required_fields: tuple[str, ...]


REGISTRY: dict[str, BenchmarkSpec] = {
    "MBPP": BenchmarkSpec("MBPP", "coding", ("task_id", "prompt", "tests")),
    "HumanEval": BenchmarkSpec("HumanEval", "coding", ("task_id", "prompt", "test")),
    "GSM8K": BenchmarkSpec("GSM8K", "reasoning", ("question", "answer")),
    "ARC-Challenge": BenchmarkSpec("ARC-Challenge", "reasoning", ("question", "choices", "answer")),
}


def validate_sample(benchmark_name: str, sample: dict[str, Any]) -> None:
    spec = REGISTRY[benchmark_name]
    for field in spec.required_fields:
        if field not in sample:
            raise ValueError(f"{benchmark_name} sample missing required field: {field}")


def build_prompt(benchmark_name: str, sample: dict[str, Any]) -> str:
    validate_sample(benchmark_name, sample)
    if benchmark_name == "MBPP":
        tests = sample["tests"]
        rendered_tests = "\n".join(str(t) for t in tests)
        return (
            "Write a Python function that solves the task below.\n\n"
            f"Task:\n{sample['prompt']}\n\n"
            f"Tests:\n{rendered_tests}\n"
        )
    if benchmark_name == "HumanEval":
        return (
            "Complete the following Python function exactly as specified.\n\n"
            f"Prompt:\n{sample['prompt']}\n\n"
            f"Reference test:\n{sample['test']}\n"
        )
    if benchmark_name == "GSM8K":
        return f"Solve the following math word problem. Return only the final answer.\n\nQuestion: {sample['question']}\n"
    if benchmark_name == "ARC-Challenge":
        choices = sample["choices"]
        rendered_choices = "\n".join(f"- {choice}" for choice in choices)
        return (
            "Answer the following multiple-choice science question. Return only the correct option text or label.\n\n"
            f"Question: {sample['question']}\nChoices:\n{rendered_choices}\n"
        )
    raise ValueError(f"Unknown benchmark: {benchmark_name}")


def load_cycle_benchmark_suite(path: Path) -> dict[str, dict[str, BenchmarkSpec]]:
    payload = json.loads(path.read_text(encoding="utf-8"))
    out: dict[str, dict[str, BenchmarkSpec]] = {"coding_benchmarks": {}, "reasoning_benchmarks": {}}
    for section in ("coding_benchmarks", "reasoning_benchmarks"):
        if section not in payload:
            raise ValueError(f"missing benchmark section: {section}")
        for slot in ("fast_iteration", "milestone"):
            if slot not in payload[section]:
                raise ValueError(f"missing benchmark slot: {section}.{slot}")
            name = payload[section][slot]["name"]
            if name not in REGISTRY:
                raise ValueError(f"unsupported benchmark: {name}")
            out[section][slot] = REGISTRY[name]
    return out


def main() -> int:
    path = Path("artifacts/benchmark_suite.cycle1.json")
    suite = load_cycle_benchmark_suite(path)
    print(json.dumps({k: {slot: spec.name for slot, spec in section.items()} for k, section in suite.items()}, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
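A small sketch of the prompt builder above; the sample dict is hypothetical:

from scripts.benchmark_suite import REGISTRY, build_prompt

sample = {"question": "What is 7 * 6?", "answer": 42}  # hypothetical GSM8K row
print(REGISTRY["GSM8K"].required_fields)  # ('question', 'answer')
print(build_prompt("GSM8K", sample))
# Solve the following math word problem. Return only the final answer.
#
# Question: What is 7 * 6?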
overlay/scripts/bootstrap_benchmark_env.py
CHANGED
@@ -1,63 +1,63 @@
(all 63 lines removed and re-added with identical content; shown once)
#!/usr/bin/env python3
from __future__ import annotations

import json
import shutil

import torch


PACKAGE_MAP = {
    "mamba_ssm": "mamba-ssm",
    "transformers": "transformers",
}


def build_install_command(*, missing_dependencies: list[str]) -> list[str]:
    packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
    return [] if not packages else ["python", "-m", "pip", "install", *packages]


def diagnose_install_blockers(
    *,
    missing_dependencies: list[str],
    torch_version: str,
    cuda_available: bool,
    nvcc_present: bool,
) -> list[str]:
    blockers: list[str] = []
    if "mamba_ssm" in missing_dependencies:
        if "+cpu" in torch_version or not cuda_available:
            blockers.append("mamba_ssm install likely blocked by CPU-only torch runtime")
        if not nvcc_present:
            blockers.append("mamba_ssm install likely blocked because nvcc is unavailable")
    return blockers


def build_bootstrap_report(*, missing_dependencies: list[str]) -> dict[str, object]:
    ready = len(missing_dependencies) == 0
    packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
    install_hint = "" if ready else f"Install missing benchmark dependencies: {', '.join(packages)}"
    blockers = diagnose_install_blockers(
        missing_dependencies=missing_dependencies,
        torch_version=getattr(torch, "__version__", "unknown"),
        cuda_available=torch.cuda.is_available(),
        nvcc_present=shutil.which("nvcc") is not None,
    )
    return {
        "ready": ready,
        "missing_dependencies": list(missing_dependencies),
        "install_hint": install_hint,
        "install_command": build_install_command(missing_dependencies=missing_dependencies),
        "install_blockers": blockers,
    }


def main() -> int:
    report = build_bootstrap_report(missing_dependencies=["mamba_ssm"])
    print(json.dumps(report, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
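The blocker diagnosis above is a pure function, so it can be sanity-checked without a GPU (importing the module still pulls in torch); the version string below is illustrative:

from scripts.bootstrap_benchmark_env import diagnose_install_blockers

print(diagnose_install_blockers(
    missing_dependencies=["mamba_ssm"],
    torch_version="2.6.0+cpu",  # hypothetical CPU-only wheel
    cuda_available=False,
    nvcc_present=False,
))
# ['mamba_ssm install likely blocked by CPU-only torch runtime',
#  'mamba_ssm install likely blocked because nvcc is unavailable']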
overlay/scripts/cycle1a_report.py
CHANGED
@@ -1,52 +1,52 @@
(all 52 lines removed and re-added with identical content; shown once)
#!/usr/bin/env python3
from __future__ import annotations

import json
from collections import defaultdict
from pathlib import Path
from typing import Any


def build_cycle1a_report(run_dir: Path) -> dict[str, Any]:
    runs = []
    for path in sorted(run_dir.glob("*.json")):
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue
        if isinstance(payload, dict) and "benchmark" in payload:
            runs.append((path.name, payload))

    n_failed = sum(1 for _, payload in runs if payload.get("status") == "failed")
    by_benchmark: dict[str, dict[str, dict[str, float]]] = defaultdict(dict)
    for filename, payload in runs:
        if payload.get("status") == "failed":
            continue
        parts = filename.removesuffix('.json').split('_')
        if len(parts) < 3:
            continue
        benchmark = payload["benchmark"]
        variant = '_'.join(parts[1:-1])
        score = float(payload.get("score", 0.0))
        slot = by_benchmark.setdefault(benchmark, {}).setdefault(variant, {"scores": []})
        slot["scores"].append(score)

    for benchmark, variants in by_benchmark.items():
        for variant, slot in variants.items():
            scores = slot.pop("scores")
            slot["mean_score"] = sum(scores) / len(scores)
            slot["n_scores"] = len(scores)

    if runs and n_failed == len(runs):
        panel_status = "blocked"
    elif n_failed > 0:
        panel_status = "partial"
    else:
        panel_status = "ready"

    return {
        "n_runs": len(runs),
        "n_failed": n_failed,
        "panel_status": panel_status,
        "by_benchmark": by_benchmark,
    }
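Result filenames follow <benchmark>_<variant>_seed<k>.json, which is how the variant key is recovered above; a sketch with one synthetic result file (the directory and payload are invented):

import json
from pathlib import Path

from scripts.cycle1a_report import build_cycle1a_report

run_dir = Path("runs_demo")  # hypothetical scratch directory
run_dir.mkdir(exist_ok=True)
(run_dir / "mbpp_hydra_full_seed42.json").write_text(
    json.dumps({"benchmark": "MBPP", "score": 0.25}), encoding="utf-8"
)
report = build_cycle1a_report(run_dir)
print(report["panel_status"])                        # ready
print(report["by_benchmark"]["MBPP"]["hydra_full"])  # {'mean_score': 0.25, 'n_scores': 1}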
overlay/scripts/cycle_executor.py
CHANGED
@@ -1,332 +1,312 @@
Removed (old version, 332 lines): old lines 1-64 and the closing single-run block match the new version below; the middle of the old listing is garbled in the page view and is not recoverable.
Added (new version, 312 lines):
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import importlib.util
import importlib
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any

from scripts.benchmark_preflight import build_readiness_report
from scripts.hf_routing import resolve_routing


REPO_ROOT = Path(__file__).resolve().parents[1]
FREEZE_PATH = REPO_ROOT / "artifacts" / "cycle_1_execution_freeze.json"
RUNNER_PATH = REPO_ROOT / "scripts" / "benchmark_runner.py"


def active_hf_token() -> str | None:
    token = os.environ.get("HF_TOKEN")
    if token:
        return token
    try:
        from huggingface_hub.utils import get_token
        return get_token()
    except Exception:
        return None


def missing_benchmark_dependencies() -> list[str]:
    required = ["mamba_ssm", "transformers"]
    missing: list[str] = []
    for name in required:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            spec = None
        if spec is None:
            try:
                importlib.import_module(name)
            except Exception:
                missing.append(name)
    return missing


def load_cycle_freeze(path: Path) -> dict[str, Any]:
    return json.loads(path.read_text(encoding="utf-8"))


def load_cycle_benchmarks(path: Path) -> list[str]:
    payload = json.loads(path.read_text(encoding="utf-8"))
    out: list[str] = []
    for section in ("coding_benchmarks", "reasoning_benchmarks"):
        for slot in ("fast_iteration", "milestone"):
            entry = payload.get(section, {}).get(slot)
            if isinstance(entry, dict) and entry.get("name"):
                out.append(str(entry["name"]))
    return out


def build_preflight_report(
    *,
    cache_dir: Path,
    output_repo: str | None = None,
    tokenizer_repo: str | None = None,
) -> dict[str, object]:
    return build_readiness_report(
        cache_dir=cache_dir,
        hf_token_present=bool(active_hf_token()),
        dependencies_present=not bool(missing_benchmark_dependencies()),
        missing_dependencies=missing_benchmark_dependencies(),
        output_repo=output_repo,
        tokenizer_repo=tokenizer_repo,
    )


def write_preflight_report(path: Path, payload: dict[str, object]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")


def write_cycle_summary(path: Path, payload: list[dict[str, Any]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")


def build_remote_checkpoint_report(output_repo: str, token: str | None) -> dict[str, Any]:
    from huggingface_hub import HfApi

    from scripts.benchmark_checkpoint_report import build_checkpoint_report

    files = HfApi(token=token).list_repo_files(repo_id=output_repo, repo_type="model", token=token)
    return build_checkpoint_report(files)


def ensure_benchmark_assets(
    *,
    cache_dir: Path,
    output_repo: str,
    tokenizer_repo: str,
    token: str | None,
    hydrate: bool,
) -> dict[str, str] | None:
    if not hydrate:
        return None
    from scripts.benchmark_assets import hydrate_benchmark_assets

    return hydrate_benchmark_assets(
        cache_dir=cache_dir,
        output_repo=output_repo,
        tokenizer_repo=tokenizer_repo,
        token=token,
    )


def build_benchmark_command(
    freeze: dict[str, Any],
    *,
    benchmark: str,
    variant: str,
    seed: int,
    out_dir: Path,
) -> tuple[list[str], dict[str, str]]:
    variant_cfg = freeze["variants"][variant]
    env = os.environ.copy()
    env.update({str(k): str(v) for k, v in variant_cfg.get("env", {}).items()})
    env["HYDRA_SEED"] = str(seed)

    out_dir.mkdir(parents=True, exist_ok=True)
    result_path = out_dir / f"{benchmark.lower()}_{variant}_seed{seed}.json"
    ledger_path = out_dir / "benchmark_ledger.json"
    cmd = [
        sys.executable,
        str(RUNNER_PATH),
        "--benchmark",
        benchmark,
        "--generator-mode",
        "hydra",
        "--out",
        str(result_path),
        "--ledger",
        str(ledger_path),
        "--variant",
        variant,
        "--seed",
        str(seed),
    ]
    return cmd, env


def build_cycle_plan(freeze: dict[str, Any], *, benchmark: str, out_dir: Path) -> list[dict[str, Any]]:
    runnable_variants = [
        name for name, cfg in freeze.get("variants", {}).items()
        if isinstance(cfg, dict) and cfg.get("status") == "runnable_now"
    ]
    seeds = [int(seed) for seed in freeze.get("seeds", [])]
    plan: list[dict[str, Any]] = []
    for variant in runnable_variants:
        for seed in seeds:
            cmd, env = build_benchmark_command(
                freeze,
                benchmark=benchmark,
                variant=variant,
                seed=seed,
                out_dir=out_dir,
            )
            plan.append({
                "benchmark": benchmark,
                "variant": variant,
                "seed": seed,
                "command": cmd,
                "env": env,
            })
    return plan


def execute_cycle_plan(plan: list[dict[str, Any]], *, repo_root: Path) -> list[dict[str, Any]]:
    results: list[dict[str, Any]] = []
    for item in plan:
        proc = subprocess.run(item["command"], cwd=str(repo_root), env=item["env"])
        results.append(
            {
                "benchmark": item["benchmark"],
                "variant": item["variant"],
                "seed": item["seed"],
                "returncode": proc.returncode,
            }
        )
    return results


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Execute a frozen Cycle 1 benchmark run")
    parser.add_argument("--freeze", type=Path, default=FREEZE_PATH)
    parser.add_argument("--suite", type=Path, default=REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json")
    parser.add_argument("--benchmark", required=True)
    parser.add_argument("--variant", required=True)
    parser.add_argument("--seed", type=int, required=True)
    parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "artifacts" / "runs")
    parser.add_argument("--preflight-out", type=Path)
    parser.add_argument("--summary-out", type=Path)
    parser.add_argument("--hydrate-assets", action="store_true")
    parser.add_argument("--all-runnable", action="store_true")
    parser.add_argument("--all-benchmarks", action="store_true")
    parser.add_argument("--require-ready", action="store_true")
    parser.add_argument("--output-repo")
    parser.add_argument("--tokenizer-repo")
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = parse_args(argv)
    cache_dir = Path(os.path.expanduser("~/.cache/autoresearch"))
    report = None
    token = active_hf_token()
    routing = resolve_routing(token=token)
    output_repo = args.output_repo or routing.output_repo
    tokenizer_repo = args.tokenizer_repo or routing.output_repo
    if args.hydrate_assets:
        try:
            ensure_benchmark_assets(
                cache_dir=cache_dir,
                output_repo=output_repo,
                tokenizer_repo=tokenizer_repo,
                token=token,
                hydrate=True,
            )
        except FileNotFoundError as exc:
            checkpoint_report = None
            try:
                checkpoint_report = build_remote_checkpoint_report(output_repo, token)
            except Exception:
                checkpoint_report = None
            if args.summary_out is not None:
                write_cycle_summary(
                    args.summary_out,
                    [{
                        "status": "blocked",
                        "reason": "asset_hydration_failed",
                        "error": str(exc),
                        "checkpoint_candidates": checkpoint_report,
                    }],
                )
            return 3
    if args.preflight_out is not None:
        report = build_preflight_report(
            cache_dir=cache_dir,
            output_repo=output_repo,
            tokenizer_repo=tokenizer_repo,
        )
        write_preflight_report(args.preflight_out, report)
    if args.require_ready:
        if report is None:
            report = build_preflight_report(
                cache_dir=cache_dir,
                output_repo=output_repo,
                tokenizer_repo=tokenizer_repo,
            )
        if not bool(report.get("ready_for_hydra_benchmarks")):
            checkpoint_report = None
            try:
                checkpoint_report = build_remote_checkpoint_report(output_repo, token)
            except Exception:
                checkpoint_report = None
            if args.summary_out is not None:
                write_cycle_summary(
                    args.summary_out,
                    [{
                        "status": "blocked",
                        "reason": "preflight_not_ready",
                        "preflight": report,
                        "checkpoint_candidates": checkpoint_report,
                    }],
                )
            return 2
    freeze = load_cycle_freeze(args.freeze)
    if args.all_runnable:
        benchmarks = load_cycle_benchmarks(args.suite) if args.all_benchmarks else [args.benchmark]
        plan = []
        for benchmark in benchmarks:
            plan.extend(build_cycle_plan(freeze, benchmark=benchmark, out_dir=args.out_dir))
        results = execute_cycle_plan(plan, repo_root=REPO_ROOT)
        if args.summary_out is not None:
            write_cycle_summary(args.summary_out, results)
        return 0 if all(item["returncode"] == 0 for item in results) else 1
    cmd, env = build_benchmark_command(
        freeze,
        benchmark=args.benchmark,
        variant=args.variant,
        seed=args.seed,
        out_dir=args.out_dir,
    )
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env)
    if args.summary_out is not None:
        write_cycle_summary(
            args.summary_out,
            [{
                "benchmark": args.benchmark,
                "variant": args.variant,
                "seed": args.seed,
                "returncode": proc.returncode,
            }],
        )
    return proc.returncode


if __name__ == "__main__":
    raise SystemExit(main())
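A sketch of the plan builder with an in-memory freeze payload; the variant names and env values are invented for illustration:

from pathlib import Path

from scripts.cycle_executor import build_cycle_plan

freeze = {  # hypothetical freeze config
    "variants": {
        "hydra_full": {"status": "runnable_now", "env": {"HYDRA_MODE": "full"}},
        "hydra_lite": {"status": "deferred"},
    },
    "seeds": [42, 43],
}
plan = build_cycle_plan(freeze, benchmark="GSM8K", out_dir=Path("artifacts/runs_demo"))
print([(p["variant"], p["seed"]) for p in plan])
# [('hydra_full', 42), ('hydra_full', 43)]  (non-runnable variants are skipped)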
overlay/scripts/export_hpo_priors.py
CHANGED
@@ -1,94 +1,94 @@
(all 94 lines removed and re-added with identical content; shown once)
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import datetime as dt
import json
from pathlib import Path
from typing import Any

import optuna

from scripts.hpo_leaderboard import build_leaderboard


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Export top Optuna trials as transfer-learning priors")
    parser.add_argument("--study-name", action="append", default=[], help="Repeat to merge multiple studies")
    parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
    parser.add_argument("--top-k", type=int, default=20)
    parser.add_argument("--out", type=Path, default=Path("docs") / "hpo_transfer_priors.json")
    parser.add_argument("--metric", default="val_bpb")
    return parser.parse_args()


def _completed_trials(study: optuna.Study) -> list[optuna.trial.FrozenTrial]:
    trials = [t for t in study.trials if t.value is not None]
    reverse = study.direction == optuna.study.StudyDirection.MAXIMIZE
    return sorted(trials, key=lambda t: float(t.value), reverse=reverse)


def _serialize_trial(trial: optuna.trial.FrozenTrial) -> dict[str, Any]:
    return {
        "trial_number": trial.number,
        "value": float(trial.value) if trial.value is not None else None,
        "params": dict(trial.params),
        "user_attrs": dict(trial.user_attrs),
    }


def collect_prior_trials(*, storage: str, study_names: list[str], top_k: int, metric: str) -> dict[str, Any]:
    leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
    selected = leaderboard["clean_trials"][: max(0, top_k)]
    trials = [
        {
            "study_name": row["study_name"],
            "trial_number": row["trial_number"],
            "value": row["value"],
            "params": row["params"],
            "user_attrs": row["user_attrs"],
        }
        for row in selected
    ]
    quarantined = [
        {
            "study_name": row["study_name"],
            "trial_number": row["trial_number"],
            "value": row["value"],
            "params": row["params"],
            "user_attrs": row["user_attrs"],
            "contamination_reason": row["contamination_reason"],
        }
        for row in leaderboard["contaminated_trials"]
    ]
    return {
        "schema_version": 2,
        "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
        "study_names": study_names,
        "metric": metric,
        "n_total_trials": sum(int(s["n_trials"]) for s in leaderboard["studies"]),
        "n_completed_trials": sum(int(s["n_completed"]) for s in leaderboard["studies"]),
        "n_exported_trials": len(trials),
        "n_quarantined_trials": len(quarantined),
        "top_k": top_k,
        "trials": trials,
        "quarantined_trials": quarantined,
    }


def main() -> int:
    args = parse_args()
    study_names = args.study_name or ["hydra_hpo"]
    payload = collect_prior_trials(storage=args.storage, study_names=study_names, top_k=args.top_k, metric=args.metric)

    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    print(
        f"[hpo-priors] wrote {args.out} with {payload['n_exported_trials']} clean trials "
        f"({payload['n_quarantined_trials']} quarantined)"
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
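The sort helper above can be exercised against a throwaway in-memory study; the trial values are made up:

import optuna

from scripts.export_hpo_priors import _completed_trials

study = optuna.create_study(direction="minimize")  # val_bpb: lower is better
study.add_trial(optuna.trial.create_trial(params={}, distributions={}, value=0.92))
study.add_trial(optuna.trial.create_trial(params={}, distributions={}, value=0.81))
print([t.value for t in _completed_trials(study)])  # [0.81, 0.92]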
overlay/scripts/hf_routing.py
CHANGED
|
@@ -1,94 +1,94 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import os
|
| 4 |
-
from dataclasses import dataclass
|
| 5 |
-
|
| 6 |
-
from huggingface_hub import HfApi
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
_OWNER_ALIASES = {
|
| 10 |
-
'jack': 'jackoatmon',
|
| 11 |
-
'jackoatmon': 'jackoatmon',
|
| 12 |
-
'icarus': 'icarus112',
|
| 13 |
-
'icarus112': 'icarus112',
|
| 14 |
-
}
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def _normalize_owner(value: str | None) -> str | None:
|
| 18 |
-
if not value:
|
| 19 |
-
return None
|
| 20 |
-
normalized = value.strip().lower().lstrip('@')
|
| 21 |
-
if not normalized:
|
| 22 |
-
return None
|
| 23 |
-
return _OWNER_ALIASES.get(normalized, normalized)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _owner_from_env() -> str | None:
|
| 27 |
-
for key in (
|
| 28 |
-
'FEATHER_HF_OWNER',
|
| 29 |
-
'FEATHER_HF_NAMESPACE_OWNER',
|
| 30 |
-
'FEATHER_HF_PROFILE',
|
| 31 |
-
'FEATHER_HF_NAMESPACE',
|
| 32 |
-
):
|
| 33 |
-
owner = _normalize_owner(os.environ.get(key))
|
| 34 |
-
if owner:
|
| 35 |
-
return owner
|
| 36 |
-
return None
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def resolve_owner(token: str | None = None) -> str:
|
| 40 |
-
"""Resolve active HF owner in a collaborator-safe way.
|
| 41 |
-
|
| 42 |
-
Resolution precedence:
|
| 43 |
-
1) explicit env owner override (FEATHER_HF_OWNER/...)
|
| 44 |
-
2) Hugging Face `whoami` from HF_TOKEN (unless disabled)
|
| 45 |
-
3) default to jackoatmon
|
| 46 |
-
"""
|
| 47 |
-
owner = _owner_from_env()
|
| 48 |
-
if owner:
|
| 49 |
-
return owner
|
| 50 |
-
|
| 51 |
-
if os.environ.get('FEATHER_HF_DISABLE_WHOAMI', '0') != '1':
|
| 52 |
-
active_token = token or os.environ.get('HF_TOKEN')
|
| 53 |
-
if active_token:
|
| 54 |
-
try:
|
| 55 |
-
info = HfApi(token=active_token).whoami(token=active_token)
|
| 56 |
-
if isinstance(info, dict):
|
| 57 |
-
whoami_owner = _normalize_owner(info.get('name'))
|
| 58 |
-
if whoami_owner:
|
| 59 |
-
return whoami_owner
|
| 60 |
-
except Exception:
|
| 61 |
-
# We intentionally fail-open to deterministic defaults.
|
| 62 |
-
pass
|
| 63 |
-
|
| 64 |
-
return 'jackoatmon'
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
@dataclass(frozen=True)
|
| 68 |
-
class HfRouting:
|
| 69 |
-
owner: str
|
| 70 |
-
space_repo: str
|
| 71 |
-
output_repo: str
|
| 72 |
-
retina_cache_repo: str
|
| 73 |
-
job_namespace: str
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def resolve_routing(token: str | None = None) -> HfRouting:
|
| 77 |
-
owner = resolve_owner(token=token)
|
| 78 |
-
|
| 79 |
-
space_name = os.environ.get('FEATHER_HF_SPACE_NAME', 'feather-a10-runtime')
|
| 80 |
-
output_name = os.environ.get('FEATHER_HF_OUTPUT_REPO_NAME', 'feather-pretrain-checkpoints')
|
| 81 |
-
retina_name = os.environ.get('FEATHER_HF_RETINA_REPO_NAME', 'feather-retina-cache')
|
| 82 |
-
|
| 83 |
-
space_repo = os.environ.get('FEATHER_HF_SPACE_REPO') or f'{owner}/{space_name}'
|
| 84 |
-
output_repo = os.environ.get('FEATHER_HF_OUTPUT_REPO') or f'{owner}/{output_name}'
|
| 85 |
-
retina_cache_repo = os.environ.get('FEATHER_HF_RETINA_CACHE_REPO') or f'{owner}/{retina_name}'
|
| 86 |
-
job_namespace = os.environ.get('FEATHER_HF_JOB_NAMESPACE') or owner
|
| 87 |
-
|
| 88 |
-
return HfRouting(
|
| 89 |
-
owner=owner,
|
| 90 |
-
space_repo=space_repo,
|
| 91 |
-
output_repo=output_repo,
|
| 92 |
-
retina_cache_repo=retina_cache_repo,
|
| 93 |
-
job_namespace=job_namespace,
|
| 94 |
-
)
|
|
|
|
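For collaborators, a minimal sketch of the owner-resolution precedence described in the resolve_owner docstring (values are hypothetical; assumes this module is importable as scripts.hf_routing and that no FEATHER_HF_*_REPO overrides are set in the environment):

import os

from scripts.hf_routing import resolve_routing

# An explicit env override wins over whoami and the default owner.
os.environ['FEATHER_HF_OWNER'] = 'icarus'        # alias, normalized to 'icarus112'
os.environ['FEATHER_HF_DISABLE_WHOAMI'] = '1'    # skip the HfApi whoami lookup entirely
routing = resolve_routing()
assert routing.owner == 'icarus112'
assert routing.space_repo == 'icarus112/feather-a10-runtime'
assert routing.job_namespace == 'icarus112'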
overlay/scripts/hpo_component_report.py
CHANGED
@@ -1,130 +1,130 @@
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import datetime as dt
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any
|
| 11 |
+
|
| 12 |
+
from scripts.hpo_leaderboard import build_leaderboard
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_COMPONENT_KEYS = [
|
| 16 |
+
"engram_subsample",
|
| 17 |
+
"htm_subsample",
|
| 18 |
+
"htm_learn_every",
|
| 19 |
+
"engram_n_columns",
|
| 20 |
+
"engram_layer_idx",
|
| 21 |
+
"sdr_target_active",
|
| 22 |
+
"mamba3_chunk",
|
| 23 |
+
"dropout",
|
| 24 |
+
"hyena_layers",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _recover_params(row: dict[str, Any]) -> dict[str, Any]:
|
| 29 |
+
params = dict(row.get("params") or {})
|
| 30 |
+
attrs = row.get("user_attrs") or {}
|
| 31 |
+
for key, value in attrs.items():
|
| 32 |
+
if key.startswith("param_"):
|
| 33 |
+
params.setdefault(key.removeprefix("param_"), value)
|
| 34 |
+
return params
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _pearson(xs: list[float], ys: list[float]) -> float | None:
|
| 38 |
+
if len(xs) < 2 or len(xs) != len(ys):
|
| 39 |
+
return None
|
| 40 |
+
mean_x = sum(xs) / len(xs)
|
| 41 |
+
mean_y = sum(ys) / len(ys)
|
| 42 |
+
cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
|
| 43 |
+
var_x = sum((x - mean_x) ** 2 for x in xs)
|
| 44 |
+
var_y = sum((y - mean_y) ** 2 for y in ys)
|
| 45 |
+
if var_x <= 0 or var_y <= 0:
|
| 46 |
+
return None
|
| 47 |
+
return cov / math.sqrt(var_x * var_y)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def build_component_report(*, storage: str, study_names: list[str], metric: str = "val_bpb") -> dict[str, Any]:
|
| 51 |
+
leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
|
| 52 |
+
clean_trials = leaderboard["clean_trials"]
|
| 53 |
+
|
| 54 |
+
ablations: dict[str, list[dict[str, Any]]] = {}
|
| 55 |
+
numeric_correlations: list[dict[str, Any]] = []
|
| 56 |
+
|
| 57 |
+
for key in _COMPONENT_KEYS:
|
| 58 |
+
grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
| 59 |
+
numeric_x: list[float] = []
|
| 60 |
+
metric_y: list[float] = []
|
| 61 |
+
tps_y: list[float] = []
|
| 62 |
+
for row in clean_trials:
|
| 63 |
+
params = _recover_params(row)
|
| 64 |
+
if key not in params:
|
| 65 |
+
continue
|
| 66 |
+
value = params[key]
|
| 67 |
+
grouped[str(value)].append({"value": value, "metric": float(row["value"]), "tps": row.get("tps")})
|
| 68 |
+
if isinstance(value, (int, float)) and isinstance(row.get("tps"), (int, float)):
|
| 69 |
+
numeric_x.append(float(value))
|
| 70 |
+
metric_y.append(float(row["value"]))
|
| 71 |
+
tps_y.append(float(row["tps"]))
|
| 72 |
+
|
| 73 |
+
rows: list[dict[str, Any]] = []
|
| 74 |
+
for grouped_rows in grouped.values():
|
| 75 |
+
value = grouped_rows[0]["value"]
|
| 76 |
+
metric_vals = [r["metric"] for r in grouped_rows]
|
| 77 |
+
tps_vals = [float(r["tps"]) for r in grouped_rows if isinstance(r["tps"], (int, float))]
|
| 78 |
+
rows.append({
|
| 79 |
+
"value": value,
|
| 80 |
+
"n_trials": len(grouped_rows),
|
| 81 |
+
"mean_metric": sum(metric_vals) / len(metric_vals),
|
| 82 |
+
"mean_tps": (sum(tps_vals) / len(tps_vals)) if tps_vals else None,
|
| 83 |
+
})
|
| 84 |
+
if rows:
|
| 85 |
+
rows.sort(key=lambda row: str(row["value"]))
|
| 86 |
+
ablations[key] = rows
|
| 87 |
+
|
| 88 |
+
pearson_metric = _pearson(numeric_x, metric_y)
|
| 89 |
+
pearson_tps = _pearson(numeric_x, tps_y)
|
| 90 |
+
if pearson_metric is not None or pearson_tps is not None:
|
| 91 |
+
numeric_correlations.append({
|
| 92 |
+
"param": key,
|
| 93 |
+
"pearson_with_metric": pearson_metric,
|
| 94 |
+
"pearson_with_tps": pearson_tps,
|
| 95 |
+
"n_points": len(numeric_x),
|
| 96 |
+
})
|
| 97 |
+
|
| 98 |
+
numeric_correlations.sort(key=lambda row: row["param"])
|
| 99 |
+
return {
|
| 100 |
+
"schema_version": 1,
|
| 101 |
+
"generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
|
| 102 |
+
"metric": metric,
|
| 103 |
+
"study_names": study_names,
|
| 104 |
+
"n_clean_trials": len(clean_trials),
|
| 105 |
+
"component_ablations": ablations,
|
| 106 |
+
"numeric_correlations": numeric_correlations,
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 111 |
+
parser = argparse.ArgumentParser(description="Build component ablation and correlation report from clean HPO trials")
|
| 112 |
+
parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
|
| 113 |
+
parser.add_argument("--study-name", action="append", default=[])
|
| 114 |
+
parser.add_argument("--metric", default="val_bpb")
|
| 115 |
+
parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "component_report.json")
|
| 116 |
+
return parser.parse_args(argv)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def main(argv: list[str] | None = None) -> int:
|
| 120 |
+
args = parse_args(argv)
|
| 121 |
+
study_names = args.study_name or ["hydra_hpo"]
|
| 122 |
+
payload = build_component_report(storage=args.storage, study_names=study_names, metric=args.metric)
|
| 123 |
+
args.out.parent.mkdir(parents=True, exist_ok=True)
|
| 124 |
+
args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
| 125 |
+
print(json.dumps(payload, indent=2, sort_keys=True))
|
| 126 |
+
return 0
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
raise SystemExit(main())
|
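As a sanity check on the statistics above: _pearson is the plain sample correlation (the 1/n factors cancel between the covariance and variance sums), so a hand-computable case behaves as expected. A minimal sketch using only the function defined in this file:

from scripts.hpo_component_report import _pearson

# Perfectly linear increasing data -> r == 1.0 (up to float error).
assert abs(_pearson([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]) - 1.0) < 1e-12

# Fewer than two points, or zero variance in either series -> None.
assert _pearson([1.0], [2.0]) is None
assert _pearson([1.0, 1.0], [3.0, 4.0]) is None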
overlay/scripts/hpo_leaderboard.py
CHANGED
@@ -1,156 +1,156 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import datetime as dt
import json
from pathlib import Path
from typing import Any

import optuna


def _trial_direction(study: optuna.Study) -> str:
    return "maximize" if study.direction == optuna.study.StudyDirection.MAXIMIZE else "minimize"


def _contamination_reason(trial: optuna.trial.FrozenTrial, metric: str) -> str | None:
    if trial.value is None:
        return "missing_value"
    attrs = trial.user_attrs
    source = attrs.get("objective_source")
    eval_status = attrs.get("eval_status")
    objective_metric = attrs.get("objective_metric")

    if source in {"train_log_fallback", "missing_metric", "missing_metrics", "missing_final_val"}:
        return f"objective_source={source}"
    if eval_status not in {None, "completed"}:
        return f"eval_status={eval_status}"
    if objective_metric not in {None, metric}:
        return f"objective_metric={objective_metric}"
    return None


def _serialize_trial(study_name: str, trial: optuna.trial.FrozenTrial, metric: str) -> dict[str, Any]:
    attrs = dict(trial.user_attrs)
    source = attrs.get("objective_source") or "legacy_completed_value"
    row = {
        "study_name": study_name,
        "trial_number": trial.number,
        "value": float(trial.value) if trial.value is not None else None,
        "metric": metric,
        "objective_source": source,
        "objective_metric": attrs.get("objective_metric", metric),
        "eval_status": attrs.get("eval_status"),
        "hf_job_id": attrs.get("hf_job_id"),
        "tps": attrs.get("tps"),
        "params": dict(trial.params),
        "user_attrs": attrs,
    }
    reason = _contamination_reason(trial, metric)
    if reason is not None:
        row["contamination_reason"] = reason
    return row


def _is_pareto_dominated(candidate: dict[str, Any], peers: list[dict[str, Any]]) -> bool:
    candidate_value = float(candidate["value"])
    candidate_tps = float(candidate["tps"])
    for peer in peers:
        if peer is candidate or peer.get("tps") is None:
            continue
        peer_value = float(peer["value"])
        peer_tps = float(peer["tps"])
        no_worse = peer_value <= candidate_value and peer_tps >= candidate_tps
        strictly_better = peer_value < candidate_value or peer_tps > candidate_tps
        if no_worse and strictly_better:
            return True
    return False


def _annotate_pareto(clean_trials: list[dict[str, Any]]) -> list[dict[str, Any]]:
    pareto_trials: list[dict[str, Any]] = []
    comparable = [row for row in clean_trials if row.get("tps") is not None]
    for row in clean_trials:
        if row.get("tps") is None:
            row["pareto_frontier"] = False
            row["pareto_dominated"] = None
            row["pareto_reason"] = "missing_tps"
            continue
        dominated = _is_pareto_dominated(row, comparable)
        row["pareto_frontier"] = not dominated
        row["pareto_dominated"] = dominated
        row["pareto_reason"] = "frontier" if not dominated else "dominated"
        if not dominated:
            pareto_trials.append(row)
    pareto_trials.sort(key=lambda row: (float(row["value"]), -float(row["tps"])))
    return pareto_trials


def build_leaderboard(*, storage: str, study_names: list[str], metric: str = "val_bpb") -> dict[str, Any]:
    clean_trials: list[dict[str, Any]] = []
    contaminated_trials: list[dict[str, Any]] = []
    study_summaries: list[dict[str, Any]] = []
    direction = "minimize"

    for study_name in study_names:
        study = optuna.load_study(study_name=study_name, storage=storage)
        direction = _trial_direction(study)
        completed = [t for t in study.trials if t.value is not None]
        study_summaries.append({
            "study_name": study_name,
            "direction": direction,
            "n_trials": len(study.trials),
            "n_completed": len(completed),
        })
        for trial in completed:
            row = _serialize_trial(study_name, trial, metric)
            if "contamination_reason" in row:
                contaminated_trials.append(row)
            else:
                clean_trials.append(row)

    reverse = direction == "maximize"
    clean_trials.sort(key=lambda row: float(row["value"]), reverse=reverse)
    contaminated_trials.sort(key=lambda row: float(row["value"]), reverse=reverse)
    pareto_trials = _annotate_pareto(clean_trials)

    return {
        "schema_version": 1,
        "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
        "metric": metric,
        "direction": direction,
        "study_names": study_names,
        "studies": study_summaries,
        "n_clean_trials": len(clean_trials),
        "n_contaminated_trials": len(contaminated_trials),
        "pareto_metric_x": metric,
        "pareto_metric_y": "tps",
        "n_pareto_trials": len(pareto_trials),
        "clean_trials": clean_trials,
        "contaminated_trials": contaminated_trials,
        "pareto_trials": pareto_trials,
    }


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Build a clean Optuna HPO leaderboard")
    parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
    parser.add_argument("--study-name", action="append", default=[], help="Repeat to merge multiple studies")
    parser.add_argument("--metric", default="val_bpb")
    parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "leaderboard.json")
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = parse_args(argv)
    study_names = args.study_name or ["hydra_hpo"]
    payload = build_leaderboard(storage=args.storage, study_names=study_names, metric=args.metric)
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
    print(json.dumps(payload, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
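The Pareto logic above treats lower metric value and higher tps as jointly better: a trial is dominated only when some peer is no worse on both axes and strictly better on at least one. A minimal sketch with hypothetical rows (only the keys _is_pareto_dominated actually reads):

from scripts.hpo_leaderboard import _is_pareto_dominated

rows = [
    {"value": 1.00, "tps": 60000.0},  # best metric, slower
    {"value": 1.10, "tps": 90000.0},  # worse metric, fastest
    {"value": 1.05, "tps": 55000.0},  # beaten by the first row on both axes
]
assert not _is_pareto_dominated(rows[0], rows)  # frontier
assert not _is_pareto_dominated(rows[1], rows)  # frontier
assert _is_pareto_dominated(rows[2], rows)      # dominated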
overlay/scripts/hpo_orchestrator.py
CHANGED
@@ -1,25 +1,25 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any

import optuna


REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from scripts.hf_routing import resolve_routing
from scripts.optuna_hpo import _enqueue_transfer_priors

HPO_SCRIPT = REPO_ROOT / "scripts" / "optuna_hpo.py"


def _run_worker(args: list[str]) -> int:
@@ -28,7 +28,7 @@ def _run_worker(args: list[str]) -> int:
    return proc.returncode


def _study_stats(storage: str, study_name: str) -> dict[str, Any]:
    try:
        study = optuna.load_study(study_name=study_name, storage=storage)
    except KeyError:
@@ -62,29 +62,29 @@ def _study_stats(storage: str, study_name: str) -> dict[str, Any]:
                "best_trial_user_attrs": study.best_trial.user_attrs,
            }
        )
    return stats


def _prime_transfer_priors(storage: str, study_name: str, priors_file: Path, apply_priors: bool) -> int:
    if not apply_priors:
        return 0
    study = optuna.create_study(
        study_name=study_name,
        storage=storage,
        load_if_exists=True,
        direction="minimize",
    )
    return _enqueue_transfer_priors(study, priors_file, apply_priors=True)


def _disable_worker_priors(worker_args: list[str]) -> list[str]:
    cleaned: list[str] = []
    for item in worker_args:
        if item in {"--apply-priors", "--no-apply-priors"}:
            continue
        cleaned.append(item)
    cleaned.append("--no-apply-priors")
    return cleaned


def _phase_args(phase: str, base: argparse.Namespace) -> list[str]:
@@ -117,25 +117,25 @@ def _phase_args(phase: str, base: argparse.Namespace) -> list[str]:
        base.hf_command,
        "--hf-token-env",
        base.hf_token_env,
        "--hf-poll-interval",
        str(base.hf_poll_interval),
        "--hf-launcher-script",
        str(base.hf_launcher_script),
        "--priors-file",
        str(base.priors_file),
    ]
    if base.hf_output_repo:
        common.extend(["--hf-output-repo", base.hf_output_repo])
    if base.hf_use_bash:
        common.append("--hf-use-bash")
    if base.hf_stop_after_metric:
        common.append("--hf-stop-after-metric")
    else:
        common.append("--no-hf-stop-after-metric")
    if base.apply_priors:
        common.append("--apply-priors")
    else:
        common.append("--no-apply-priors")
    if phase == "phase1":
        return [
            *common,
@@ -184,32 +184,32 @@ def cmd_phase(args: argparse.Namespace) -> int:
    return rc


def cmd_parallel(args: argparse.Namespace) -> int:
    enqueued_priors = _prime_transfer_priors(args.storage, args.study_name, args.priors_file, args.apply_priors)
    worker_args = _disable_worker_priors(_phase_args(args.phase, args))
    procs: list[subprocess.Popen[str]] = []
    for _ in range(args.workers):
        cmd = [sys.executable, str(HPO_SCRIPT), *worker_args]
        procs.append(subprocess.Popen(cmd, cwd=str(REPO_ROOT), text=True))

    exit_codes = [p.wait() for p in procs]
    stats = _study_stats(args.storage, args.study_name)
    payload = {
        "phase": args.phase,
        "workers": args.workers,
        "exit_codes": exit_codes,
        "enqueued_priors": enqueued_priors,
        "stats": stats,
    }
    args.summary_out.parent.mkdir(parents=True, exist_ok=True)
    args.summary_out.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    print(json.dumps(payload, indent=2))
    return 0 if all(code == 0 for code in exit_codes) else 1


def cmd_recommend(args: argparse.Namespace) -> int:
    stats = _study_stats(args.storage, args.study_name)
    min_tps_floor = float(args.min_tps)
    if stats.get("status") == "missing":
        payload = {
            "stats": stats,
@@ -226,41 +226,41 @@ def cmd_recommend(args: argparse.Namespace) -> int:

    n_completed = int(stats.get("n_completed", 0))

    if n_completed < 10:
        recommendation = {
            "status": "insufficient_data",
            "next_step": "Run phase1 with 2-4 parallel workers until >=10 completed trials.",
            "early_stop_policy": {
                "patience_trials": 8,
                "min_improvement": 0.001,
            },
            "throughput_guard": {
                "min_tps": min_tps_floor,
                "note": "Trials below this TPS floor are pruned.",
            },
            "transfer_learning": {
                "export_priors": f"python scripts/export_hpo_priors.py --storage {args.storage} --study-name {args.study_name} --top-k 10 --out docs/hpo_transfer_priors.json",
                "use_priors": "Enabled by default in scripts/optuna_hpo.py (override with --no-apply-priors)",
            },
        }
    else:
        recommendation = {
            "status": "ready_for_full_optimization",
            "next_step": "Run phase2 with 3-4 parallel workers.",
            "suggested_full_run": {
                "trials": 60,
                "workers": 4,
                "trial_time_budget": 300,
                "trial_timeout": 900,
                "min_tps": min_tps_floor,
                "patience_trials": 12,
                "min_improvement": 0.0005,
            },
            "transfer_learning": {
                "refresh_priors": f"python scripts/export_hpo_priors.py --storage {args.storage} --study-name {args.study_name} --top-k 20 --out docs/hpo_transfer_priors.json",
                "notes": "Carry priors into new studies unless architecture/objective diverges significantly.",
            },
        }

    payload = {"stats": stats, "recommendation": recommendation}
    args.summary_out.parent.mkdir(parents=True, exist_ok=True)
@@ -269,9 +269,9 @@ def cmd_recommend(args: argparse.Namespace) -> int:
    return 0


def build_parser() -> argparse.ArgumentParser:
    routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
    parser = argparse.ArgumentParser(description="Phase-oriented orchestration for Optuna HPO")
    sub = parser.add_subparsers(dest="cmd", required=True)

    def add_common(p: argparse.ArgumentParser) -> None:
@@ -283,21 +283,21 @@ def build_parser() -> argparse.ArgumentParser:
        p.add_argument("--min-tps", type=float, default=50000.0)
        p.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "orchestrator_summary.json")
        p.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local")
        p.add_argument("--hf-namespace", default=routing_defaults.job_namespace)
        p.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}")
        p.add_argument("--hf-flavor", default="a10g-large")
        p.add_argument("--hf-timeout", default="25m")
        p.add_argument("--hf-command", default="/app/entrypoint.py")
        p.add_argument("--hf-use-bash", action="store_true")
        p.add_argument("--hf-token-env", default="HF_TOKEN")
        p.add_argument("--hf-poll-interval", type=int, default=12)
        p.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py")
        p.add_argument("--hf-output-repo", default=routing_defaults.output_repo)
        p.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json")
        p.add_argument("--apply-priors", action="store_true", default=True)
        p.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
        p.add_argument("--hf-stop-after-metric", action="store_true", default=True)
        p.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")

        # Phase-1 defaults
        p.add_argument("--phase1-trials", type=int, default=30)
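One subtlety worth calling out in cmd_parallel: transfer priors are enqueued once by the orchestrator, and every worker is then forced to --no-apply-priors so that N parallel workers do not each enqueue the same prior trials. A minimal sketch of that argument rewrite, using only _disable_worker_priors from this file (flag values are illustrative):

from scripts.hpo_orchestrator import _disable_worker_priors

args = ["--storage", "sqlite:///optuna_hpo.db", "--apply-priors", "--trials", "30"]
# Any priors flag is stripped, then --no-apply-priors is appended.
assert _disable_worker_priors(args) == [
    "--storage", "sqlite:///optuna_hpo.db", "--trials", "30", "--no-apply-priors",
]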
overlay/scripts/hpo_retest.py
CHANGED
@@ -1,151 +1,151 @@
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import datetime as dt
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import optuna
|
| 11 |
+
|
| 12 |
+
from scripts.hpo_leaderboard import build_leaderboard
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_PARAM_TO_ENV = {
|
| 16 |
+
"d_model": "HYDRA_D_MODEL",
|
| 17 |
+
"n_layer": "HYDRA_N_LAYER",
|
| 18 |
+
"d_state": "HYDRA_D_STATE",
|
| 19 |
+
"headdim": "HYDRA_HEADDIM",
|
| 20 |
+
"expand": "HYDRA_EXPAND",
|
| 21 |
+
"seq_len": "HYDRA_SEQ_LEN",
|
| 22 |
+
"batch_size": "HYDRA_BATCH_SIZE",
|
| 23 |
+
"matrix_lr": "HYDRA_MATRIX_LR",
|
| 24 |
+
"embed_lr": "HYDRA_EMBED_LR",
|
| 25 |
+
"unembed_lr": "HYDRA_UNEMBED_LR",
|
| 26 |
+
"engram_n_columns": "HYDRA_ENGRAM_N_COLUMNS",
|
| 27 |
+
"engram_layer_idx": "HYDRA_ENGRAM_LAYER_IDX",
|
| 28 |
+
"sdr_target_active": "HYDRA_SDR_TARGET_ACTIVE",
|
| 29 |
+
"htm_learn_every": "HYDRA_HTM_LEARN_EVERY",
|
| 30 |
+
"htm_subsample": "HYDRA_HTM_SUBSAMPLE",
|
| 31 |
+
"engram_subsample": "HYDRA_ENGRAM_SUBSAMPLE",
|
| 32 |
+
"mamba3_chunk": "HYDRA_MAMBA3_CHUNK",
|
| 33 |
+
"dropout": "HYDRA_DROPOUT",
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
_DEFAULT_ENV = {
|
| 37 |
+
"HYDRA_USE_NEMOTRON": "1",
|
| 38 |
+
"HYDRA_LOCAL_SHARDS_ONLY": "0",
|
| 39 |
+
"HYDRA_THROUGHPUT_MODE": "0",
|
| 40 |
+
"HYDRA_FASTPATH": "0",
|
| 41 |
+
"HYDRA_FORCE_HTM_CPU": "0",
|
| 42 |
+
"HYDRA_INERT_MAMBA": "0",
|
| 43 |
+
"HYDRA_ALLOW_SYNTHETIC_RETINA": "0",
|
| 44 |
+
"HYDRA_HTM_FUSED": "1",
|
| 45 |
+
"HYDRA_HYENA_LAYERS": "",
|
| 46 |
+
"HYDRA_CKPT_INTERVAL": "0",
|
| 47 |
+
"HYDRA_ENGRAM_SUBSAMPLE": "1",
|
| 48 |
+
"HYDRA_HTM_SUBSAMPLE": "2",
|
| 49 |
+
"HYDRA_HTM_LEARN_EVERY": "8",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _recover_params(row: dict[str, Any]) -> dict[str, Any]:
|
| 54 |
+
params = dict(row.get("params") or {})
|
| 55 |
+
attrs = row.get("user_attrs") or {}
|
| 56 |
+
for key, value in attrs.items():
|
| 57 |
+
if key.startswith("param_"):
|
| 58 |
+
params.setdefault(key.removeprefix("param_"), value)
|
| 59 |
+
return params
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _candidate_env(params: dict[str, Any], *, eval_tokens: int, eval_batch: int, time_budget: int) -> dict[str, str]:
|
| 63 |
+
env = dict(_DEFAULT_ENV)
|
| 64 |
+
env["HYDRA_EVAL_TOKENS"] = str(eval_tokens)
|
| 65 |
+
env["HYDRA_EVAL_BATCH"] = str(eval_batch)
|
| 66 |
+
env["HYDRA_TIME_BUDGET"] = str(time_budget)
|
| 67 |
+
for key, value in params.items():
|
| 68 |
+
env_key = _PARAM_TO_ENV.get(key)
|
| 69 |
+
if env_key is not None:
|
| 70 |
+
env[env_key] = str(value)
|
| 71 |
+
if "HYDRA_BATCH_SIZE" in env and "HYDRA_SEQ_LEN" in env:
|
| 72 |
+
grad_accum = int(params.get("grad_accum", 16))
|
| 73 |
+
env["HYDRA_TOTAL_BATCH"] = str(int(env["HYDRA_BATCH_SIZE"]) * int(env["HYDRA_SEQ_LEN"]) * grad_accum)
|
| 74 |
+
return env
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_retest_plan(
|
| 78 |
+
*,
|
| 79 |
+
storage: str,
|
| 80 |
+
study_names: list[str],
|
| 81 |
+
top_k: int,
|
| 82 |
+
metric: str = "val_bpb",
|
| 83 |
+
eval_tokens: int = 16384,
|
| 84 |
+
eval_batch: int = 2,
|
| 85 |
+
time_budget: int = 420,
|
| 86 |
+
) -> dict[str, Any]:
|
| 87 |
+
leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
|
| 88 |
+
rows = [*leaderboard["contaminated_trials"], *leaderboard["clean_trials"]]
|
| 89 |
+
reverse = leaderboard["direction"] == "maximize"
|
| 90 |
+
rows.sort(key=lambda row: float(row["value"]), reverse=reverse)
|
| 91 |
+
candidates = []
|
| 92 |
+
for row in rows[: max(0, top_k)]:
|
| 93 |
+
params = _recover_params(row)
|
| 94 |
+
env = _candidate_env(params, eval_tokens=eval_tokens, eval_batch=eval_batch, time_budget=time_budget)
|
| 95 |
+
reason = row.get("contamination_reason") or "canonical_truth_eval_retest"
|
| 96 |
+
candidates.append({
|
| 97 |
+
"study_name": row["study_name"],
|
| 98 |
+
"trial_number": row["trial_number"],
|
| 99 |
+
"source_value": row["value"],
|
| 100 |
+
"source_objective": row["objective_source"],
|
| 101 |
+
"source_job_id": row.get("hf_job_id"),
|
| 102 |
+
"needs_retest_reason": reason,
|
| 103 |
+
"params": params,
|
| 104 |
+
"env": env,
|
| 105 |
+
})
|
| 106 |
+
return {
|
| 107 |
+
"schema_version": 1,
|
| 108 |
+
"generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
|
| 109 |
+
"metric": metric,
|
| 110 |
+
"study_names": study_names,
|
| 111 |
+
"eval_tokens": eval_tokens,
|
| 112 |
+
"eval_batch": eval_batch,
|
| 113 |
+
"time_budget": time_budget,
|
| 114 |
+
"n_candidates": len(candidates),
|
| 115 |
+
"candidates": candidates,
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 120 |
+
parser = argparse.ArgumentParser(description="Plan canonical-eval retests for historical HPO configs")
|
| 121 |
+
parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
|
| 122 |
+
parser.add_argument("--study-name", action="append", default=[])
|
| 123 |
+
parser.add_argument("--metric", default="val_bpb")
|
| 124 |
+
parser.add_argument("--top-k", type=int, default=10)
|
| 125 |
+
parser.add_argument("--eval-tokens", type=int, default=16384)
|
| 126 |
+
parser.add_argument("--eval-batch", type=int, default=2)
|
| 127 |
+
parser.add_argument("--time-budget", type=int, default=420)
|
| 128 |
+
parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "retest_plan.json")
|
| 129 |
+
return parser.parse_args(argv)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main(argv: list[str] | None = None) -> int:
|
| 133 |
+
args = parse_args(argv)
|
| 134 |
+
study_names = args.study_name or ["hydra_hpo"]
|
| 135 |
+
payload = build_retest_plan(
|
| 136 |
+
storage=args.storage,
|
| 137 |
+
study_names=study_names,
|
| 138 |
+
top_k=args.top_k,
|
| 139 |
+
metric=args.metric,
|
| 140 |
+
eval_tokens=args.eval_tokens,
|
| 141 |
+
eval_batch=args.eval_batch,
|
| 142 |
+
time_budget=args.time_budget,
|
| 143 |
+
)
|
| 144 |
+
args.out.parent.mkdir(parents=True, exist_ok=True)
|
| 145 |
+
args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
| 146 |
+
print(json.dumps(payload, indent=2, sort_keys=True))
|
| 147 |
+
return 0
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
raise SystemExit(main())
|
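A minimal driver sketch for the retest planner above; the storage URI matches the script's default, while the study name and top-k are illustrative and assume the study already exists in the local Optuna database:

from scripts.hpo_retest import build_retest_plan

# Plan canonical-eval retests for the top 5 historical trials and show why
# each candidate needs a rerun. "hydra_hpo_v2" is a hypothetical study name.
plan = build_retest_plan(
    storage="sqlite:///optuna_hpo.db",
    study_names=["hydra_hpo_v2"],
    top_k=5,
    metric="val_bpb",
)
for cand in plan["candidates"]:
    print(cand["study_name"], cand["trial_number"], cand["needs_retest_reason"])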
overlay/scripts/hydra_generation.py
CHANGED
|
@@ -1,183 +1,180 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import os
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Callable
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
|
| 10 |
-
from scripts.benchmark_checkpoint import hydrate_checkpoint
|
| 11 |
-
from scripts.hf_routing import resolve_routing
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def default_checkpoint_path() -> Path:
|
| 15 |
-
return Path(os.path.expanduser("~/.cache/autoresearch/latest.pt"))
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
|
| 19 |
-
base = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch"))
|
| 20 |
-
return [
|
| 21 |
-
base / "best_bpb.pt",
|
| 22 |
-
base / "pretrain_final.pt",
|
| 23 |
-
base / "latest.pt",
|
| 24 |
-
]
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
|
| 28 |
-
if explicit_path is not None:
|
| 29 |
-
return explicit_path
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Callable
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from scripts.benchmark_checkpoint import hydrate_checkpoint
|
| 11 |
+
from scripts.hf_routing import resolve_routing
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def default_checkpoint_path() -> Path:
|
| 15 |
+
return Path(os.path.expanduser("~/.cache/autoresearch/latest.pt"))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
|
| 19 |
+
base = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch"))
|
| 20 |
+
return [
|
| 21 |
+
base / "best_bpb.pt",
|
| 22 |
+
base / "pretrain_final.pt",
|
| 23 |
+
base / "latest.pt",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
|
| 28 |
+
if explicit_path is not None:
|
| 29 |
+
return explicit_path
|
| 30 |
+
for candidate in checkpoint_candidates(cache_dir=cache_dir):
|
| 31 |
+
if candidate.exists():
|
| 32 |
+
return candidate
|
| 33 |
+
return default_checkpoint_path()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def validate_checkpoint_compatibility(
|
| 37 |
+
*,
|
| 38 |
+
baseline_arch: str,
|
| 39 |
+
missing_keys: list[str],
|
| 40 |
+
unexpected_keys: list[str],
|
| 41 |
+
total_model_keys: int,
|
| 42 |
+
) -> None:
|
| 43 |
+
if baseline_arch == "transformer" and (missing_keys or unexpected_keys):
|
| 44 |
+
raise RuntimeError(
|
| 45 |
+
"checkpoint incompatible with transformer baseline architecture; "
|
| 46 |
+
"use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
|
| 47 |
+
)
|
| 48 |
+
mismatch_count = len(missing_keys) + len(unexpected_keys)
|
| 49 |
+
if total_model_keys > 0 and mismatch_count > max(8, total_model_keys // 2):
|
| 50 |
+
raise RuntimeError("checkpoint incompatible with requested model architecture")
|
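# Example of the mismatch gate above: with total_model_keys=300, up to
# max(8, 300 // 2) = 150 missing+unexpected keys are tolerated (partial
# overlays); anything beyond that is treated as an incompatible checkpoint.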
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_from_callable(
|
| 54 |
+
generator: Callable[[str], str] | Callable[..., str],
|
| 55 |
+
prompt: str,
|
| 56 |
+
*,
|
| 57 |
+
max_new_tokens: int,
|
| 58 |
+
temperature: float,
|
| 59 |
+
top_p: float,
|
| 60 |
+
) -> str:
|
| 61 |
+
text = generator(
|
| 62 |
+
prompt,
|
| 63 |
+
max_new_tokens=max_new_tokens,
|
| 64 |
+
temperature=temperature,
|
| 65 |
+
top_p=top_p,
|
| 66 |
+
)
|
| 67 |
+
return str(text).strip()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
|
| 71 |
+
ckpt_path = resolve_checkpoint_path(checkpoint_path)
|
| 72 |
+
if not ckpt_path.exists():
|
| 73 |
+
hydrated = hydrate_checkpoint(
|
| 74 |
+
cache_dir=ckpt_path.parent,
|
| 75 |
+
output_repo=resolve_routing(token=os.environ.get("HF_TOKEN")).output_repo,
|
| 76 |
+
token=os.environ.get("HF_TOKEN"),
|
| 77 |
+
)
|
| 78 |
+
if hydrated is not None:
|
| 79 |
+
ckpt_path = hydrated
|
| 80 |
+
if not ckpt_path.exists():
|
| 81 |
+
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 82 |
+
|
| 83 |
+
from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
|
| 84 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 85 |
+
|
| 86 |
+
from hydra.config import PostSemClawConfig
|
| 87 |
+
from hydra.model import PostSemClawModel
|
| 88 |
+
from prepare import Tokenizer
|
| 89 |
+
|
| 90 |
+
resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 91 |
+
|
| 92 |
+
class _HydraGenConfig(PretrainedConfig):
|
| 93 |
+
model_type = "hydra"
|
| 94 |
+
|
| 95 |
+
def __init__(self, vocab_size: int = 65536, **kw):
|
| 96 |
+
super().__init__(**kw)
|
| 97 |
+
self.vocab_size = vocab_size
|
| 98 |
+
|
| 99 |
+
class HydraForCausalLM(PreTrainedModel, GenerationMixin):
|
| 100 |
+
config_class = _HydraGenConfig
|
| 101 |
+
|
| 102 |
+
def __init__(self, gen_config, inner_model):
|
| 103 |
+
super().__init__(gen_config)
|
| 104 |
+
self.inner = inner_model
|
| 105 |
+
self.config.vocab_size = gen_config.vocab_size
|
| 106 |
+
|
| 107 |
+
def forward(self, input_ids, attention_mask=None, **kw):
|
| 108 |
+
logits = self.inner(input_ids)
|
| 109 |
+
return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)
|
| 110 |
+
|
| 111 |
+
def prepare_inputs_for_generation(self, input_ids, **kw):
|
| 112 |
+
return {"input_ids": input_ids}
|
| 113 |
+
|
| 114 |
+
def get_input_embeddings(self):
|
| 115 |
+
return self.inner.wte
|
| 116 |
+
|
| 117 |
+
def can_generate(self) -> bool:
|
| 118 |
+
return True
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def _supports_cache_class(self):
|
| 122 |
+
return False
|
| 123 |
+
|
| 124 |
+
tokenizer = Tokenizer.from_directory()
|
| 125 |
+
vocab_size = tokenizer.get_vocab_size()
|
| 126 |
+
bos = tokenizer.get_bos_token_id()
|
| 127 |
+
ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
|
| 128 |
+
cfg = PostSemClawConfig(**ckpt["config"])
|
| 129 |
+
with torch.device("meta"):
|
| 130 |
+
inner = PostSemClawModel(cfg)
|
| 131 |
+
inner.to_empty(device=resolved_device)
|
| 132 |
+
missing, unexpected = inner.load_state_dict(ckpt["model_state_dict"], strict=False)
|
| 133 |
+
validate_checkpoint_compatibility(
|
| 134 |
+
baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
|
| 135 |
+
missing_keys=list(missing),
|
| 136 |
+
unexpected_keys=list(unexpected),
|
| 137 |
+
total_model_keys=len(inner.state_dict()),
|
| 138 |
+
)
|
| 139 |
+
inner.eval()
|
| 140 |
+
|
| 141 |
+
gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
|
| 142 |
+
gen_cfg.bos_token_id = bos
|
| 143 |
+
gen_cfg.eos_token_id = bos
|
| 144 |
+
gen_cfg.pad_token_id = bos
|
| 145 |
+
model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
|
| 146 |
+
model.eval()
|
| 147 |
+
return tokenizer, model, bos, resolved_device, GenerationConfig
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def build_hydra_generator(
|
| 151 |
+
*,
|
| 152 |
+
checkpoint_path: Path | None = None,
|
| 153 |
+
device: str | None = None,
|
| 154 |
+
max_new_tokens: int,
|
| 155 |
+
temperature: float,
|
| 156 |
+
top_p: float,
|
| 157 |
+
):
|
| 158 |
+
tokenizer, model, bos, resolved_device, GenerationConfig = load_hydra_causal_lm(checkpoint_path=checkpoint_path, device=device)
|
| 159 |
+
|
| 160 |
+
def _generate(prompt: str) -> str:
|
| 161 |
+
ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=resolved_device)
|
| 162 |
+
gen_config = GenerationConfig(
|
| 163 |
+
max_new_tokens=max_new_tokens,
|
| 164 |
+
use_cache=False,
|
| 165 |
+
do_sample=temperature > 0.0,
|
| 166 |
+
temperature=temperature,
|
| 167 |
+
top_p=top_p,
|
| 168 |
+
bos_token_id=bos,
|
| 169 |
+
eos_token_id=bos,
|
| 170 |
+
pad_token_id=bos,
|
| 171 |
+
)
|
| 172 |
+
if str(resolved_device).startswith("cuda"):
|
| 173 |
+
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 174 |
+
out = model.generate(ids, generation_config=gen_config)
|
| 175 |
+
else:
|
| 176 |
+
with torch.no_grad():
|
| 177 |
+
out = model.generate(ids, generation_config=gen_config)
|
| 178 |
+
return tokenizer.decode(out[0].tolist())
|
| 179 |
+
|
| 180 |
+
return _generate
|
|
|
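A minimal usage sketch for the rewritten generator factory; the prompt, sampling values, and tokenizer assets are illustrative, and the checkpoint path mirrors the candidate list above:

from pathlib import Path
from scripts.hydra_generation import build_hydra_generator

# Resolves the checkpoint (hydrating from the HF output repo when the local
# cache is empty) and returns a prompt -> text callable.
generate = build_hydra_generator(
    checkpoint_path=Path("~/.cache/autoresearch/best_bpb.pt").expanduser(),
    device=None,  # auto-selects cuda when available
    max_new_tokens=64,
    temperature=0.8,
    top_p=0.95,
)
print(generate("Once upon a time"))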
overlay/scripts/launch_benchmark_hf_job.py
CHANGED
|
@@ -1,222 +1,157 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import argparse
|
| 5 |
-
import json
|
| 6 |
-
import os
|
| 7 |
-
import sys
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 11 |
-
if str(REPO_ROOT) not in sys.path:
|
| 12 |
-
sys.path.insert(0, str(REPO_ROOT))
|
| 13 |
-
|
| 14 |
-
from huggingface_hub import HfApi
|
| 15 |
-
from huggingface_hub.utils import get_token
|
| 16 |
-
|
| 17 |
-
from scripts.
|
| 18 |
-
from scripts.
|
| 19 |
-
|
| 158 |
-
output_repo=args.output_repo,
|
| 159 |
-
tokenizer_repo=args.tokenizer_repo,
|
| 160 |
-
retina_repo=args.retina_repo,
|
| 161 |
-
freeze=freeze,
|
| 162 |
-
)
|
| 163 |
-
command = build_benchmark_job_command(
|
| 164 |
-
benchmark=args.benchmark,
|
| 165 |
-
variant=args.variant,
|
| 166 |
-
seed=args.seed,
|
| 167 |
-
suite_path=args.suite,
|
| 168 |
-
freeze=freeze,
|
| 169 |
-
)
|
| 170 |
-
payload = {
|
| 171 |
-
"benchmark": args.benchmark,
|
| 172 |
-
"variant": args.variant,
|
| 173 |
-
"seed": args.seed,
|
| 174 |
-
"output_repo": args.output_repo,
|
| 175 |
-
"tokenizer_repo": args.tokenizer_repo,
|
| 176 |
-
"retina_repo": args.retina_repo,
|
| 177 |
-
"freeze": str(args.freeze),
|
| 178 |
-
"suite": str(args.suite) if args.suite is not None else None,
|
| 179 |
-
"image": args.image,
|
| 180 |
-
"namespace": args.namespace,
|
| 181 |
-
"command": command,
|
| 182 |
-
"env": env,
|
| 183 |
-
"dry_run": args.dry_run,
|
| 184 |
-
}
|
| 185 |
-
if not args.dry_run:
|
| 186 |
-
token = os.environ.get("HF_TOKEN") or get_token()
|
| 187 |
-
if not token:
|
| 188 |
-
raise SystemExit("HF_TOKEN must be set or cached via huggingface-cli login")
|
| 189 |
-
api = HfApi(token=token)
|
| 190 |
-
if args.refresh_image:
|
| 191 |
-
space_repo = args.image.removeprefix("hf.co/spaces/")
|
| 192 |
-
if args.sync_overlay:
|
| 193 |
-
sync_overlay_from_repo()
|
| 194 |
-
api.upload_folder(
|
| 195 |
-
repo_id=space_repo,
|
| 196 |
-
repo_type="space",
|
| 197 |
-
folder_path=str(IMAGE_DIR),
|
| 198 |
-
commit_message="Update benchmark runtime image",
|
| 199 |
-
token=token,
|
| 200 |
-
)
|
| 201 |
-
wait_for_space(api, space_repo, token=token)
|
| 202 |
-
payload.update(
|
| 203 |
-
submit_benchmark_job(
|
| 204 |
-
api=api,
|
| 205 |
-
image=args.image,
|
| 206 |
-
command=command,
|
| 207 |
-
env=env,
|
| 208 |
-
token=token,
|
| 209 |
-
namespace=args.namespace,
|
| 210 |
-
flavor=args.flavor,
|
| 211 |
-
timeout=args.timeout,
|
| 212 |
-
)
|
| 213 |
-
)
|
| 214 |
-
if args.summary_out is not None:
|
| 215 |
-
args.summary_out.parent.mkdir(parents=True, exist_ok=True)
|
| 216 |
-
args.summary_out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
| 217 |
-
print(json.dumps(payload, indent=2, sort_keys=True))
|
| 218 |
-
return 0
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
if __name__ == "__main__":
|
| 222 |
-
raise SystemExit(main())
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 11 |
+
if str(REPO_ROOT) not in sys.path:
|
| 12 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import HfApi
|
| 15 |
+
from huggingface_hub.utils import get_token
|
| 16 |
+
|
| 17 |
+
from scripts.hf_routing import resolve_routing
|
| 18 |
+
from scripts.launch_feather_hf_job import IMAGE_DIR, sync_overlay_from_repo, wait_for_space
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def build_benchmark_job_env(
|
| 22 |
+
*,
|
| 23 |
+
benchmark: str,
|
| 24 |
+
variant: str,
|
| 25 |
+
seed: int,
|
| 26 |
+
output_repo: str,
|
| 27 |
+
tokenizer_repo: str,
|
| 28 |
+
) -> dict[str, str]:
|
| 29 |
+
env = {
|
| 30 |
+
"FEATHER_HF_OUTPUT_REPO": output_repo,
|
| 31 |
+
"FEATHER_RUNTIME_MODE": "benchmark",
|
| 32 |
+
"HYDRA_TOKENIZER_CACHE_REPO": tokenizer_repo,
|
| 33 |
+
"HYDRA_BENCHMARK_NAME": benchmark,
|
| 34 |
+
"HYDRA_BENCHMARK_VARIANT": variant,
|
| 35 |
+
"HYDRA_SEED": str(seed),
|
| 36 |
+
"PYTHONUNBUFFERED": "1",
|
| 37 |
+
}
|
| 38 |
+
for key, value in os.environ.items():
|
| 39 |
+
if key.startswith("HYDRA_") and key not in env:
|
| 40 |
+
env[key] = value
|
| 41 |
+
return env
|
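# Precedence note: keys set explicitly above stay pinned; only HYDRA_* keys
# absent from the payload are inherited from the caller's environment.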
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def build_benchmark_job_command(*, benchmark: str, variant: str, seed: int) -> list[str]:
|
| 45 |
+
return [
|
| 46 |
+
"python",
|
| 47 |
+
"/app/entrypoint.py",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def submit_benchmark_job(
|
| 52 |
+
*,
|
| 53 |
+
api,
|
| 54 |
+
image: str,
|
| 55 |
+
command: list[str],
|
| 56 |
+
env: dict[str, str],
|
| 57 |
+
token: str,
|
| 58 |
+
namespace: str,
|
| 59 |
+
flavor: str,
|
| 60 |
+
timeout: str,
|
| 61 |
+
) -> dict[str, str]:
|
| 62 |
+
job = api.run_job(
|
| 63 |
+
image=image,
|
| 64 |
+
command=command,
|
| 65 |
+
env=env,
|
| 66 |
+
secrets={"HF_TOKEN": token},
|
| 67 |
+
flavor=flavor,
|
| 68 |
+
timeout=timeout,
|
| 69 |
+
namespace=namespace,
|
| 70 |
+
token=token,
|
| 71 |
+
)
|
| 72 |
+
return {
|
| 73 |
+
"job_id": job.id,
|
| 74 |
+
"job_url": job.url,
|
| 75 |
+
"job_stage": str(job.status.stage),
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 80 |
+
routing = resolve_routing(token=os.environ.get("HF_TOKEN"))
|
| 81 |
+
parser = argparse.ArgumentParser(description="Prepare or submit a remote HF benchmark job")
|
| 82 |
+
parser.add_argument("--benchmark", required=True)
|
| 83 |
+
parser.add_argument("--variant", required=True)
|
| 84 |
+
parser.add_argument("--seed", type=int, required=True)
|
| 85 |
+
parser.add_argument("--output-repo", default=routing.output_repo)
|
| 86 |
+
parser.add_argument("--tokenizer-repo", default=routing.output_repo)
|
| 87 |
+
parser.add_argument("--image", default=f"hf.co/spaces/{routing.space_repo}")
|
| 88 |
+
parser.add_argument("--namespace", default=routing.job_namespace)
|
| 89 |
+
parser.add_argument("--flavor", default="a10g-small")
|
| 90 |
+
parser.add_argument("--timeout", default="30m")
|
| 91 |
+
parser.add_argument("--summary-out", type=Path)
|
| 92 |
+
parser.add_argument("--dry-run", action="store_true")
|
| 93 |
+
parser.add_argument("--refresh-image", action="store_true")
|
| 94 |
+
parser.add_argument("--sync-overlay", action="store_true")
|
| 95 |
+
return parser.parse_args(argv)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def main(argv: list[str] | None = None) -> int:
|
| 99 |
+
args = parse_args(argv)
|
| 100 |
+
env = build_benchmark_job_env(
|
| 101 |
+
benchmark=args.benchmark,
|
| 102 |
+
variant=args.variant,
|
| 103 |
+
seed=args.seed,
|
| 104 |
+
output_repo=args.output_repo,
|
| 105 |
+
tokenizer_repo=args.tokenizer_repo,
|
| 106 |
+
)
|
| 107 |
+
command = build_benchmark_job_command(benchmark=args.benchmark, variant=args.variant, seed=args.seed)
|
| 108 |
+
payload = {
|
| 109 |
+
"benchmark": args.benchmark,
|
| 110 |
+
"variant": args.variant,
|
| 111 |
+
"seed": args.seed,
|
| 112 |
+
"output_repo": args.output_repo,
|
| 113 |
+
"tokenizer_repo": args.tokenizer_repo,
|
| 114 |
+
"image": args.image,
|
| 115 |
+
"namespace": args.namespace,
|
| 116 |
+
"command": command,
|
| 117 |
+
"env": env,
|
| 118 |
+
"dry_run": args.dry_run,
|
| 119 |
+
}
|
| 120 |
+
if not args.dry_run:
|
| 121 |
+
token = os.environ.get("HF_TOKEN") or get_token()
|
| 122 |
+
if not token:
|
| 123 |
+
raise SystemExit("HF_TOKEN must be set or cached via huggingface-cli login")
|
| 124 |
+
api = HfApi(token=token)
|
| 125 |
+
if args.refresh_image:
|
| 126 |
+
space_repo = args.image.removeprefix("hf.co/spaces/")
|
| 127 |
+
if args.sync_overlay:
|
| 128 |
+
sync_overlay_from_repo()
|
| 129 |
+
api.upload_folder(
|
| 130 |
+
repo_id=space_repo,
|
| 131 |
+
repo_type="space",
|
| 132 |
+
folder_path=str(IMAGE_DIR),
|
| 133 |
+
commit_message="Update benchmark runtime image",
|
| 134 |
+
token=token,
|
| 135 |
+
)
|
| 136 |
+
wait_for_space(api, space_repo, token=token)
|
| 137 |
+
payload.update(
|
| 138 |
+
submit_benchmark_job(
|
| 139 |
+
api=api,
|
| 140 |
+
image=args.image,
|
| 141 |
+
command=command,
|
| 142 |
+
env=env,
|
| 143 |
+
token=token,
|
| 144 |
+
namespace=args.namespace,
|
| 145 |
+
flavor=args.flavor,
|
| 146 |
+
timeout=args.timeout,
|
| 147 |
+
)
|
| 148 |
+
)
|
| 149 |
+
if args.summary_out is not None:
|
| 150 |
+
args.summary_out.parent.mkdir(parents=True, exist_ok=True)
|
| 151 |
+
args.summary_out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
| 152 |
+
print(json.dumps(payload, indent=2, sort_keys=True))
|
| 153 |
+
return 0
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
raise SystemExit(main())
|
|
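A dry-run sketch of the launcher above; it prints the job payload (image, command, env) without submitting, though repo-routing defaults are still resolved at argument-parse time. The benchmark and variant names are placeholders, not values from the diff:

from scripts.launch_benchmark_hf_job import main

main([
    "--benchmark", "lambada",
    "--variant", "baseline",
    "--seed", "0",
    "--dry-run",
])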
overlay/scripts/launch_feather_hf_job.py
CHANGED
|
@@ -1,224 +1,218 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
-
import os
|
| 5 |
-
import shutil
|
| 6 |
-
import sys
|
| 7 |
-
import time
|
| 8 |
-
import json
|
| 9 |
-
from collections.abc import Mapping, Sequence
|
| 10 |
-
from typing import Any, cast
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
|
| 13 |
-
import httpx
|
| 14 |
-
from huggingface_hub import HfApi
|
| 15 |
-
from huggingface_hub.utils import HfHubHTTPError
|
| 16 |
-
from huggingface_hub.utils import get_token
|
| 17 |
-
|
| 18 |
-
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 19 |
-
if str(REPO_ROOT) not in sys.path:
|
| 20 |
-
sys.path.insert(0, str(REPO_ROOT))
|
| 21 |
-
|
| 22 |
-
from scripts.hf_routing import resolve_routing
|
| 23 |
-
from configs.harness_config import HarnessConfig
|
| 24 |
-
|
| 25 |
-
DEFAULT_IMAGE = os.environ.get('FEATHER_HF_IMAGE', 'ghcr.io/slapglif/feather-hf-runtime:latest')
|
| 26 |
-
IMAGE_DIR = Path(__file__).resolve().parents[1] / 'hf_jobs' / 'feather_h200_image'
|
| 27 |
-
TIMEOUT = os.environ.get('FEATHER_HF_JOB_TIMEOUT', '12h')
|
| 28 |
-
TARGET_SHARDS = os.environ.get('HYDRA_TARGET_SHARDS', '2048')
|
| 29 |
-
TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '43200')
|
| 30 |
-
DOWNLOAD_WORKERS = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '16')
|
| 31 |
-
CKPT_INTERVAL = os.environ.get('HYDRA_CKPT_INTERVAL', '1000')
|
| 32 |
-
JOB_FLAVOR = os.environ.get('FEATHER_HF_FLAVOR', 'a10g-small')
|
| 33 |
-
DRY_RUN = os.environ.get('FEATHER_HF_DRY_RUN', '0') == '1'
|
| 34 |
-
USE_SPACE_IMAGE = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
|
| 35 |
# When true, assume the Space image has already been built by a previous
|
| 36 |
# invocation and skip the upload+build wait. Used by sweep drivers that fan
|
| 37 |
# out many jobs against a single pre-uploaded image.
|
| 38 |
-
SKIP_UPLOAD = os.environ.get('FEATHER_HF_SKIP_UPLOAD', '0') == '1'
|
| 39 |
-
SYNC_OVERLAY = os.environ.get('FEATHER_HF_SYNC_OVERLAY', '1') == '1'
|
| 40 |
-
JOB_SUBMIT_RETRIES = max(1, int(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRIES', '3')))
|
| 41 |
-
JOB_SUBMIT_RETRY_BASE_S = float(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRY_BASE_S', '5'))
|
| 42 |
-
BUILD_LOG_TAIL_LINES = max(1, int(os.environ.get('FEATHER_HF_BUILD_LOG_TAIL_LINES', '120')))
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def should_enable_fast_start_streaming(target_shards: str, time_budget: str) -> bool:
|
| 46 |
-
"""Use streaming data path for short-budget launch profiles."""
|
| 47 |
-
try:
|
| 48 |
-
shards = int(target_shards)
|
| 49 |
-
budget = int(time_budget)
|
| 50 |
-
except ValueError:
|
| 51 |
-
return False
|
| 52 |
-
return shards > 0 and shards <= 256 and budget > 0 and budget <= 1800
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def apply_a10_env_profile(
|
| 56 |
-
env: dict[str, str],
|
| 57 |
-
*,
|
| 58 |
-
job_flavor: str,
|
| 59 |
-
parent_env: Mapping[str, str] = os.environ,
|
| 60 |
-
) -> str | None:
|
| 61 |
-
if not job_flavor.startswith('a10'):
|
| 62 |
-
return None
|
| 63 |
-
|
| 64 |
-
full_arch = parent_env.get('HYDRA_THROUGHPUT_MODE') == '0' or env.get('HYDRA_THROUGHPUT_MODE') == '0'
|
| 65 |
-
if full_arch:
|
| 66 |
-
defaults = {
|
| 67 |
-
'HYDRA_THROUGHPUT_MODE': '0',
|
| 68 |
-
'HYDRA_MUON_COMPILE': '0',
|
| 69 |
-
'HYDRA_FORCE_HTM_CPU': '0',
|
| 70 |
-
'HYDRA_INERT_MAMBA': '0',
|
| 71 |
-
'HYDRA_ALLOW_SYNTHETIC_RETINA': '0',
|
| 72 |
-
'HYDRA_FASTPATH': '0',
|
| 73 |
-
}
|
| 74 |
-
profile = 'full-architecture'
|
| 75 |
-
else:
|
| 76 |
-
defaults = {
|
| 77 |
-
'HYDRA_MUON_COMPILE': '0',
|
| 78 |
-
'HYDRA_FORCE_HTM_CPU': '1',
|
| 79 |
-
'HYDRA_INERT_MAMBA': '1',
|
| 80 |
-
'HYDRA_ALLOW_SYNTHETIC_RETINA': '1',
|
| 81 |
-
'HYDRA_FASTPATH': '1',
|
| 82 |
-
}
|
| 83 |
-
profile = 'compatibility'
|
| 84 |
-
|
| 85 |
-
for key, default in defaults.items():
|
| 86 |
-
env[key] = parent_env.get(key, env.get(key, default))
|
| 87 |
-
return profile
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def _http_status(exc: BaseException) -> int | None:
|
| 91 |
-
response = getattr(exc, 'response', None)
|
| 92 |
-
status = getattr(response, 'status_code', None)
|
| 93 |
-
return int(status) if isinstance(status, int) else None
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def submit_job_with_retry(
|
| 97 |
-
api: HfApi,
|
| 98 |
-
*,
|
| 99 |
-
image: str,
|
| 100 |
-
command: Sequence[str],
|
| 101 |
-
env: dict[str, str],
|
| 102 |
-
secrets: dict[str, str],
|
| 103 |
-
flavor: Any,
|
| 104 |
-
timeout: str,
|
| 105 |
-
token: str,
|
| 106 |
-
namespace: str,
|
| 107 |
-
):
|
| 108 |
-
for attempt in range(1, JOB_SUBMIT_RETRIES + 1):
|
| 109 |
-
try:
|
| 110 |
-
return api.run_job(
|
| 111 |
-
image=image,
|
| 112 |
-
command=list(command),
|
| 113 |
-
env=env,
|
| 114 |
-
secrets=secrets,
|
| 115 |
-
flavor=flavor,
|
| 116 |
-
timeout=timeout,
|
| 117 |
-
namespace=namespace,
|
| 118 |
-
token=token,
|
| 119 |
-
)
|
| 120 |
-
except HfHubHTTPError as exc:
|
| 121 |
-
status = _http_status(exc)
|
| 122 |
-
if status is None or status < 500 or status >= 600:
|
| 123 |
-
raise
|
| 124 |
-
if attempt >= JOB_SUBMIT_RETRIES:
|
| 125 |
-
raise SystemExit(
|
| 126 |
-
f'HF job submit failed after {JOB_SUBMIT_RETRIES} attempts '
|
| 127 |
-
f'(http {status}); failing fast'
|
| 128 |
-
) from exc
|
| 129 |
-
time.sleep(JOB_SUBMIT_RETRY_BASE_S * attempt)
|
| 130 |
-
raise SystemExit('HF job submit failed unexpectedly; failing fast')
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def fetch_space_build_log_tail(
|
| 134 |
-
api: HfApi,
|
| 135 |
-
repo_id: str,
|
| 136 |
-
token: str,
|
| 137 |
-
*,
|
| 138 |
-
limit: int = BUILD_LOG_TAIL_LINES,
|
| 139 |
-
) -> str:
|
| 140 |
-
try:
|
| 141 |
-
lines = list(api.fetch_space_logs(repo_id, build=True, follow=False, token=token))
|
| 142 |
-
except Exception as exc:
|
| 143 |
-
return f'[space-build-log] failed to fetch build logs: {exc!r}'
|
| 144 |
-
tail = lines[-limit:]
|
| 145 |
-
text = ''.join(tail)
|
| 146 |
-
if text and not text.endswith('\n'):
|
| 147 |
-
text += '\n'
|
| 148 |
-
return text or '[space-build-log] no buffered build logs available\n'
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def sync_overlay_from_repo() -> None:
|
| 152 |
-
"""Refresh Space overlay with required project files."""
|
| 153 |
-
overlay = IMAGE_DIR / 'overlay'
|
| 154 |
-
overlay.mkdir(parents=True, exist_ok=True)
|
| 155 |
-
|
| 156 |
-
for child in overlay.iterdir():
|
| 157 |
-
if child.is_dir():
|
| 158 |
-
shutil.rmtree(child)
|
| 159 |
-
else:
|
| 160 |
-
child.unlink()
|
| 161 |
-
|
| 162 |
-
include_paths = [
|
| 163 |
-
'hydra',
|
| 164 |
-
'subsystems',
|
| 165 |
-
'scripts',
|
| 216 |
-
if not token:
|
| 217 |
-
raise SystemExit('HF_TOKEN must be set or cached via huggingface-cli login for launch_feather_hf_job.py')
|
| 218 |
-
return token
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
def wait_for_space(api: HfApi, repo_id: str, token: str, timeout_s: int = 1800) -> None:
|
| 222 |
"""Wait until the Space image has been built.
|
| 223 |
|
| 224 |
We use the Space purely as a container-image builder for HF Jobs. The Space
|
|
@@ -233,134 +227,134 @@ def wait_for_space(api: HfApi, repo_id: str, token: str, timeout_s: int = 1800)
|
|
| 233 |
and APP_STARTING_ERROR after a successful BUILDING→APP_STARTING transition
|
| 234 |
are acceptable — the image exists in the registry and Jobs can use it.
|
| 235 |
"""
|
| 236 |
-
start = time.time()
|
| 237 |
-
seen_build_completion = False
|
| 238 |
-
seen_building = False
|
| 239 |
-
while True:
|
| 240 |
-
try:
|
| 241 |
-
runtime = api.get_space_runtime(repo_id, token=token)
|
| 242 |
-
except httpx.TransportError as exc:
|
| 243 |
-
if time.time() - start > timeout_s:
|
| 244 |
-
raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with network errors') from exc
|
| 245 |
-
print(f'[space] transient runtime endpoint network error: {exc!r}; retrying', flush=True)
|
| 246 |
-
time.sleep(20)
|
| 247 |
-
continue
|
| 248 |
-
except HfHubHTTPError as exc:
|
| 249 |
-
status = _http_status(exc)
|
| 250 |
-
if status is not None and 500 <= status < 600:
|
| 251 |
-
if time.time() - start > timeout_s:
|
| 252 |
-
raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with HTTP {status}') from exc
|
| 253 |
-
time.sleep(20)
|
| 254 |
-
continue
|
| 255 |
-
raise
|
| 256 |
-
stage = getattr(runtime, 'stage', None)
|
| 257 |
-
hardware = getattr(runtime, 'hardware', None)
|
| 258 |
-
err = getattr(runtime, 'errorMessage', None) or getattr(runtime, 'error_message', None)
|
| 259 |
-
print(f'[space] stage={stage} hardware={hardware}', flush=True)
|
| 260 |
-
if stage == 'BUILDING':
|
| 261 |
-
seen_building = True
|
| 262 |
-
if stage in {'APP_STARTING', 'RUNNING', 'PAUSED', 'SLEEPING'}:
|
| 263 |
-
seen_build_completion = True
|
| 264 |
-
if stage in {'RUNNING', 'PAUSED', 'SLEEPING'}:
|
| 265 |
-
return
|
| 266 |
-
# Image is built — Jobs can use it regardless of Space boot outcome.
|
| 267 |
-
if (seen_build_completion or seen_building) and stage in {'RUNTIME_ERROR', 'APP_STARTING_ERROR'}:
|
| 268 |
-
print(f'[space] Space boot failed with {stage} but built image is '
|
| 269 |
-
f'available in the Space registry and is usable by HF Jobs.',
|
| 270 |
-
flush=True)
|
| 271 |
-
return
|
| 272 |
-
# Hard build failures — no image was produced.
|
| 273 |
-
if stage in {'BUILD_ERROR', 'CONFIG_ERROR', 'NO_APP_FILE'}:
|
| 274 |
-
build_log_tail = fetch_space_build_log_tail(api, repo_id, token)
|
| 275 |
-
raise RuntimeError(
|
| 276 |
-
f'Space {repo_id} build failed: stage={stage} error={err!r}\n'
|
| 277 |
-
f'--- Space build log tail ---\n{build_log_tail}'
|
| 278 |
-
)
|
| 279 |
if time.time() - start > timeout_s:
|
| 280 |
raise TimeoutError(f'Space {repo_id} did not become ready in {timeout_s}s (last stage={stage})')
|
| 281 |
time.sleep(20)
|
| 282 |
|
| 283 |
|
| 284 |
-
def main() -> int:
|
| 285 |
-
token = require_token()
|
| 286 |
-
routing = resolve_routing(token=token)
|
| 287 |
-
api = HfApi(token=token)
|
| 288 |
-
secondary_gates = HarnessConfig().to_secondary_gates()
|
| 289 |
-
|
| 290 |
-
print(f'[launch] image_dir={IMAGE_DIR}', flush=True)
|
| 291 |
-
print(f'[launch] owner={routing.owner}', flush=True)
|
| 292 |
-
print(f'[launch] space_repo={routing.space_repo}', flush=True)
|
| 293 |
-
print(f'[launch] output_repo={routing.output_repo}', flush=True)
|
| 294 |
-
print(f'[launch] retina_cache_repo={routing.retina_cache_repo}', flush=True)
|
| 295 |
-
print(f'[launch] target_shards={TARGET_SHARDS} time_budget={TIME_BUDGET} timeout={TIMEOUT}', flush=True)
|
| 296 |
-
print(f'[launch] flavor={JOB_FLAVOR}', flush=True)
|
| 297 |
-
print(f'[launch] namespace={routing.job_namespace}', flush=True)
|
| 298 |
-
print(f'[launch] image_mode={"space" if USE_SPACE_IMAGE else "ghcr"}', flush=True)
|
| 299 |
-
print(f'[launch] secondary_gates={json.dumps(secondary_gates, sort_keys=True)}', flush=True)
|
| 300 |
-
if not USE_SPACE_IMAGE:
|
| 301 |
-
print(f'[launch] image={DEFAULT_IMAGE}', flush=True)
|
| 302 |
-
|
| 303 |
-
api.create_repo(repo_id=routing.space_repo, repo_type='space', space_sdk='docker', private=True, exist_ok=True, token=token)
|
| 304 |
-
api.create_repo(repo_id=routing.output_repo, repo_type='model', private=True, exist_ok=True, token=token)
|
| 305 |
|
| 306 |
if DRY_RUN:
|
| 307 |
print('[launch] dry-run mode; skipping upload and job submission', flush=True)
|
| 308 |
return 0
|
| 309 |
|
| 310 |
-
image_ref = DEFAULT_IMAGE
|
| 311 |
-
if USE_SPACE_IMAGE:
|
| 312 |
-
if SKIP_UPLOAD:
|
| 313 |
-
print('[launch] FEATHER_HF_SKIP_UPLOAD=1; reusing existing Space image', flush=True)
|
| 314 |
-
else:
|
| 315 |
-
if SYNC_OVERLAY:
|
| 316 |
-
sync_overlay_from_repo()
|
| 317 |
-
print('[launch] uploading custom Docker Space image context...', flush=True)
|
| 318 |
-
api.upload_folder(
|
| 319 |
-
repo_id=routing.space_repo,
|
| 320 |
-
repo_type='space',
|
| 321 |
-
folder_path=str(IMAGE_DIR),
|
| 322 |
-
commit_message='Update Feather training runtime image',
|
| 323 |
-
token=token,
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
print('[launch] waiting for Space image build to become ready...', flush=True)
|
| 327 |
-
wait_for_space(api, routing.space_repo, token=token)
|
| 328 |
-
image_ref = f'hf.co/spaces/{routing.space_repo}'
|
| 329 |
-
|
| 330 |
-
env = {
|
| 331 |
-
'HF_REPO_ID': routing.output_repo,
|
| 332 |
-
'FEATHER_HF_OWNER': routing.owner,
|
| 333 |
-
'FEATHER_HF_SPACE_REPO': routing.space_repo,
|
| 334 |
-
'FEATHER_HF_OUTPUT_REPO': routing.output_repo,
|
| 335 |
-
'FEATHER_HF_RETINA_CACHE_REPO': routing.retina_cache_repo,
|
| 336 |
-
'HYDRA_RETINA_CACHE_REPO': routing.retina_cache_repo,
|
| 337 |
-
'HYDRA_TARGET_SHARDS': TARGET_SHARDS,
|
| 338 |
-
'HYDRA_TIME_BUDGET': TIME_BUDGET,
|
| 339 |
-
'HYDRA_DOWNLOAD_WORKERS': DOWNLOAD_WORKERS,
|
| 340 |
'HYDRA_CKPT_INTERVAL': CKPT_INTERVAL,
|
| 341 |
'PYTHONUNBUFFERED': '1',
|
| 342 |
-
'FEATHER_RUNTIME_MODE': 'job',
|
| 343 |
-
}
|
| 344 |
-
if 'HYDRA_USE_NEMOTRON' not in os.environ and should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET):
|
| 345 |
-
env['HYDRA_USE_NEMOTRON'] = '1'
|
| 346 |
-
print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
|
| 347 |
-
# A10 profile: default compatibility mode avoids known PTX/compile runtime
|
| 348 |
-
# pitfalls, but HYDRA_THROUGHPUT_MODE=0 explicitly selects the full
|
| 349 |
-
# SDR/HTM/Engram architecture instead of silently inheriting bypass defaults.
|
| 350 |
-
_a10_profile = apply_a10_env_profile(env, job_flavor=JOB_FLAVOR)
|
| 351 |
-
if _a10_profile is not None:
|
| 352 |
-
if env.get('HYDRA_INERT_MAMBA') == '0' and 'HYDRA_FASTPATH' not in os.environ:
|
| 353 |
-
env['HYDRA_FASTPATH'] = '0'
|
| 354 |
-
print(
|
| 355 |
-
f'[launch] applied A10 {_a10_profile} env profile '
|
| 356 |
-
f"(HYDRA_MUON_COMPILE={env['HYDRA_MUON_COMPILE']}, "
|
| 357 |
-
f"HYDRA_THROUGHPUT_MODE={env.get('HYDRA_THROUGHPUT_MODE', 'unset')}, "
|
| 358 |
-
f"HYDRA_FORCE_HTM_CPU={env['HYDRA_FORCE_HTM_CPU']}, "
|
| 359 |
-
f"HYDRA_INERT_MAMBA={env['HYDRA_INERT_MAMBA']}, "
|
| 360 |
-
f"HYDRA_ALLOW_SYNTHETIC_RETINA={env['HYDRA_ALLOW_SYNTHETIC_RETINA']}, "
|
| 361 |
-
f"HYDRA_FASTPATH={env['HYDRA_FASTPATH']})",
|
| 362 |
-
flush=True,
|
| 363 |
-
)
|
| 364 |
# Pass through any HYDRA_* / FEATHER_* overrides from the caller's env so
|
| 365 |
# sweep drivers can set HYDRA_N_LAYER, HYDRA_SDR_TARGET_ACTIVE,
|
| 366 |
# HYDRA_LAYER_DIAGNOSTICS, HYDRA_METRICS_OUT, HYDRA_MID_VAL_INTERVAL, etc.
|
|
@@ -370,18 +364,18 @@ def main() -> int:
|
|
| 370 |
env[_k] = _v
|
| 371 |
secrets = {'HF_TOKEN': token}
|
| 372 |
|
| 373 |
-
print(f'[launch] submitting HF Job on flavor={JOB_FLAVOR}...', flush=True)
|
| 374 |
-
job = submit_job_with_retry(
|
| 375 |
-
api,
|
| 376 |
-
image=image_ref,
|
| 377 |
-
command=['python', '/app/entrypoint.py'],
|
| 378 |
-
env=env,
|
| 379 |
-
secrets=secrets,
|
| 380 |
-
flavor=cast(Any, JOB_FLAVOR),
|
| 381 |
-
timeout=TIMEOUT,
|
| 382 |
-
namespace=routing.job_namespace,
|
| 383 |
-
token=token,
|
| 384 |
-
)
|
| 385 |
print(f'[launch] submitted job_id={job.id} status={job.status.stage} url={job.url}', flush=True)
|
| 386 |
return 0
|
| 387 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
import sys
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
from collections.abc import Mapping, Sequence
|
| 10 |
+
from typing import Any, cast
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import httpx
|
| 14 |
+
from huggingface_hub import HfApi
|
| 15 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 16 |
+
from huggingface_hub.utils import get_token
|
| 17 |
+
|
| 18 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 19 |
+
if str(REPO_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 21 |
+
|
| 22 |
+
from scripts.hf_routing import resolve_routing
|
| 23 |
+
from configs.harness_config import HarnessConfig
|
| 24 |
+
|
| 25 |
+
DEFAULT_IMAGE = os.environ.get('FEATHER_HF_IMAGE', 'ghcr.io/slapglif/feather-hf-runtime:latest')
|
| 26 |
+
IMAGE_DIR = Path(__file__).resolve().parents[1] / 'hf_jobs' / 'feather_h200_image'
|
| 27 |
+
TIMEOUT = os.environ.get('FEATHER_HF_JOB_TIMEOUT', '12h')
|
| 28 |
+
TARGET_SHARDS = os.environ.get('HYDRA_TARGET_SHARDS', '2048')
|
| 29 |
+
TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '43200')
|
| 30 |
+
DOWNLOAD_WORKERS = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '16')
|
| 31 |
+
CKPT_INTERVAL = os.environ.get('HYDRA_CKPT_INTERVAL', '1000')
|
| 32 |
+
JOB_FLAVOR = os.environ.get('FEATHER_HF_FLAVOR', 'a10g-small')
|
| 33 |
+
DRY_RUN = os.environ.get('FEATHER_HF_DRY_RUN', '0') == '1'
|
| 34 |
+
USE_SPACE_IMAGE = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
|
| 35 |
# When true, assume the Space image has already been built by a previous
|
| 36 |
# invocation and skip the upload+build wait. Used by sweep drivers that fan
|
| 37 |
# out many jobs against a single pre-uploaded image.
|
| 38 |
+
SKIP_UPLOAD = os.environ.get('FEATHER_HF_SKIP_UPLOAD', '0') == '1'
|
| 39 |
+
SYNC_OVERLAY = os.environ.get('FEATHER_HF_SYNC_OVERLAY', '1') == '1'
|
| 40 |
+
JOB_SUBMIT_RETRIES = max(1, int(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRIES', '3')))
|
| 41 |
+
JOB_SUBMIT_RETRY_BASE_S = float(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRY_BASE_S', '5'))
|
| 42 |
+
BUILD_LOG_TAIL_LINES = max(1, int(os.environ.get('FEATHER_HF_BUILD_LOG_TAIL_LINES', '120')))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def should_enable_fast_start_streaming(target_shards: str, time_budget: str) -> bool:
|
| 46 |
+
"""Use streaming data path for short-budget launch profiles."""
|
| 47 |
+
try:
|
| 48 |
+
shards = int(target_shards)
|
| 49 |
+
budget = int(time_budget)
|
| 50 |
+
except ValueError:
|
| 51 |
+
return False
|
| 52 |
+
return shards > 0 and shards <= 256 and budget > 0 and budget <= 1800
|
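# Illustrative cases: ('128', '900') -> True (small shard target, short
# budget: stream); ('2048', '43200') -> False; non-numeric input -> False.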
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def apply_a10_env_profile(
|
| 56 |
+
env: dict[str, str],
|
| 57 |
+
*,
|
| 58 |
+
job_flavor: str,
|
| 59 |
+
parent_env: Mapping[str, str] = os.environ,
|
| 60 |
+
) -> str | None:
|
| 61 |
+
if not job_flavor.startswith('a10'):
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
full_arch = parent_env.get('HYDRA_THROUGHPUT_MODE') == '0' or env.get('HYDRA_THROUGHPUT_MODE') == '0'
|
| 65 |
+
if full_arch:
|
| 66 |
+
defaults = {
|
| 67 |
+
'HYDRA_THROUGHPUT_MODE': '0',
|
| 68 |
+
'HYDRA_MUON_COMPILE': '0',
|
| 69 |
+
'HYDRA_FORCE_HTM_CPU': '0',
|
| 70 |
+
'HYDRA_INERT_MAMBA': '0',
|
| 71 |
+
'HYDRA_ALLOW_SYNTHETIC_RETINA': '0',
|
| 72 |
+
'HYDRA_FASTPATH': '0',
|
| 73 |
+
}
|
| 74 |
+
profile = 'full-architecture'
|
| 75 |
+
else:
|
| 76 |
+
defaults = {
|
| 77 |
+
'HYDRA_MUON_COMPILE': '0',
|
| 78 |
+
'HYDRA_FORCE_HTM_CPU': '1',
|
| 79 |
+
'HYDRA_INERT_MAMBA': '1',
|
| 80 |
+
'HYDRA_ALLOW_SYNTHETIC_RETINA': '1',
|
| 81 |
+
'HYDRA_FASTPATH': '1',
|
| 82 |
+
}
|
| 83 |
+
profile = 'compatibility'
|
| 84 |
+
|
| 85 |
+
for key, default in defaults.items():
|
| 86 |
+
env[key] = parent_env.get(key, env.get(key, default))
|
| 87 |
+
return profile
|
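# Per-key resolution order: the caller's process env wins, then any value
# already present in the job env, then the profile default chosen above.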
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _http_status(exc: BaseException) -> int | None:
|
| 91 |
+
response = getattr(exc, 'response', None)
|
| 92 |
+
status = getattr(response, 'status_code', None)
|
| 93 |
+
return int(status) if isinstance(status, int) else None
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def submit_job_with_retry(
|
| 97 |
+
api: HfApi,
|
| 98 |
+
*,
|
| 99 |
+
image: str,
|
| 100 |
+
command: Sequence[str],
|
| 101 |
+
env: dict[str, str],
|
| 102 |
+
secrets: dict[str, str],
|
| 103 |
+
flavor: Any,
|
| 104 |
+
timeout: str,
|
| 105 |
+
token: str,
|
| 106 |
+
namespace: str,
|
| 107 |
+
):
|
| 108 |
+
for attempt in range(1, JOB_SUBMIT_RETRIES + 1):
|
| 109 |
+
try:
|
| 110 |
+
return api.run_job(
|
| 111 |
+
image=image,
|
| 112 |
+
command=list(command),
|
| 113 |
+
env=env,
|
| 114 |
+
secrets=secrets,
|
| 115 |
+
flavor=flavor,
|
| 116 |
+
timeout=timeout,
|
| 117 |
+
namespace=namespace,
|
| 118 |
+
token=token,
|
| 119 |
+
)
|
| 120 |
+
except HfHubHTTPError as exc:
|
| 121 |
+
status = _http_status(exc)
|
| 122 |
+
if status is None or status < 500 or status >= 600:
|
| 123 |
+
raise
|
| 124 |
+
if attempt >= JOB_SUBMIT_RETRIES:
|
| 125 |
+
raise SystemExit(
|
| 126 |
+
f'HF job submit failed after {JOB_SUBMIT_RETRIES} attempts '
|
| 127 |
+
f'(http {status}); failing fast'
|
| 128 |
+
) from exc
|
| 129 |
+
time.sleep(JOB_SUBMIT_RETRY_BASE_S * attempt)
|
| 130 |
+
raise SystemExit('HF job submit failed unexpectedly; failing fast')
|
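# With the defaults (3 attempts, 5s base), transient 5xx failures sleep 5s,
# then 10s, before the third failure aborts; non-5xx errors raise immediately.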
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def fetch_space_build_log_tail(
|
| 134 |
+
api: HfApi,
|
| 135 |
+
repo_id: str,
|
| 136 |
+
token: str,
|
| 137 |
+
*,
|
| 138 |
+
limit: int = BUILD_LOG_TAIL_LINES,
|
| 139 |
+
) -> str:
|
| 140 |
+
try:
|
| 141 |
+
lines = list(api.fetch_space_logs(repo_id, build=True, follow=False, token=token))
|
| 142 |
+
except Exception as exc:
|
| 143 |
+
return f'[space-build-log] failed to fetch build logs: {exc!r}'
|
| 144 |
+
tail = lines[-limit:]
|
| 145 |
+
text = ''.join(tail)
|
| 146 |
+
if text and not text.endswith('\n'):
|
| 147 |
+
text += '\n'
|
| 148 |
+
return text or '[space-build-log] no buffered build logs available\n'
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def sync_overlay_from_repo() -> None:
|
| 152 |
+
"""Refresh Space overlay with required project files."""
|
| 153 |
+
overlay = IMAGE_DIR / 'overlay'
|
| 154 |
+
overlay.mkdir(parents=True, exist_ok=True)
|
| 155 |
+
|
| 156 |
+
for child in overlay.iterdir():
|
| 157 |
+
if child.is_dir():
|
| 158 |
+
shutil.rmtree(child)
|
| 159 |
+
else:
|
| 160 |
+
child.unlink()
|
| 161 |
+
|
| 162 |
+
include_paths = [
|
| 163 |
+
'hydra',
|
| 164 |
+
'subsystems',
|
| 165 |
+
'scripts',
|
| 166 |
+
'htm_rust',
|
| 167 |
+
'harness',
|
| 168 |
+
'configs',
|
| 169 |
+
'prepare.py',
|
| 170 |
+
'prepare_nemotron.py',
|
| 171 |
+
'train.py',
|
| 172 |
+
'pyproject.toml',
|
| 173 |
+
'uv.lock',
|
| 174 |
+
]
|
| 175 |
+
ignore = shutil.ignore_patterns(
|
| 176 |
+
'__pycache__',
|
| 177 |
+
'.pytest_cache',
|
| 178 |
+
'.ruff_cache',
|
| 179 |
+
'.venv',
|
| 180 |
+
'.git',
|
| 181 |
+
'target',
|
| 182 |
+
'*.pyc',
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
copied: list[str] = []
|
| 186 |
+
for rel in include_paths:
|
| 187 |
+
src = REPO_ROOT / rel
|
| 188 |
+
dst = overlay / rel
|
| 189 |
+
if not src.exists():
|
| 190 |
+
continue
|
| 191 |
+
if src.is_dir():
|
| 192 |
+
shutil.copytree(src, dst, dirs_exist_ok=True, ignore=ignore)
|
| 193 |
+
else:
|
| 194 |
+
dst.parent.mkdir(parents=True, exist_ok=True)
|
| 195 |
+
shutil.copy2(src, dst)
|
| 196 |
+
copied.append(rel)
|
| 197 |
+
|
| 198 |
+
scripts_dir = overlay / 'scripts'
|
| 199 |
+
if scripts_dir.exists():
|
| 200 |
+
for sh_path in scripts_dir.rglob('*.sh'):
|
| 201 |
+
data = sh_path.read_bytes()
|
| 202 |
+
data = data.replace(b'\r\n', b'\n').replace(b'\r', b'\n')
|
| 203 |
+
sh_path.write_bytes(data)
|
| 204 |
+
|
| 205 |
+
print(f'[launch] overlay synced from repo ({len(copied)} paths): {copied}', flush=True)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def require_token() -> str:
|
| 209 |
+
token = os.environ.get('HF_TOKEN') or get_token()
|
| 210 |
+
if not token:
|
| 211 |
+
raise SystemExit('HF_TOKEN must be set or cached via huggingface-cli login for launch_feather_hf_job.py')
|
| 212 |
+
return token
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def wait_for_space(api: HfApi, repo_id: str, token: str, timeout_s: int = 1800) -> None:
|
| 216 |
"""Wait until the Space image has been built.
|
| 217 |
|
| 218 |
We use the Space purely as a container-image builder for HF Jobs. The Space
|
|
|
|
| 227 |
and APP_STARTING_ERROR after a successful BUILDING→APP_STARTING transition
|
| 228 |
are acceptable — the image exists in the registry and Jobs can use it.
|
| 229 |
"""
|
| 230 |
+
start = time.time()
|
| 231 |
+
seen_build_completion = False
|
| 232 |
+
seen_building = False
|
| 233 |
+
while True:
|
| 234 |
+
try:
|
| 235 |
+
runtime = api.get_space_runtime(repo_id, token=token)
|
| 236 |
+
except httpx.TransportError as exc:
|
| 237 |
+
if time.time() - start > timeout_s:
|
| 238 |
+
raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with network errors') from exc
|
| 239 |
+
print(f'[space] transient runtime endpoint network error: {exc!r}; retrying', flush=True)
|
| 240 |
+
time.sleep(20)
|
| 241 |
+
continue
|
| 242 |
+
except HfHubHTTPError as exc:
|
| 243 |
+
status = _http_status(exc)
|
| 244 |
+
if status is not None and 500 <= status < 600:
|
| 245 |
+
if time.time() - start > timeout_s:
|
| 246 |
+
raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with HTTP {status}') from exc
|
| 247 |
+
time.sleep(20)
|
| 248 |
+
continue
|
| 249 |
+
raise
|
| 250 |
+
stage = getattr(runtime, 'stage', None)
|
| 251 |
+
hardware = getattr(runtime, 'hardware', None)
|
| 252 |
+
err = getattr(runtime, 'errorMessage', None) or getattr(runtime, 'error_message', None)
|
| 253 |
+
print(f'[space] stage={stage} hardware={hardware}', flush=True)
|
| 254 |
+
if stage == 'BUILDING':
|
| 255 |
+
seen_building = True
|
| 256 |
+
if stage in {'APP_STARTING', 'RUNNING', 'PAUSED', 'SLEEPING'}:
|
| 257 |
+
seen_build_completion = True
|
| 258 |
+
if stage in {'RUNNING', 'PAUSED', 'SLEEPING'}:
|
| 259 |
+
return
|
| 260 |
+
# Image is built — Jobs can use it regardless of Space boot outcome.
|
| 261 |
+
if (seen_build_completion or seen_building) and stage in {'RUNTIME_ERROR', 'APP_STARTING_ERROR'}:
|
| 262 |
+
print(f'[space] Space boot failed with {stage} but built image is '
|
| 263 |
+
f'available in the Space registry and is usable by HF Jobs.',
|
| 264 |
+
flush=True)
|
| 265 |
+
return
|
| 266 |
+
# Hard build failures — no image was produced.
|
| 267 |
+
if stage in {'BUILD_ERROR', 'CONFIG_ERROR', 'NO_APP_FILE'}:
|
| 268 |
+
build_log_tail = fetch_space_build_log_tail(api, repo_id, token)
|
| 269 |
+
raise RuntimeError(
|
| 270 |
+
f'Space {repo_id} build failed: stage={stage} error={err!r}\n'
|
| 271 |
+
f'--- Space build log tail ---\n{build_log_tail}'
|
| 272 |
+
)
|
| 273 |
if time.time() - start > timeout_s:
|
| 274 |
raise TimeoutError(f'Space {repo_id} did not become ready in {timeout_s}s (last stage={stage})')
|
| 275 |
time.sleep(20)
|
| 276 |
|
| 277 |
|
| 278 |
+
def main() -> int:
|
| 279 |
+
token = require_token()
|
| 280 |
+
routing = resolve_routing(token=token)
|
| 281 |
+
api = HfApi(token=token)
|
| 282 |
+
secondary_gates = HarnessConfig().to_secondary_gates()
|
| 283 |
+
|
| 284 |
+
print(f'[launch] image_dir={IMAGE_DIR}', flush=True)
|
| 285 |
+
print(f'[launch] owner={routing.owner}', flush=True)
|
| 286 |
+
print(f'[launch] space_repo={routing.space_repo}', flush=True)
|
| 287 |
+
print(f'[launch] output_repo={routing.output_repo}', flush=True)
|
| 288 |
+
print(f'[launch] retina_cache_repo={routing.retina_cache_repo}', flush=True)
|
| 289 |
+
print(f'[launch] target_shards={TARGET_SHARDS} time_budget={TIME_BUDGET} timeout={TIMEOUT}', flush=True)
|
| 290 |
+
print(f'[launch] flavor={JOB_FLAVOR}', flush=True)
|
| 291 |
+
print(f'[launch] namespace={routing.job_namespace}', flush=True)
|
| 292 |
+
print(f'[launch] image_mode={"space" if USE_SPACE_IMAGE else "ghcr"}', flush=True)
|
| 293 |
+
print(f'[launch] secondary_gates={json.dumps(secondary_gates, sort_keys=True)}', flush=True)
|
| 294 |
+
if not USE_SPACE_IMAGE:
|
| 295 |
+
print(f'[launch] image={DEFAULT_IMAGE}', flush=True)
|
| 296 |
+
|
| 297 |
+
api.create_repo(repo_id=routing.space_repo, repo_type='space', space_sdk='docker', private=True, exist_ok=True, token=token)
|
| 298 |
+
api.create_repo(repo_id=routing.output_repo, repo_type='model', private=True, exist_ok=True, token=token)
|
| 299 |
|
| 300 |
if DRY_RUN:
|
| 301 |
print('[launch] dry-run mode; skipping upload and job submission', flush=True)
|
| 302 |
return 0
|
| 303 |
|
| 304 |
+
image_ref = DEFAULT_IMAGE
|
| 305 |
+
if USE_SPACE_IMAGE:
|
| 306 |
+
if SKIP_UPLOAD:
|
| 307 |
+
print('[launch] FEATHER_HF_SKIP_UPLOAD=1; reusing existing Space image', flush=True)
|
| 308 |
+
else:
|
| 309 |
+
if SYNC_OVERLAY:
|
| 310 |
+
sync_overlay_from_repo()
|
| 311 |
+
print('[launch] uploading custom Docker Space image context...', flush=True)
|
| 312 |
+
api.upload_folder(
|
| 313 |
+
repo_id=routing.space_repo,
|
| 314 |
+
repo_type='space',
|
| 315 |
+
folder_path=str(IMAGE_DIR),
|
| 316 |
+
commit_message='Update Feather training runtime image',
|
| 317 |
+
token=token,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
print('[launch] waiting for Space image build to become ready...', flush=True)
|
| 321 |
+
wait_for_space(api, routing.space_repo, token=token)
|
| 322 |
+
image_ref = f'hf.co/spaces/{routing.space_repo}'
|
| 323 |
+
|
| 324 |
+
env = {
|
| 325 |
+
'HF_REPO_ID': routing.output_repo,
|
| 326 |
+
'FEATHER_HF_OWNER': routing.owner,
|
| 327 |
+
'FEATHER_HF_SPACE_REPO': routing.space_repo,
|
| 328 |
+
'FEATHER_HF_OUTPUT_REPO': routing.output_repo,
|
| 329 |
+
'FEATHER_HF_RETINA_CACHE_REPO': routing.retina_cache_repo,
|
| 330 |
+
'HYDRA_RETINA_CACHE_REPO': routing.retina_cache_repo,
|
| 331 |
+
'HYDRA_TARGET_SHARDS': TARGET_SHARDS,
|
| 332 |
+
'HYDRA_TIME_BUDGET': TIME_BUDGET,
|
| 333 |
+
'HYDRA_DOWNLOAD_WORKERS': DOWNLOAD_WORKERS,
|
| 334 |
'HYDRA_CKPT_INTERVAL': CKPT_INTERVAL,
|
| 335 |
'PYTHONUNBUFFERED': '1',
|
| 336 |
+
'FEATHER_RUNTIME_MODE': 'job',
|
| 337 |
+
}
|
| 338 |
+
if 'HYDRA_USE_NEMOTRON' not in os.environ and should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET):
|
| 339 |
+
env['HYDRA_USE_NEMOTRON'] = '1'
|
| 340 |
+
print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
|
| 341 |
+
# A10 profile: default compatibility mode avoids known PTX/compile runtime
|
| 342 |
+
# pitfalls, but HYDRA_THROUGHPUT_MODE=0 explicitly selects the full
|
| 343 |
+
# SDR/HTM/Engram architecture instead of silently inheriting bypass defaults.
|
| 344 |
+
_a10_profile = apply_a10_env_profile(env, job_flavor=JOB_FLAVOR)
|
| 345 |
+
if _a10_profile is not None:
|
| 346 |
+
if env.get('HYDRA_INERT_MAMBA') == '0' and 'HYDRA_FASTPATH' not in os.environ:
|
| 347 |
+
env['HYDRA_FASTPATH'] = '0'
|
| 348 |
+
print(
|
| 349 |
+
f'[launch] applied A10 {_a10_profile} env profile '
|
| 350 |
+
f"(HYDRA_MUON_COMPILE={env['HYDRA_MUON_COMPILE']}, "
|
| 351 |
+
f"HYDRA_THROUGHPUT_MODE={env.get('HYDRA_THROUGHPUT_MODE', 'unset')}, "
|
| 352 |
+
f"HYDRA_FORCE_HTM_CPU={env['HYDRA_FORCE_HTM_CPU']}, "
|
| 353 |
+
f"HYDRA_INERT_MAMBA={env['HYDRA_INERT_MAMBA']}, "
|
| 354 |
+
f"HYDRA_ALLOW_SYNTHETIC_RETINA={env['HYDRA_ALLOW_SYNTHETIC_RETINA']}, "
|
| 355 |
+
f"HYDRA_FASTPATH={env['HYDRA_FASTPATH']})",
|
| 356 |
+
flush=True,
|
| 357 |
+
)
|
| 358 |
# Pass through any HYDRA_* / FEATHER_* overrides from the caller's env so
|
| 359 |
# sweep drivers can set HYDRA_N_LAYER, HYDRA_SDR_TARGET_ACTIVE,
|
| 360 |
# HYDRA_LAYER_DIAGNOSTICS, HYDRA_METRICS_OUT, HYDRA_MID_VAL_INTERVAL, etc.
|
|
|
|
| 364 |
env[_k] = _v
|
| 365 |
secrets = {'HF_TOKEN': token}
|
| 366 |
|
| 367 |
+
print(f'[launch] submitting HF Job on flavor={JOB_FLAVOR}...', flush=True)
|
| 368 |
+
job = submit_job_with_retry(
|
| 369 |
+
api,
|
| 370 |
+
image=image_ref,
|
| 371 |
+
command=['python', '/app/entrypoint.py'],
|
| 372 |
+
env=env,
|
| 373 |
+
secrets=secrets,
|
| 374 |
+
flavor=cast(Any, JOB_FLAVOR),
|
| 375 |
+
timeout=TIMEOUT,
|
| 376 |
+
namespace=routing.job_namespace,
|
| 377 |
+
token=token,
|
| 378 |
+
)
|
| 379 |
print(f'[launch] submitted job_id={job.id} status={job.status.stage} url={job.url}', flush=True)
|
| 380 |
return 0
|
| 381 |
|
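A small self-contained check of the A10 profile helper above; the env values are illustrative:

from scripts.launch_feather_hf_job import apply_a10_env_profile

env = {'PYTHONUNBUFFERED': '1'}
# The caller's HYDRA_THROUGHPUT_MODE=0 opts into the full-architecture profile.
profile = apply_a10_env_profile(
    env,
    job_flavor='a10g-small',
    parent_env={'HYDRA_THROUGHPUT_MODE': '0'},
)
assert profile == 'full-architecture'
assert env['HYDRA_INERT_MAMBA'] == '0'  # full SDR/HTM/Engram path selected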
overlay/scripts/optuna_hpo.py
CHANGED

@@ -5,131 +5,131 @@ import argparse
import json
import os
import re
import subprocess
import sys
import time
import tempfile
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from pathlib import Path
from typing import Any

import optuna


_HF_ENV_KEY_RE = re.compile(r"^[A-Z][A-Z0-9_]*$")


REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from scripts.hf_routing import resolve_routing

TRAIN_ENTRYPOINT = REPO_ROOT / "train.py"
SEARCH_SPACE_KEYS = {
    "d_model",
    "n_layer",
    "d_state",
    "headdim",
    "expand",
    "seq_len",
    "batch_size",
    "grad_accum",
    "matrix_lr",
    "embed_lr",
    "unembed_lr",
    "hyena_layers",
    "engram_n_columns",
    "engram_layer_idx",
    "sdr_target_active",
    "htm_learn_every",
    "htm_subsample",
    "engram_subsample",
    "mamba3_chunk",
    "dropout",
}


def _filter_prior_params(raw: dict[str, Any]) -> dict[str, Any]:
    return {k: v for k, v in raw.items() if k in SEARCH_SPACE_KEYS}


def _load_prior_param_sets(path: Path) -> list[dict[str, Any]]:
    if not path.exists():
        return []

    payload = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(payload, dict):
        rows = payload.get("trials", [])
    elif isinstance(payload, list):
        rows = payload
    else:
        rows = []

    out: list[dict[str, Any]] = []
    for item in rows:
        if not isinstance(item, dict):
            continue
        params_obj = item.get("params", item)
        if not isinstance(params_obj, dict):
            continue
        filtered = _filter_prior_params(params_obj)
        if filtered:
            out.append(filtered)
    return out


def _enqueue_transfer_priors(study: optuna.Study, priors_file: Path, apply_priors: bool) -> int:
    if not apply_priors:
        return 0

    priors_raw = _load_prior_param_sets(priors_file)
    if not priors_raw:
        return 0

    # Deduplicate param sets across merged studies.
    priors: list[dict[str, Any]] = []
    seen: set[str] = set()
    for params in priors_raw:
        key = json.dumps(params, sort_keys=True)
        if key in seen:
            continue
        seen.add(key)
        priors.append(params)

    enqueued = 0
    for params in priors:
        before = len(study.get_trials(deepcopy=False))
        try:
            study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"}, skip_if_exists=True)
        except TypeError:
            study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"})
        after = len(study.get_trials(deepcopy=False))
        if after > before:
            enqueued += 1
    return enqueued


def _enqueue_quality_anchors(study: optuna.Study, priors_file: Path, quality_mode_local: bool, top_k: int) -> int:
    if not quality_mode_local or top_k <= 0:
        return 0

    priors = _load_prior_param_sets(priors_file)[:top_k]
    enqueued = 0
    for params in priors:
        before = len(study.get_trials(deepcopy=False))
        try:
            study.enqueue_trial(
                params,
                user_attrs={"seed_source": "quality_anchor"},
                skip_if_exists=True,
            )
        except TypeError:
            study.enqueue_trial(params, user_attrs={"seed_source": "quality_anchor"})
        after = len(study.get_trials(deepcopy=False))
        if after > before:
            enqueued += 1
    return enqueued


def _parse_metrics_from_stdout(stdout: str) -> dict[str, Any] | None:

@@ -164,241 +164,241 @@ def _parse_metrics_from_log_lines(lines: list[str]) -> dict[str, Any] | None:
    return None


def _parse_last_train_bpb_from_logs(lines: list[str]) -> float | None:
    """Best-effort fallback when final eval crashes before metrics JSON write."""
    last: float | None = None
    for line in lines:
        m = re.search(r"\bbpb=([0-9]+(?:\.[0-9]+)?)", line)
        if m:
            last = float(m.group(1))
    return last


def _persist_trial_artifacts(
    *,
    trial_dir: Path,
    metrics: dict[str, Any] | None,
    log_lines: list[str] | None,
    log_name: str,
    metadata: dict[str, Any],
) -> dict[str, str | None]:
    trial_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = trial_dir / "metrics.json"
    log_path = trial_dir / log_name
    manifest_path = trial_dir / "trial_artifacts.json"

    if metrics is not None:
        metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True), encoding="utf-8")
    if log_lines is not None:
        log_path.write_text("\n".join(log_lines), encoding="utf-8")

    manifest = {
        **metadata,
        "metrics_path": str(metrics_path) if metrics is not None else None,
        "log_path": str(log_path) if log_lines is not None else None,
    }
    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
    return {
        "metrics_path": str(metrics_path) if metrics is not None else None,
        "log_path": str(log_path) if log_lines is not None else None,
        "manifest_path": str(manifest_path),
    }


def _resolve_objective_metric(
    trial: optuna.Trial,
    *,
    metric_key: str,
    metrics: dict[str, Any] | None,
    allow_log_metric_fallback: bool,
    fallback_bpb: float | None,
    tps_seen: float | None,
) -> float:
    """Resolve the objective value while labeling where it came from.

    Validation metrics and live training-log fallbacks are intentionally
    different sources. Keeping that distinction in trial attrs prevents a
    skipped/OOM eval from being mistaken for a real validation result.
    """
    if metrics is None:
        if allow_log_metric_fallback and metric_key == "val_bpb" and fallback_bpb is not None:
            trial.set_user_attr("objective_source", "train_log_fallback")
            trial.set_user_attr("objective_metric", "train_bpb")
            trial.set_user_attr("eval_status", "missing_metrics")
            trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
            if tps_seen is not None:
                trial.set_user_attr("tps", float(tps_seen))
            return float(fallback_bpb)
        trial.set_user_attr("objective_source", "missing_metrics")
        raise optuna.TrialPruned("No metrics payload found")

    eval_status = str(
        metrics.get(
            "eval_status",
            "completed" if metrics.get("val_bpb") is not None else "unknown",
        )
    )
    trial.set_user_attr("eval_status", eval_status)

    if fallback_bpb is not None:
        trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))

    if metric_key not in metrics or metrics[metric_key] is None:
        trial.set_user_attr("objective_source", "missing_metric")
        trial.set_user_attr("objective_metric", metric_key)
        raise optuna.TrialPruned(f"Metric '{metric_key}' missing in metrics payload")

    value = float(metrics[metric_key])
    trial.set_user_attr("objective_metric", metric_key)
    if metric_key == "val_bpb":
        trial.set_user_attr("objective_source", "final_val")
        trial.set_user_attr("final_val_bpb", value)
    else:
        trial.set_user_attr("objective_source", "metrics_json")
    return value


def _fetch_job_logs_safe(
    api,
    *,
    job_id: str,
    token: str,
    namespace: str,
    retries: int = 3,
    sleep_s: float = 2.0,
    timeout_s: float = 20.0,
) -> list[str]:
    last_exc: Exception | None = None
    for attempt in range(1, retries + 1):
        try:
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(
                    lambda: list(api.fetch_job_logs(job_id=job_id, follow=False, token=token, namespace=namespace))
                )
                return future.result(timeout=timeout_s)
        except FuturesTimeoutError as exc:
            last_exc = TimeoutError(f"Timed out fetching HF job logs for {job_id} after {timeout_s:.1f}s")
        except Exception as exc:  # noqa: BLE001
            last_exc = exc
        if attempt >= retries:
            raise
        time.sleep(sleep_s)
    if last_exc is not None:
        raise last_exc
    return []


def _effective_min_tps(args: argparse.Namespace) -> float | None:
    min_tps = args.min_tps
    if getattr(args, "quality_mode_local", False) and min_tps == 50000.0:
        return 0.0
    return min_tps


def _trial_env(trial: optuna.Trial, args: argparse.Namespace, metrics_path: Path) -> dict[str, str]:
    env = os.environ.copy()
    full_arch_hpo = env.get("HYDRA_HPO_FULL_ARCH", "0") == "1"
    speed_arch_hpo = full_arch_hpo and env.get("HYDRA_HPO_SPEED_ARCH", "0") == "1"
    quality_mode_local = bool(getattr(args, "quality_mode_local", False))

    # Runtime and reporting
    env["HYDRA_METRICS_OUT"] = str(metrics_path)
    env["HYDRA_TIME_BUDGET"] = str(args.trial_time_budget)
    env["PYTHONUNBUFFERED"] = "1"

    # Search space — fully env-driven to match existing training stack.
    if speed_arch_hpo:
        # Full-arch speed mode targets A10 underutilization observed in HPO:
        # low VRAM/MFU, strong BPB from shallow models, and fixed SDR/HTM
        # overhead dominating small microbatches. Keep all components enabled
        # while amortizing overhead over more tokens.
        env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96]))
        env["HYDRA_N_LAYER"] = str(trial.suggest_categorical("n_layer", [2]))
        env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
        env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
        env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
    elif quality_mode_local and full_arch_hpo:
        env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128]))
        env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 2, 3))
        env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
        env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
        env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
    else:
        env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128, 160, 192]))
        env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 1, 4))
        env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32, 48]))
        env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [8, 16, 32]))
        env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))

    if speed_arch_hpo:
        seq_len = trial.suggest_categorical("seq_len", [64, 128])
        batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
        grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
    elif quality_mode_local and full_arch_hpo:
        seq_len = trial.suggest_categorical("seq_len", [64])
        batch_size = trial.suggest_categorical("batch_size", [4, 8])
        grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
    else:
        seq_len = trial.suggest_categorical("seq_len", [32, 64])
        batch_size = trial.suggest_categorical("batch_size", [4, 8] if full_arch_hpo else [4, 8, 16])
        grad_accum = trial.suggest_categorical("grad_accum", [1, 4, 8, 16] if full_arch_hpo else [8, 16, 32, 64])
    # Keep TOTAL_BATCH_SIZE divisible by DEVICE_BATCH_SIZE * MAX_SEQ_LEN.
    total_batch = batch_size * seq_len * grad_accum
    env["HYDRA_SEQ_LEN"] = str(seq_len)
    env["HYDRA_BATCH_SIZE"] = str(batch_size)
    env["HYDRA_TOTAL_BATCH"] = str(total_batch)

    if quality_mode_local and full_arch_hpo:
        env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.008, 0.03, log=True))
        env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.15, 0.6, log=True))
        env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.001, 0.01, log=True))
    else:
        env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.005, 0.2, log=True))
        env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.05, 1.0, log=True))
        env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.0005, 0.02, log=True))

    if full_arch_hpo:
        env["HYDRA_HYENA_LAYERS"] = ""
        env["HYDRA_ENGRAM_N_COLUMNS"] = str(
            trial.suggest_categorical(
                "engram_n_columns",
                [512, 1024] if (speed_arch_hpo or quality_mode_local) else [512, 1024, 2048],
            )
        )
        env["HYDRA_ENGRAM_LAYER_IDX"] = str(trial.suggest_int("engram_layer_idx", 0, max(0, int(env["HYDRA_N_LAYER"]) - 1)))
        env["HYDRA_SDR_TARGET_ACTIVE"] = str(
            trial.suggest_categorical(
                "sdr_target_active",
                [327] if quality_mode_local else ([164, 327] if speed_arch_hpo else [164, 327, 512]),
            )
        )
        env["HYDRA_HTM_LEARN_EVERY"] = str(
            trial.suggest_categorical("htm_learn_every", [8, 16] if (speed_arch_hpo or quality_mode_local) else [4, 8, 16])
        )
        env["HYDRA_HTM_SUBSAMPLE"] = str(
            trial.suggest_categorical("htm_subsample", [1, 2] if quality_mode_local else ([4, 8, 16] if speed_arch_hpo else [1, 2, 4, 8]))
        )
        env["HYDRA_ENGRAM_SUBSAMPLE"] = str(
            trial.suggest_categorical("engram_subsample", [1, 2] if quality_mode_local else ([1, 2, 4] if speed_arch_hpo else [1]))
        )
        env["HYDRA_MAMBA3_CHUNK"] = str(trial.suggest_categorical("mamba3_chunk", [32, 64]))
        env["HYDRA_DROPOUT"] = str(trial.suggest_categorical("dropout", [0.0, 0.1] if (speed_arch_hpo or quality_mode_local) else [0.0, 0.1, 0.2]))
    else:
        env["HYDRA_HYENA_LAYERS"] = trial.suggest_categorical("hyena_layers", ["", "0", "1", "0,1"])

    # Keep trials alive long enough to emit metrics.
    env["HYDRA_FAIL_LOSS_THRESHOLD"] = "1000000"
    env["HYDRA_USE_NEMOTRON"] = os.environ.get("HYDRA_USE_NEMOTRON", "1")
    env["HYDRA_LOCAL_SHARDS_ONLY"] = os.environ.get("HYDRA_LOCAL_SHARDS_ONLY", "0")
    # Strict optimal-path defaults (no forced fallback profile).
    env["HYDRA_MUON_COMPILE"] = os.environ.get("HYDRA_MUON_COMPILE", "1")
    env["HYDRA_THROUGHPUT_MODE"] = os.environ.get("HYDRA_THROUGHPUT_MODE", "0" if full_arch_hpo else "1")
    env["HYDRA_FORCE_HTM_CPU"] = os.environ.get("HYDRA_FORCE_HTM_CPU", "0")
    env["HYDRA_ALLOW_SYNTHETIC_RETINA"] = os.environ.get("HYDRA_ALLOW_SYNTHETIC_RETINA", "0")
    env["HYDRA_INERT_MAMBA"] = os.environ.get("HYDRA_INERT_MAMBA", "0")
    env["HYDRA_FASTPATH"] = os.environ.get("HYDRA_FASTPATH", "0" if full_arch_hpo else "1")

    return env


def _sanitize_hf_env(env: dict[str, str]) -> dict[str, str]:

@@ -410,7 +410,7 @@ def _sanitize_hf_env(env: dict[str, str]) -> dict[str, str]:
    return sanitized


def _hf_command_candidates(args: argparse.Namespace) -> list[list[str]]:
    if args.hf_use_bash:
        return [["bash", "-lc", args.hf_command]]

@@ -432,20 +432,20 @@ def _hf_command_candidates(args: argparse.Namespace) -> list[list[str]]:
            uniq.append(c)
        return uniq

    return [raw.split()]


def _space_repo_from_hf_image(image: str, namespace: str) -> str:
    prefix = "hf.co/spaces/"
    if image.startswith(prefix):
        return image[len(prefix):]
    return os.environ.get("FEATHER_HF_SPACE_REPO", f"{namespace}/feather-a10-runtime")


def _objective_local(args: argparse.Namespace):
    effective_min_tps = _effective_min_tps(args)

    def objective(trial: optuna.Trial) -> float:
        trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
        metrics_path = trial_dir / "metrics.json"

@@ -460,67 +460,67 @@ def _objective_local(args: argparse.Namespace):
            timeout=args.trial_timeout,
        )

        metrics: dict[str, Any] | None = None
        if metrics_path.exists():
            try:
                metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
            except json.JSONDecodeError:
                metrics = None
        if metrics is None:
            metrics = _parse_metrics_from_stdout(proc.stdout)

        artifact_paths = _persist_trial_artifacts(
            trial_dir=trial_dir,
            metrics=metrics,
            log_lines=(proc.stdout or "").splitlines(),
            log_name="train_stdout.log",
            metadata={"runner": "local", "returncode": proc.returncode},
        )
        (trial_dir / "train_stderr.log").write_text(proc.stderr or "", encoding="utf-8")

        fallback_bpb = _parse_last_train_bpb_from_logs(proc.stdout.splitlines())
        if metrics is None:
            _resolve_objective_metric(
                trial,
                metric_key=args.metric,
                metrics=None,
                allow_log_metric_fallback=args.allow_log_metric_fallback,
                fallback_bpb=fallback_bpb,
                tps_seen=None,
            )
            raise optuna.TrialPruned("No metrics found (HYDRA_METRICS_OUT/[METRICS_JSON])")

        if proc.returncode != 0:
            raise optuna.TrialPruned(f"Training failed rc={proc.returncode}")

        metric_key = args.metric

        tps_val = metrics.get("tps")
        if tps_val is not None:
            tps_f = float(tps_val)
            trial.set_user_attr("tps", tps_f)
            if effective_min_tps is not None and tps_f < effective_min_tps:
                raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")

        value = _resolve_objective_metric(
            trial,
            metric_key=metric_key,
            metrics=metrics,
            allow_log_metric_fallback=args.allow_log_metric_fallback,
            fallback_bpb=fallback_bpb,
            tps_seen=None,
        )

        # Keep useful context on trial
        trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
        trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])

        return value

    return objective


def _objective_hf_job(args: argparse.Namespace):
    from huggingface_hub import HfApi
    from huggingface_hub.utils import get_token

@@ -530,9 +530,9 @@ def _objective_hf_job(args: argparse.Namespace):
            f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
        )

    api = HfApi(token=token)
    terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
    effective_min_tps = _effective_min_tps(args)

    def objective(trial: optuna.Trial) -> float:
        trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))

@@ -568,14 +568,14 @@ def _objective_hf_job(args: argparse.Namespace):
            info = api.inspect_job(job_id=job.id, token=token, namespace=args.hf_namespace)
            bootstrap_stage = str(info.status.stage)
            bootstrap_msg = str(getattr(info.status, "message", "") or "")
            bootstrap_logs = _fetch_job_logs_safe(
                api,
                job_id=job.id,
                token=token,
                namespace=args.hf_namespace,
                retries=2,
                sleep_s=1.0,
            )
            if bootstrap_stage in {"RUNNING", "COMPLETED"} or bootstrap_logs:
                break
            if bootstrap_stage in {"ERROR", "FAILED", "CANCELLED", "CANCELED", "TIMEOUT"}:

@@ -611,12 +611,12 @@ def _objective_hf_job(args: argparse.Namespace):
            info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
            stage = str(info.status.stage)
            terminal_detail = str(getattr(info.status, "message", "")) or terminal_detail
            log_lines = _fetch_job_logs_safe(
                api,
                job_id=job_id,
                token=token,
                namespace=args.hf_namespace,
            )

            m = _parse_metrics_from_log_lines(log_lines)
            if m is not None:

@@ -643,66 +643,66 @@ def _objective_hf_job(args: argparse.Namespace):
            except Exception:
                pass

        artifact_paths = _persist_trial_artifacts(
            trial_dir=trial_dir,
            metrics=metrics,
            log_lines=log_lines,
            log_name="hf_job.log",
            metadata={"runner": "hf-job", "hf_job_id": job_id, "hf_stage": stage},
        )
        trial.set_user_attr("hf_stage", stage)
        trial.set_user_attr("hf_log_lines", len(log_lines))
        if terminal_detail:
            trial.set_user_attr("hf_status_message", terminal_detail)

        fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
        if metrics is None:
            try:
                value = _resolve_objective_metric(
                    trial,
                    metric_key=args.metric,
                    metrics=None,
                    allow_log_metric_fallback=args.allow_log_metric_fallback,
                    fallback_bpb=fallback_bpb,
                    tps_seen=tps_seen,
                )
                if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
                    raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
                return value
            except optuna.TrialPruned:
                pass
            if tps_seen is not None:
                trial.set_user_attr("tps", tps_seen)
            detail = f"stage={stage}, logs={len(log_lines)}"
            if terminal_detail:
                detail = f"{detail}, message={terminal_detail}"
            raise optuna.TrialPruned(f"No metrics found from HF job ({detail})")

        metric_key = args.metric

        tps_val = metrics.get("tps")
        if tps_val is not None:
            tps_f = float(tps_val)
            trial.set_user_attr("tps", tps_f)
            if effective_min_tps is not None and tps_f < effective_min_tps:
                raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")

        value = _resolve_objective_metric(
            trial,
            metric_key=metric_key,
            metrics=metrics,
            allow_log_metric_fallback=args.allow_log_metric_fallback,
            fallback_bpb=fallback_bpb,
            tps_seen=tps_seen,
        )
        trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
        trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
        return value

    return objective


def _objective_hf_launcher(args: argparse.Namespace):
    from huggingface_hub import HfApi
    from huggingface_hub.utils import get_token

@@ -712,9 +712,9 @@ def _objective_hf_launcher(args: argparse.Namespace):
            f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
        )

    api = HfApi(token=token)
    terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
    effective_min_tps = _effective_min_tps(args)

    def objective(trial: optuna.Trial) -> float:
        trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))

@@ -725,11 +725,11 @@ def _objective_hf_launcher(args: argparse.Namespace):
        local_env = os.environ.copy()
        local_env.update(env)
        local_env[args.hf_token_env] = token
        local_env["FEATHER_HF_NAMESPACE"] = args.hf_namespace
        local_env["FEATHER_HF_FLAVOR"] = args.hf_flavor
        local_env["FEATHER_HF_JOB_TIMEOUT"] = args.hf_timeout
        local_env["FEATHER_HF_IMAGE"] = args.hf_image
        local_env["FEATHER_HF_SPACE_REPO"] = _space_repo_from_hf_image(args.hf_image, args.hf_namespace)
        if args.hf_output_repo:
            local_env["FEATHER_HF_OUTPUT_REPO"] = args.hf_output_repo
        else:

@@ -766,12 +766,12 @@ def _objective_hf_launcher(args: argparse.Namespace):
            info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
            stage = str(info.status.stage)
            terminal_detail = str(getattr(info.status, "message", "") or "") or terminal_detail
            log_lines = _fetch_job_logs_safe(
                api,
                job_id=job_id,
                token=token,
                namespace=args.hf_namespace,
            )

            mtr = _parse_metrics_from_log_lines(log_lines)
            if mtr is not None:

@@ -796,85 +796,85 @@ def _objective_hf_launcher(args: argparse.Namespace):
            except Exception:
                pass

        artifact_paths = _persist_trial_artifacts(
            trial_dir=trial_dir,
            metrics=metrics,
            log_lines=log_lines,
            log_name="hf_job.log",
            metadata={"runner": "hf-launcher", "hf_job_id": job_id, "hf_stage": stage},
        )
        trial.set_user_attr("hf_stage", stage)
        trial.set_user_attr("hf_log_lines", len(log_lines))
        if terminal_detail:
            trial.set_user_attr("hf_status_message", terminal_detail)

        fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
        if metrics is None:
            try:
                value = _resolve_objective_metric(
                    trial,
                    metric_key=args.metric,
                    metrics=None,
                    allow_log_metric_fallback=args.allow_log_metric_fallback,
                    fallback_bpb=fallback_bpb,
                    tps_seen=tps_seen,
                )
                if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
                    raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
                return value
            except optuna.TrialPruned:
                pass
            if tps_seen is not None:
                trial.set_user_attr("tps", tps_seen)
            detail = f"stage={stage}, logs={len(log_lines)}"
            if terminal_detail:
                detail = f"{detail}, message={terminal_detail}"
            raise optuna.TrialPruned(f"No metrics found from HF launcher job ({detail})")

        metric_key = args.metric

        tps_val = metrics.get("tps")
        if tps_val is not None:
            tps_f = float(tps_val)
            trial.set_user_attr("tps", tps_f)
            if effective_min_tps is not None and tps_f < effective_min_tps:
                raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")

        value = _resolve_objective_metric(
            trial,
            metric_key=metric_key,
            metrics=metrics,
            allow_log_metric_fallback=args.allow_log_metric_fallback,
            fallback_bpb=fallback_bpb,
            tps_seen=tps_seen,
        )
        trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
        trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
        return value

    return objective


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
    parser = argparse.ArgumentParser(description="Optuna HPO runner for HYDRA train.py")
    parser.add_argument("--study-name", default="hydra_hpo", help="Optuna study name")
    parser.add_argument("--storage", default="sqlite:///optuna_hpo.db", help="Optuna storage URL")
    parser.add_argument("--direction", choices=["minimize", "maximize"], default="minimize")
    parser.add_argument("--metric", default="val_bpb", help="Metric key to optimize from HYDRA metrics")
    parser.add_argument(
        "--min-tps",
        type=float,
        default=50000.0,
        help="TPS floor; prune trials under this value (set 0 to disable)",
    )
    parser.add_argument("--trials", type=int, default=20, help="Number of Optuna trials")
    parser.add_argument("--study-timeout", type=int, default=None, help="Study timeout in seconds")
    parser.add_argument("--trial-time-budget", type=int, default=300, help="HYDRA_TIME_BUDGET passed to each trial")
    parser.add_argument("--trial-timeout", type=int, default=900, help="Subprocess timeout per trial in seconds")
    parser.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local", help="Trial execution backend")
    parser.add_argument("--hf-namespace", default=routing_defaults.job_namespace, help="HF namespace for jobs")
    parser.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}", help="HF jobs image")
    parser.add_argument("--hf-flavor", default="a10g-large", help="HF jobs hardware flavor")
    parser.add_argument("--hf-timeout", default="25m", help="HF job timeout string")
    parser.add_argument("--hf-command", default="/app/entrypoint.py", help="Command executed inside HF job")

@@ -886,23 +886,23 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    parser.add_argument("--hf-token-env", default="HF_TOKEN", help="Token env key passed as HF job secret")
    parser.add_argument("--hf-stop-after-metric", action="store_true", default=True, help="Cancel running job after metrics captured")
    parser.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")
    parser.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py", help="Local launcher script for hf-launcher runner")
    parser.add_argument("--hf-output-repo", default=routing_defaults.output_repo, help="Optional FEATHER_HF_OUTPUT_REPO override for launcher runner")
    parser.add_argument("--allow-log-metric-fallback", action="store_true", default=False, help="When metrics JSON is missing, allow val_bpb fallback from latest logged train bpb")
    parser.add_argument("--no-allow-log-metric-fallback", action="store_false", dest="allow_log_metric_fallback")
    parser.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json", help="Path to transfer-learning prior trials JSON")
    parser.add_argument("--apply-priors", action="store_true", default=True, help="Enqueue transfer-learning prior trials before optimize")
    parser.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
    parser.add_argument("--quality-mode-local", action="store_true", default=False, help="Narrow local full-architecture search around the proven quality-winning region")
    parser.add_argument("--quality-anchor-top-k", type=int, default=3, help="Number of top clean priors to enqueue as deterministic local quality anchors")
    parser.add_argument("--seed", type=int, default=42, help="Seed for sampler")
    parser.add_argument("--n-startup-trials", type=int, default=5, help="Pruner startup trials before pruning")
    parser.add_argument("--n-warmup-steps", type=int, default=0, help="Pruner warmup steps")
    parser.add_argument("--patience-trials", type=int, default=None, help="Stop study after this many completed trials without meaningful improvement")
    parser.add_argument("--min-improvement", type=float, default=0.0, help="Minimum best-value improvement to reset patience")
    parser.add_argument("--work-dir", type=Path, default=REPO_ROOT / ".tmp" / "optuna", help="Directory for trial artifacts")
    parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "best_summary.json")
    return parser.parse_args(argv)


def main() -> int:

@@ -916,22 +916,22 @@ def main() -> int:
        n_warmup_steps=args.n_warmup_steps,
    )

    study = optuna.create_study(
        study_name=args.study_name,
        storage=args.storage,
        load_if_exists=True,
        direction=args.direction,
        sampler=sampler,
        pruner=pruner,
    )

    enqueued_quality_anchors = _enqueue_quality_anchors(study, args.priors_file, args.quality_mode_local, args.quality_anchor_top_k)
    if enqueued_quality_anchors:
        print(f"[hpo] enqueued {enqueued_quality_anchors} local quality anchors from {args.priors_file}")

    enqueued_priors = _enqueue_transfer_priors(study, args.priors_file, args.apply_priors)
    if enqueued_priors:
        print(f"[hpo] enqueued {enqueued_priors} transfer priors from {args.priors_file}")

    state: dict[str, Any] = {
        "best": None,

@@ -990,29 +990,29 @@ def main() -> int:
            "best_trial_number": study.best_trial.number,
            "best_trial_user_attrs": study.best_trial.user_attrs,
            "n_trials": len(study.trials),
            "n_completed": len(completed),
            "patience_trials": args.patience_trials,
            "min_improvement": args.min_improvement,
            "quality_mode_local": args.quality_mode_local,
            "enqueued_quality_anchors": enqueued_quality_anchors,
            "enqueued_priors": enqueued_priors,
        }
    else:
        best = {
            "study_name": study.study_name,
            "direction": args.direction,
            "metric": args.metric,
            "best_value": None,
            "best_params": {},
            "best_trial_number": None,
            "best_trial_user_attrs": {},
            "n_trials": len(study.trials),
            "n_completed": 0,
            "quality_mode_local": args.quality_mode_local,
            "enqueued_quality_anchors": enqueued_quality_anchors,
            "enqueued_priors": enqueued_priors,
            "note": "No completed trials with metrics found.",
        }
    args.summary_out.write_text(json.dumps(best, indent=2), encoding="utf-8")
    print(json.dumps(best, indent=2))
    return 0
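A small illustration of the priors payload _load_prior_param_sets accepts: either a bare list or a dict with a "trials" key, where each entry carries its params inline or under "params", and any key outside SEARCH_SPACE_KEYS is dropped by _filter_prior_params. The file name and parameter values below are examples only:

import json
from pathlib import Path

priors = {
    "trials": [
        # "note" is not in SEARCH_SPACE_KEYS, so filtering drops it
        {"params": {"d_model": 128, "n_layer": 3, "matrix_lr": 0.02, "note": "best clean run"}},
        # params may also sit at the top level of an entry
        {"d_model": 96, "seq_len": 64, "batch_size": 8},
    ]
}
Path("hpo_transfer_priors.json").write_text(json.dumps(priors, indent=2), encoding="utf-8")

Entries that survive filtering are enqueued as deterministic first trials via study.enqueue_trial, tagged with a seed_source user attr so leaderboard tooling can tell seeded trials from sampled ones.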
|
|
| 5 |
import json
|
| 6 |
import os
|
| 7 |
import re
|
| 8 |
+
import subprocess
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
import tempfile
|
| 12 |
+
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any
|
| 15 |
+
|
| 16 |
+
import optuna
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_HF_ENV_KEY_RE = re.compile(r"^[A-Z][A-Z0-9_]*$")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 23 |
+
if str(REPO_ROOT) not in sys.path:
|
| 24 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 25 |
+
|
| 26 |
+
from scripts.hf_routing import resolve_routing
|
| 27 |
+
|
| 28 |
+
TRAIN_ENTRYPOINT = REPO_ROOT / "train.py"
|
| 29 |
+
SEARCH_SPACE_KEYS = {
|
| 30 |
+
"d_model",
|
| 31 |
+
"n_layer",
|
| 32 |
+
"d_state",
|
| 33 |
+
"headdim",
|
| 34 |
+
"expand",
|
| 35 |
+
"seq_len",
|
| 36 |
+
"batch_size",
|
| 37 |
+
"grad_accum",
|
| 38 |
+
"matrix_lr",
|
| 39 |
+
"embed_lr",
|
| 40 |
+
"unembed_lr",
|
| 41 |
+
"hyena_layers",
|
| 42 |
+
"engram_n_columns",
|
| 43 |
+
"engram_layer_idx",
|
| 44 |
+
"sdr_target_active",
|
| 45 |
+
"htm_learn_every",
|
| 46 |
+
"htm_subsample",
|
| 47 |
+
"engram_subsample",
|
| 48 |
+
"mamba3_chunk",
|
| 49 |
+
"dropout",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _filter_prior_params(raw: dict[str, Any]) -> dict[str, Any]:
|
| 54 |
+
return {k: v for k, v in raw.items() if k in SEARCH_SPACE_KEYS}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _load_prior_param_sets(path: Path) -> list[dict[str, Any]]:
|
| 58 |
+
if not path.exists():
|
| 59 |
+
return []
|
| 60 |
+
|
| 61 |
+
payload = json.loads(path.read_text(encoding="utf-8"))
|
| 62 |
+
if isinstance(payload, dict):
|
| 63 |
+
rows = payload.get("trials", [])
|
| 64 |
+
elif isinstance(payload, list):
|
| 65 |
+
rows = payload
|
| 66 |
+
else:
|
| 67 |
+
rows = []
|
| 68 |
+
|
| 69 |
+
out: list[dict[str, Any]] = []
|
| 70 |
+
for item in rows:
|
| 71 |
+
if not isinstance(item, dict):
|
| 72 |
+
continue
|
| 73 |
+
params_obj = item.get("params", item)
|
| 74 |
+
if not isinstance(params_obj, dict):
|
| 75 |
+
continue
|
| 76 |
+
filtered = _filter_prior_params(params_obj)
|
| 77 |
+
if filtered:
|
| 78 |
+
out.append(filtered)
|
| 79 |
+
return out
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _enqueue_transfer_priors(study: optuna.Study, priors_file: Path, apply_priors: bool) -> int:
|
| 83 |
+
if not apply_priors:
|
| 84 |
+
return 0
|
| 85 |
+
|
| 86 |
+
priors_raw = _load_prior_param_sets(priors_file)
|
| 87 |
+
if not priors_raw:
|
| 88 |
+
return 0
|
| 89 |
+
|
| 90 |
+
# Deduplicate param sets across merged studies.
|
| 91 |
+
priors: list[dict[str, Any]] = []
|
| 92 |
+
seen: set[str] = set()
|
| 93 |
+
for params in priors_raw:
|
| 94 |
+
key = json.dumps(params, sort_keys=True)
|
| 95 |
+
if key in seen:
|
| 96 |
+
continue
|
| 97 |
+
seen.add(key)
|
| 98 |
+
priors.append(params)
|
| 99 |
+
|
| 100 |
+
enqueued = 0
|
| 101 |
+
for params in priors:
|
| 102 |
+
before = len(study.get_trials(deepcopy=False))
|
| 103 |
+
try:
|
| 104 |
+
study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"}, skip_if_exists=True)
|
| 105 |
+
except TypeError:
|
| 106 |
+
study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"})
|
| 107 |
+
after = len(study.get_trials(deepcopy=False))
|
| 108 |
+
if after > before:
|
| 109 |
+
enqueued += 1
|
| 110 |
+
return enqueued
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _enqueue_quality_anchors(study: optuna.Study, priors_file: Path, quality_mode_local: bool, top_k: int) -> int:
|
| 114 |
+
if not quality_mode_local or top_k <= 0:
|
| 115 |
+
return 0
|
| 116 |
+
|
| 117 |
+
priors = _load_prior_param_sets(priors_file)[:top_k]
|
| 118 |
+
enqueued = 0
|
| 119 |
+
for params in priors:
|
| 120 |
+
before = len(study.get_trials(deepcopy=False))
|
| 121 |
+
try:
|
| 122 |
+
study.enqueue_trial(
|
| 123 |
+
params,
|
| 124 |
+
user_attrs={"seed_source": "quality_anchor"},
|
| 125 |
+
skip_if_exists=True,
|
| 126 |
+
)
|
| 127 |
+
except TypeError:
|
| 128 |
+
study.enqueue_trial(params, user_attrs={"seed_source": "quality_anchor"})
|
| 129 |
+
after = len(study.get_trials(deepcopy=False))
|
| 130 |
+
if after > before:
|
| 131 |
+
enqueued += 1
|
| 132 |
+
return enqueued
|
| 133 |
|
| 134 |
|
| 135 |
def _parse_metrics_from_stdout(stdout: str) -> dict[str, Any] | None:
|
|
|
|
| 164 |
return None
|
| 165 |
|
| 166 |
|
| 167 |
+
def _parse_last_train_bpb_from_logs(lines: list[str]) -> float | None:
|
| 168 |
+
"""Best-effort fallback when final eval crashes before metrics JSON write."""
|
| 169 |
+
last: float | None = None
|
| 170 |
+
for line in lines:
|
| 171 |
+
m = re.search(r"\bbpb=([0-9]+(?:\.[0-9]+)?)", line)
|
| 172 |
if m:
|
| 173 |
+
last = float(m.group(1))
|
| 174 |
+
return last
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _persist_trial_artifacts(
|
| 178 |
+
*,
|
| 179 |
+
trial_dir: Path,
|
| 180 |
+
metrics: dict[str, Any] | None,
|
| 181 |
+
log_lines: list[str] | None,
|
| 182 |
+
log_name: str,
|
| 183 |
+
metadata: dict[str, Any],
|
| 184 |
+
) -> dict[str, str | None]:
|
| 185 |
+
trial_dir.mkdir(parents=True, exist_ok=True)
|
| 186 |
+
metrics_path = trial_dir / "metrics.json"
|
| 187 |
+
log_path = trial_dir / log_name
|
| 188 |
+
manifest_path = trial_dir / "trial_artifacts.json"
|
| 189 |
+
|
| 190 |
+
if metrics is not None:
|
| 191 |
+
metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True), encoding="utf-8")
|
| 192 |
+
if log_lines is not None:
|
| 193 |
+
log_path.write_text("\n".join(log_lines), encoding="utf-8")
|
| 194 |
+
|
| 195 |
+
manifest = {
|
| 196 |
+
**metadata,
|
| 197 |
+
"metrics_path": str(metrics_path) if metrics is not None else None,
|
| 198 |
+
"log_path": str(log_path) if log_lines is not None else None,
|
| 199 |
+
}
|
| 200 |
+
manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
|
| 201 |
+
return {
|
| 202 |
+
"metrics_path": str(metrics_path) if metrics is not None else None,
|
| 203 |
+
"log_path": str(log_path) if log_lines is not None else None,
|
| 204 |
+
"manifest_path": str(manifest_path),
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _resolve_objective_metric(
|
| 209 |
+
trial: optuna.Trial,
|
| 210 |
+
*,
|
| 211 |
+
metric_key: str,
|
| 212 |
+
metrics: dict[str, Any] | None,
|
| 213 |
+
allow_log_metric_fallback: bool,
|
| 214 |
+
fallback_bpb: float | None,
|
| 215 |
+
tps_seen: float | None,
|
| 216 |
+
) -> float:
|
| 217 |
+
"""Resolve the objective value while labeling where it came from.
|
| 218 |
+
|
| 219 |
+
Validation metrics and live training-log fallbacks are intentionally
|
| 220 |
+
different sources. Keeping that distinction in trial attrs prevents a
|
| 221 |
+
skipped/OOM eval from being mistaken for a real validation result.
|
| 222 |
+
"""
|
| 223 |
+
if metrics is None:
|
| 224 |
+
if allow_log_metric_fallback and metric_key == "val_bpb" and fallback_bpb is not None:
|
| 225 |
+
trial.set_user_attr("objective_source", "train_log_fallback")
|
| 226 |
+
trial.set_user_attr("objective_metric", "train_bpb")
|
| 227 |
+
trial.set_user_attr("eval_status", "missing_metrics")
|
| 228 |
+
trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
|
| 229 |
+
if tps_seen is not None:
|
| 230 |
+
trial.set_user_attr("tps", float(tps_seen))
|
| 231 |
+
return float(fallback_bpb)
|
| 232 |
+
trial.set_user_attr("objective_source", "missing_metrics")
|
| 233 |
+
raise optuna.TrialPruned("No metrics payload found")
|
| 234 |
+
|
| 235 |
+
eval_status = str(
|
| 236 |
+
metrics.get(
|
| 237 |
+
"eval_status",
|
| 238 |
+
"completed" if metrics.get("val_bpb") is not None else "unknown",
|
| 239 |
+
)
|
| 240 |
+
)
|
| 241 |
+
trial.set_user_attr("eval_status", eval_status)
|
| 242 |
+
|
| 243 |
+
if fallback_bpb is not None:
|
| 244 |
+
trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
|
| 245 |
+
|
| 246 |
+
if metric_key not in metrics or metrics[metric_key] is None:
|
| 247 |
+
trial.set_user_attr("objective_source", "missing_metric")
|
| 248 |
+
trial.set_user_attr("objective_metric", metric_key)
|
| 249 |
+
raise optuna.TrialPruned(f"Metric '{metric_key}' missing in metrics payload")
|
| 250 |
+
|
| 251 |
+
value = float(metrics[metric_key])
|
| 252 |
+
trial.set_user_attr("objective_metric", metric_key)
|
| 253 |
+
if metric_key == "val_bpb":
|
| 254 |
+
trial.set_user_attr("objective_source", "final_val")
|
| 255 |
+
trial.set_user_attr("final_val_bpb", value)
|
| 256 |
+
else:
|
| 257 |
+
trial.set_user_attr("objective_source", "metrics_json")
|
| 258 |
+
return value
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _fetch_job_logs_safe(
|
| 262 |
+
api,
|
| 263 |
+
*,
|
| 264 |
+
job_id: str,
|
| 265 |
+
token: str,
|
| 266 |
+
namespace: str,
|
| 267 |
+
retries: int = 3,
|
| 268 |
+
sleep_s: float = 2.0,
|
| 269 |
+
timeout_s: float = 20.0,
|
| 270 |
+
) -> list[str]:
|
| 271 |
+
last_exc: Exception | None = None
|
| 272 |
+
for attempt in range(1, retries + 1):
|
| 273 |
+
try:
|
| 274 |
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
| 275 |
+
future = executor.submit(
|
| 276 |
+
lambda: list(api.fetch_job_logs(job_id=job_id, follow=False, token=token, namespace=namespace))
|
| 277 |
+
)
|
| 278 |
+
return future.result(timeout=timeout_s)
|
| 279 |
+
except FuturesTimeoutError as exc:
|
| 280 |
+
last_exc = TimeoutError(f"Timed out fetching HF job logs for {job_id} after {timeout_s:.1f}s")
|
| 281 |
+
except Exception as exc: # noqa: BLE001
|
| 282 |
+
last_exc = exc
|
| 283 |
+
if attempt >= retries:
|
| 284 |
+
raise
|
| 285 |
+
time.sleep(sleep_s)
|
| 286 |
+
if last_exc is not None:
|
| 287 |
+
raise last_exc
|
| 288 |
+
return []
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def _effective_min_tps(args: argparse.Namespace) -> float | None:
|
| 292 |
+
min_tps = args.min_tps
|
| 293 |
+
if getattr(args, "quality_mode_local", False) and min_tps == 50000.0:
|
| 294 |
+
return 0.0
|
| 295 |
+
return min_tps
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _trial_env(trial: optuna.Trial, args: argparse.Namespace, metrics_path: Path) -> dict[str, str]:
|
| 299 |
+
env = os.environ.copy()
|
| 300 |
+
full_arch_hpo = env.get("HYDRA_HPO_FULL_ARCH", "0") == "1"
|
| 301 |
+
speed_arch_hpo = full_arch_hpo and env.get("HYDRA_HPO_SPEED_ARCH", "0") == "1"
|
| 302 |
+
quality_mode_local = bool(getattr(args, "quality_mode_local", False))
|
| 303 |
|
| 304 |
# Runtime and reporting
|
| 305 |
env["HYDRA_METRICS_OUT"] = str(metrics_path)
|
| 306 |
env["HYDRA_TIME_BUDGET"] = str(args.trial_time_budget)
|
| 307 |
env["PYTHONUNBUFFERED"] = "1"
|
| 308 |
|
| 309 |
+
# Search space — fully env-driven to match existing training stack.
|
| 310 |
+
if speed_arch_hpo:
|
| 311 |
+
# Full-arch speed mode targets A10 underutilization observed in HPO:
|
| 312 |
+
# low VRAM/MFU, strong BPB from shallow models, and fixed SDR/HTM
|
| 313 |
+
# overhead dominating small microbatches. Keep all components enabled
|
| 314 |
+
# while amortizing overhead over more tokens.
|
| 315 |
+
env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96]))
|
| 316 |
+
env["HYDRA_N_LAYER"] = str(trial.suggest_categorical("n_layer", [2]))
|
| 317 |
+
env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
|
| 318 |
+
env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
|
| 319 |
+
env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
|
| 320 |
+
elif quality_mode_local and full_arch_hpo:
|
| 321 |
+
env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128]))
|
| 322 |
+
env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 2, 3))
|
| 323 |
+
env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
|
| 324 |
+
env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
|
| 325 |
+
env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
|
| 326 |
+
else:
|
| 327 |
+
env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128, 160, 192]))
|
| 328 |
+
env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 1, 4))
|
| 329 |
+
env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32, 48]))
|
| 330 |
+
env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [8, 16, 32]))
|
| 331 |
+
env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
|
| 332 |
+
|
| 333 |
+
if speed_arch_hpo:
|
| 334 |
+
seq_len = trial.suggest_categorical("seq_len", [64, 128])
|
| 335 |
+
batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
|
| 336 |
+
grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
|
| 337 |
+
elif quality_mode_local and full_arch_hpo:
|
| 338 |
+
seq_len = trial.suggest_categorical("seq_len", [64])
|
| 339 |
+
batch_size = trial.suggest_categorical("batch_size", [4, 8])
|
| 340 |
+
grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
|
| 341 |
+
else:
|
| 342 |
+
seq_len = trial.suggest_categorical("seq_len", [32, 64])
|
| 343 |
+
batch_size = trial.suggest_categorical("batch_size", [4, 8] if full_arch_hpo else [4, 8, 16])
|
| 344 |
+
grad_accum = trial.suggest_categorical("grad_accum", [1, 4, 8, 16] if full_arch_hpo else [8, 16, 32, 64])
|
| 345 |
# Keep TOTAL_BATCH_SIZE divisible by DEVICE_BATCH_SIZE * MAX_SEQ_LEN.
|
| 346 |
total_batch = batch_size * seq_len * grad_accum
|
| 347 |
env["HYDRA_SEQ_LEN"] = str(seq_len)
|
| 348 |
env["HYDRA_BATCH_SIZE"] = str(batch_size)
|
| 349 |
env["HYDRA_TOTAL_BATCH"] = str(total_batch)
|
| 350 |
|
| 351 |
+
if quality_mode_local and full_arch_hpo:
|
| 352 |
+
env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.008, 0.03, log=True))
|
| 353 |
+
env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.15, 0.6, log=True))
|
| 354 |
+
env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.001, 0.01, log=True))
|
| 355 |
+
else:
|
| 356 |
+
env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.005, 0.2, log=True))
|
| 357 |
+
env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.05, 1.0, log=True))
|
| 358 |
+
env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.0005, 0.02, log=True))
|
| 359 |
+
|
| 360 |
+
if full_arch_hpo:
|
| 361 |
+
env["HYDRA_HYENA_LAYERS"] = ""
|
| 362 |
+
env["HYDRA_ENGRAM_N_COLUMNS"] = str(
|
| 363 |
+
trial.suggest_categorical(
|
| 364 |
+
"engram_n_columns",
|
| 365 |
+
[512, 1024] if (speed_arch_hpo or quality_mode_local) else [512, 1024, 2048],
|
| 366 |
+
)
|
| 367 |
+
)
|
| 368 |
+
env["HYDRA_ENGRAM_LAYER_IDX"] = str(trial.suggest_int("engram_layer_idx", 0, max(0, int(env["HYDRA_N_LAYER"]) - 1)))
|
| 369 |
+
env["HYDRA_SDR_TARGET_ACTIVE"] = str(
|
| 370 |
+
trial.suggest_categorical(
|
| 371 |
+
"sdr_target_active",
|
| 372 |
+
[327] if quality_mode_local else ([164, 327] if speed_arch_hpo else [164, 327, 512]),
|
| 373 |
+
)
|
| 374 |
+
)
|
| 375 |
+
env["HYDRA_HTM_LEARN_EVERY"] = str(
|
| 376 |
+
trial.suggest_categorical("htm_learn_every", [8, 16] if (speed_arch_hpo or quality_mode_local) else [4, 8, 16])
|
| 377 |
+
)
|
| 378 |
+
env["HYDRA_HTM_SUBSAMPLE"] = str(
|
| 379 |
+
trial.suggest_categorical("htm_subsample", [1, 2] if quality_mode_local else ([4, 8, 16] if speed_arch_hpo else [1, 2, 4, 8]))
|
| 380 |
+
)
|
| 381 |
+
env["HYDRA_ENGRAM_SUBSAMPLE"] = str(
|
| 382 |
+
trial.suggest_categorical("engram_subsample", [1, 2] if quality_mode_local else ([1, 2, 4] if speed_arch_hpo else [1]))
|
| 383 |
+
)
|
| 384 |
+
env["HYDRA_MAMBA3_CHUNK"] = str(trial.suggest_categorical("mamba3_chunk", [32, 64]))
|
| 385 |
+
env["HYDRA_DROPOUT"] = str(trial.suggest_categorical("dropout", [0.0, 0.1] if (speed_arch_hpo or quality_mode_local) else [0.0, 0.1, 0.2]))
|
| 386 |
+
else:
|
| 387 |
+
env["HYDRA_HYENA_LAYERS"] = trial.suggest_categorical("hyena_layers", ["", "0", "1", "0,1"])
|
| 388 |
|
| 389 |
# Keep trials alive long enough to emit metrics.
|
| 390 |
env["HYDRA_FAIL_LOSS_THRESHOLD"] = "1000000"
|
| 391 |
env["HYDRA_USE_NEMOTRON"] = os.environ.get("HYDRA_USE_NEMOTRON", "1")
|
| 392 |
env["HYDRA_LOCAL_SHARDS_ONLY"] = os.environ.get("HYDRA_LOCAL_SHARDS_ONLY", "0")
|
| 393 |
# Strict optimal-path defaults (no forced fallback profile).
|
| 394 |
+
env["HYDRA_MUON_COMPILE"] = os.environ.get("HYDRA_MUON_COMPILE", "1")
|
| 395 |
+
env["HYDRA_THROUGHPUT_MODE"] = os.environ.get("HYDRA_THROUGHPUT_MODE", "0" if full_arch_hpo else "1")
|
| 396 |
+
env["HYDRA_FORCE_HTM_CPU"] = os.environ.get("HYDRA_FORCE_HTM_CPU", "0")
|
| 397 |
+
env["HYDRA_ALLOW_SYNTHETIC_RETINA"] = os.environ.get("HYDRA_ALLOW_SYNTHETIC_RETINA", "0")
|
| 398 |
+
env["HYDRA_INERT_MAMBA"] = os.environ.get("HYDRA_INERT_MAMBA", "0")
|
| 399 |
+
env["HYDRA_FASTPATH"] = os.environ.get("HYDRA_FASTPATH", "0" if full_arch_hpo else "1")
|
| 400 |
+
|
| 401 |
+
return env
|
| 402 |
|
| 403 |
|
| 404 |
def _sanitize_hf_env(env: dict[str, str]) -> dict[str, str]:
|
|
|
|
| 410 |
return sanitized
|
| 411 |
|
| 412 |
|
| 413 |
+
def _hf_command_candidates(args: argparse.Namespace) -> list[list[str]]:
|
| 414 |
if args.hf_use_bash:
|
| 415 |
return [["bash", "-lc", args.hf_command]]
|
| 416 |
|
|
|
|
| 432 |
uniq.append(c)
|
| 433 |
return uniq
|
| 434 |
|
| 435 |
+
return [raw.split()]
|
| 436 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
|
| 438 |
+
def _space_repo_from_hf_image(image: str, namespace: str) -> str:
|
| 439 |
+
prefix = "hf.co/spaces/"
|
| 440 |
+
if image.startswith(prefix):
|
| 441 |
+
return image[len(prefix):]
|
| 442 |
+
return os.environ.get("FEATHER_HF_SPACE_REPO", f"{namespace}/feather-a10-runtime")
|
| 443 |
|
| 444 |
+
|
| 445 |
+
def _objective_local(args: argparse.Namespace):
|
| 446 |
+
effective_min_tps = _effective_min_tps(args)
|
| 447 |
+
|
| 448 |
+
def objective(trial: optuna.Trial) -> float:
|
| 449 |
trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
|
| 450 |
metrics_path = trial_dir / "metrics.json"
|
| 451 |
|
|
|
|
| 460 |
timeout=args.trial_timeout,
|
| 461 |
)
|
| 462 |
|
| 463 |
+
metrics: dict[str, Any] | None = None
|
| 464 |
if metrics_path.exists():
|
| 465 |
try:
|
| 466 |
metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
|
| 467 |
except json.JSONDecodeError:
|
| 468 |
metrics = None
|
| 469 |
+
if metrics is None:
|
| 470 |
+
metrics = _parse_metrics_from_stdout(proc.stdout)
|
| 471 |
+
|
| 472 |
+
artifact_paths = _persist_trial_artifacts(
|
| 473 |
+
trial_dir=trial_dir,
|
| 474 |
+
metrics=metrics,
|
| 475 |
+
log_lines=(proc.stdout or "").splitlines(),
|
| 476 |
+
log_name="train_stdout.log",
|
| 477 |
+
metadata={"runner": "local", "returncode": proc.returncode},
|
| 478 |
+
)
|
| 479 |
+
(trial_dir / "train_stderr.log").write_text(proc.stderr or "", encoding="utf-8")
|
| 480 |
+
|
| 481 |
+
fallback_bpb = _parse_last_train_bpb_from_logs(proc.stdout.splitlines())
|
| 482 |
+
if metrics is None:
|
| 483 |
+
_resolve_objective_metric(
|
| 484 |
+
trial,
|
| 485 |
+
metric_key=args.metric,
|
| 486 |
+
metrics=None,
|
| 487 |
+
allow_log_metric_fallback=args.allow_log_metric_fallback,
|
| 488 |
+
fallback_bpb=fallback_bpb,
|
| 489 |
+
tps_seen=None,
|
| 490 |
+
)
|
| 491 |
+
raise optuna.TrialPruned("No metrics found (HYDRA_METRICS_OUT/[METRICS_JSON])")
|
| 492 |
|
| 493 |
if proc.returncode != 0:
|
| 494 |
raise optuna.TrialPruned(f"Training failed rc={proc.returncode}")
|
| 495 |
|
| 496 |
+
metric_key = args.metric
|
| 497 |
|
| 498 |
tps_val = metrics.get("tps")
|
| 499 |
if tps_val is not None:
|
| 500 |
tps_f = float(tps_val)
|
| 501 |
trial.set_user_attr("tps", tps_f)
|
| 502 |
+
if effective_min_tps is not None and tps_f < effective_min_tps:
|
| 503 |
+
raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
|
| 504 |
+
|
| 505 |
+
value = _resolve_objective_metric(
|
| 506 |
+
trial,
|
| 507 |
+
metric_key=metric_key,
|
| 508 |
+
metrics=metrics,
|
| 509 |
+
allow_log_metric_fallback=args.allow_log_metric_fallback,
|
| 510 |
+
fallback_bpb=fallback_bpb,
|
| 511 |
+
tps_seen=None,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Keep useful context on trial
|
| 515 |
+
trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
|
| 516 |
+
trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
|
| 517 |
+
|
| 518 |
+
return value
|
| 519 |
|
| 520 |
return objective
|
| 521 |
|
+def _objective_hf_job(args: argparse.Namespace):
     from huggingface_hub import HfApi
     from huggingface_hub.utils import get_token
…
             f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
         )
 
+    api = HfApi(token=token)
+    terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
+    effective_min_tps = _effective_min_tps(args)
 
     def objective(trial: optuna.Trial) -> float:
         trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
…
             info = api.inspect_job(job_id=job.id, token=token, namespace=args.hf_namespace)
             bootstrap_stage = str(info.status.stage)
             bootstrap_msg = str(getattr(info.status, "message", "") or "")
+            bootstrap_logs = _fetch_job_logs_safe(
+                api,
+                job_id=job.id,
+                token=token,
+                namespace=args.hf_namespace,
+                retries=2,
+                sleep_s=1.0,
+            )
             if bootstrap_stage in {"RUNNING", "COMPLETED"} or bootstrap_logs:
                 break
             if bootstrap_stage in {"ERROR", "FAILED", "CANCELLED", "CANCELED", "TIMEOUT"}:
…
             info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
             stage = str(info.status.stage)
             terminal_detail = str(getattr(info.status, "message", "")) or terminal_detail
+            log_lines = _fetch_job_logs_safe(
+                api,
+                job_id=job_id,
+                token=token,
+                namespace=args.hf_namespace,
+            )
 
             m = _parse_metrics_from_log_lines(log_lines)
             if m is not None:
…
         except Exception:
             pass
 
+        artifact_paths = _persist_trial_artifacts(
+            trial_dir=trial_dir,
+            metrics=metrics,
+            log_lines=log_lines,
+            log_name="hf_job.log",
+            metadata={"runner": "hf-job", "hf_job_id": job_id, "hf_stage": stage},
+        )
+        trial.set_user_attr("hf_stage", stage)
+        trial.set_user_attr("hf_log_lines", len(log_lines))
         if terminal_detail:
             trial.set_user_attr("hf_status_message", terminal_detail)
 
+        fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
+        if metrics is None:
+            try:
+                value = _resolve_objective_metric(
+                    trial,
+                    metric_key=args.metric,
+                    metrics=None,
+                    allow_log_metric_fallback=args.allow_log_metric_fallback,
+                    fallback_bpb=fallback_bpb,
+                    tps_seen=tps_seen,
+                )
+                if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
+                    raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
+                return value
+            except optuna.TrialPruned:
+                pass
+            if tps_seen is not None:
+                trial.set_user_attr("tps", tps_seen)
+            detail = f"stage={stage}, logs={len(log_lines)}"
+            if terminal_detail:
+                detail = f"{detail}, message={terminal_detail}"
             raise optuna.TrialPruned(f"No metrics found from HF job ({detail})")
 
+        metric_key = args.metric
 
         tps_val = metrics.get("tps")
         if tps_val is not None:
             tps_f = float(tps_val)
             trial.set_user_attr("tps", tps_f)
+            if effective_min_tps is not None and tps_f < effective_min_tps:
+                raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
+
+        value = _resolve_objective_metric(
+            trial,
+            metric_key=metric_key,
+            metrics=metrics,
+            allow_log_metric_fallback=args.allow_log_metric_fallback,
+            fallback_bpb=fallback_bpb,
+            tps_seen=tps_seen,
+        )
+        trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
+        trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
+        return value
 
     return objective
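The `_fetch_job_logs_safe` calls above pass `retries` and `sleep_s`, but the helper's body is not part of this diff. A plausible shape, written against any fetch callable so that no HfApi method signature is guessed (names and defaults here are assumptions, not this repo's implementation):

import time
from collections.abc import Callable

def fetch_logs_with_retries(
    fetch: Callable[[], list[str]],
    *,
    retries: int = 2,
    sleep_s: float = 1.0,
) -> list[str]:
    # Retry transient log-endpoint failures; never fail the trial just
    # because logs were momentarily unavailable.
    for attempt in range(retries + 1):
        try:
            return list(fetch())
        except Exception:
            if attempt == retries:
                return []
            time.sleep(sleep_s)
    return []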
+def _objective_hf_launcher(args: argparse.Namespace):
     from huggingface_hub import HfApi
     from huggingface_hub.utils import get_token
…
             f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
         )
 
+    api = HfApi(token=token)
+    terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
+    effective_min_tps = _effective_min_tps(args)
 
     def objective(trial: optuna.Trial) -> float:
         trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
…
         local_env = os.environ.copy()
         local_env.update(env)
         local_env[args.hf_token_env] = token
+        local_env["FEATHER_HF_NAMESPACE"] = args.hf_namespace
+        local_env["FEATHER_HF_FLAVOR"] = args.hf_flavor
+        local_env["FEATHER_HF_JOB_TIMEOUT"] = args.hf_timeout
+        local_env["FEATHER_HF_IMAGE"] = args.hf_image
+        local_env["FEATHER_HF_SPACE_REPO"] = _space_repo_from_hf_image(args.hf_image, args.hf_namespace)
         if args.hf_output_repo:
             local_env["FEATHER_HF_OUTPUT_REPO"] = args.hf_output_repo
         else:
…
             info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
             stage = str(info.status.stage)
             terminal_detail = str(getattr(info.status, "message", "") or "") or terminal_detail
+            log_lines = _fetch_job_logs_safe(
+                api,
+                job_id=job_id,
+                token=token,
+                namespace=args.hf_namespace,
+            )
 
             mtr = _parse_metrics_from_log_lines(log_lines)
             if mtr is not None:
…
         except Exception:
             pass
 
+        artifact_paths = _persist_trial_artifacts(
+            trial_dir=trial_dir,
+            metrics=metrics,
+            log_lines=log_lines,
+            log_name="hf_job.log",
+            metadata={"runner": "hf-launcher", "hf_job_id": job_id, "hf_stage": stage},
+        )
+        trial.set_user_attr("hf_stage", stage)
+        trial.set_user_attr("hf_log_lines", len(log_lines))
         if terminal_detail:
             trial.set_user_attr("hf_status_message", terminal_detail)
 
+        fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
+        if metrics is None:
+            try:
+                value = _resolve_objective_metric(
+                    trial,
+                    metric_key=args.metric,
+                    metrics=None,
+                    allow_log_metric_fallback=args.allow_log_metric_fallback,
+                    fallback_bpb=fallback_bpb,
+                    tps_seen=tps_seen,
+                )
+                if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
+                    raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
+                return value
+            except optuna.TrialPruned:
+                pass
+            if tps_seen is not None:
+                trial.set_user_attr("tps", tps_seen)
+            detail = f"stage={stage}, logs={len(log_lines)}"
+            if terminal_detail:
+                detail = f"{detail}, message={terminal_detail}"
             raise optuna.TrialPruned(f"No metrics found from HF launcher job ({detail})")
 
+        metric_key = args.metric
 
         tps_val = metrics.get("tps")
         if tps_val is not None:
             tps_f = float(tps_val)
             trial.set_user_attr("tps", tps_f)
+            if effective_min_tps is not None and tps_f < effective_min_tps:
+                raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
+
+        value = _resolve_objective_metric(
+            trial,
+            metric_key=metric_key,
+            metrics=metrics,
+            allow_log_metric_fallback=args.allow_log_metric_fallback,
+            fallback_bpb=fallback_bpb,
+            tps_seen=tps_seen,
+        )
+        trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
+        trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
+        return value
 
     return objective
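The FEATHER_HF_* variables set above are the whole hand-off between the HPO driver and the launcher script: configuration travels through the environment of the launched subprocess. A hedged sketch of that contract; the namespace, flavor, and image values here are illustrative placeholders:

import os
import subprocess

# Build the launcher environment the same way the driver above does
# (values illustrative; the real ones come from args.hf_*).
env = os.environ.copy()
env["FEATHER_HF_NAMESPACE"] = "my-namespace"
env["FEATHER_HF_FLAVOR"] = "a10g-large"
env["FEATHER_HF_JOB_TIMEOUT"] = "25m"
env["FEATHER_HF_IMAGE"] = "hf.co/spaces/my-namespace/feather"

# The driver then runs the repo's launcher script with this environment.
subprocess.run(["python", "scripts/launch_feather_hf_job.py"], env=env, check=True)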
+def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+    routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
+    parser = argparse.ArgumentParser(description="Optuna HPO runner for HYDRA train.py")
     parser.add_argument("--study-name", default="hydra_hpo", help="Optuna study name")
     parser.add_argument("--storage", default="sqlite:///optuna_hpo.db", help="Optuna storage URL")
     parser.add_argument("--direction", choices=["minimize", "maximize"], default="minimize")
     parser.add_argument("--metric", default="val_bpb", help="Metric key to optimize from HYDRA metrics")
+    parser.add_argument(
+        "--min-tps",
+        type=float,
+        default=50000.0,
+        help="TPS floor; prune trials under this value (set 0 to disable)",
+    )
     parser.add_argument("--trials", type=int, default=20, help="Number of Optuna trials")
     parser.add_argument("--study-timeout", type=int, default=None, help="Study timeout in seconds")
     parser.add_argument("--trial-time-budget", type=int, default=300, help="HYDRA_TIME_BUDGET passed to each trial")
     parser.add_argument("--trial-timeout", type=int, default=900, help="Subprocess timeout per trial in seconds")
     parser.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local", help="Trial execution backend")
+    parser.add_argument("--hf-namespace", default=routing_defaults.job_namespace, help="HF namespace for jobs")
+    parser.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}", help="HF jobs image")
     parser.add_argument("--hf-flavor", default="a10g-large", help="HF jobs hardware flavor")
     parser.add_argument("--hf-timeout", default="25m", help="HF job timeout string")
     parser.add_argument("--hf-command", default="/app/entrypoint.py", help="Command executed inside HF job")
…
     parser.add_argument("--hf-token-env", default="HF_TOKEN", help="Token env key passed as HF job secret")
     parser.add_argument("--hf-stop-after-metric", action="store_true", default=True, help="Cancel running job after metrics captured")
     parser.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")
+    parser.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py", help="Local launcher script for hf-launcher runner")
+    parser.add_argument("--hf-output-repo", default=routing_defaults.output_repo, help="Optional FEATHER_HF_OUTPUT_REPO override for launcher runner")
+    parser.add_argument("--allow-log-metric-fallback", action="store_true", default=False, help="When metrics JSON is missing, allow val_bpb fallback from latest logged train bpb")
+    parser.add_argument("--no-allow-log-metric-fallback", action="store_false", dest="allow_log_metric_fallback")
+    parser.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json", help="Path to transfer-learning prior trials JSON")
+    parser.add_argument("--apply-priors", action="store_true", default=True, help="Enqueue transfer-learning prior trials before optimize")
+    parser.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
+    parser.add_argument("--quality-mode-local", action="store_true", default=False, help="Narrow local full-architecture search around the proven quality-winning region")
+    parser.add_argument("--quality-anchor-top-k", type=int, default=3, help="Number of top clean priors to enqueue as deterministic local quality anchors")
+    parser.add_argument("--seed", type=int, default=42, help="Seed for sampler")
     parser.add_argument("--n-startup-trials", type=int, default=5, help="Pruner startup trials before pruning")
     parser.add_argument("--n-warmup-steps", type=int, default=0, help="Pruner warmup steps")
     parser.add_argument("--patience-trials", type=int, default=None, help="Stop study after this many completed trials without meaningful improvement")
     parser.add_argument("--min-improvement", type=float, default=0.0, help="Minimum best-value improvement to reset patience")
     parser.add_argument("--work-dir", type=Path, default=REPO_ROOT / ".tmp" / "optuna", help="Directory for trial artifacts")
     parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "best_summary.json")
+    return parser.parse_args(argv)
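A quick invocation example of this CLI surface (flag spellings per the parser above; note that `resolve_routing` supplies the HF defaults, so cached routing may be consulted when the function runs):

# Illustrative call; argparse maps --min-tps to args.min_tps, and so on.
args = parse_args([
    "--runner", "hf-job",
    "--metric", "val_bpb",
    "--min-tps", "0",          # disable the TPS floor
    "--trials", "8",
    "--allow-log-metric-fallback",
])
assert args.runner == "hf-job" and args.min_tps == 0.0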
|
| 907 |
|
| 908 |
def main() -> int:
|
|
|
|
| 916 |
n_warmup_steps=args.n_warmup_steps,
|
| 917 |
)
|
| 918 |
|
| 919 |
+
study = optuna.create_study(
|
| 920 |
+
study_name=args.study_name,
|
| 921 |
+
storage=args.storage,
|
| 922 |
+
load_if_exists=True,
|
| 923 |
+
direction=args.direction,
|
| 924 |
+
sampler=sampler,
|
| 925 |
+
pruner=pruner,
|
| 926 |
+
)
|
| 927 |
+
|
| 928 |
+
enqueued_quality_anchors = _enqueue_quality_anchors(study, args.priors_file, args.quality_mode_local, args.quality_anchor_top_k)
|
| 929 |
+
if enqueued_quality_anchors:
|
| 930 |
+
print(f"[hpo] enqueued {enqueued_quality_anchors} local quality anchors from {args.priors_file}")
|
| 931 |
+
|
| 932 |
+
enqueued_priors = _enqueue_transfer_priors(study, args.priors_file, args.apply_priors)
|
| 933 |
+
if enqueued_priors:
|
| 934 |
+
print(f"[hpo] enqueued {enqueued_priors} transfer priors from {args.priors_file}")
|
| 935 |
|
| 936 |
state: dict[str, Any] = {
|
| 937 |
"best": None,
|
|
|
|
| 990 |
"best_trial_number": study.best_trial.number,
|
| 991 |
"best_trial_user_attrs": study.best_trial.user_attrs,
|
| 992 |
"n_trials": len(study.trials),
|
| 993 |
+
"n_completed": len(completed),
|
| 994 |
+
"patience_trials": args.patience_trials,
|
| 995 |
+
"min_improvement": args.min_improvement,
|
| 996 |
+
"quality_mode_local": args.quality_mode_local,
|
| 997 |
+
"enqueued_quality_anchors": enqueued_quality_anchors,
|
| 998 |
+
"enqueued_priors": enqueued_priors,
|
| 999 |
+
}
|
| 1000 |
+
else:
|
| 1001 |
+
best = {
|
| 1002 |
"study_name": study.study_name,
|
| 1003 |
"direction": args.direction,
|
| 1004 |
"metric": args.metric,
|
| 1005 |
"best_value": None,
|
| 1006 |
"best_params": {},
|
| 1007 |
+
"best_trial_number": None,
|
| 1008 |
+
"best_trial_user_attrs": {},
|
| 1009 |
+
"n_trials": len(study.trials),
|
| 1010 |
+
"n_completed": 0,
|
| 1011 |
+
"quality_mode_local": args.quality_mode_local,
|
| 1012 |
+
"enqueued_quality_anchors": enqueued_quality_anchors,
|
| 1013 |
+
"enqueued_priors": enqueued_priors,
|
| 1014 |
+
"note": "No completed trials with metrics found.",
|
| 1015 |
+
}
|
| 1016 |
args.summary_out.write_text(json.dumps(best, indent=2), encoding="utf-8")
|
| 1017 |
print(json.dumps(best, indent=2))
|
| 1018 |
return 0
|
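The patience bookkeeping referenced by --patience-trials and --min-improvement is elided in this hunk. One common way to express that stop rule is an `optimize()` callback; this is an assumed sketch consistent with the flags above, not the file's actual implementation, and it assumes direction="minimize":

import optuna

def make_patience_callback(patience: int, min_improvement: float):
    # Stop the study after `patience` completed trials without the best
    # value improving by more than `min_improvement` (minimize direction).
    state = {"best": None, "stale": 0}

    def callback(study: optuna.Study, trial: optuna.trial.FrozenTrial) -> None:
        if trial.state != optuna.trial.TrialState.COMPLETE:
            return
        value = trial.value
        if state["best"] is None or value < state["best"] - min_improvement:
            state["best"], state["stale"] = value, 0
        else:
            state["stale"] += 1
            if state["stale"] >= patience:
                study.stop()

    return callback

# study.optimize(objective, n_trials=args.trials,
#                callbacks=[make_patience_callback(10, 0.001)])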
overlay/scripts/run_cycle1a.py
CHANGED
@@ -1,46 +1,45 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import sys
+from pathlib import Path
+
+from scripts import cycle_executor
+
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+
+
+def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Run the full local Cycle 1a benchmark suite")
+    parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_runs")
+    parser.add_argument("--preflight-out", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_preflight.json")
+    parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_summary.json")
+    parser.add_argument("--hydrate-assets", action="store_true")
+    parser.add_argument("--require-ready", action="store_true")
+    parser.add_argument("--output-repo")
+    parser.add_argument("--tokenizer-repo")
+    return parser.parse_args(argv)
+
+
+def main(argv: list[str] | None = None) -> int:
+    args = parse_args(argv)
+    return cycle_executor.main([
+        "--benchmark", "GSM8K",
+        "--variant", "hydra_full",
+        "--seed", "42",
+        "--out-dir", str(args.out_dir),
+        "--preflight-out", str(args.preflight_out),
+        "--summary-out", str(args.summary_out),
+        "--all-runnable",
+        "--all-benchmarks",
+        *(["--hydrate-assets"] if args.hydrate_assets else []),
+        *(["--require-ready"] if args.require_ready else []),
+        *(["--output-repo", args.output_repo] if args.output_repo else []),
+        *(["--tokenizer-repo", args.tokenizer_repo] if args.tokenizer_repo else []),
+    ])
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
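The wrapper above is a thin forwarding shim: its flags become cycle_executor flags verbatim, plus the fixed GSM8K/hydra_full/seed-42 defaults. A usage sketch, assuming the overlay root is the working directory so `scripts` resolves as a package:

import sys
sys.path.insert(0, ".")  # repo-layout assumption for the `scripts` package
from scripts import run_cycle1a

# Equivalent to: python scripts/run_cycle1a.py --hydrate-assets --require-ready
raise SystemExit(run_cycle1a.main(["--hydrate-assets", "--require-ready"]))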
overlay/scripts/setup.sh
CHANGED
@@ -25,4 +25,3 @@ echo "=== Setup complete ==="
 echo "Run experiments with: uv run train.py"
 echo "Run orchestrator with: uv run -m harness.orchestrator"
 echo "Run Phase 1 subsystems with: bash scripts/run_phase1.sh"
-echo "For WSL/CUDA throughput gate: see docs/WSL_TPS_RUNBOOK.md"
overlay/scripts/sweep_depth_aggregate.py
CHANGED
@@ -11,77 +11,77 @@ Usage:
"""
from __future__ import annotations

import json
import os
import statistics
import re
import sys
from pathlib import Path

from configs.harness_config import HarnessConfig

type MetricValue = float | int | str | bool | None
type MetricsDict = dict[str, MetricValue]

MANIFEST = Path(sys.argv[1] if len(sys.argv) > 1 else '/tmp/sweep_depth_manifest.txt')
STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
MIN_TPS = float(os.environ.get('SWEEP_MIN_TPS', '0'))
TARGET_TOKENS_M = float(os.environ.get('SWEEP_TARGET_TOKENS_M', '0'))
TARGET_SECONDS = float(os.environ.get('SWEEP_TARGET_SECONDS', '0'))


def _zero_shot_score(result: MetricsDict) -> float:
    """Composite quality score for tie-breaking among BPB-near runs."""
    factual = float(result.get('factual_english_score', 0.0) or 0.0)
    instruction = float(result.get('instruction_following_score', 0.0) or 0.0)
    distinct_2 = float(result.get('distinct_2', 0.0) or 0.0)
    repetition = float(result.get('repetition_rate', 0.0) or 0.0)
    return factual + instruction + distinct_2 - repetition


def _metric_float(result: MetricsDict, key: str, default: float = 0.0) -> float:
    value = result.get(key, default)
    return float(value) if isinstance(value, (int, float)) else default


def _metric_int(result: MetricsDict, key: str, default: int = 0) -> int:
    value = result.get(key, default)
    return int(value) if isinstance(value, int) else default


def _fixed_budget_ranking(results: dict[int, MetricsDict], *, metric_key: str, target: float) -> list[tuple[int, MetricsDict, float]]:
    ranked: list[tuple[int, MetricsDict, float]] = []
    for n_layer, row in results.items():
        budget_val = row.get(metric_key)
        if not isinstance(budget_val, (int, float)):
            continue
        gap = abs(float(budget_val) - target)
        ranked.append((n_layer, row, gap))
    ranked.sort(
        key=lambda item: (
            item[2],
            _metric_float(item[1], 'val_bpb', float('inf')),
            -_zero_shot_score(item[1]),
            -_metric_float(item[1], 'tps_median', 0.0),
        )
    )
    return ranked


def _percentile_linear(sorted_values: list[float], pct: float) -> float:
    if not sorted_values:
        return 0.0
    if len(sorted_values) == 1:
        return sorted_values[0]
    rank = (len(sorted_values) - 1) * (pct / 100.0)
    lo = int(rank)
    hi = min(lo + 1, len(sorted_values) - 1)
    frac = rank - lo
    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac


def fetch_metrics_from_job(job_id: str) -> MetricsDict | None:
    """Fetch HF Job stdout and parse the [METRICS_JSON] line."""
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:
@@ -94,73 +94,73 @@ def fetch_metrics_from_job(job_id: str) -> MetricsDict | None:
        print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
        return None

    last_json = None
    tps_samples: list[tuple[int, int]] = []
    warmup_steps = 25
    for line in logs_stream:
        # HfApi returns strings or JobLogEntry-like objects depending on version.
        text = getattr(line, 'data', None) or str(line)

        wm = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", text)
        if wm:
            warmup_steps = int(wm.group(1))

        sm = STEP_TPS_PATTERN.search(text)
        if sm:
            tps_samples.append((int(sm.group(1)), int(sm.group(2))))

        if '[METRICS_JSON]' in text:
            payload = text.split('[METRICS_JSON]', 1)[1].strip()
            try:
                last_json = json.loads(payload)
            except Exception:
                # Might be truncated on a line boundary — keep looking.
                pass
    if last_json is None:
        return None

    steady_tps = [float(tps) for step, tps in tps_samples if step >= warmup_steps]
    if not steady_tps:
        steady_tps = [float(tps) for _, tps in tps_samples]
    if steady_tps:
        sorted_tps = sorted(steady_tps)
        last_json['tps_samples'] = len(steady_tps)
        last_json['tps_median'] = float(statistics.median(steady_tps))
        last_json['tps_p10'] = float(_percentile_linear(sorted_tps, 10.0))
        last_json['tps_min'] = float(sorted_tps[0])
        last_json['tps_max'] = float(sorted_tps[-1])
        last_json['tps_warmup_steps'] = int(warmup_steps)

    return last_json


def compare(results: dict[int, MetricsDict]) -> None:
    """Pretty-print comparison across n_layer values."""
    if not results:
        print('[agg] no results')
        return
    sorted_n = sorted(results.keys())
    secondary_gates = HarnessConfig().to_secondary_gates()

    print('\n=== Active secondary gates ===')
    for metric, thresholds in sorted(secondary_gates.items()):
        print(f' {metric}: {json.dumps(thresholds, sort_keys=True)}')

    # Top-level scalars
    print('\n=== Top-level scalars ===')
    hdr = ['metric'] + [f'L={n}' for n in sorted_n]
    print(' '.join(f'{h:>14}' for h in hdr))
    for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
                'training_seconds', 'peak_vram_mb', 'sdr_target_active',
                'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits',
                'tps_median', 'tps_p10', 'tps_min', 'tps_max', 'tps_samples'):
        row = [key] + [f'{results[n].get(key, float("nan")):.4f}' if isinstance(results[n].get(key), (int, float)) else 'n/a' for n in sorted_n]
        print(' '.join(f'{c:>14}' for c in row))

    # Per-layer panel — one table per metric.
    print('\n=== Per-layer: delta_ratio (residual contribution) ===')
    print(' '.join(['layer'] + [f'L={n:>2}' for n in sorted_n]))
    max_depth = max(_metric_int(results[n], 'n_layer', 0) for n in sorted_n)
    for li in range(max_depth):
        row = [f'L{li:02d}']
        for n in sorted_n:
@@ -197,62 +197,62 @@ def compare(results: dict[int, MetricsDict]) -> None:
    # Dead-layer detection
    print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
    for n in sorted_n:
        r = results[n]
        n_layer = _metric_int(r, 'n_layer', 0)
        dead = []
        for li in range(n_layer):
            v = r.get(f'layer_{li}_delta_ratio')
            if isinstance(v, (int, float)) and v < 0.02:
                dead.append(li)
        status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
        print(f' n_layer={n:2d} val_bpb={r.get("val_bpb", float("nan")):.4f} {status}')

    print('\n=== Throughput-constrained ranking ===')
    ranked = sorted(
        ((n, r) for n, r in results.items() if isinstance(r.get('val_bpb'), (int, float))),
        key=lambda x: (
            (MIN_TPS > 0) and (_metric_float(x[1], 'tps_median', 0.0) < MIN_TPS),
            _metric_float(x[1], 'val_bpb', float('inf')),
            -_zero_shot_score(x[1]),
        ),
    )
    feasible_count = 0
    for n, r in ranked:
        tps_median = _metric_float(r, 'tps_median', 0.0)
        feasible = (MIN_TPS <= 0) or (tps_median >= MIN_TPS)
        zero_shot_score = _zero_shot_score(r)
        if feasible:
            feasible_count += 1
        print(
            f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
            f"tps_median={tps_median:.0f} zero_shot_score={zero_shot_score:.4f} feasible={feasible}",
            flush=True,
        )
    if MIN_TPS > 0:
        print(f"[agg] throughput gate: tps_median >= {MIN_TPS:.0f}; feasible={feasible_count}/{len(ranked)}")

    if TARGET_TOKENS_M > 0:
        print('\n=== Fixed-token champion comparison ===')
        print(f' target_tokens_M={TARGET_TOKENS_M:.4f}')
        for n, r, gap in _fixed_budget_ranking(results, metric_key='total_tokens_M', target=TARGET_TOKENS_M):
            print(
                f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
                f"total_tokens_M={_metric_float(r, 'total_tokens_M', float('nan')):.4f} "
                f"token_gap_M={gap:.4f} tps_median={_metric_float(r, 'tps_median', 0.0):.0f}",
                flush=True,
            )

    if TARGET_SECONDS > 0:
        print('\n=== Fixed-time champion comparison ===')
        print(f' target_seconds={TARGET_SECONDS:.1f}')
        for n, r, gap in _fixed_budget_ranking(results, metric_key='training_seconds', target=TARGET_SECONDS):
            print(
                f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
                f"training_seconds={_metric_float(r, 'training_seconds', float('nan')):.1f} "
                f"time_gap_s={gap:.1f} tps_median={_metric_float(r, 'tps_median', 0.0):.0f}",
                flush=True,
            )


def main() -> int:
@@ -273,7 +273,7 @@ def main() -> int:
        jobs[n_layer] = job_id

    print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
    results: dict[int, MetricsDict] = {}
    for n, jid in jobs.items():
        print(f'[agg] fetching job={jid} (n_layer={n}) ...')
        m = fetch_metrics_from_job(jid)
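A worked example of the linear-interpolation percentile used for tps_p10 above: with five sorted samples, the 10th percentile sits at rank (5-1)*0.10 = 0.4, i.e. 40% of the way between the first two values (sample numbers are illustrative):

samples = sorted([52000.0, 48000.0, 50000.0, 51000.0, 49000.0])
# rank = (len - 1) * (10 / 100) = 0.4 -> lo=0, hi=1, frac=0.4
p10 = samples[0] * 0.6 + samples[1] * 0.4
assert p10 == 48400.0  # matches _percentile_linear(samples, 10.0)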
overlay/scripts/watch_benchmark_hf_job.py
CHANGED
@@ -1,81 +1,33 @@
-#!/usr/bin/env python3
-from __future__ import annotations
-
-import argparse
-import json
-import …
…
-            continue
…
-        return {
-            "job_id": job_id,
-            "stage": stage,
-            "message": message,
-            "log_lines": len(texts),
-            "result": parse_benchmark_result_from_logs(texts),
-        }
-
-
-def wait_for_terminal_snapshot(api, *, job_id: str, token: str, namespace: str, poll_interval: float = 10.0, timeout_s: float = 1800.0) -> dict[str, object]:
-    deadline = time.time() + timeout_s
-    terminal = {"COMPLETED", "ERROR", "FAILED", "CANCELLED", "CANCELED", "TIMEOUT"}
-    while True:
-        payload = collect_job_snapshot(api, job_id=job_id, token=token, namespace=namespace)
-        if payload["stage"] in terminal:
-            return payload
-        if time.time() >= deadline:
-            return payload
-        time.sleep(poll_interval)
-
-
-def write_watch_summary(path: Path, payload: dict[str, object]) -> None:
-    path.parent.mkdir(parents=True, exist_ok=True)
-    path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-
-
-def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Watch or snapshot a remote benchmark job")
-    parser.add_argument("--job-id", required=True)
-    parser.add_argument("--namespace", default="jackoatmon")
-    parser.add_argument("--summary-out", type=Path)
-    return parser.parse_args(argv)
-
-
-def main(argv: list[str] | None = None) -> int:
-    args = parse_args(argv)
-    token = get_token()
-    if not token:
-        raise SystemExit("HF_TOKEN must be set or cached via huggingface-cli login")
-    api = HfApi(token=token)
-    payload = collect_job_snapshot(api, job_id=args.job_id, token=token, namespace=args.namespace)
-    if args.summary_out is not None:
-        write_watch_summary(args.summary_out, payload)
-    print(json.dumps(payload, indent=2, sort_keys=True))
-    return 0
-
-
-if __name__ == "__main__":
-    raise SystemExit(main())
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+
+
+def parse_benchmark_result_from_logs(lines: list[str]):
+    for line in reversed(lines):
+        text = line.strip()
+        if not text.startswith("{"):
+            continue
+        try:
+            payload = json.loads(text)
+        except json.JSONDecodeError:
+            continue
+        if isinstance(payload, dict) and "benchmark" in payload:
+            return payload
+    return None
+
+
+def write_watch_summary(path: Path, payload: dict[str, object]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+
+
+def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Watch or snapshot a remote benchmark job")
+    parser.add_argument("--job-id", required=True)
+    parser.add_argument("--namespace", default="jackoatmon")
+    parser.add_argument("--summary-out", type=Path)
+    return parser.parse_args(argv)
overlay/subsystems/htm.py
CHANGED
@@ -29,46 +29,46 @@ copy is small compared to the SP/TM compute.
from __future__ import annotations

import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import numpy as np
import torch
import torch.nn as nn

import htm_rust

_HTM_REGION: Any = getattr(htm_rust, "HTMRegion", None)
_HTM_REGION_GPU: Any = getattr(htm_rust, "HTMRegionGpu", None)
_HTM_STEP_BATCH_FUSED_CUDA: Any = getattr(htm_rust, "step_batch_fused_cuda", None)

# step_many releases the GIL for the whole pass, so multiple threads can
# truly run regions in parallel — wall-clock scales with B up to CPU cores.
_HTM_HAS_STEP_MANY = hasattr(_HTM_REGION, "step_many")
# GPU backend: built with `maturin develop --features gpu`. One CUDA region
# per batch slot, persistent device state for SP synapses. Transparent
# fallback to CPU when not available.
_HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")
# Zero-copy CUDA path: consumes torch CUDA tensors directly via the
# __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
# and the D2H of outputs. Huge win when the input SDR already lives on GPU
# (which is the train.py hot path — retina is a device buffer).
_HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(_HTM_REGION_GPU, "step_many_cuda")
# Fused megakernel path: collapses all T timesteps + SP + TM into a single
# CUDA launch per forward. Replaces global top-K with per-column threshold
# inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
# Opt-in via env var (default on when available).
import os as _os_fused
_HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(_HTM_REGION_GPU, "step_many_fused_cuda")
_HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))


def _is_fused_unavailable_error(exc: RuntimeError) -> bool:
    message = str(exc)
    return (
        "Fused HTM kernel is unavailable" in message
        or "fused HTM kernel disabled for this CUDA arch" in message
    )


class HTMLayer(nn.Module):
@@ -93,11 +93,11 @@ class HTMLayer(nn.Module):
        learn: bool = True,
        reset_each_forward: bool = True,
        use_gpu: bool | None = None,
    ) -> None:
        super().__init__()
        self.input_bits = input_bits
        self.n_columns = n_columns
        self.cells_per_column = cells_per_column
        self.learn = learn
        self.reset_each_forward = reset_each_forward
        self._seed_base = seed
@@ -107,23 +107,23 @@
        # converges since the EMA accumulates over many calls. Env:
        # HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward, 0 = disabled).
        import os as _os
        self._learn_every = int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1"))
        self._forward_counter = 0
        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
        # GPU backend gate. Default: auto-detect — use GPU when the pyo3
        # module was built with --features gpu AND CUDA is actually usable.
        if use_gpu is None:
            use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
        elif use_gpu and not _HTM_HAS_GPU:
            raise RuntimeError(
                "HTMLayer(use_gpu=True) but htm_rust was not built with "
                "--features gpu. Re-run `maturin develop --features gpu`."
            )
        elif use_gpu and force_cpu:
            use_gpu = False
        self._use_gpu = bool(use_gpu)
        cls = _HTM_REGION_GPU if self._use_gpu else _HTM_REGION
        self._region_cls = cls
        self._regions = [
            cls(input_bits, n_columns, cells_per_column, seed + i)
            for i in range(batch_size)
@@ -144,19 +144,19 @@
            )
        )

    def reset(self) -> None:
        """Clear TM predictive state on every region (keeps SP synapses)."""
        for r in self._regions:
            r.reset()

    def _next_learn_flag(self) -> bool:
        self._forward_counter += 1
        return bool(
            self.learn
            and self.training
            and self._learn_every > 0
            and (self._forward_counter % self._learn_every == 0)
        )

    @torch.no_grad()
    def forward(self, sdr: torch.Tensor) -> torch.Tensor:
@@ -167,9 +167,9 @@
        if self.reset_each_forward:
            self.reset()

        # Learn-gate: run learn kernels only every N forwards (skips 56% of
        # HTM CUDA time on skip-forwards; Hebbian EMA still converges).
        learn = self._next_learn_flag()

        # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
        # so we skip sdr.cpu()/numpy round-trip AND the output D2H. The Rust
@@ -178,30 +178,30 @@
        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
            cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
            anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
            # Pick fused (1 launch) or legacy (12*T launches) path.
            if _HTM_USE_FUSED:
                try:
                    for b in range(B):
                        self._regions[b].step_many_fused_cuda(
                            sdr_u8[b].__cuda_array_interface__,
                            cols_out[b].__cuda_array_interface__,
                            anom_out[b].__cuda_array_interface__,
                            learn,
                        )
                except RuntimeError as exc:
                    if not _is_fused_unavailable_error(exc):
                        raise
                    for b in range(B):
                        self._regions[b].step_many_cuda(
                            sdr_u8[b].__cuda_array_interface__,
                            cols_out[b].__cuda_array_interface__,
                            anom_out[b].__cuda_array_interface__,
                            learn,
                        )
            else:
                for b in range(B):
                    self._regions[b].step_many_cuda(
                        sdr_u8[b].__cuda_array_interface__,
                        cols_out[b].__cuda_array_interface__,
                        anom_out[b].__cuda_array_interface__,
@@ -275,7 +275,7 @@
        self._ensure_regions(B)
        if self.reset_each_forward:
            self.reset()
        learn = self._next_learn_flag()

        if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
            sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
@@ -287,61 +287,61 @@
            # grid.y = B processes all regions concurrently — ~B× speedup.
            # Falls back to sequential dispatch if the batched entry isn't
            # available (older htm_rust wheel).
            if _HTM_USE_FUSED and _HTM_STEP_BATCH_FUSED_CUDA is not None:
                # Slice self._regions to match B: _ensure_regions may have
                # allocated more regions than the current batch size needs
                # (e.g. factual eval uses smaller batches than training).
                try:
                    _HTM_STEP_BATCH_FUSED_CUDA(
                        self._regions[:B],
                        [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                        [cols_out[b].__cuda_array_interface__ for b in range(B)],
                        [anom_out[b].__cuda_array_interface__ for b in range(B)],
                        learn,
                    )
                except RuntimeError as _e:
                    if "COOPERATIVE_LAUNCH_TOO_LARGE" in str(_e):
                        # Batch too large for cooperative grid. Fall back to
                        # sequential per-region fused launches (each B=1).
                        for b in range(B):
                            self._regions[b].step_many_fused_cuda(
                                sdr_u8[b].__cuda_array_interface__,
                                cols_out[b].__cuda_array_interface__,
                                anom_out[b].__cuda_array_interface__,
                                learn,
                            )
                    elif _is_fused_unavailable_error(_e):
                        for b in range(B):
                            self._regions[b].step_many_cuda(
                                sdr_u8[b].__cuda_array_interface__,
                                cols_out[b].__cuda_array_interface__,
                                anom_out[b].__cuda_array_interface__,
                                learn,
                            )
                    else:
                        raise
            elif _HTM_USE_FUSED:
                try:
                    for b in range(B):
                        self._regions[b].step_many_fused_cuda(
                            sdr_u8[b].__cuda_array_interface__,
                            cols_out[b].__cuda_array_interface__,
                            anom_out[b].__cuda_array_interface__,
                            learn,
                        )
                except RuntimeError as exc:
                    if not _is_fused_unavailable_error(exc):
                        raise
                    for b in range(B):
                        self._regions[b].step_many_cuda(
                            sdr_u8[b].__cuda_array_interface__,
                            cols_out[b].__cuda_array_interface__,
                            anom_out[b].__cuda_array_interface__,
                            learn,
                        )
            else:
                for b in range(B):
                    self._regions[b].step_many_cuda(
                        sdr_u8[b].__cuda_array_interface__,
                        cols_out[b].__cuda_array_interface__,
                        anom_out[b].__cuda_array_interface__,
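The zero-copy handoff above relies on the fact that a contiguous torch CUDA tensor exposes __cuda_array_interface__, a dict carrying the device pointer, shape, and dtype string that the Rust/CUDA side can consume without a .cpu()/numpy round-trip. A minimal sketch of what crosses the boundary (guarded so it only runs where CUDA exists; the commented region call mirrors this file's htm_rust API):

import torch

if torch.cuda.is_available():
    # A boolean SDR as uint8, already resident on the GPU.
    sdr = (torch.rand(4, 128, 2048, device="cuda") < 0.05).to(torch.uint8)
    cai = sdr[0].__cuda_array_interface__
    print(cai["shape"], cai["typestr"])  # (128, 2048) '|u1'
    # A region built with --features gpu takes these dicts directly, e.g.:
    # region.step_many_cuda(sdr[0].__cuda_array_interface__,
    #                       cols_out[0].__cuda_array_interface__,
    #                       anom_out[0].__cuda_array_interface__, True)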
|