Jackoatmon committed
Commit 951f760 · verified · 1 parent: 6a47c48

Update Feather training runtime image

Files changed (41)
  1. Dockerfile +7 -19
  2. entrypoint.py +9 -67
  3. mamba_ssm_init.py +3 -35
  4. overlay/configs/harness_config.py +47 -47
  5. overlay/harness/eval_agent.py +188 -188
  6. overlay/harness/orchestrator.py +16 -16
  7. overlay/htm_rust/src/gpu/fused.rs +73 -73
  8. overlay/hydra/eval.py +1 -8
  9. overlay/hydra/model.py +296 -296
  10. overlay/hydra/training.py +387 -387
  11. overlay/prepare.py +60 -60
  12. overlay/prepare_nemotron.py +159 -162
  13. overlay/scripts/__init__.py +1 -1
  14. overlay/scripts/audit_overlay_sync.py +100 -100
  15. overlay/scripts/benchmark_assets.py +62 -124
  16. overlay/scripts/benchmark_checkpoint.py +19 -118
  17. overlay/scripts/benchmark_checkpoint_report.py +50 -50
  18. overlay/scripts/benchmark_contract.py +67 -67
  19. overlay/scripts/benchmark_datasets.py +18 -190
  20. overlay/scripts/benchmark_hyena_stack.py +41 -66
  21. overlay/scripts/benchmark_preflight.py +31 -35
  22. overlay/scripts/benchmark_runner.py +248 -327
  23. overlay/scripts/benchmark_suite.py +84 -84
  24. overlay/scripts/bootstrap_benchmark_env.py +63 -63
  25. overlay/scripts/cycle1a_report.py +52 -52
  26. overlay/scripts/cycle_executor.py +312 -332
  27. overlay/scripts/export_hpo_priors.py +94 -94
  28. overlay/scripts/hf_routing.py +94 -94
  29. overlay/scripts/hpo_component_report.py +130 -130
  30. overlay/scripts/hpo_leaderboard.py +156 -156
  31. overlay/scripts/hpo_orchestrator.py +118 -118
  32. overlay/scripts/hpo_retest.py +151 -151
  33. overlay/scripts/hydra_generation.py +180 -183
  34. overlay/scripts/launch_benchmark_hf_job.py +157 -222
  35. overlay/scripts/launch_feather_hf_job.py +337 -343
  36. overlay/scripts/optuna_hpo.py +575 -575
  37. overlay/scripts/run_cycle1a.py +45 -46
  38. overlay/scripts/setup.sh +0 -1
  39. overlay/scripts/sweep_depth_aggregate.py +184 -184
  40. overlay/scripts/watch_benchmark_hf_job.py +33 -81
  41. overlay/subsystems/htm.py +128 -128
Dockerfile CHANGED
@@ -1,6 +1,4 @@
-FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
-
-ARG HTM_CUDA_ARCH=sm_86
+FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
 
 ENV DEBIAN_FRONTEND=noninteractive \
     PIP_NO_CACHE_DIR=1 \
@@ -107,22 +105,12 @@ COPY overlay /workspace/feather
 COPY entrypoint.py /app/entrypoint.py
 WORKDIR /workspace/feather
 
-RUN python - <<'PY'
-from pathlib import Path
-for sh in Path('/workspace/feather/scripts').glob('*.sh'):
-    raw = sh.read_bytes()
-    norm = raw.replace(b'\r\n', b'\n')
-    if norm != raw:
-        sh.write_bytes(norm)
-PY
-
-RUN python -m py_compile hydra/training.py prepare.py train.py && \
-    bash -n scripts/run_domain_expanded_pretrain.sh
+RUN python -m py_compile hydra/training.py prepare.py train.py && \
+    bash -n scripts/run_domain_expanded_pretrain.sh
 
-RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
-    export HTM_CUDA_ARCH=${HTM_CUDA_ARCH} && \
-    export CARGO_BUILD_JOBS=1 && \
-    maturin build --release -j 1 --features gpu --manifest-path htm_rust/Cargo.toml && \
-    pip install htm_rust/target/wheels/htm_rust-*.whl
+RUN export LD_LIBRARY_PATH=/usr/local/cuda/lib64:${LD_LIBRARY_PATH} && \
+    export HTM_CUDA_ARCH=sm_90 && \
+    maturin build --release --features gpu --manifest-path htm_rust/Cargo.toml && \
+    pip install htm_rust/target/wheels/htm_rust-*.whl
 
 CMD ["python", "/app/entrypoint.py"]
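Note: dropping the HTM_CUDA_ARCH build arg pins the kernel build to sm_90 (Hopper). A hypothetical smoke test for the built image, not part of this commit, that confirms the wheel installed by the final RUN step imports cleanly; the module name htm_rust is assumed from the wheel filename above:

# check_htm_wheel.py - hypothetical smoke test; run inside the built image.
import importlib

mod = importlib.import_module("htm_rust")  # name assumed from htm_rust-*.whl
print("htm_rust imported from:", getattr(mod, "__file__", "<unknown>"))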
entrypoint.py CHANGED
@@ -68,11 +68,7 @@ try:
 except ImportError:
     print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
 
-from huggingface_hub import HfApi  # noqa: E402 (import after cuda kick)
-if '/workspace/feather' not in sys.path:  # noqa: E402
-    sys.path.insert(0, '/workspace/feather')
-from scripts.benchmark_assets import hydrate_benchmark_assets  # noqa: E402
-from subsystems.sdr_retina import build_retina  # noqa: E402
+from huggingface_hub import HfApi  # noqa: E402 (import after cuda kick)
 
 REPO_ROOT = Path('/workspace/feather')
 CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
@@ -114,7 +110,7 @@ def _start_health_server() -> HTTPServer:
     return server
 
 
-def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
+def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
     if not path.exists():
         print(f'[upload] skip missing {path}', flush=True)
         return
@@ -124,20 +120,7 @@ def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
         repo_id=OUTPUT_REPO,
         repo_type='model',
     )
-    print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
-
-
-def build_benchmark_mode_command() -> list[str]:
-    return [
-        'python',
-        str(REPO_ROOT / 'scripts' / 'benchmark_runner.py'),
-        '--benchmark', os.environ.get('HYDRA_BENCHMARK_NAME', 'GSM8K'),
-        '--generator-mode', 'hydra',
-        '--variant', os.environ.get('HYDRA_BENCHMARK_VARIANT', 'hydra_full'),
-        '--seed', os.environ.get('HYDRA_SEED', '42'),
-        '--out', str(REPO_ROOT / 'benchmark_result.json'),
-        '--ledger', str(REPO_ROOT / 'benchmark_ledger.json'),
-    ] + sys.argv[1:]
+    print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
 
 
 def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
@@ -175,7 +158,7 @@ def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
     print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
 
 
-def run_job_mode() -> int:
+def run_job_mode() -> int:
     os.chdir(REPO_ROOT)
     os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
     os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
@@ -220,46 +203,7 @@ def run_job_mode() -> int:
     else:
         print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
 
-    return proc.returncode
-
-
-def run_benchmark_mode() -> int:
-    os.chdir(REPO_ROOT)
-    os.environ.setdefault('HYDRA_USE_NEMOTRON', '1')
-    if TOKEN:
-        try:
-            hydrate_benchmark_assets(
-                cache_dir=CACHE_ROOT,
-                output_repo=OUTPUT_REPO,
-                tokenizer_repo=os.environ.get('HYDRA_TOKENIZER_CACHE_REPO', OUTPUT_REPO),
-                token=TOKEN,
-            )
-        except Exception as e:
-            print(f'[benchmark] asset hydrate warning: {type(e).__name__}: {e}', flush=True)
-    try:
-        build_retina()
-    except Exception as e:
-        print(f'[benchmark] retina materialize warning: {type(e).__name__}: {e}', flush=True)
-    cmd = build_benchmark_mode_command()
-    print(f'[benchmark] command={cmd}', flush=True)
-    proc = subprocess.run(cmd, check=False)
-
-    if TOKEN:
-        api = HfApi(token=TOKEN)
-        try:
-            api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
-        except Exception as e:
-            print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
-        prefix = f'jobs/{JOB_ID}'
-        try:
-            upload_artifact(api, REPO_ROOT / 'benchmark_result.json', f'{prefix}/benchmark_result.json')
-            upload_artifact(api, REPO_ROOT / 'benchmark_ledger.json', f'{prefix}/benchmark_ledger.json')
-        except Exception as e:
-            print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
-    else:
-        print('[upload] HF_TOKEN not set; skipping benchmark artifact upload', flush=True)
-
-    return proc.returncode
+    return proc.returncode
 
 
 def run_space_mode() -> int:
@@ -273,12 +217,10 @@ def run_space_mode() -> int:
     server.server_close()
 
 
-def main() -> int:
-    if RUNTIME_MODE == 'job':
-        return run_job_mode()
-    if RUNTIME_MODE == 'benchmark':
-        return run_benchmark_mode()
-    return run_space_mode()
+def main() -> int:
+    if RUNTIME_MODE == 'job':
+        return run_job_mode()
+    return run_space_mode()
 
 
 if __name__ == '__main__':
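Note: with run_benchmark_mode() and its helpers removed, the entrypoint now dispatches only between 'job' and 'space' runtime modes, and upload_artifact is the single remaining artifact path. A usage sketch, with placeholder token and paths (upload_artifact wraps the real huggingface_hub HfApi.upload_file call shown in the hunk above):

from pathlib import Path
from huggingface_hub import HfApi

api = HfApi(token='hf_xxx')  # placeholder token
# Missing files are skipped with a log line; existing files land in
# OUTPUT_REPO under the destination path.
upload_artifact(api, Path('/workspace/feather/run.log'), 'jobs/demo/run.log')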
mamba_ssm_init.py CHANGED
@@ -24,8 +24,8 @@ mamba_inner_fn = None
 # stub is never actually invoked at runtime because the codebase does not use
 # torch.compile — but importing torch._inductor.* still requires the symbol to
 # exist at module load time.
-import triton as _triton  # noqa: E402
-if not hasattr(_triton, "set_allocator"):
+import triton as _triton  # noqa: E402
+if not hasattr(_triton, "set_allocator"):
     def _noop_set_allocator(_fn):  # pragma: no cover
         return None
     _triton.set_allocator = _noop_set_allocator
@@ -53,39 +53,7 @@ if not hasattr(_tcc, "triton_key"):
     def _triton_key_shim():
         import triton as _t
         return f"triton-{_t.__version__}-shim"
-    _tcc.triton_key = _triton_key_shim
-
-# Triton 3.5 wheels can occasionally load with an empty backend registry in
-# HF Jobs environments (driver.active -> "0 active drivers"), even though the
-# NVIDIA backend module is present and CudaDriver.is_active() is True.
-# Patch _create_driver to directly select CudaDriver when registry discovery
-# returns empty.
-import importlib as _importlib  # noqa: E402
-_triton_driver_mod = _importlib.import_module("triton.runtime.driver")
-if getattr(_triton_driver_mod, "backends", None) == {}:
-    from triton.backends.nvidia import driver as _nvidia_driver  # noqa: E402
-
-    def _create_driver_shim():
-        if hasattr(_nvidia_driver, "CudaDriver") and _nvidia_driver.CudaDriver.is_active():
-            return _nvidia_driver.CudaDriver()
-        raise RuntimeError(
-            "Triton backend registry is empty and NVIDIA CudaDriver is not active"
-        )
-
-    _triton_driver_mod._create_driver = _create_driver_shim
-    if hasattr(_triton_driver_mod, "driver") and hasattr(_triton_driver_mod.driver, "reset_active"):
-        _triton_driver_mod.driver.reset_active()
-
-_triton_compiler_mod = _importlib.import_module("triton.compiler.compiler")
-if getattr(_triton_compiler_mod, "backends", None) == {}:
-    from triton.backends import Backend as _Backend  # noqa: E402
-    from triton.backends.nvidia.compiler import CUDABackend as _CUDABackend  # noqa: E402
-    from triton.backends.nvidia.driver import CudaDriver as _CudaDriver  # noqa: E402
-
-    _triton_compiler_mod.backends["nvidia"] = _Backend(
-        compiler=_CUDABackend,
-        driver=_CudaDriver,
-    )
+    _tcc.triton_key = _triton_key_shim
 
 # Suppress torch.compile/_dynamo errors globally — we don't rely on torch.compile
 # for performance in this codebase (Muon + mamba3 CUDA kernels already fused),
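Note: the empty-registry workaround deleted above patched triton.runtime.driver and triton.compiler.compiler directly. A minimal diagnostic sketch, assuming the Triton 3.x module layout referenced in the removed comments, to check whether those registries are actually populated in a given environment:

import triton.compiler.compiler as cc
import triton.runtime.driver as rd

# Both attributes are read defensively, mirroring the removed getattr checks.
for name, mod in (("runtime.driver", rd), ("compiler.compiler", cc)):
    backends = getattr(mod, "backends", None)
    print(f"triton.{name}.backends:", backends if backends else "EMPTY")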
overlay/configs/harness_config.py CHANGED
@@ -1,13 +1,13 @@
-"""Harness configuration for HYDRA's self-evolving outer loop."""
-from typing import Literal
-
-from pydantic import BaseModel, Field
-
-type GateThresholds = dict[str, float]
-type GateConfig = dict[str, GateThresholds]
+"""Harness configuration for HYDRA's self-evolving outer loop."""
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+type GateThresholds = dict[str, float]
+type GateConfig = dict[str, GateThresholds]
 
 
-class HarnessConfig(BaseModel):
+class HarnessConfig(BaseModel):
     """Configuration for the HYDRA harness behavior."""
 
     # Inner loop
@@ -50,19 +50,19 @@ class HarnessConfig(BaseModel):
         default=5.0, description="Max % regression from best known val_bpb"
     )
 
-    # Keep/discard criteria
-    primary_metric: str = "val_bpb"
-    secondary_metrics: GateConfig = Field(
-        default_factory=lambda: {
-            "mhc_spectral_norm": {"max": 2.0},
-            "engram_hit_rate": {"min": 0.1},
-            "factual_english_score": {"min": 0.5},
-            "instruction_following_score": {"min": 0.5},
-            "distinct_2": {"min": 0.1},
-            "repetition_rate": {"max": 0.2},
-            "hestia_quant_error": {"max": 0.05},
-        }
-    )
+    # Keep/discard criteria
+    primary_metric: str = "val_bpb"
+    secondary_metrics: GateConfig = Field(
+        default_factory=lambda: {
+            "mhc_spectral_norm": {"max": 2.0},
+            "engram_hit_rate": {"min": 0.1},
+            "factual_english_score": {"min": 0.5},
+            "instruction_following_score": {"min": 0.5},
+            "distinct_2": {"min": 0.1},
+            "repetition_rate": {"max": 0.2},
+            "hestia_quant_error": {"max": 0.05},
+        }
+    )
 
     # Experiment execution
     experiment_timeout: int = Field(
@@ -80,29 +80,29 @@ class HarnessConfig(BaseModel):
     gate_mhc_spectral_norm: float | None = Field(
         default=None, description="Max mhc_spectral_norm for keep (None=disabled)"
     )
-    gate_engram_hit_rate: float | None = Field(
-        default=None, description="Min engram_hit_rate for keep (None=disabled)"
-    )
-    gate_tps_median: float | None = Field(
-        default=None,
-        description="Min steady-state tps_median for keep (None=disabled)",
-    )
-    gate_tps_p10: float | None = Field(
-        default=None,
-        description="Min steady-state tps_p10 for keep (None=disabled)",
-    )
-
-    def to_secondary_gates(self) -> GateConfig:
-        """Build active keep/discard gates from defaults plus gate_* overrides."""
-        gates = {metric: thresholds.copy() for metric, thresholds in self.secondary_metrics.items()}
-
-        if self.gate_mhc_spectral_norm is not None:
-            gates.setdefault("mhc_spectral_norm", {})["max"] = self.gate_mhc_spectral_norm
-        if self.gate_engram_hit_rate is not None:
-            gates.setdefault("engram_hit_rate", {})["min"] = self.gate_engram_hit_rate
-        if self.gate_tps_median is not None:
-            gates.setdefault("tps_median", {})["min"] = self.gate_tps_median
-        if self.gate_tps_p10 is not None:
-            gates.setdefault("tps_p10", {})["min"] = self.gate_tps_p10
-
-        return gates
+    gate_engram_hit_rate: float | None = Field(
+        default=None, description="Min engram_hit_rate for keep (None=disabled)"
+    )
+    gate_tps_median: float | None = Field(
+        default=None,
+        description="Min steady-state tps_median for keep (None=disabled)",
+    )
+    gate_tps_p10: float | None = Field(
+        default=None,
+        description="Min steady-state tps_p10 for keep (None=disabled)",
+    )
+
+    def to_secondary_gates(self) -> GateConfig:
+        """Build active keep/discard gates from defaults plus gate_* overrides."""
+        gates = {metric: thresholds.copy() for metric, thresholds in self.secondary_metrics.items()}
+
+        if self.gate_mhc_spectral_norm is not None:
+            gates.setdefault("mhc_spectral_norm", {})["max"] = self.gate_mhc_spectral_norm
+        if self.gate_engram_hit_rate is not None:
+            gates.setdefault("engram_hit_rate", {})["min"] = self.gate_engram_hit_rate
+        if self.gate_tps_median is not None:
+            gates.setdefault("tps_median", {})["min"] = self.gate_tps_median
+        if self.gate_tps_p10 is not None:
+            gates.setdefault("tps_p10", {})["min"] = self.gate_tps_p10
+
+        return gates
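Usage sketch for the gate merge above (illustrative values; assumes the HarnessConfig fields not shown in this diff keep their defaults):

from configs.harness_config import HarnessConfig

cfg = HarnessConfig(gate_tps_median=50000.0)
gates = cfg.to_secondary_gates()
# Defaults flow through from secondary_metrics; the gate_* override adds a floor.
assert gates["mhc_spectral_norm"] == {"max": 2.0}
assert gates["tps_median"] == {"min": 50000.0}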
overlay/harness/eval_agent.py CHANGED
@@ -1,15 +1,15 @@
-"""Eval agent: parse run.log and extract metrics from training runs."""
-import re
-import statistics
-from dataclasses import dataclass
-
-
-type GateThresholds = dict[str, float]
-type GateConfig = dict[str, GateThresholds]
+"""Eval agent: parse run.log and extract metrics from training runs."""
+import re
+import statistics
+from dataclasses import dataclass
+
+
+type GateThresholds = dict[str, float]
+type GateConfig = dict[str, GateThresholds]
 
 
 @dataclass
-class ExperimentResult:
+class ExperimentResult:
     """Parsed result from a single experiment run.
 
     All float fields default to 0.0; integer fields default to 0.
@@ -28,38 +28,38 @@ class ExperimentResult:
     peak_vram_mb: float = 0.0
    mfu_percent: float = 0.0
 
-    # Throughput
-    total_tokens_m: float = 0.0
-    num_steps: int = 0
-    tps_median: float = 0.0
-    tps_p10: float = 0.0
-    tps_min: float = 0.0
-    tps_max: float = 0.0
-    tps_samples: int = 0
+    # Throughput
+    total_tokens_m: float = 0.0
+    num_steps: int = 0
+    tps_median: float = 0.0
+    tps_p10: float = 0.0
+    tps_min: float = 0.0
+    tps_max: float = 0.0
+    tps_samples: int = 0
 
     # Model shape (echoed by train.py summary block)
     num_params_m: float = 0.0
     n_layer: int = 0
     d_model: int = 0
 
-    # Secondary health metrics
-    mhc_spectral_norm: float = 0.0
-    engram_hit_rate: float = 0.0
-    sr_bypass_rate: float = 0.0
-
-    # Evaluation breadth metrics
-    factual_english_score: float = 0.0
-    instruction_following_score: float = 0.0
-    distinct_1: float = 0.0
-    distinct_2: float = 0.0
-    repetition_rate: float = 0.0
-    repetition_bigram_rate: float = 0.0
-    calibration_ece: float = 0.0
-    calibration_brier: float = 0.0
-    calibration_accuracy: float = 0.0
-    calibration_tokens: int = 0
-    eval_seed: int = 0
-    eval_seed_group: str = ""
+    # Secondary health metrics
+    mhc_spectral_norm: float = 0.0
+    engram_hit_rate: float = 0.0
+    sr_bypass_rate: float = 0.0
+
+    # Evaluation breadth metrics
+    factual_english_score: float = 0.0
+    instruction_following_score: float = 0.0
+    distinct_1: float = 0.0
+    distinct_2: float = 0.0
+    repetition_rate: float = 0.0
+    repetition_bigram_rate: float = 0.0
+    calibration_ece: float = 0.0
+    calibration_brier: float = 0.0
+    calibration_accuracy: float = 0.0
+    calibration_tokens: int = 0
+    eval_seed: int = 0
+    eval_seed_group: str = ""
 
     # Status
     crashed: bool = False
@@ -80,48 +80,48 @@ _PATTERNS: dict[str, str] = {
     "n_layer": r"^n_layer:\s+(\d+)",
     "d_model": r"^d_model:\s+(\d+)",
     "mhc_spectral_norm": r"^mhc_spectral_norm:\s+([\d.]+)",
-    "engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
-    "sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
-    "factual_english_score": r"^factual_english_score:\s+([\d.]+)",
-    "instruction_following_score": r"^instruction_following_score:\s+([\d.]+)",
-    "distinct_1": r"^distinct_1:\s+([\d.]+)",
-    "distinct_2": r"^distinct_2:\s+([\d.]+)",
-    "repetition_rate": r"^repetition_rate:\s+([\d.]+)",
-    "repetition_bigram_rate": r"^repetition_bigram_rate:\s+([\d.]+)",
-    "calibration_ece": r"^calibration_ece:\s+([\d.]+)",
-    "calibration_brier": r"^calibration_brier:\s*([\d.]+)",
-    "calibration_accuracy": r"^calibration_accuracy:\s+([\d.]+)",
-    "calibration_tokens": r"^calibration_tokens:\s+(\d+)",
-    "eval_seed": r"^eval_seed:\s+(\d+)",
-    "eval_seed_group": r"^eval_seed_group:\s+(.+)",
-}
+    "engram_hit_rate": r"^engram_hit_rate:\s+([\d.]+)",
+    "sr_bypass_rate": r"^sr_bypass_rate:\s+([\d.]+)",
+    "factual_english_score": r"^factual_english_score:\s+([\d.]+)",
+    "instruction_following_score": r"^instruction_following_score:\s+([\d.]+)",
+    "distinct_1": r"^distinct_1:\s+([\d.]+)",
+    "distinct_2": r"^distinct_2:\s+([\d.]+)",
+    "repetition_rate": r"^repetition_rate:\s+([\d.]+)",
+    "repetition_bigram_rate": r"^repetition_bigram_rate:\s+([\d.]+)",
+    "calibration_ece": r"^calibration_ece:\s+([\d.]+)",
+    "calibration_brier": r"^calibration_brier:\s*([\d.]+)",
+    "calibration_accuracy": r"^calibration_accuracy:\s+([\d.]+)",
+    "calibration_tokens": r"^calibration_tokens:\s+(\d+)",
+    "eval_seed": r"^eval_seed:\s+(\d+)",
+    "eval_seed_group": r"^eval_seed_group:\s+(.+)",
+}
 
 # Attributes that should be parsed as int rather than float.
-_INT_ATTRS: frozenset[str] = frozenset(
-    {
-        "num_steps",
-        "n_layer",
-        "d_model",
-        "calibration_tokens",
-        "eval_seed",
-    }
-)
-_STR_ATTRS: frozenset[str] = frozenset({"eval_seed_group"})
-_STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
-_TPS_PATTERN = re.compile(r"\btps=(\d+)\b")
-
-
-def _percentile_linear(sorted_values: list[float], pct: float) -> float:
-    """Compute percentile via linear interpolation (0 <= pct <= 100)."""
-    if not sorted_values:
-        return 0.0
-    if len(sorted_values) == 1:
-        return sorted_values[0]
-    rank = (len(sorted_values) - 1) * (pct / 100.0)
-    lo = int(rank)
-    hi = min(lo + 1, len(sorted_values) - 1)
-    frac = rank - lo
-    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
+_INT_ATTRS: frozenset[str] = frozenset(
+    {
+        "num_steps",
+        "n_layer",
+        "d_model",
+        "calibration_tokens",
+        "eval_seed",
+    }
+)
+_STR_ATTRS: frozenset[str] = frozenset({"eval_seed_group"})
+_STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
+_TPS_PATTERN = re.compile(r"\btps=(\d+)\b")
+
+
+def _percentile_linear(sorted_values: list[float], pct: float) -> float:
+    """Compute percentile via linear interpolation (0 <= pct <= 100)."""
+    if not sorted_values:
+        return 0.0
+    if len(sorted_values) == 1:
+        return sorted_values[0]
+    rank = (len(sorted_values) - 1) * (pct / 100.0)
+    lo = int(rank)
+    hi = min(lo + 1, len(sorted_values) - 1)
+    frac = rank - lo
+    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
 
 
 def parse_run_log(log_path: str) -> ExperimentResult:
@@ -144,60 +144,60 @@ def parse_run_log(log_path: str) -> ExperimentResult:
         result.error_message = f"Log file not found: {log_path}"
         return result
 
-    # Detect crash signals in output. Keep this strict to avoid false positives
-    # from benign log lines that include "error" in a non-fatal context.
-    if (
-        "Traceback" in content
-        or "\nFAIL\n" in content
-        or "[TPS_GUARD] FAIL" in content
-        or "raise SystemExit(1)" in content
-    ):
-        result.crashed = True
-        lines = content.strip().splitlines()
-        result.error_message = "\n".join(lines[-20:])
-
-    for attr, pattern in _PATTERNS.items():
-        match = re.search(pattern, content, re.MULTILINE)
-        if match:
-            raw = match.group(1)
-            if attr in _INT_ATTRS:
-                setattr(result, attr, int(raw))
-            elif attr in _STR_ATTRS:
-                setattr(result, attr, raw.strip())
-            else:
-                setattr(result, attr, float(raw))
-
-    warmup_steps = 10
-    warmup_match = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", content)
-    if warmup_match:
-        warmup_steps = int(warmup_match.group(1))
-
-    step_tps_samples: list[tuple[int, int]] = []
-    for m in _STEP_TPS_PATTERN.finditer(content):
-        step_tps_samples.append((int(m.group(1)), int(m.group(2))))
-
-    tps_values: list[float] = []
-    if step_tps_samples:
-        for step, tps in step_tps_samples:
-            if step >= warmup_steps:
-                tps_values.append(float(tps))
-        if not tps_values:
-            tps_values = [float(tps) for _, tps in step_tps_samples]
-    else:
-        tps_values = [float(m.group(1)) for m in _TPS_PATTERN.finditer(content)]
-
-    if tps_values:
-        sorted_tps = sorted(tps_values)
-        result.tps_samples = len(tps_values)
-        result.tps_median = float(statistics.median(tps_values))
-        result.tps_p10 = float(_percentile_linear(sorted_tps, 10.0))
-        result.tps_min = float(sorted_tps[0])
-        result.tps_max = float(sorted_tps[-1])
-
-    return result
-
-
-def check_secondary_alarms(result: ExperimentResult) -> list[str]:
+    # Detect crash signals in output. Keep this strict to avoid false positives
+    # from benign log lines that include "error" in a non-fatal context.
+    if (
+        "Traceback" in content
+        or "\nFAIL\n" in content
+        or "[TPS_GUARD] FAIL" in content
+        or "raise SystemExit(1)" in content
+    ):
+        result.crashed = True
+        lines = content.strip().splitlines()
+        result.error_message = "\n".join(lines[-20:])
+
+    for attr, pattern in _PATTERNS.items():
+        match = re.search(pattern, content, re.MULTILINE)
+        if match:
+            raw = match.group(1)
+            if attr in _INT_ATTRS:
+                setattr(result, attr, int(raw))
+            elif attr in _STR_ATTRS:
+                setattr(result, attr, raw.strip())
+            else:
+                setattr(result, attr, float(raw))
+
+    warmup_steps = 10
+    warmup_match = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", content)
+    if warmup_match:
+        warmup_steps = int(warmup_match.group(1))
+
+    step_tps_samples: list[tuple[int, int]] = []
+    for m in _STEP_TPS_PATTERN.finditer(content):
+        step_tps_samples.append((int(m.group(1)), int(m.group(2))))
+
+    tps_values: list[float] = []
+    if step_tps_samples:
+        for step, tps in step_tps_samples:
+            if step >= warmup_steps:
+                tps_values.append(float(tps))
+        if not tps_values:
+            tps_values = [float(tps) for _, tps in step_tps_samples]
+    else:
+        tps_values = [float(m.group(1)) for m in _TPS_PATTERN.finditer(content)]
+
+    if tps_values:
+        sorted_tps = sorted(tps_values)
+        result.tps_samples = len(tps_values)
+        result.tps_median = float(statistics.median(tps_values))
+        result.tps_p10 = float(_percentile_linear(sorted_tps, 10.0))
+        result.tps_min = float(sorted_tps[0])
+        result.tps_max = float(sorted_tps[-1])
+
+    return result
+
+
+def check_secondary_alarms(result: ExperimentResult) -> list[str]:
     """Check secondary metrics against fixed alarm thresholds.
 
     Args:
@@ -216,44 +216,44 @@ def check_secondary_alarms(result: ExperimentResult) -> list[str]:
         alarms.append(
             f"engram_hit_rate={result.engram_hit_rate:.4f} < 0.1 (memory underused)"
         )
-    if 0 < result.mfu_percent < 10:
-        alarms.append(
-            f"mfu_percent={result.mfu_percent:.2f}% < 10% (GPU underutilized)"
-        )
-    if result.calibration_ece > 0.35:
-        alarms.append(
-            f"calibration_ece={result.calibration_ece:.4f} > 0.35 (poor calibration)"
-        )
-    if result.tps_median > 0 and result.tps_median < 50000:
-        alarms.append(
-            f"tps_median={result.tps_median:.0f} < 50000 (throughput below A10 objective)"
-        )
-
-    return alarms
-
-
-def _check_gate(
-    result: ExperimentResult,
-    gates: GateConfig,
-    metric: str,
-) -> tuple[bool, str] | None:
-    """Evaluate a single min/max gate against an ExperimentResult metric."""
-    gate = gates.get(metric, {})
-    value = getattr(result, metric)
-    max_value = gate.get("max")
-    if max_value is not None and value > max_value:
-        return False, f"{metric} {value:.4f} > gate {max_value}"
-    min_value = gate.get("min")
-    if min_value is not None and value < min_value:
-        return False, f"{metric} {value:.4f} < gate {min_value}"
-    return None
-
-
-def should_keep(
-    result: ExperimentResult,
-    best_bpb: float,
-    gates: GateConfig | None = None,
-) -> tuple[bool, str]:
+    if 0 < result.mfu_percent < 10:
+        alarms.append(
+            f"mfu_percent={result.mfu_percent:.2f}% < 10% (GPU underutilized)"
+        )
+    if result.calibration_ece > 0.35:
+        alarms.append(
+            f"calibration_ece={result.calibration_ece:.4f} > 0.35 (poor calibration)"
+        )
+    if result.tps_median > 0 and result.tps_median < 50000:
+        alarms.append(
+            f"tps_median={result.tps_median:.0f} < 50000 (throughput below A10 objective)"
+        )
+
+    return alarms
+
+
+def _check_gate(
+    result: ExperimentResult,
+    gates: GateConfig,
+    metric: str,
+) -> tuple[bool, str] | None:
+    """Evaluate a single min/max gate against an ExperimentResult metric."""
+    gate = gates.get(metric, {})
+    value = getattr(result, metric)
+    max_value = gate.get("max")
+    if max_value is not None and value > max_value:
+        return False, f"{metric} {value:.4f} > gate {max_value}"
+    min_value = gate.get("min")
+    if min_value is not None and value < min_value:
+        return False, f"{metric} {value:.4f} < gate {min_value}"
+    return None
+
+
+def should_keep(
+    result: ExperimentResult,
+    best_bpb: float,
+    gates: GateConfig | None = None,
+) -> tuple[bool, str]:
    """Decide whether to keep or discard an experiment.
 
     The primary criterion is strictly lower val_bpb than the current best.
@@ -277,24 +277,24 @@ def should_keep(
     if result.val_bpb >= best_bpb:
         return False, "discard"
 
-    # Secondary gate checks.
-    if gates:
-        gate_metrics = (
-            "mhc_spectral_norm",
-            "engram_hit_rate",
-            "factual_english_score",
-            "instruction_following_score",
-            "distinct_1",
-            "distinct_2",
-            "repetition_rate",
-            "repetition_bigram_rate",
-            "calibration_ece",
-            "tps_median",
-            "tps_p10",
-        )
-        for metric in gate_metrics:
-            gate_result = _check_gate(result, gates, metric)
-            if gate_result is not None:
-                return gate_result
-
-    return True, "keep"
+    # Secondary gate checks.
+    if gates:
+        gate_metrics = (
+            "mhc_spectral_norm",
+            "engram_hit_rate",
+            "factual_english_score",
+            "instruction_following_score",
+            "distinct_1",
+            "distinct_2",
+            "repetition_rate",
+            "repetition_bigram_rate",
+            "calibration_ece",
+            "tps_median",
+            "tps_p10",
+        )
+        for metric in gate_metrics:
+            gate_result = _check_gate(result, gates, metric)
+            if gate_result is not None:
+                return gate_result
+
+    return True, "keep"
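Worked example for the interpolation and gate helpers above (a sketch, not repo test code; ExperimentResult fields not set here keep their 0/0.0 defaults):

from harness.eval_agent import ExperimentResult, _percentile_linear, should_keep

# p10 over five sorted samples: rank = (5 - 1) * 0.10 = 0.4,
# so the result interpolates 0.6*100 + 0.4*200 = 140.0.
assert _percentile_linear([100.0, 200.0, 300.0, 400.0, 500.0], 10.0) == 140.0

result = ExperimentResult(val_bpb=0.95, tps_median=60000.0)
keep, reason = should_keep(result, best_bpb=1.0,
                           gates={"tps_median": {"min": 50000.0}})
assert keep and reason == "keep"  # bpb improved and the tps floor is met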
overlay/harness/orchestrator.py CHANGED
@@ -20,12 +20,12 @@ provides the infrastructure ("rails") that the autoresearch loop runs on.
 """
 import argparse
 import csv
-import os
-import subprocess
-import time
-
-from configs.harness_config import HarnessConfig
-from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
+import os
+import subprocess
+import time
+
+from configs.harness_config import HarnessConfig
+from harness.eval_agent import ExperimentResult, check_secondary_alarms, parse_run_log, should_keep
 from harness.git_utils import REPO_DIR, commit_all, current_commit_short, reset_to
 from harness.health_monitor import check_health, reset_peak_stats
 from harness.meta_agent import run_meta_iteration
@@ -145,12 +145,12 @@ def run_experiment(timeout: int = 600) -> str:
 # ---------------------------------------------------------------------------
 
 
-def run_loop(
-    meta_interval: int = 20,
-    max_experiments: int | None = None,
-    experiment_timeout: int = 600,
-    secondary_gates: dict[str, dict[str, float]] | None = None,
-) -> None:
+def run_loop(
+    meta_interval: int = 20,
+    max_experiments: int | None = None,
+    experiment_timeout: int = 600,
+    secondary_gates: dict[str, dict[str, float]] | None = None,
+) -> None:
     """Run the HYDRA autoresearch loop.
 
     This function runs indefinitely (or until ``max_experiments`` is reached
@@ -163,10 +163,10 @@ def run_loop(
         secondary_gates: Optional gate thresholds forwarded to
             :func:`~harness.eval_agent.should_keep`.
     """
-    init_results_tsv()
-    if secondary_gates is None:
-        secondary_gates = HarnessConfig().to_secondary_gates()
-    best_bpb = _load_best_bpb()
+    init_results_tsv()
+    if secondary_gates is None:
+        secondary_gates = HarnessConfig().to_secondary_gates()
+    best_bpb = _load_best_bpb()
     experiment_num = count_experiments()
 
     print(
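Invocation sketch matching the run_loop signature above (values illustrative; the gate dict shape follows GateConfig from harness_config):

from harness.orchestrator import run_loop

run_loop(
    meta_interval=20,        # run the meta-agent every 20 experiments
    max_experiments=100,     # stop after 100 experiments
    experiment_timeout=600,  # seconds per inner-loop training run
    secondary_gates={"tps_median": {"min": 50000.0}},
)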
overlay/htm_rust/src/gpu/fused.rs CHANGED
@@ -513,41 +513,41 @@ pub(super) fn launch_fused_batched_raw(
     assert_eq!(anom_per_region.len(), b);
     assert!(b >= 1, "need at least one region");
 
-    // Reset per-region step_scratch before each launch.
-    for &rp in region_ptrs.iter() {
-        let r = unsafe { &mut *rp };
-        let dev = r.sp_gpu.dev_ref().clone();
-        let fused = r
-            .fused_state
-            .as_mut()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        dev.memset_zeros(&mut fused.step_scratch)?;
-        fused.iter_counter = fused.iter_counter.wrapping_add(1);
-    }
+    // Reset per-region step_scratch before each launch.
+    for &rp in region_ptrs.iter() {
+        let r = unsafe { &mut *rp };
+        let dev = r.sp_gpu.dev_ref().clone();
+        let fused = r
+            .fused_state
+            .as_mut()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        dev.memset_zeros(&mut fused.step_scratch)?;
+        fused.iter_counter = fused.iter_counter.wrapping_add(1);
+    }
 
     // Shared config — all regions use identical sp/tm parameters.
-    let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
-        let r0 = unsafe { &*region_ptrs[0] };
-        let fused = r0
-            .fused_state
-            .as_ref()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        (
-            fused.grid_dim_x,
-            fused.block_dim_x,
-            fused.raw_kernel.function_batched,
-            *r0.sp_gpu.dev_ref().cu_stream(),
-            *r0.sp_gpu.dev_ref().cu_primary_ctx(),
-        )
+    let (grid_x, block_x, function_batched, cu_stream, cu_ctx) = {
+        let r0 = unsafe { &*region_ptrs[0] };
+        let fused = r0
+            .fused_state
+            .as_ref()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        (
+            fused.grid_dim_x,
+            fused.block_dim_x,
+            fused.raw_kernel.function_batched,
+            *r0.sp_gpu.dev_ref().cu_stream(),
+            *r0.sp_gpu.dev_ref().cu_primary_ctx(),
+        )
     };
 
-    let cfg = {
-        let r = unsafe { &*region_ptrs[0] };
-        let fused = r
-            .fused_state
-            .as_ref()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        FusedConfig {
+    let cfg = {
+        let r = unsafe { &*region_ptrs[0] };
+        let fused = r
+            .fused_state
+            .as_ref()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        FusedConfig {
             input_bits: input_bits as u32,
             n_columns: r.sp_gpu.n_columns_accessor() as u32,
             synapses_per_col: r.sp_gpu.synapses_per_col_accessor() as u32,
@@ -572,42 +572,42 @@ pub(super) fn launch_fused_batched_raw(
             initial_perm_i16: r.tm_gpu.initial_perm_i16 as i32,
             t: t as u32,
             learn: if learn { 1 } else { 0 },
-            iter_seed: fused.iter_counter,
-            cooperative_grid_sync: 1,
-        }
-    };
+            iter_seed: fused.iter_counter,
+            cooperative_grid_sync: 1,
+        }
+    };
 
     // Build B FusedPtrs per-region.
-    let ptrs_vec: Vec<FusedPtrs> = (0..b)
-        .map(|i| {
-            let r = unsafe { &*region_ptrs[i] };
-            let fused = r
-                .fused_state
-                .as_ref()
-                .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-            Ok(FusedPtrs {
-                syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
-                syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
-                boost: *r.sp_gpu.boost_accessor().device_ptr(),
-                active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
-                inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
-                seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
-                seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
-                syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
-                tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
-                cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
-                cell_active_a: *fused.cell_active_bits_a.device_ptr(),
-                cell_active_b: *fused.cell_active_bits_b.device_ptr(),
-                cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
-                cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
-                inputs: inputs_per_region[i],
-                cols_out: cols_per_region[i],
-                anom_out: anom_per_region[i],
-                barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
-                step_scratch: *fused.step_scratch.device_ptr(),
-            })
-        })
-        .collect::<Result<Vec<_>, DriverError>>()?;
+    let ptrs_vec: Vec<FusedPtrs> = (0..b)
+        .map(|i| {
+            let r = unsafe { &*region_ptrs[i] };
+            let fused = r
+                .fused_state
+                .as_ref()
+                .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+            Ok(FusedPtrs {
+                syn_bit: *r.sp_gpu.syn_bit_accessor().device_ptr(),
+                syn_perm: *r.sp_gpu.syn_perm_accessor().device_ptr(),
+                boost: *r.sp_gpu.boost_accessor().device_ptr(),
+                active_duty: *r.sp_gpu.active_duty_accessor().device_ptr(),
+                inhibition_threshold: *fused.inhibition_threshold.device_ptr(),
+                seg_cell_id: *r.tm_gpu.seg_cell_id_accessor().device_ptr(),
+                seg_syn_count: *r.tm_gpu.seg_syn_count_accessor().device_ptr(),
+                syn_presyn: *r.tm_gpu.syn_presyn_accessor().device_ptr(),
+                tm_syn_perm: *r.tm_gpu.syn_perm_accessor().device_ptr(),
+                cell_seg_count: *r.tm_gpu.cell_seg_count_accessor().device_ptr(),
+                cell_active_a: *fused.cell_active_bits_a.device_ptr(),
+                cell_active_b: *fused.cell_active_bits_b.device_ptr(),
+                cell_winner_a: *fused.cell_winner_bits_a.device_ptr(),
+                cell_winner_b: *fused.cell_winner_bits_b.device_ptr(),
+                inputs: inputs_per_region[i],
+                cols_out: cols_per_region[i],
+                anom_out: anom_per_region[i],
+                barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
+                step_scratch: *fused.step_scratch.device_ptr(),
+            })
+        })
+        .collect::<Result<Vec<_>, DriverError>>()?;
 
     // Upload FusedPtrs array to device (B * sizeof(FusedPtrs) bytes).
     // FusedPtrs is repr(C) + DeviceRepr so htod_sync_copy handles it.
@@ -619,14 +619,14 @@ pub(super) fn launch_fused_batched_raw(
     // Grid = (grid_x, B, 1) with cluster_dim=(16,1,1): each region (Y slice)
     // occupies exactly one cluster of 16 blocks. All 8 clusters run concurrently
     // on the H200's 132 SMs (8 × 16 = 128 blocks ≤ 132 SMs).
-    let use_cluster = {
-        let r0 = unsafe { &*region_ptrs[0] };
-        let fused = r0
-            .fused_state
-            .as_ref()
-            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
-        fused.cluster_info.max_cluster_size > 0
-    };
+    let use_cluster = {
+        let r0 = unsafe { &*region_ptrs[0] };
+        let fused = r0
+            .fused_state
+            .as_ref()
+            .ok_or(DriverError(sys::CUresult::CUDA_ERROR_NOT_INITIALIZED))?;
+        fused.cluster_info.max_cluster_size > 0
+    };
 
     unsafe {
         result::ctx::set_current(cu_ctx)?;
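Back-of-envelope check of the launch geometry described in the comments above (numbers mirror the comment, not queried from the driver):

regions = 8              # grid.y = B
blocks_per_cluster = 16  # cluster_dim = (16, 1, 1)
h200_sms = 132
total_blocks = regions * blocks_per_cluster  # 128
assert total_blocks <= h200_sms  # all clusters co-resident, no serialization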
overlay/hydra/eval.py CHANGED
@@ -138,9 +138,6 @@ def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
     num_samples = FACTUAL_SAMPLES
     batch = FACTUAL_BATCH
     gen_tokens = FACTUAL_GEN_TOKENS
-    # Optional fast incremental decode path for recurrence-capable backbones.
-    # If disabled, we preserve the original full-context re-forward behavior.
-    incremental_decode = os.environ.get("HYDRA_FACTUAL_GEN_INCREMENTAL", "1") == "1"
     temps = [0.7, 0.9, 1.1]
     hits = 0
 
@@ -157,18 +154,14 @@ def _run_factual_english_gen(model, tokenizer, max_seq_len: int):
         temp = temps[batch_idx % len(temps)]
         batch_idx += 1
         ctx = torch.tensor([ids] * b, device="cuda", dtype=torch.long)
-        logits = model(ctx, targets=None)
         for _ in range(gen_tokens):
+            logits = model(ctx, targets=None)
             next_logits = logits[:, -1, :] if logits.dim() == 3 else logits
             probs = torch.softmax(next_logits.float() / temp, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
             ctx = torch.cat([ctx, next_id], dim=1)
             if ctx.size(1) >= max_seq_len:
                 break
-            if incremental_decode:
-                logits = model(ctx[:, -1:], targets=None)
-            else:
-                logits = model(ctx, targets=None)
         # Transfer to CPU in one shot, no per-row sync
         all_rows.extend(ctx.cpu().tolist())
         samples_done += b
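Note: with the incremental path gone, the sampler re-forwards the full, growing context on every generated token. Cost sketch (P = prompt length, G = generated tokens; a back-of-envelope count, not a benchmark):

def token_forwards(prompt_len: int, gen_tokens: int) -> int:
    # Step k forwards prompt_len + k tokens, so the total is
    # G*P + G*(G-1)/2, versus roughly P + G for incremental decode.
    return sum(prompt_len + k for k in range(gen_tokens))

assert token_forwards(64, 32) == 64 * 32 + 32 * 31 // 2  # 2544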
overlay/hydra/model.py CHANGED
@@ -32,58 +32,58 @@ from __future__ import annotations
 
 import os
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-try:
-    from mamba_ssm import Mamba3
-except Exception:  # pragma: no cover - depends on optional runtime install
-    Mamba3 = None  # type: ignore[assignment]
-
-
-def _get_mamba3_cls():
-    global Mamba3
-    if Mamba3 is None:
-        try:
-            from mamba_ssm import Mamba3 as _Mamba3  # type: ignore
-            Mamba3 = _Mamba3  # type: ignore[assignment]
-        except Exception as exc:  # pragma: no cover - environment dependent
-            raise ImportError(
-                "mamba_ssm is required for Mamba-based HYDRA blocks. "
-                "Install mamba-ssm or use HYDRA_BASELINE_ARCH=transformer."
-            ) from exc
-    return Mamba3
-
-
-def _ensure_triton_cuda_backend_registered() -> None:
-    """Ensure Triton sees exactly one CUDA backend in HF Jobs A10 runtime.
-
-    In some Triton 3.5.1 environments, `triton.compiler.compiler.backends`
-    and `triton.runtime.driver.backends` are empty even though
-    `triton.backends.nvidia` is available and CUDA is active. When that
-    happens, Mamba3 layernorm path crashes at first forward with
-    "0 compatible backends for target (cuda)".
-    """
-    try:
-        import triton.compiler.compiler as cc
-        import triton.runtime.driver as rd
-        from triton.backends import Backend
-        from triton.backends.nvidia.compiler import CUDABackend
-        from triton.backends.nvidia.driver import CudaDriver
-
-        if hasattr(rd, "backends") and isinstance(rd.backends, dict) and not rd.backends:
-            rd.backends["nvidia"] = Backend(compiler=CUDABackend, driver=CudaDriver)
-
-        if hasattr(cc, "backends") and isinstance(cc.backends, dict) and not cc.backends:
-            cc.backends["nvidia"] = Backend(compiler=CUDABackend, driver=CudaDriver)
-    except Exception:
-        # Keep model construction resilient; runtime will raise explicit Triton
-        # errors later if backend setup is still invalid.
-        pass
-
-
-_ensure_triton_cuda_backend_registered()
 
 from subsystems.hestia_mini import HestiaQAT
 from subsystems.htm import HTMLayer
@@ -98,30 +98,30 @@ from hydra.hyena_block import HyenaBlock
 from hydra.optimizer import MuonAdamW
 
 
-def norm(x: torch.Tensor) -> torch.Tensor:
-    """RMSNorm over the last dim — stateless, autocast-friendly."""
-    return F.rms_norm(x, (x.size(-1),))
-
-
-class TransformerBaselineBlock(nn.Module):
-    """Transformer-style delta block for matched baseline experiments.
-
-    This block returns a transformed delta tensor rather than owning the outer
-    residual connection, because ManifoldHyperConnection already handles stream
-    mixing and residual injection around the block function.
-    """
-
-    def __init__(self, d_model: int, n_heads: int, expand: int, dropout: float) -> None:
-        super().__init__()
-        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
-        self.ff_in = nn.Linear(d_model, expand * d_model, bias=False)
-        self.ff_out = nn.Linear(expand * d_model, d_model, bias=False)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        attn_out, _ = self.self_attn(x, x, x, need_weights=False)
-        ff = self.ff_out(F.gelu(self.ff_in(attn_out)))
-        return self.dropout(attn_out + ff)
 
 
 class PostSemClawModel(nn.Module):
@@ -136,12 +136,12 @@ class PostSemClawModel(nn.Module):
         model(x, y, reduction='mean') -> scalar loss
     """
 
-    def __init__(self, config):
-        super().__init__()
-        _ensure_triton_cuda_backend_registered()
-        self.config = config
-        self._throughput_mode = os.environ.get("HYDRA_THROUGHPUT_MODE", "0") == "1"
-        self._baseline_arch = os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower()
 
         # Token embedding
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
@@ -163,31 +163,31 @@ class PostSemClawModel(nn.Module):
            print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True)
             _gdn_layer_set -= _hyena_layer_set
 
-        if _gdn_layer_set:
-            from hydra.gdn_block import GDNBlock  # requires `fla` package
-
-        def _build_block(i: int) -> nn.Module:
-            if self._baseline_arch == "transformer":
-                return TransformerBaselineBlock(
-                    d_model=config.d_model,
-                    n_heads=config.n_heads,
-                    expand=config.expand,
-                    dropout=float(os.environ.get("HYDRA_DROPOUT", "0.2")),
-                )
-            if i in _hyena_layer_set:
-                return HyenaBlock(
                     d_model=config.d_model,
                     seq_len=config.sequence_len,
                     order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")),
                     filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")),
                 )
-            if i in _gdn_layer_set:
-                return GDNBlock(
-                    d_model=config.d_model,
-                    n_heads=config.n_heads,
-                )
-            mamba3_cls = _get_mamba3_cls()
-            return mamba3_cls(
                 d_model=config.d_model,
                 d_state=config.d_state,
                 expand=config.expand,
@@ -201,43 +201,43 @@ class PostSemClawModel(nn.Module):
         self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)])
 
         # Full-architecture SDR: offline semantic retina + STE (no-bypass).
-        if self._throughput_mode:
-            self.sdr_semantic = None
-            self.htm = None
-            self.htm_proj = None
-            self.htm_anom_proj = None
-            self.engram = None
-            self.engram_layer_idx = -1
-        else:
-            self.sdr_semantic = SemanticFoldingSDR(
-                vocab_size=config.vocab_size,
-                n_bits=config.sdr_n_bits,
-                target_active=config.sdr_target_active,
-                delta_rank=config.sdr_delta_rank,
-                som_warmup_steps=config.sdr_som_warmup,
-                som_update_interval=config.sdr_som_interval,
-            )
-
-            # HTM spatial pooler + temporal memory (Rust, Hebbian).
-            self.htm = HTMLayer(
-                input_bits=config.sdr_n_bits,
-                n_columns=config.htm_n_columns,
-                cells_per_column=config.htm_cells_per_column,
-                batch_size=1,
-                seed=42,
-                learn=True,
-                reset_each_forward=True,
-            )
-
-            self.htm_proj = nn.Linear(config.htm_n_columns, config.d_model, bias=False)
-            self.htm_anom_proj = nn.Linear(1, config.d_model, bias=False)
-
-            self.engram = GPUEngram(
-                d_model=config.d_model,
-                n_columns=config.engram_n_columns,
-                max_ngram=3,
-            )
-            self.engram_layer_idx = config.engram_layer_idx
 
         # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
         self.mhc = nn.ModuleList([
@@ -258,18 +258,18 @@ class PostSemClawModel(nn.Module):
         # additional CE losses; no new params. Activated via HYDRA_MTP_K.
         self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1")))
 
-        # Learnability knob 3: gradient checkpointing on Mamba3 blocks.
-        self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"
-
-        # Full-arch throughput knob: Engram remains in the architecture, but
-        # can run every N forwards and reuse its most recent residual delta on
-        # skipped forwards. This amortizes top-k Hopfield retrieval while still
-        # injecting Engram signal every microbatch. N=1 preserves exact legacy
-        # behavior.
-        self._engram_subsample = max(1, int(os.environ.get("HYDRA_ENGRAM_SUBSAMPLE", "1")))
-        self._engram_call_idx = 0
-        self._engram_delta_cache = None
-        self._engram_hit_rate_cache = None
 
         # Learnability knob 4: doc-separator BOS masking in packed sequences.
         self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
@@ -373,8 +373,8 @@ class PostSemClawModel(nn.Module):
         # Required because to_empty() only moves params/buffers, and _retina_indices
         # is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__.
         device = self.wte.weight.device
-        if self.sdr_semantic is not None and hasattr(self.sdr_semantic, '_retina_indices'):
-            self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)
 
         # Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for
         # vocab=8192; at larger vocabs, smaller std prevents logit blowup.
@@ -413,20 +413,20 @@ class PostSemClawModel(nn.Module):
             ))
             nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)
 
-        if self.htm_proj is not None:
-            nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
-        if self.htm_anom_proj is not None:
-            nn.init.normal_(self.htm_anom_proj.weight, mean=0.0, std=s)
 
         # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
         # dtypes in the same shape group would break lerp_ dtype checks.
         self.wte.to(dtype=torch.bfloat16)
-        if self.htm_proj is not None:
-            self.htm_proj.to(dtype=torch.bfloat16)
-        if self.htm_anom_proj is not None:
-            self.htm_anom_proj.to(dtype=torch.bfloat16)
-        if self.engram is not None:
-            self.engram.to(dtype=torch.bfloat16)
 
     def set_bos_token_id(self, bos_id: int) -> None:
         """Inform the model of the tokenizer's BOS id so doc-separator
@@ -472,10 +472,10 @@ class PostSemClawModel(nn.Module):
         wte = sum(p.numel() for p in self.wte.parameters())
         lm_head = sum(p.numel() for p in self.lm_head.parameters())
         blocks = sum(p.numel() for p in self.blocks.parameters())
-        sdr = sum(p.numel() for p in self.sdr_semantic.parameters()) if self.sdr_semantic is not None else 0
-        htm_proj = sum(p.numel() for p in self.htm_proj.parameters()) if self.htm_proj is not None else 0
-        htm_anom_proj = sum(p.numel() for p in self.htm_anom_proj.parameters()) if self.htm_anom_proj is not None else 0
-        engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0
         total = sum(p.numel() for p in self.parameters())
         return {
            'wte': wte, 'lm_head': lm_head, 'blocks': blocks,
@@ -544,19 +544,19 @@ class PostSemClawModel(nn.Module):
         # for p in self.sdr_semantic.parameters():
         #     if p.dim() == 2:
         #         matrix_params.append(p)
-        if self.htm_proj is not None:
-            for name, p in self.htm_proj.named_parameters():
-                if _muon_eligible(name, p):
-                    matrix_params.append(p)
-        if self.engram is not None:
-            for name, p in self.engram.named_parameters():
-                if _muon_eligible(name, p):
-                    matrix_params.append(p)
 
         # SDR params are intentionally not in any optimizer group — they
         # receive no gradient in the current forward, so any update would be
         # pure noise (weight_decay × lr on a zero-grad param).
-        sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters()) if self.sdr_semantic is not None else set()
         assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params)
         scalar_params = [
             p for p in self.parameters()
@@ -565,7 +565,7 @@ class PostSemClawModel(nn.Module):
 
         total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params)
         total_params = len(list(self.parameters()))
-        sdr_excluded = len(list(self.sdr_semantic.parameters())) if self.sdr_semantic is not None else 0
         assert total_assigned + sdr_excluded == total_params, (
             f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded "
             f"{sdr_excluded} vs total {total_params}"
@@ -633,59 +633,59 @@ class PostSemClawModel(nn.Module):
         else:
             _t0 = None
 
-        dense_emb = self.wte(idx)  # (B, T, d_model) bf16
-        if self._throughput_mode:
-            self._last_sdr = None
-            sdr_active_bits = 0.0
-            htm_anomaly = dense_emb.new_tensor(0.0)
-            x = norm(dense_emb)
-            if _profile:
-                _t_htm_async = _ev()
-                _t_wte = _ev()
-                _t_htm_await = _ev()
-                _t_htm_proj = _ev()
-        else:
-            sdr_binary = self.sdr_semantic.binary_only(idx)
-            self._last_sdr = sdr_binary
-            _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
-            if not hasattr(self, '_htm_call_idx'):
-                self._htm_call_idx = 0
-
-            _run_htm = (self._htm_call_idx % _htm_sub == 0)
-            self._htm_call_idx += 1
-            if _run_htm:
-                htm_handle = self.htm.forward_async(sdr_binary)
-            else:
-                htm_handle = None
-
-            if _profile: _t_htm_async = _ev()
-            if _profile: _t_wte = _ev()
-
-            if _run_htm:
-                htm_out = self.htm.forward_await(htm_handle)
-                self._htm_cache = htm_out.detach()
-            elif hasattr(self, '_htm_cache') and self._htm_cache is not None and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T:
-                htm_out = self._htm_cache
-            else:
-                htm_handle = self.htm.forward_async(sdr_binary)
-                htm_out = self.htm.forward_await(htm_handle)
-                self._htm_cache = htm_out.detach()
-
-            if _profile: _t_htm_await = _ev()
-            with torch.no_grad():
-                sdr_active_bits = float(self.sdr_semantic.target_active)
-                htm_anomaly = htm_out[..., -1].mean()
-            if self._htm_stop_grad:
-                htm_out = htm_out.detach()
-            htm_cols = htm_out[..., :-1].to(dense_emb.dtype)
-            htm_anom = htm_out[..., -1:].to(dense_emb.dtype)
-            htm_proj_out = self.htm_proj(htm_cols) + self.htm_anom_proj(htm_anom)
-            x = norm(dense_emb + htm_proj_out)
-            if _profile: _t_htm_proj = _ev()
-
-        # mHC-routed Mamba-3 stack with Engram injection at configured layer.
-        streams = self.mhc[0].init_streams(x)
-        _profile_layer_events = []
 
         # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
         # measure residual contribution of each layer: delta_N = h_post - h_pre.
@@ -702,20 +702,20 @@ class PostSemClawModel(nn.Module):
         h_pre = self.mhc[0].merge_streams(streams).detach().float()
         _run_svd = (self._diag_step % self._diag_svd_every) == 0
 
-        for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
-            if _profile:
-                _t_layer_start = _ev()
-                _t_layer_mhc = None
-                _t_engram_start = None
-                _t_engram_end = None
-            else:
-                _t_layer_start = None
-                _t_layer_mhc = None
-                _t_engram_start = None
-                _t_engram_end = None
-
-            def _block_fn(h, _block=block):
-                return self.drop(_block(norm(h)))
 
             # Learnability #3: gradient checkpointing. Wrap the block-fn so
             # the mhc layer's internal uses of it re-run the block in backward
@@ -726,31 +726,31 @@ class PostSemClawModel(nn.Module):
                 _raw_fn = _block_fn
                 def _block_fn(h, _raw=_raw_fn):  # noqa: E731
728
  return _ckpt.checkpoint(_raw, h, use_reentrant=False)
729
-
730
- streams = mhc_layer(streams, _block_fn)
731
- if _profile:
732
- _t_layer_mhc = _ev()
733
-
734
- if self.engram is not None and i == self.engram_layer_idx:
735
- if _profile: _t_engram_start = _ev()
736
- x_mid = mhc_layer.merge_streams(streams)
737
- _run_engram = (self._engram_call_idx % self._engram_subsample == 0)
738
- self._engram_call_idx += 1
739
- if _run_engram or self._engram_delta_cache is None or self._engram_delta_cache.shape != x_mid.shape:
740
- x_engram, hit_rate = self.engram(x_mid, idx)
741
- self._engram_delta_cache = (x_engram - x_mid).detach()
742
- self._engram_hit_rate_cache = hit_rate.detach() if torch.is_tensor(hit_rate) else hit_rate
743
- x_mid = x_engram
744
- else:
745
- # Preserve gradient flow through the identity path while
746
- # reusing a detached Engram residual. The Engram module
747
- # still contributes to every forward; its expensive top-k
748
- # retrieval and parameter gradients run on the cadence.
749
- x_mid = x_mid + self._engram_delta_cache.to(dtype=x_mid.dtype, device=x_mid.device)
750
- hit_rate = self._engram_hit_rate_cache
751
- streams = mhc_layer.init_streams(x_mid)
752
- self._metrics['engram_hit_rate'] = hit_rate
753
- if _profile: _t_engram_end = _ev()
754
 
755
  if _diag:
756
  with torch.no_grad():
@@ -773,20 +773,20 @@ class PostSemClawModel(nn.Module):
773
  self._metrics[f'layer_{i}_eff_rank'] = eff_rank
774
  except Exception:
775
  pass
776
- h_pre = h_post
777
-
778
- if _profile:
779
- _profile_layer_events.append(
780
- (i, _t_layer_start, _t_layer_mhc, _t_engram_start, _t_engram_end, _ev())
781
- )
782
 
783
  if _diag:
784
  self._diag_step += 1
785
 
786
  if _profile: _t_blocks = _ev()
787
 
788
- self._metrics['sdr_active_bits'] = sdr_active_bits
789
- self._metrics['htm_anomaly'] = htm_anomaly
790
 
791
  x = self.mhc[-1].merge_streams(streams)
792
  x = norm(x)
@@ -1000,8 +1000,8 @@ class PostSemClawModel(nn.Module):
1000
  _t_end = _ev()
1001
  torch.cuda.synchronize()
1002
  def _ms(a, b): return a.elapsed_time(b)
1003
- print(
1004
- f"[PROFILE B={B} T={T}] "
1005
  f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
1006
  f"wte={_ms(_t_htm_async, _t_wte):.2f} "
1007
  f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
@@ -1010,23 +1010,23 @@ class PostSemClawModel(nn.Module):
1010
  f"merge={_ms(_t_blocks, _t_merge):.2f} "
1011
  f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
1012
  f"total={_ms(_t0, _t_end):.2f} ms",
1013
- flush=True,
1014
- )
1015
- for _li, _start, _after_mhc, _engram_start, _engram_end, _end in _profile_layer_events:
1016
- print(
1017
- f"[PROFILE_LAYER B={B} T={T} layer={_li}] "
1018
- f"mhc_block={_ms(_start, _after_mhc):.2f} "
1019
- f"layer_total={_ms(_start, _end):.2f} ms",
1020
- flush=True,
1021
- )
1022
- if _engram_start is not None and _engram_end is not None:
1023
- print(
1024
- f"[PROFILE_ENGRAM B={B} T={T} layer={_li}] "
1025
- f"engram={_ms(_engram_start, _engram_end):.2f} "
1026
- f"post_layer_total={_ms(_after_mhc, _end):.2f} ms",
1027
- flush=True,
1028
- )
1029
- return out
1030
 
1031
  logits = self.lm_head(x).float()
1032
  if _softcap_clamp:
 
 
 import os
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from mamba_ssm import Mamba3
+except Exception:  # pragma: no cover - depends on optional runtime install
+    Mamba3 = None  # type: ignore[assignment]
+
+
+def _get_mamba3_cls():
+    global Mamba3
+    if Mamba3 is None:
+        try:
+            from mamba_ssm import Mamba3 as _Mamba3  # type: ignore
+            Mamba3 = _Mamba3  # type: ignore[assignment]
+        except Exception as exc:  # pragma: no cover - environment dependent
+            raise ImportError(
+                "mamba_ssm is required for Mamba-based HYDRA blocks. "
+                "Install mamba-ssm or use HYDRA_BASELINE_ARCH=transformer."
+            ) from exc
+    return Mamba3
+
+
+def _ensure_triton_cuda_backend_registered() -> None:
+    """Ensure Triton sees exactly one CUDA backend in the HF Jobs A10 runtime.
+
+    In some Triton 3.5.1 environments, `triton.compiler.compiler.backends`
+    and `triton.runtime.driver.backends` are empty even though
+    `triton.backends.nvidia` is available and CUDA is active. When that
+    happens, the Mamba3 layernorm path crashes at first forward with
+    "0 compatible backends for target (cuda)".
+    """
+    try:
+        import triton.compiler.compiler as cc
+        import triton.runtime.driver as rd
+        from triton.backends import Backend
+        from triton.backends.nvidia.compiler import CUDABackend
+        from triton.backends.nvidia.driver import CudaDriver
+
+        if hasattr(rd, "backends") and isinstance(rd.backends, dict) and not rd.backends:
+            rd.backends["nvidia"] = Backend(compiler=CUDABackend, driver=CudaDriver)
+
+        if hasattr(cc, "backends") and isinstance(cc.backends, dict) and not cc.backends:
+            cc.backends["nvidia"] = Backend(compiler=CUDABackend, driver=CudaDriver)
+    except Exception:
+        # Keep model construction resilient; runtime will raise explicit Triton
+        # errors later if backend setup is still invalid.
+        pass
+
+
+_ensure_triton_cuda_backend_registered()
 
 from subsystems.hestia_mini import HestiaQAT
 from subsystems.htm import HTMLayer
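
A quick way to sanity-check this workaround from a Python shell is to inspect the registries directly. A minimal sketch, assuming the Triton 3.5.x module layout described in the docstring above:

import triton.compiler.compiler as cc
import triton.runtime.driver as rd

# Both registries should list "nvidia" once the shim above has run; if either
# is still empty, Triton kernels (including Mamba3's layernorm) fail with
# "0 compatible backends for target (cuda)".
print("runtime backends:", list(getattr(rd, "backends", {}) or {}))
print("compiler backends:", list(getattr(cc, "backends", {}) or {}))
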
 
 from hydra.optimizer import MuonAdamW
 
 
+def norm(x: torch.Tensor) -> torch.Tensor:
+    """RMSNorm over the last dim — stateless, autocast-friendly."""
+    return F.rms_norm(x, (x.size(-1),))
+
+
+class TransformerBaselineBlock(nn.Module):
+    """Transformer-style delta block for matched baseline experiments.
+
+    This block returns a transformed delta tensor rather than owning the outer
+    residual connection, because ManifoldHyperConnection already handles stream
+    mixing and residual injection around the block function.
+    """
+
+    def __init__(self, d_model: int, n_heads: int, expand: int, dropout: float) -> None:
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+        self.ff_in = nn.Linear(d_model, expand * d_model, bias=False)
+        self.ff_out = nn.Linear(expand * d_model, d_model, bias=False)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        attn_out, _ = self.self_attn(x, x, x, need_weights=False)
+        ff = self.ff_out(F.gelu(self.ff_in(attn_out)))
+        return self.dropout(attn_out + ff)
 
 
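The delta contract is the easy thing to miss when skimming this diff: the block maps a normalized state to a residual delta and never adds the skip connection itself. A minimal standalone sketch (shapes illustrative; in the real model the residual add is performed by the mHC stream machinery, not by hand):

import torch

block = TransformerBaselineBlock(d_model=64, n_heads=4, expand=4, dropout=0.0)
h = torch.randn(2, 8, 64)          # (B, T, d_model)
delta = block(norm(h))             # the block only ever sees the normalized state
h = h + delta                      # the caller owns the residual injection
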
  class PostSemClawModel(nn.Module):
 
     model(x, y, reduction='mean') -> scalar loss
     """
 
+    def __init__(self, config):
+        super().__init__()
+        _ensure_triton_cuda_backend_registered()
+        self.config = config
+        self._throughput_mode = os.environ.get("HYDRA_THROUGHPUT_MODE", "0") == "1"
+        self._baseline_arch = os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower()
 
         # Token embedding
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
 
         print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True)
         _gdn_layer_set -= _hyena_layer_set
 
+        if _gdn_layer_set:
+            from hydra.gdn_block import GDNBlock  # requires `fla` package
+
+        def _build_block(i: int) -> nn.Module:
+            if self._baseline_arch == "transformer":
+                return TransformerBaselineBlock(
+                    d_model=config.d_model,
+                    n_heads=config.n_heads,
+                    expand=config.expand,
+                    dropout=float(os.environ.get("HYDRA_DROPOUT", "0.2")),
+                )
+            if i in _hyena_layer_set:
+                return HyenaBlock(
                     d_model=config.d_model,
                     seq_len=config.sequence_len,
                     order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")),
                     filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")),
                 )
+            if i in _gdn_layer_set:
+                return GDNBlock(
+                    d_model=config.d_model,
+                    n_heads=config.n_heads,
+                )
+            mamba3_cls = _get_mamba3_cls()
+            return mamba3_cls(
                 d_model=config.d_model,
                 d_state=config.d_state,
                 expand=config.expand,
 
         self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)])
 
         # Full-architecture SDR: offline semantic retina + STE (no-bypass).
+        if self._throughput_mode:
+            self.sdr_semantic = None
+            self.htm = None
+            self.htm_proj = None
+            self.htm_anom_proj = None
+            self.engram = None
+            self.engram_layer_idx = -1
+        else:
+            self.sdr_semantic = SemanticFoldingSDR(
+                vocab_size=config.vocab_size,
+                n_bits=config.sdr_n_bits,
+                target_active=config.sdr_target_active,
+                delta_rank=config.sdr_delta_rank,
+                som_warmup_steps=config.sdr_som_warmup,
+                som_update_interval=config.sdr_som_interval,
+            )
+
+            # HTM spatial pooler + temporal memory (Rust, Hebbian).
+            self.htm = HTMLayer(
+                input_bits=config.sdr_n_bits,
+                n_columns=config.htm_n_columns,
+                cells_per_column=config.htm_cells_per_column,
+                batch_size=1,
+                seed=42,
+                learn=True,
+                reset_each_forward=True,
+            )
+
+            self.htm_proj = nn.Linear(config.htm_n_columns, config.d_model, bias=False)
+            self.htm_anom_proj = nn.Linear(1, config.d_model, bias=False)
+
+            self.engram = GPUEngram(
+                d_model=config.d_model,
+                n_columns=config.engram_n_columns,
+                max_ngram=3,
+            )
+            self.engram_layer_idx = config.engram_layer_idx
 
         # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
         self.mhc = nn.ModuleList([
         # additional CE losses; no new params. Activated via HYDRA_MTP_K.
         self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1")))
 
+        # Learnability knob 3: gradient checkpointing on Mamba3 blocks.
+        self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"
+
+        # Full-arch throughput knob: Engram remains in the architecture, but
+        # can run every N forwards and reuse its most recent residual delta on
+        # skipped forwards. This amortizes top-k Hopfield retrieval while still
+        # injecting Engram signal every microbatch. N=1 preserves exact legacy
+        # behavior.
+        self._engram_subsample = max(1, int(os.environ.get("HYDRA_ENGRAM_SUBSAMPLE", "1")))
+        self._engram_call_idx = 0
+        self._engram_delta_cache = None
+        self._engram_hit_rate_cache = None
 
         # Learnability knob 4: doc-separator BOS masking in packed sequences.
         self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
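
The same run-every-N / replay-the-detached-delta idea, pulled out of the model as a standalone sketch (the class name and wrapper are hypothetical, not part of the codebase):

class CadencedResidual:
    """Run an expensive residual module every N calls; replay its last delta otherwise."""

    def __init__(self, module, every_n: int = 4):
        self.module = module
        self.every_n = max(1, every_n)
        self.calls = 0
        self.cache = None

    def __call__(self, x):
        run = (self.calls % self.every_n == 0)
        self.calls += 1
        if run or self.cache is None or self.cache.shape != x.shape:
            y = self.module(x)
            self.cache = (y - x).detach()   # store the residual, not the output
            return y
        return x + self.cache               # identity path keeps gradients flowing
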
 
373
  # Required because to_empty() only moves params/buffers, and _retina_indices
374
  # is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__.
375
  device = self.wte.weight.device
376
+ if self.sdr_semantic is not None and hasattr(self.sdr_semantic, '_retina_indices'):
377
+ self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)
378
 
379
  # Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for
380
  # vocab=8192; at larger vocabs, smaller std prevents logit blowup.
 
413
  ))
414
  nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)
415
 
416
+ if self.htm_proj is not None:
417
+ nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
418
+ if self.htm_anom_proj is not None:
419
+ nn.init.normal_(self.htm_anom_proj.weight, mean=0.0, std=s)
420
 
421
  # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
422
  # dtypes in the same shape group would break lerp_ dtype checks.
423
  self.wte.to(dtype=torch.bfloat16)
424
+ if self.htm_proj is not None:
425
+ self.htm_proj.to(dtype=torch.bfloat16)
426
+ if self.htm_anom_proj is not None:
427
+ self.htm_anom_proj.to(dtype=torch.bfloat16)
428
+ if self.engram is not None:
429
+ self.engram.to(dtype=torch.bfloat16)
430
 
431
  def set_bos_token_id(self, bos_id: int) -> None:
432
  """Inform the model of the tokenizer's BOS id so doc-separator
 
472
  wte = sum(p.numel() for p in self.wte.parameters())
473
  lm_head = sum(p.numel() for p in self.lm_head.parameters())
474
  blocks = sum(p.numel() for p in self.blocks.parameters())
475
+ sdr = sum(p.numel() for p in self.sdr_semantic.parameters()) if self.sdr_semantic is not None else 0
476
+ htm_proj = sum(p.numel() for p in self.htm_proj.parameters()) if self.htm_proj is not None else 0
477
+ htm_anom_proj = sum(p.numel() for p in self.htm_anom_proj.parameters()) if self.htm_anom_proj is not None else 0
478
+ engram = sum(p.numel() for p in self.engram.parameters()) if self.engram is not None else 0
479
  total = sum(p.numel() for p in self.parameters())
480
  return {
481
  'wte': wte, 'lm_head': lm_head, 'blocks': blocks,
 
544
  # for p in self.sdr_semantic.parameters():
545
  # if p.dim() == 2:
546
  # matrix_params.append(p)
547
+ if self.htm_proj is not None:
548
+ for name, p in self.htm_proj.named_parameters():
549
+ if _muon_eligible(name, p):
550
+ matrix_params.append(p)
551
+ if self.engram is not None:
552
+ for name, p in self.engram.named_parameters():
553
+ if _muon_eligible(name, p):
554
+ matrix_params.append(p)
555
 
556
  # SDR params are intentionally not in any optimizer group — they
557
  # receive no gradient in the current forward, so any update would be
558
  # pure noise (weight_decay × lr on a zero-grad param).
559
+ sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters()) if self.sdr_semantic is not None else set()
560
  assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params)
561
  scalar_params = [
562
  p for p in self.parameters()
 
565
 
566
  total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params)
567
  total_params = len(list(self.parameters()))
568
+ sdr_excluded = len(list(self.sdr_semantic.parameters())) if self.sdr_semantic is not None else 0
569
  assert total_assigned + sdr_excluded == total_params, (
570
  f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded "
571
  f"{sdr_excluded} vs total {total_params}"
 
633
  else:
634
  _t0 = None
635
 
636
+ dense_emb = self.wte(idx) # (B, T, d_model) bf16
637
+ if self._throughput_mode:
638
+ self._last_sdr = None
639
+ sdr_active_bits = 0.0
640
+ htm_anomaly = dense_emb.new_tensor(0.0)
641
+ x = norm(dense_emb)
642
+ if _profile:
643
+ _t_htm_async = _ev()
644
+ _t_wte = _ev()
645
+ _t_htm_await = _ev()
646
+ _t_htm_proj = _ev()
647
+ else:
648
+ sdr_binary = self.sdr_semantic.binary_only(idx)
649
+ self._last_sdr = sdr_binary
650
+ _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
651
+ if not hasattr(self, '_htm_call_idx'):
652
+ self._htm_call_idx = 0
653
+
654
+ _run_htm = (self._htm_call_idx % _htm_sub == 0)
655
+ self._htm_call_idx += 1
656
+ if _run_htm:
657
+ htm_handle = self.htm.forward_async(sdr_binary)
658
+ else:
659
+ htm_handle = None
660
+
661
+ if _profile: _t_htm_async = _ev()
662
+ if _profile: _t_wte = _ev()
663
+
664
+ if _run_htm:
665
+ htm_out = self.htm.forward_await(htm_handle)
666
+ self._htm_cache = htm_out.detach()
667
+ elif hasattr(self, '_htm_cache') and self._htm_cache is not None and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T:
668
+ htm_out = self._htm_cache
669
+ else:
670
+ htm_handle = self.htm.forward_async(sdr_binary)
671
+ htm_out = self.htm.forward_await(htm_handle)
672
+ self._htm_cache = htm_out.detach()
673
+
674
+ if _profile: _t_htm_await = _ev()
675
+ with torch.no_grad():
676
+ sdr_active_bits = float(self.sdr_semantic.target_active)
677
+ htm_anomaly = htm_out[..., -1].mean()
678
+ if self._htm_stop_grad:
679
+ htm_out = htm_out.detach()
680
+ htm_cols = htm_out[..., :-1].to(dense_emb.dtype)
681
+ htm_anom = htm_out[..., -1:].to(dense_emb.dtype)
682
+ htm_proj_out = self.htm_proj(htm_cols) + self.htm_anom_proj(htm_anom)
683
+ x = norm(dense_emb + htm_proj_out)
684
+ if _profile: _t_htm_proj = _ev()
685
+
686
+ # mHC-routed Mamba-3 stack with Engram injection at configured layer.
687
+ streams = self.mhc[0].init_streams(x)
688
+ _profile_layer_events = []
689
 
690
  # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
691
  # measure residual contribution of each layer: delta_N = h_post - h_pre.
 
702
  h_pre = self.mhc[0].merge_streams(streams).detach().float()
703
  _run_svd = (self._diag_step % self._diag_svd_every) == 0
704
 
705
+ for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
706
+ if _profile:
707
+ _t_layer_start = _ev()
708
+ _t_layer_mhc = None
709
+ _t_engram_start = None
710
+ _t_engram_end = None
711
+ else:
712
+ _t_layer_start = None
713
+ _t_layer_mhc = None
714
+ _t_engram_start = None
715
+ _t_engram_end = None
716
+
717
+ def _block_fn(h, _block=block):
718
+ return self.drop(_block(norm(h)))
719
 
720
  # Learnability #3: gradient checkpointing. Wrap the block-fn so
721
  # the mhc layer's internal uses of it re-run the block in backward
 
726
  _raw_fn = _block_fn
727
  def _block_fn(h, _raw=_raw_fn): # noqa: E731
728
  return _ckpt.checkpoint(_raw, h, use_reentrant=False)
729
+
730
+ streams = mhc_layer(streams, _block_fn)
731
+ if _profile:
732
+ _t_layer_mhc = _ev()
733
+
734
+ if self.engram is not None and i == self.engram_layer_idx:
735
+ if _profile: _t_engram_start = _ev()
736
+ x_mid = mhc_layer.merge_streams(streams)
737
+ _run_engram = (self._engram_call_idx % self._engram_subsample == 0)
738
+ self._engram_call_idx += 1
739
+ if _run_engram or self._engram_delta_cache is None or self._engram_delta_cache.shape != x_mid.shape:
740
+ x_engram, hit_rate = self.engram(x_mid, idx)
741
+ self._engram_delta_cache = (x_engram - x_mid).detach()
742
+ self._engram_hit_rate_cache = hit_rate.detach() if torch.is_tensor(hit_rate) else hit_rate
743
+ x_mid = x_engram
744
+ else:
745
+ # Preserve gradient flow through the identity path while
746
+ # reusing a detached Engram residual. The Engram module
747
+ # still contributes to every forward; its expensive top-k
748
+ # retrieval and parameter gradients run on the cadence.
749
+ x_mid = x_mid + self._engram_delta_cache.to(dtype=x_mid.dtype, device=x_mid.device)
750
+ hit_rate = self._engram_hit_rate_cache
751
+ streams = mhc_layer.init_streams(x_mid)
752
+ self._metrics['engram_hit_rate'] = hit_rate
753
+ if _profile: _t_engram_end = _ev()
754
 
755
  if _diag:
756
  with torch.no_grad():
 
773
  self._metrics[f'layer_{i}_eff_rank'] = eff_rank
774
  except Exception:
775
  pass
776
+ h_pre = h_post
777
+
778
+ if _profile:
779
+ _profile_layer_events.append(
780
+ (i, _t_layer_start, _t_layer_mhc, _t_engram_start, _t_engram_end, _ev())
781
+ )
782
 
783
  if _diag:
784
  self._diag_step += 1
785
 
786
  if _profile: _t_blocks = _ev()
787
 
788
+ self._metrics['sdr_active_bits'] = sdr_active_bits
789
+ self._metrics['htm_anomaly'] = htm_anomaly
790
 
791
  x = self.mhc[-1].merge_streams(streams)
792
  x = norm(x)
 
1000
  _t_end = _ev()
1001
  torch.cuda.synchronize()
1002
  def _ms(a, b): return a.elapsed_time(b)
1003
+ print(
1004
+ f"[PROFILE B={B} T={T}] "
1005
  f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
1006
  f"wte={_ms(_t_htm_async, _t_wte):.2f} "
1007
  f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
 
1010
  f"merge={_ms(_t_blocks, _t_merge):.2f} "
1011
  f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
1012
  f"total={_ms(_t0, _t_end):.2f} ms",
1013
+ flush=True,
1014
+ )
1015
+ for _li, _start, _after_mhc, _engram_start, _engram_end, _end in _profile_layer_events:
1016
+ print(
1017
+ f"[PROFILE_LAYER B={B} T={T} layer={_li}] "
1018
+ f"mhc_block={_ms(_start, _after_mhc):.2f} "
1019
+ f"layer_total={_ms(_start, _end):.2f} ms",
1020
+ flush=True,
1021
+ )
1022
+ if _engram_start is not None and _engram_end is not None:
1023
+ print(
1024
+ f"[PROFILE_ENGRAM B={B} T={T} layer={_li}] "
1025
+ f"engram={_ms(_engram_start, _engram_end):.2f} "
1026
+ f"post_layer_total={_ms(_after_mhc, _end):.2f} ms",
1027
+ flush=True,
1028
+ )
1029
+ return out
1030
 
1031
  logits = self.lm_head(x).float()
1032
  if _softcap_clamp:
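
The profile prints above rely on CUDA events rather than wall-clock timers, because kernel launches are asynchronous and host-side timing would mostly measure launch overhead. A minimal sketch of the same `_ev()` pattern, assuming a CUDA device is available:

import torch

def _ev():
    e = torch.cuda.Event(enable_timing=True)
    e.record()                     # stamped on the current stream, not the host clock
    return e

t0 = _ev()
# ... launch the work to be timed ...
t1 = _ev()
torch.cuda.synchronize()           # events must have completed before reading them
print(f"elapsed={t0.elapsed_time(t1):.2f} ms")   # elapsed_time() returns milliseconds
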
overlay/hydra/training.py CHANGED
@@ -4,20 +4,20 @@ Extracted from the monolithic train.py (W1 modularization). Semantics
4
  preserved. Public entrypoint: `main()`.
5
  """
6
 
7
- from __future__ import annotations
8
-
9
- import gc
10
- import hashlib
11
- import json
12
- import math
13
- import os
14
- import sys
15
- import threading
16
- import time
17
- from collections.abc import Mapping
18
- from dataclasses import asdict
19
- from pathlib import Path
20
- from typing import Any
21
 
22
  import torch
23
 
@@ -133,7 +133,7 @@ def _ckpt_snapshot_state_dicts(
133
  return msd, osd
134
 
135
 
136
- def save_ckpt(
137
  model: PostSemClawModel,
138
  optimizer: torch.optim.Optimizer,
139
  config: PostSemClawConfig,
@@ -214,233 +214,233 @@ def save_ckpt(
214
  target=_write, daemon=True, name=f"ckpt-save-{step}"
215
  )
216
  _CKPT_WORKER_THREAD.start()
217
- except Exception as e:
218
- print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)
219
-
220
-
221
- def _env_flag_enabled(env: Mapping[str, str], key: str) -> bool:
222
- value = str(env.get(key, "0") or "0").strip().lower()
223
- return value not in {"", "0", "false", "no", "off"}
224
-
225
-
226
- def _env_int(env: Mapping[str, str], key: str, default: int) -> int:
227
- try:
228
- return int(str(env.get(key, str(default)) or str(default)))
229
- except ValueError:
230
- return default
231
-
232
-
233
- def architecture_compliance_payload(env: Mapping[str, str]) -> dict[str, bool | int | str]:
234
- throughput_mode = _env_flag_enabled(env, "HYDRA_THROUGHPUT_MODE")
235
- fastpath = _env_flag_enabled(env, "HYDRA_FASTPATH")
236
- force_htm_cpu = _env_flag_enabled(env, "HYDRA_FORCE_HTM_CPU")
237
- inert_mamba = _env_flag_enabled(env, "HYDRA_INERT_MAMBA")
238
- synthetic_retina = _env_flag_enabled(env, "HYDRA_ALLOW_SYNTHETIC_RETINA")
239
- hyena_layers = str(env.get("HYDRA_HYENA_LAYERS", "") or "")
240
- engram_subsample = _env_int(env, "HYDRA_ENGRAM_SUBSAMPLE", 1)
241
- htm_subsample = _env_int(env, "HYDRA_HTM_SUBSAMPLE", 1)
242
- full_arch_compliant = not any((
243
- throughput_mode,
244
- fastpath,
245
- force_htm_cpu,
246
- inert_mamba,
247
- synthetic_retina,
248
- bool(hyena_layers.strip()),
249
- ))
250
- return {
251
- 'full_arch_compliant': full_arch_compliant,
252
- 'throughput_mode': throughput_mode,
253
- 'fastpath': fastpath,
254
- 'force_htm_cpu': force_htm_cpu,
255
- 'inert_mamba': inert_mamba,
256
- 'synthetic_retina': synthetic_retina,
257
- 'hyena_layers': hyena_layers,
258
- 'engram_subsample': engram_subsample,
259
- 'htm_subsample': htm_subsample,
260
- }
261
-
262
-
263
- def eval_attempt_batches(*, requested_batch: int, min_batch: int) -> list[int]:
264
- requested = max(1, int(requested_batch))
265
- minimum = max(1, int(min_batch))
266
- batches: list[int] = []
267
- current = requested
268
- while current >= minimum:
269
- if current not in batches:
270
- batches.append(current)
271
- if current == minimum:
272
- break
273
- next_batch = max(minimum, current // 2)
274
- if next_batch == current:
275
- break
276
- current = next_batch
277
- if minimum not in batches:
278
- batches.append(minimum)
279
- return batches
280
-
281
-
282
- def build_eval_plan(*, eval_tokens: int, requested_batch: int, max_seq_len: int, chunk_tokens: int, min_batch: int) -> dict[str, Any]:
283
- effective_chunk_tokens = max(int(chunk_tokens), int(requested_batch) * int(max_seq_len))
284
- chunk_count = max(1, math.ceil(int(eval_tokens) / effective_chunk_tokens))
285
- return {
286
- 'eval_tokens': int(eval_tokens),
287
- 'eval_requested_batch': int(requested_batch),
288
- 'eval_chunk_tokens': int(effective_chunk_tokens),
289
- 'eval_chunk_count': int(chunk_count),
290
- 'eval_attempt_batches': eval_attempt_batches(requested_batch=requested_batch, min_batch=min_batch),
291
- 'eval_min_batch': int(max(1, min_batch)),
292
- }
293
-
294
-
295
- def _fingerprint_descriptor(descriptor: Mapping[str, Any]) -> str:
296
- payload = json.dumps(dict(descriptor), sort_keys=True, separators=(",", ":"))
297
- return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:12]
298
-
299
-
300
- def dataset_domain_payload(*, env: Mapping[str, str], prepare_module: Any, nemotron_module: Any | None) -> dict[str, Any]:
301
- use_nemotron = _env_flag_enabled(env, "HYDRA_USE_NEMOTRON")
302
- vocab_size = int(getattr(prepare_module, "VOCAB_SIZE", 0))
303
-
304
- if use_nemotron and nemotron_module is not None:
305
- use_full_blend = _env_flag_enabled(env, "HYDRA_USE_FULL_BLEND")
306
- phase = str(env.get("HYDRA_NEMOTRON_PHASE", "phase1") or "phase1").strip().lower()
307
- if use_full_blend:
308
- train_weights = dict(getattr(nemotron_module, "FULL_BLEND_WEIGHTS", {}))
309
- val_weights = dict(train_weights)
310
- else:
311
- train_weights = dict(
312
- getattr(nemotron_module, "PHASE2_WEIGHTS", {}) if phase == "phase2" else getattr(nemotron_module, "PHASE1_WEIGHTS", {})
313
- )
314
- val_weights = {"Nemotron-Pretraining-Multiple-Choice": 1.0}
315
- train_descriptor = {
316
- "backend": "nemotron_stream",
317
- "phase": "full_blend" if use_full_blend else phase,
318
- "weights": train_weights,
319
- "factual_inject_rate": _env_int(env, "HYDRA_FACTUAL_INJECT_RATE", 50),
320
- "vocab_size": vocab_size,
321
- }
322
- val_descriptor = {
323
- "backend": "nemotron_stream",
324
- "phase": "full_blend" if use_full_blend else "val_multiple_choice",
325
- "weights": val_weights,
326
- "vocab_size": vocab_size,
327
- }
328
- data_backend = "nemotron_stream"
329
- else:
330
- all_files = list(getattr(prepare_module, "list_parquet_files", lambda: [])())
331
- val_filename = str(getattr(prepare_module, "VAL_FILENAME", ""))
332
- train_files = [str(path) for path in all_files if not str(path).endswith(val_filename)]
333
- val_files = [str(path) for path in all_files if str(path).endswith(val_filename)]
334
- train_descriptor = {
335
- "backend": "climbmix_parquet",
336
- "train_shard_count": len(train_files),
337
- "train_shard_examples": sorted(Path(path).name for path in train_files[:3]),
338
- "vocab_size": vocab_size,
339
- }
340
- val_descriptor = {
341
- "backend": "climbmix_parquet",
342
- "val_filename": val_filename,
343
- "val_shard_count": len(val_files),
344
- "vocab_size": vocab_size,
345
- }
346
- data_backend = "climbmix_parquet"
347
-
348
- train_fingerprint = _fingerprint_descriptor(train_descriptor)
349
- val_fingerprint = _fingerprint_descriptor(val_descriptor)
350
- return {
351
- "data_backend": data_backend,
352
- "train_domain_descriptor": train_descriptor,
353
- "val_domain_descriptor": val_descriptor,
354
- "train_domain_fingerprint": train_fingerprint,
355
- "val_domain_fingerprint": val_fingerprint,
356
- "train_val_domain_match": train_fingerprint == val_fingerprint,
357
- }
358
-
359
-
360
- def build_lineage_payload(
361
- *,
362
- env: Mapping[str, str],
363
- seed: int,
364
- resume_requested: bool,
365
- resume_requested_path: str | None,
366
- resume_loaded_path: str | None,
367
- resume_step: int,
368
- resume_epoch: int,
369
- ) -> dict[str, Any]:
370
- warmstart = _env_flag_enabled(env, "HYDRA_WARMSTART")
371
- resume_applied = resume_loaded_path is not None and int(resume_step) > 0
372
- if resume_applied and warmstart:
373
- lineage_mode = "warmstart_resume"
374
- elif resume_applied:
375
- lineage_mode = "resume"
376
- else:
377
- lineage_mode = "fresh"
378
- return {
379
- "seed": int(seed),
380
- "warmstart": warmstart,
381
- "resume_requested": bool(resume_requested),
382
- "resume_applied": resume_applied,
383
- "resume_requested_path": resume_requested_path,
384
- "resume_loaded_path": resume_loaded_path,
385
- "resume_step": int(resume_step),
386
- "resume_epoch": int(resume_epoch),
387
- "lineage_mode": lineage_mode,
388
- }
389
-
390
-
391
- def build_final_metrics_payload(
392
- *,
393
- secondary_metrics: dict[str, Any],
394
- val_bpb: float | None,
395
- val_ppl: float | None,
396
- eval_status: str,
397
- eval_error: str | None,
398
- n_layer: int,
399
- d_model: int,
400
- num_params: int,
401
- step: int,
402
- total_tokens: int,
403
- peak_vram_mb: float,
404
- total_training_time: float,
405
- sdr_target_active: int,
406
- architecture_env: Mapping[str, str] | None = None,
407
- eval_diagnostics: Mapping[str, Any] | None = None,
408
- domain_fingerprints: Mapping[str, Any] | None = None,
409
- lineage_payload: Mapping[str, Any] | None = None,
410
- ) -> dict[str, Any]:
411
- """Build final run metrics without conflating skipped eval and validation.
412
-
413
- This helper deliberately preserves ``val_bpb=None`` when final eval did not
414
- complete. HPO can then prune or explicitly label a fallback instead of
415
- accidentally treating live training BPB as validation BPB.
416
- """
417
- payload = dict(secondary_metrics)
418
- payload.update({
419
- 'eval_status': eval_status,
420
- 'eval_error': eval_error,
421
- 'objective_source': 'final_val' if val_bpb is not None else 'missing_final_val',
422
- 'val_bpb': float(val_bpb) if val_bpb is not None else None,
423
- 'val_ppl': float(val_ppl) if val_ppl is not None else None,
424
- 'n_layer': int(n_layer),
425
- 'd_model': int(d_model),
426
- 'num_params_M': float(num_params / 1e6),
427
- 'num_steps': int(step),
428
- 'total_tokens_M': float(total_tokens / 1e6),
429
- 'peak_vram_mb': float(peak_vram_mb),
430
- 'training_seconds': float(total_training_time),
431
- 'sdr_target_active': int(sdr_target_active),
432
- })
433
- payload.update(architecture_compliance_payload(architecture_env or dict(os.environ)))
434
- if eval_diagnostics:
435
- payload.update(dict(eval_diagnostics))
436
- if domain_fingerprints:
437
- payload.update(dict(domain_fingerprints))
438
- if lineage_payload:
439
- payload.update(dict(lineage_payload))
440
- return payload
441
-
442
-
443
- def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
444
  """Reconstruct a PostSemClawConfig from a checkpoint's asdict() payload.
445
 
446
  Newly-added fields (e.g. `hyena_layers`) are defaulted when absent in
@@ -500,14 +500,14 @@ def _try_load_ckpt(path: Path, model, optimizer, device):
500
  return step, total_training_time, smooth_train_loss, bpt_ema, epoch
501
 
502
 
503
- def maybe_resume_ckpt(
504
- model: PostSemClawModel,
505
- optimizer: torch.optim.Optimizer,
506
- device: torch.device,
507
- ) -> tuple[int, float, float, float, int, str | None]:
508
- if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
509
- print("[ckpt] resume disabled; starting fresh", flush=True)
510
- return 0, 0.0, 0.0, 0.0, 0, None
511
 
512
  resume_path = Path(os.path.expanduser(RESUME_CKPT))
513
  # Try the primary path, then rotated backups. This is crucial because a
@@ -521,18 +521,18 @@ def maybe_resume_ckpt(
521
  if not cand.exists():
522
  continue
523
  try:
524
- result = _try_load_ckpt(cand, model, optimizer, device)
525
- if result is not None:
526
- if cand != resume_path:
527
- print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
528
- step, total_training_time, smooth_train_loss, bpt_ema, epoch = result
529
- return step, total_training_time, smooth_train_loss, bpt_ema, epoch, str(cand)
530
  except Exception as e:
531
  print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
532
  continue
533
 
534
- print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
535
- return 0, 0.0, 0.0, 0.0, 0, None
536
 
537
 
538
  # ---------------------------------------------------------------------------
@@ -561,14 +561,14 @@ def main() -> None:
561
 
562
  # Streaming path skips prepare.py (which normally trains the tokenizer
563
  # and builds the retina), so we must materialize both before model init.
564
- if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
565
- _p_nemo.ensure_tokenizer()
566
- if os.environ.get("HYDRA_THROUGHPUT_MODE", "0") != "1":
567
- # Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
568
- # returns in seconds; otherwise build_retina streams Nemotron docs to
569
- # compute cooccurrence + train SOM, then uploads back to the cache.
570
- import subsystems.sdr_retina as _sdr_retina
571
- _sdr_retina.build_retina()
572
  tokenizer = Tokenizer.from_directory()
573
  vocab_size = tokenizer.get_vocab_size()
574
  print(f"Vocab size: {vocab_size:,}")
@@ -614,18 +614,18 @@ def main() -> None:
614
  weight_decay=WEIGHT_DECAY,
615
  )
616
 
617
- step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch, resume_loaded_path = maybe_resume_ckpt(
618
- model, optimizer, device,
619
- )
620
- lineage_payload = build_lineage_payload(
621
- env=dict(os.environ),
622
- seed=SEED,
623
- resume_requested=bool(RESUME_CKPT and RESUME_CKPT.lower() != "none"),
624
- resume_requested_path=RESUME_CKPT if RESUME_CKPT and RESUME_CKPT.lower() != "none" else None,
625
- resume_loaded_path=resume_loaded_path,
626
- resume_step=step,
627
- resume_epoch=resume_epoch,
628
- )
629
 
630
  # Learnability #4: inform the model of the BOS token id so it can mask
631
  # doc-separator positions in packed sequences. Always set (the mask only
@@ -1020,22 +1020,22 @@ def main() -> None:
1020
  # does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
1021
  # how many val tokens to sweep (default 2 M, short enough for autoresearch
1022
  # 5-min budgets).
1023
- val_bpb: float | None = None
1024
- val_ppl: float | None = None
1025
- eval_status = "not_started"
1026
- eval_error: str | None = None
1027
- _eval_B = int(os.environ.get("HYDRA_EVAL_BATCH", str(max(1, DEVICE_BATCH_SIZE // 2))))
1028
- _eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(2 * 524288)))
1029
- _eval_chunk_tokens = int(os.environ.get("HYDRA_EVAL_CHUNK_TOKENS", str(_eval_tokens)))
1030
- _eval_min_batch = int(os.environ.get("HYDRA_EVAL_MIN_BATCH", "1"))
1031
- eval_diagnostics = build_eval_plan(
1032
- eval_tokens=_eval_tokens,
1033
- requested_batch=_eval_B,
1034
- max_seq_len=MAX_SEQ_LEN,
1035
- chunk_tokens=_eval_chunk_tokens,
1036
- min_batch=_eval_min_batch,
1037
- )
1038
- try:
1039
  # Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
1040
  # which leaves < 1GB for the eval forward — the driver can't satisfy
1041
  # the allocation. Free EVERY tensor we don't strictly need:
@@ -1057,70 +1057,70 @@ def main() -> None:
1057
  model._last_sdr = None
1058
  import gc as _gc
1059
  _gc.collect()
1060
- torch.cuda.empty_cache()
1061
- torch.cuda.synchronize()
1062
- try:
1063
- _free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
1064
- eval_diagnostics["eval_free_vram_before_mb"] = float(_free_mb)
1065
- print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
1066
- except Exception:
1067
- pass
1068
- print(
1069
- f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B} "
1070
- f"chunk_tokens={eval_diagnostics['eval_chunk_tokens']} attempts={eval_diagnostics['eval_attempt_batches']}...",
1071
- flush=True,
1072
- )
1073
- model.eval()
1074
- _orig = _prepare_mod.EVAL_TOKENS
1075
- _orig_chunk = getattr(_prepare_mod, "EVAL_CHUNK_TOKENS", _eval_tokens)
1076
- _prepare_mod.EVAL_TOKENS = _eval_tokens
1077
- _prepare_mod.EVAL_CHUNK_TOKENS = int(eval_diagnostics["eval_chunk_tokens"])
1078
- _successful_batch: int | None = None
1079
- _attempts: list[int] = []
1080
- try:
1081
- for _attempt_batch in eval_diagnostics["eval_attempt_batches"]:
1082
- _attempts.append(int(_attempt_batch))
1083
- eval_diagnostics["eval_attempted_batch"] = int(_attempt_batch)
1084
- try:
1085
- with autocast_ctx:
1086
- val_bpb = evaluate_bpb(model, tokenizer, int(_attempt_batch))
1087
- _successful_batch = int(_attempt_batch)
1088
- break
1089
- except torch.cuda.OutOfMemoryError as _attempt_oom:
1090
- eval_error = str(_attempt_oom)
1091
- eval_status = "oom"
1092
- torch.cuda.empty_cache()
1093
- if int(_attempt_batch) == eval_diagnostics["eval_attempt_batches"][-1]:
1094
- raise
1095
- finally:
1096
- _prepare_mod.EVAL_TOKENS = _orig
1097
- _prepare_mod.EVAL_CHUNK_TOKENS = _orig_chunk
1098
- eval_diagnostics["eval_attempt_batches"] = _attempts
1099
- eval_diagnostics["eval_effective_batch"] = _successful_batch
1100
- val_ppl = 2 ** val_bpb
1101
- eval_status = "completed"
1102
- print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
1103
- except torch.cuda.OutOfMemoryError as e:
1104
- eval_status = "oom"
1105
- eval_error = str(e)
1106
- print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
1107
- torch.cuda.empty_cache()
1108
- try:
1109
- eval_diagnostics["eval_free_vram_after_mb"] = float(torch.cuda.mem_get_info()[0] / 1024 / 1024)
1110
- except Exception:
1111
- pass
1112
- except Exception as e:
1113
- import traceback as _tb
1114
- eval_status = type(e).__name__
1115
- eval_error = str(e)
1116
- print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
1117
- _tb.print_exc()
1118
- try:
1119
- _free = torch.cuda.mem_get_info()[0] / 1024 / 1024
1120
- eval_diagnostics["eval_free_vram_after_mb"] = float(_free)
1121
- print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
1122
- except Exception:
1123
- pass
1124
 
1125
  # Final ckpts with val_bpb filled in (if eval succeeded).
1126
  save_ckpt(
@@ -1164,13 +1164,13 @@ def main() -> None:
1164
  / total_training_time / GPU_BF16_PEAK_FLOPS
1165
  if total_training_time > 0 else 0
1166
  )
1167
- peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
1168
- metrics = model.get_secondary_metrics()
1169
- domain_fingerprints = dataset_domain_payload(
1170
- env=dict(os.environ),
1171
- prepare_module=_prepare_mod,
1172
- nemotron_module=globals().get("_p_nemo"),
1173
- )
1174
 
1175
  print("---")
1176
  print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")
@@ -1206,28 +1206,28 @@ def main() -> None:
1206
  # Emit full metrics dictionary as JSON for sweep aggregation. Path from
1207
  # HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
1208
  # written (even without diagnostics) so the aggregator can compare runs.
1209
- _metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
1210
- try:
1211
- _dump = build_final_metrics_payload(
1212
- secondary_metrics=metrics,
1213
- val_bpb=val_bpb,
1214
- val_ppl=val_ppl,
1215
- eval_status=eval_status,
1216
- eval_error=eval_error,
1217
- n_layer=N_LAYER,
1218
- d_model=D_MODEL,
1219
- num_params=num_params,
1220
- step=step,
1221
- total_tokens=total_tokens,
1222
- peak_vram_mb=peak_vram_mb,
1223
- total_training_time=total_training_time,
1224
- sdr_target_active=int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
1225
- architecture_env=dict(os.environ),
1226
- eval_diagnostics=eval_diagnostics,
1227
- domain_fingerprints=domain_fingerprints,
1228
- lineage_payload=lineage_payload,
1229
- )
1230
- Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
1231
  with open(_metrics_out, 'w') as _f:
1232
  json.dump(_dump, _f, indent=2, sort_keys=True)
1233
  print(f"[METRICS] wrote {_metrics_out}", flush=True)
 
 preserved. Public entrypoint: `main()`.
 """
 
+from __future__ import annotations
+
+import gc
+import hashlib
+import json
+import math
+import os
+import sys
+import threading
+import time
+from collections.abc import Mapping
+from dataclasses import asdict
+from pathlib import Path
+from typing import Any
 
 import torch
 
 
     return msd, osd
 
 
+def save_ckpt(
     model: PostSemClawModel,
     optimizer: torch.optim.Optimizer,
     config: PostSemClawConfig,
 
             target=_write, daemon=True, name=f"ckpt-save-{step}"
         )
         _CKPT_WORKER_THREAD.start()
+    except Exception as e:
+        print(f"[ckpt] SNAPSHOT FAILED {path}: {type(e).__name__}: {e}", flush=True)
+
+
+def _env_flag_enabled(env: Mapping[str, str], key: str) -> bool:
+    value = str(env.get(key, "0") or "0").strip().lower()
+    return value not in {"", "0", "false", "no", "off"}
+
+
+def _env_int(env: Mapping[str, str], key: str, default: int) -> int:
+    try:
+        return int(str(env.get(key, str(default)) or str(default)))
+    except ValueError:
+        return default
+
+
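A few concrete readings of these helpers (values illustrative):

env = {"HYDRA_FASTPATH": "off", "HYDRA_HTM_SUBSAMPLE": "8", "BROKEN": "8x"}
assert _env_flag_enabled(env, "HYDRA_FASTPATH") is False   # "off" counts as disabled
assert _env_flag_enabled(env, "MISSING") is False          # absent keys default to "0"
assert _env_flag_enabled({"X": "TRUE"}, "X") is True       # comparison is case-insensitive
assert _env_int(env, "HYDRA_HTM_SUBSAMPLE", 1) == 8
assert _env_int(env, "BROKEN", 7) == 7                     # unparseable values fall back
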
+def architecture_compliance_payload(env: Mapping[str, str]) -> dict[str, bool | int | str]:
+    throughput_mode = _env_flag_enabled(env, "HYDRA_THROUGHPUT_MODE")
+    fastpath = _env_flag_enabled(env, "HYDRA_FASTPATH")
+    force_htm_cpu = _env_flag_enabled(env, "HYDRA_FORCE_HTM_CPU")
+    inert_mamba = _env_flag_enabled(env, "HYDRA_INERT_MAMBA")
+    synthetic_retina = _env_flag_enabled(env, "HYDRA_ALLOW_SYNTHETIC_RETINA")
+    hyena_layers = str(env.get("HYDRA_HYENA_LAYERS", "") or "")
+    engram_subsample = _env_int(env, "HYDRA_ENGRAM_SUBSAMPLE", 1)
+    htm_subsample = _env_int(env, "HYDRA_HTM_SUBSAMPLE", 1)
+    full_arch_compliant = not any((
+        throughput_mode,
+        fastpath,
+        force_htm_cpu,
+        inert_mamba,
+        synthetic_retina,
+        bool(hyena_layers.strip()),
+    ))
+    return {
+        'full_arch_compliant': full_arch_compliant,
+        'throughput_mode': throughput_mode,
+        'fastpath': fastpath,
+        'force_htm_cpu': force_htm_cpu,
+        'inert_mamba': inert_mamba,
+        'synthetic_retina': synthetic_retina,
+        'hyena_layers': hyena_layers,
+        'engram_subsample': engram_subsample,
+        'htm_subsample': htm_subsample,
+    }
+
+
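One concrete reading (environment illustrative): any fast-path escape hatch disqualifies a run, while subsample cadences are only reported:

payload = architecture_compliance_payload({"HYDRA_FASTPATH": "1", "HYDRA_ENGRAM_SUBSAMPLE": "4"})
assert payload["full_arch_compliant"] is False   # fastpath is a compliance violation
assert payload["engram_subsample"] == 4          # cadence is recorded but not penalized
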
+def eval_attempt_batches(*, requested_batch: int, min_batch: int) -> list[int]:
+    requested = max(1, int(requested_batch))
+    minimum = max(1, int(min_batch))
+    batches: list[int] = []
+    current = requested
+    while current >= minimum:
+        if current not in batches:
+            batches.append(current)
+        if current == minimum:
+            break
+        next_batch = max(minimum, current // 2)
+        if next_batch == current:
+            break
+        current = next_batch
+    if minimum not in batches:
+        batches.append(minimum)
+    return batches
+
+
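The halving backoff schedule is easiest to read off concrete values (numbers illustrative):

assert eval_attempt_batches(requested_batch=32, min_batch=4) == [32, 16, 8, 4]
assert eval_attempt_batches(requested_batch=6, min_batch=4) == [6, 4]
assert eval_attempt_batches(requested_batch=2, min_batch=8) == [8]   # floor is always included
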
+def build_eval_plan(*, eval_tokens: int, requested_batch: int, max_seq_len: int, chunk_tokens: int, min_batch: int) -> dict[str, Any]:
+    effective_chunk_tokens = max(int(chunk_tokens), int(requested_batch) * int(max_seq_len))
+    chunk_count = max(1, math.ceil(int(eval_tokens) / effective_chunk_tokens))
+    return {
+        'eval_tokens': int(eval_tokens),
+        'eval_requested_batch': int(requested_batch),
+        'eval_chunk_tokens': int(effective_chunk_tokens),
+        'eval_chunk_count': int(chunk_count),
+        'eval_attempt_batches': eval_attempt_batches(requested_batch=requested_batch, min_batch=min_batch),
+        'eval_min_batch': int(max(1, min_batch)),
+    }
+
+
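A worked example with the 2 M-token default used later in main() (batch and sequence length illustrative):

plan = build_eval_plan(
    eval_tokens=2 * 524288,    # HYDRA_EVAL_TOKENS default
    requested_batch=8,
    max_seq_len=2048,
    chunk_tokens=0,            # below the floor, so batch * seq wins
    min_batch=1,
)
# effective chunk = max(0, 8 * 2048) = 16384 tokens
# chunk count    = ceil(1048576 / 16384) = 64
assert plan["eval_chunk_tokens"] == 16384
assert plan["eval_chunk_count"] == 64
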
+def _fingerprint_descriptor(descriptor: Mapping[str, Any]) -> str:
+    payload = json.dumps(dict(descriptor), sort_keys=True, separators=(",", ":"))
+    return hashlib.sha1(payload.encode("utf-8")).hexdigest()[:12]
+
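The fingerprint is insensitive to key order by construction: `sort_keys=True` canonicalizes the JSON before hashing.

a = _fingerprint_descriptor({"backend": "nemotron_stream", "vocab_size": 8192})
b = _fingerprint_descriptor({"vocab_size": 8192, "backend": "nemotron_stream"})
assert a == b
assert len(a) == 12   # first 12 hex chars of the SHA-1 digest
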
+
+def dataset_domain_payload(*, env: Mapping[str, str], prepare_module: Any, nemotron_module: Any | None) -> dict[str, Any]:
+    use_nemotron = _env_flag_enabled(env, "HYDRA_USE_NEMOTRON")
+    vocab_size = int(getattr(prepare_module, "VOCAB_SIZE", 0))
+
+    if use_nemotron and nemotron_module is not None:
+        use_full_blend = _env_flag_enabled(env, "HYDRA_USE_FULL_BLEND")
+        phase = str(env.get("HYDRA_NEMOTRON_PHASE", "phase1") or "phase1").strip().lower()
+        if use_full_blend:
+            train_weights = dict(getattr(nemotron_module, "FULL_BLEND_WEIGHTS", {}))
+            val_weights = dict(train_weights)
+        else:
+            train_weights = dict(
+                getattr(nemotron_module, "PHASE2_WEIGHTS", {}) if phase == "phase2" else getattr(nemotron_module, "PHASE1_WEIGHTS", {})
+            )
+            val_weights = {"Nemotron-Pretraining-Multiple-Choice": 1.0}
+        train_descriptor = {
+            "backend": "nemotron_stream",
+            "phase": "full_blend" if use_full_blend else phase,
+            "weights": train_weights,
+            "factual_inject_rate": _env_int(env, "HYDRA_FACTUAL_INJECT_RATE", 50),
+            "vocab_size": vocab_size,
+        }
+        val_descriptor = {
+            "backend": "nemotron_stream",
+            "phase": "full_blend" if use_full_blend else "val_multiple_choice",
+            "weights": val_weights,
+            "vocab_size": vocab_size,
+        }
+        data_backend = "nemotron_stream"
+    else:
+        all_files = list(getattr(prepare_module, "list_parquet_files", lambda: [])())
+        val_filename = str(getattr(prepare_module, "VAL_FILENAME", ""))
+        train_files = [str(path) for path in all_files if not str(path).endswith(val_filename)]
+        val_files = [str(path) for path in all_files if str(path).endswith(val_filename)]
+        train_descriptor = {
+            "backend": "climbmix_parquet",
+            "train_shard_count": len(train_files),
+            "train_shard_examples": sorted(Path(path).name for path in train_files[:3]),
+            "vocab_size": vocab_size,
+        }
+        val_descriptor = {
+            "backend": "climbmix_parquet",
+            "val_filename": val_filename,
+            "val_shard_count": len(val_files),
+            "vocab_size": vocab_size,
+        }
+        data_backend = "climbmix_parquet"
+
+    train_fingerprint = _fingerprint_descriptor(train_descriptor)
+    val_fingerprint = _fingerprint_descriptor(val_descriptor)
+    return {
+        "data_backend": data_backend,
+        "train_domain_descriptor": train_descriptor,
+        "val_domain_descriptor": val_descriptor,
+        "train_domain_fingerprint": train_fingerprint,
+        "val_domain_fingerprint": val_fingerprint,
+        "train_val_domain_match": train_fingerprint == val_fingerprint,
+    }
+
+
+def build_lineage_payload(
+    *,
+    env: Mapping[str, str],
+    seed: int,
+    resume_requested: bool,
+    resume_requested_path: str | None,
+    resume_loaded_path: str | None,
+    resume_step: int,
+    resume_epoch: int,
+) -> dict[str, Any]:
+    warmstart = _env_flag_enabled(env, "HYDRA_WARMSTART")
+    resume_applied = resume_loaded_path is not None and int(resume_step) > 0
+    if resume_applied and warmstart:
+        lineage_mode = "warmstart_resume"
+    elif resume_applied:
+        lineage_mode = "resume"
+    else:
+        lineage_mode = "fresh"
+    return {
+        "seed": int(seed),
+        "warmstart": warmstart,
+        "resume_requested": bool(resume_requested),
+        "resume_applied": resume_applied,
+        "resume_requested_path": resume_requested_path,
+        "resume_loaded_path": resume_loaded_path,
+        "resume_step": int(resume_step),
+        "resume_epoch": int(resume_epoch),
+        "lineage_mode": lineage_mode,
+    }
+
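The three lineage modes in miniature (arguments illustrative):

fresh = build_lineage_payload(
    env={}, seed=0, resume_requested=False, resume_requested_path=None,
    resume_loaded_path=None, resume_step=0, resume_epoch=0,
)
assert fresh["lineage_mode"] == "fresh"

resumed = build_lineage_payload(
    env={}, seed=0, resume_requested=True, resume_requested_path="ckpt.pt",
    resume_loaded_path="ckpt.pt", resume_step=100, resume_epoch=1,
)
assert resumed["lineage_mode"] == "resume"   # "warmstart_resume" additionally needs HYDRA_WARMSTART
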
389
+
390
+
391
+ def build_final_metrics_payload(
392
+ *,
393
+ secondary_metrics: dict[str, Any],
394
+ val_bpb: float | None,
395
+ val_ppl: float | None,
396
+ eval_status: str,
397
+ eval_error: str | None,
398
+ n_layer: int,
399
+ d_model: int,
400
+ num_params: int,
401
+ step: int,
402
+ total_tokens: int,
403
+ peak_vram_mb: float,
404
+ total_training_time: float,
405
+ sdr_target_active: int,
406
+ architecture_env: Mapping[str, str] | None = None,
407
+ eval_diagnostics: Mapping[str, Any] | None = None,
408
+ domain_fingerprints: Mapping[str, Any] | None = None,
409
+ lineage_payload: Mapping[str, Any] | None = None,
410
+ ) -> dict[str, Any]:
411
+ """Build final run metrics without conflating skipped eval and validation.
412
+
413
+ This helper deliberately preserves ``val_bpb=None`` when final eval did not
414
+ complete. HPO can then prune or explicitly label a fallback instead of
415
+ accidentally treating live training BPB as validation BPB.
416
+ """
417
+ payload = dict(secondary_metrics)
418
+ payload.update({
419
+ 'eval_status': eval_status,
420
+ 'eval_error': eval_error,
421
+ 'objective_source': 'final_val' if val_bpb is not None else 'missing_final_val',
422
+ 'val_bpb': float(val_bpb) if val_bpb is not None else None,
423
+ 'val_ppl': float(val_ppl) if val_ppl is not None else None,
424
+ 'n_layer': int(n_layer),
425
+ 'd_model': int(d_model),
426
+ 'num_params_M': float(num_params / 1e6),
427
+ 'num_steps': int(step),
428
+ 'total_tokens_M': float(total_tokens / 1e6),
429
+ 'peak_vram_mb': float(peak_vram_mb),
430
+ 'training_seconds': float(total_training_time),
431
+ 'sdr_target_active': int(sdr_target_active),
432
+ })
433
+ payload.update(architecture_compliance_payload(architecture_env or dict(os.environ)))
434
+ if eval_diagnostics:
435
+ payload.update(dict(eval_diagnostics))
436
+ if domain_fingerprints:
437
+ payload.update(dict(domain_fingerprints))
438
+ if lineage_payload:
439
+ payload.update(dict(lineage_payload))
440
+ return payload
441
+
442
+
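The guard in action: a skipped eval keeps `val_bpb` honest instead of silently substituting a training metric (arguments illustrative):

p = build_final_metrics_payload(
    secondary_metrics={}, val_bpb=None, val_ppl=None,
    eval_status="oom", eval_error="CUDA out of memory",
    n_layer=12, d_model=512, num_params=50_000_000, step=1000,
    total_tokens=500_000_000, peak_vram_mb=5100.0, total_training_time=3600.0,
    sdr_target_active=327, architecture_env={},
)
assert p["objective_source"] == "missing_final_val"
assert p["val_bpb"] is None
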
+def config_from_dict(cfg_dict: dict) -> PostSemClawConfig:
     """Reconstruct a PostSemClawConfig from a checkpoint's asdict() payload.
 
     Newly-added fields (e.g. `hyena_layers`) are defaulted when absent in
 
     return step, total_training_time, smooth_train_loss, bpt_ema, epoch
 
 
+def maybe_resume_ckpt(
+    model: PostSemClawModel,
+    optimizer: torch.optim.Optimizer,
+    device: torch.device,
+) -> tuple[int, float, float, float, int, str | None]:
+    if not RESUME_CKPT or RESUME_CKPT.lower() == "none":
+        print("[ckpt] resume disabled; starting fresh", flush=True)
+        return 0, 0.0, 0.0, 0.0, 0, None
 
     resume_path = Path(os.path.expanduser(RESUME_CKPT))
     # Try the primary path, then rotated backups. This is crucial because a
 
         if not cand.exists():
             continue
         try:
+            result = _try_load_ckpt(cand, model, optimizer, device)
+            if result is not None:
+                if cand != resume_path:
+                    print(f"[ckpt] fell back to rotation {cand.name}", flush=True)
+                step, total_training_time, smooth_train_loss, bpt_ema, epoch = result
+                return step, total_training_time, smooth_train_loss, bpt_ema, epoch, str(cand)
         except Exception as e:
             print(f"[ckpt] {cand.name} load failed: {type(e).__name__}: {e}", flush=True)
             continue
 
+    print(f"[ckpt] no usable checkpoint in {resume_path} + rotations; starting fresh", flush=True)
+    return 0, 0.0, 0.0, 0.0, 0, None
 
 
 # ---------------------------------------------------------------------------
 
 # Streaming path skips prepare.py (which normally trains the tokenizer
 # and builds the retina), so we must materialize both before model init.
+if os.environ.get("HYDRA_USE_NEMOTRON", "0") == "1":
+    _p_nemo.ensure_tokenizer()
+    if os.environ.get("HYDRA_THROUGHPUT_MODE", "0") != "1":
+        # Retina: HF Hub cache hit for this (vocab, n_bits, target_active) combo
+        # returns in seconds; otherwise build_retina streams Nemotron docs to
+        # compute cooccurrence + train SOM, then uploads back to the cache.
+        import subsystems.sdr_retina as _sdr_retina
+        _sdr_retina.build_retina()
 tokenizer = Tokenizer.from_directory()
 vocab_size = tokenizer.get_vocab_size()
 print(f"Vocab size: {vocab_size:,}")

     weight_decay=WEIGHT_DECAY,
 )

+step, total_training_time, smooth_train_loss, bpt_ema, resume_epoch, resume_loaded_path = maybe_resume_ckpt(
+    model, optimizer, device,
+)
+lineage_payload = build_lineage_payload(
+    env=dict(os.environ),
+    seed=SEED,
+    resume_requested=bool(RESUME_CKPT and RESUME_CKPT.lower() != "none"),
+    resume_requested_path=RESUME_CKPT if RESUME_CKPT and RESUME_CKPT.lower() != "none" else None,
+    resume_loaded_path=resume_loaded_path,
+    resume_step=step,
+    resume_epoch=resume_epoch,
+)

 # Learnability #4: inform the model of the BOS token id so it can mask
 # doc-separator positions in packed sequences. Always set (the mask only

 # does not benefit from overlap with backward). HYDRA_EVAL_TOKENS controls
 # how many val tokens to sweep (default 2 M, short enough for autoresearch
 # 5-min budgets).
+val_bpb: float | None = None
+val_ppl: float | None = None
+eval_status = "not_started"
+eval_error: str | None = None
+_eval_B = int(os.environ.get("HYDRA_EVAL_BATCH", str(max(1, DEVICE_BATCH_SIZE // 2))))
+_eval_tokens = int(os.environ.get("HYDRA_EVAL_TOKENS", str(2 * 524288)))
+_eval_chunk_tokens = int(os.environ.get("HYDRA_EVAL_CHUNK_TOKENS", str(_eval_tokens)))
+_eval_min_batch = int(os.environ.get("HYDRA_EVAL_MIN_BATCH", "1"))
+eval_diagnostics = build_eval_plan(
+    eval_tokens=_eval_tokens,
+    requested_batch=_eval_B,
+    max_seq_len=MAX_SEQ_LEN,
+    chunk_tokens=_eval_chunk_tokens,
+    min_batch=_eval_min_batch,
+)
+try:
     # Aggressive VRAM reclaim for 6GB cards. Peak training VRAM = 5.1GB
     # which leaves < 1GB for the eval forward — the driver can't satisfy
     # the allocation. Free EVERY tensor we don't strictly need:

     model._last_sdr = None
     import gc as _gc
     _gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    try:
+        _free_mb = torch.cuda.mem_get_info()[0] / 1024 / 1024
+        eval_diagnostics["eval_free_vram_before_mb"] = float(_free_mb)
+        print(f"[VAL] free_vram_mb={_free_mb:.0f} (cleared optimizer state)", flush=True)
+    except Exception:
+        pass
+    print(
+        f"[VAL] running eval on {_eval_tokens} tokens at B={_eval_B} "
+        f"chunk_tokens={eval_diagnostics['eval_chunk_tokens']} attempts={eval_diagnostics['eval_attempt_batches']}...",
+        flush=True,
+    )
+    model.eval()
+    _orig = _prepare_mod.EVAL_TOKENS
+    _orig_chunk = getattr(_prepare_mod, "EVAL_CHUNK_TOKENS", _eval_tokens)
+    _prepare_mod.EVAL_TOKENS = _eval_tokens
+    _prepare_mod.EVAL_CHUNK_TOKENS = int(eval_diagnostics["eval_chunk_tokens"])
+    _successful_batch: int | None = None
+    _attempts: list[int] = []
+    try:
+        for _attempt_batch in eval_diagnostics["eval_attempt_batches"]:
+            _attempts.append(int(_attempt_batch))
+            eval_diagnostics["eval_attempted_batch"] = int(_attempt_batch)
+            try:
+                with autocast_ctx:
+                    val_bpb = evaluate_bpb(model, tokenizer, int(_attempt_batch))
+                _successful_batch = int(_attempt_batch)
+                break
+            except torch.cuda.OutOfMemoryError as _attempt_oom:
+                eval_error = str(_attempt_oom)
+                eval_status = "oom"
+                torch.cuda.empty_cache()
+                if int(_attempt_batch) == eval_diagnostics["eval_attempt_batches"][-1]:
+                    raise
+    finally:
+        _prepare_mod.EVAL_TOKENS = _orig
+        _prepare_mod.EVAL_CHUNK_TOKENS = _orig_chunk
+        eval_diagnostics["eval_attempt_batches"] = _attempts
+        eval_diagnostics["eval_effective_batch"] = _successful_batch
+    val_ppl = 2 ** val_bpb
+    eval_status = "completed"
+    print(f"[VAL] step={step} val_bpb={val_bpb:.4f} val_ppl={val_ppl:.3f}", flush=True)
+except torch.cuda.OutOfMemoryError as e:
+    eval_status = "oom"
+    eval_error = str(e)
+    print(f"[VAL] SKIPPED (OOM): {e}", flush=True)
+    torch.cuda.empty_cache()
+    try:
+        eval_diagnostics["eval_free_vram_after_mb"] = float(torch.cuda.mem_get_info()[0] / 1024 / 1024)
+    except Exception:
+        pass
+except Exception as e:
+    import traceback as _tb
+    eval_status = type(e).__name__
+    eval_error = str(e)
+    print(f"[VAL] SKIPPED ({type(e).__name__}): {e}", flush=True)
+    _tb.print_exc()
+    try:
+        _free = torch.cuda.mem_get_info()[0] / 1024 / 1024
+        eval_diagnostics["eval_free_vram_after_mb"] = float(_free)
+        print(f"[VAL] post-crash free_vram_mb={_free:.0f}", flush=True)
+    except Exception:
+        pass

 # Final ckpts with val_bpb filled in (if eval succeeded).
 save_ckpt(

     / total_training_time / GPU_BF16_PEAK_FLOPS
     if total_training_time > 0 else 0
 )
+peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
+metrics = model.get_secondary_metrics()
+domain_fingerprints = dataset_domain_payload(
+    env=dict(os.environ),
+    prepare_module=_prepare_mod,
+    nemotron_module=globals().get("_p_nemo"),
+)

 print("---")
 print(f"val_bpb: {val_bpb:.6f}" if val_bpb is not None else "val_bpb: SKIPPED")

 # Emit full metrics dictionary as JSON for sweep aggregation. Path from
 # HYDRA_METRICS_OUT env var; default=/tmp/hydra_run_metrics.json. Always
 # written (even without diagnostics) so the aggregator can compare runs.
+_metrics_out = os.environ.get("HYDRA_METRICS_OUT", "/tmp/hydra_run_metrics.json")
+try:
+    _dump = build_final_metrics_payload(
+        secondary_metrics=metrics,
+        val_bpb=val_bpb,
+        val_ppl=val_ppl,
+        eval_status=eval_status,
+        eval_error=eval_error,
+        n_layer=N_LAYER,
+        d_model=D_MODEL,
+        num_params=num_params,
+        step=step,
+        total_tokens=total_tokens,
+        peak_vram_mb=peak_vram_mb,
+        total_training_time=total_training_time,
+        sdr_target_active=int(os.environ.get("HYDRA_SDR_TARGET_ACTIVE", "327")),
+        architecture_env=dict(os.environ),
+        eval_diagnostics=eval_diagnostics,
+        domain_fingerprints=domain_fingerprints,
+        lineage_payload=lineage_payload,
+    )
+    Path(_metrics_out).parent.mkdir(parents=True, exist_ok=True)
     with open(_metrics_out, 'w') as _f:
         json.dump(_dump, _f, indent=2, sort_keys=True)
     print(f"[METRICS] wrote {_metrics_out}", flush=True)
overlay/prepare.py CHANGED
@@ -13,10 +13,10 @@ import os
 import sys
 import time
 import math
 import argparse
 import pickle
 from multiprocessing import Pool
 from typing import Any

 import requests
 import pyarrow.parquet as pq

@@ -30,8 +30,8 @@ import torch

 MAX_SEQ_LEN = int(os.environ.get("HYDRA_SEQ_LEN", "512"))  # context length
 TIME_BUDGET = 300  # training time budget in seconds (5 minutes)
 EVAL_TOKENS = 40 * 524288  # number of tokens for val eval
 EVAL_CHUNK_TOKENS = int(os.environ.get("HYDRA_EVAL_CHUNK_TOKENS", str(EVAL_TOKENS)))

 # ---------------------------------------------------------------------------
 # Configuration

@@ -160,8 +160,8 @@ def train_tokenizer():
     print("Tokenizer: training BPE tokenizer...")
     t0 = time.time()

     tokenizer_cls = getattr(rustbpe, "Tokenizer")
     tokenizer: Any = tokenizer_cls()
     vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
     tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)

@@ -228,10 +228,10 @@ class Tokenizer:
     def get_bos_token_id(self):
         return self.bos_token_id

     def encode(self, text, prepend=None, num_threads=8):
         prepend_id = None
         if prepend is not None:
             prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
         if isinstance(text, str):
             ids = self.enc.encode_ordinary(text)
             if prepend is not None:

@@ -249,7 +249,7 @@
         return self.enc.decode(ids)


 _TOKEN_BYTES_CACHE: dict[str, torch.Tensor] = {}

 def get_token_bytes(device="cpu"):
     key = str(device)

@@ -345,30 +345,30 @@ def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
         gpu_buffer.copy_(cpu_buffer, non_blocking=True)
         yield inputs, targets, epoch

 # ---------------------------------------------------------------------------
 # Evaluation (DO NOT CHANGE — this is the fixed metric)
 # ---------------------------------------------------------------------------

 def compute_bpb_from_totals(total_nats: torch.Tensor, total_bytes: torch.Tensor) -> torch.Tensor:
     if int(total_bytes.item()) <= 0:
         raise ValueError("BPB normalization requires at least one non-special token")
     return total_nats.to(dtype=torch.float64) / (math.log(2) * total_bytes.to(dtype=torch.float64))


 def compute_bpb_from_losses(loss_flat: torch.Tensor, nbytes: torch.Tensor) -> torch.Tensor:
     """Convert per-token losses and token byte lengths into bits-per-byte.

     Tokens with zero byte length (special tokens) are excluded from both the
     numerator and denominator so BPB remains comparable across tokenizer
     special-token conventions.
     """
     mask = nbytes > 0
     total_nats = (loss_flat * mask).sum(dtype=torch.float64)
     total_bytes = nbytes[mask].sum(dtype=torch.int64)
     return compute_bpb_from_totals(total_nats, total_bytes)

 @torch.no_grad()
 def evaluate_bpb(model, tokenizer, batch_size):
     """
     Bits per byte (BPB): vocab size-independent evaluation metric.
     Sums per-token cross-entropy (in nats), sums target byte lengths,

@@ -379,35 +379,35 @@ def evaluate_bpb(model, tokenizer, batch_size):
     Perf: accumulates on GPU (single sync at end), prefetches next batch
     while current forward runs.
     """
     token_bytes = get_token_bytes(device="cuda")
     val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
     steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
     chunk_steps = max(1, EVAL_CHUNK_TOKENS // (batch_size * MAX_SEQ_LEN))

     # GPU-resident accumulators — avoid per-batch .item() sync
     total_nats_t = torch.zeros(1, device="cuda", dtype=torch.float64)
     total_bytes_t = torch.zeros(1, device="cuda", dtype=torch.int64)

     # Prefetch first batch
     next_batch = next(val_loader)
     steps_done = 0
     while steps_done < steps:
         this_chunk = min(chunk_steps, steps - steps_done)
         for _ in range(this_chunk):
             x, y, _epoch = next_batch
             # Prefetch NEXT batch while GPU computes current forward
             next_batch = next(val_loader)
             loss_flat = model(x, y, reduction='none').view(-1)
             y_flat = y.view(-1)
             nbytes = token_bytes[y_flat]
             total_nats_t += (loss_flat * (nbytes > 0)).sum(dtype=torch.float64)
             total_bytes_t += nbytes[nbytes > 0].sum(dtype=torch.int64)
         steps_done += this_chunk
         if steps_done < steps:
             torch.cuda.empty_cache()

     # Single GPU→CPU sync at end
     return float(compute_bpb_from_totals(total_nats_t, total_bytes_t).item())

 # ---------------------------------------------------------------------------
 # Main
 
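A quick numeric sanity check of the fixed metric above (not part of the commit): a token whose cross-entropy is ln(2) nats and whose text is one UTF-8 byte contributes exactly one bit per byte, and zero-byte special tokens drop out of both sums.

import math
import torch

# Mirrors compute_bpb_from_losses from prepare.py above: two real tokens at
# ln(2) nats / 1 byte each, plus one special token (0 bytes) whose loss is
# masked out of numerator and denominator alike.
loss_flat = torch.tensor([math.log(2), math.log(2), 5.0])
nbytes = torch.tensor([1, 1, 0])
mask = nbytes > 0
total_nats = (loss_flat * mask).sum(dtype=torch.float64)   # 2 * ln(2)
total_bytes = nbytes[mask].sum(dtype=torch.int64)          # 2
bpb = total_nats / (math.log(2) * total_bytes.to(torch.float64))
print(float(bpb))  # 1.0 — ln(2) nats per byte is exactly 1 bit per byte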
overlay/prepare_nemotron.py CHANGED
@@ -20,16 +20,15 @@ Full blend mode (env HYDRA_USE_FULL_BLEND=1):
 """
 from __future__ import annotations

 import os
 import random
 import importlib
-import shutil
 from itertools import cycle
 from typing import Any, Iterator, cast

 import torch

 import prepare as _p  # reuse tokenizer, BOS, byte-length helpers

 NEMOTRON_REPO = "nvidia/Nemotron-Pretraining-Specialized-v1.1"

@@ -66,96 +65,94 @@
     "Nemotron-Pretraining-Formal-Logic": 0.20,
     "Nemotron-Pretraining-Multiple-Choice": 0.20,
 }
 PHASE2_WEIGHTS = {
     "Nemotron-Pretraining-Multiple-Choice": 0.45,
     "Nemotron-Pretraining-Economics": 0.20,
     "Nemotron-Pretraining-Formal-Logic": 0.15,
     "Nemotron-Pretraining-Code-Concepts": 0.10,
     "Nemotron-Pretraining-Unconditional-Algorithmic": 0.10,
 }

-StreamBatch = tuple[list[str], int]
-TokenBatch = tuple[list[list[int]], int]
+type StreamBatch = tuple[list[str], int]
+type TokenBatch = tuple[list[list[int]], int]


 def _tokenizer_cache_repo() -> str:
     return (
         os.environ.get("HYDRA_TOKENIZER_CACHE_REPO")
         or os.environ.get("FEATHER_HF_OUTPUT_REPO")
         or os.environ.get("HF_REPO_ID")
         or os.environ.get("HYDRA_RETINA_CACHE_REPO")
         or os.environ.get("FEATHER_HF_RETINA_CACHE_REPO")
         or ""
     )


 def _tokenizer_cache_prefix() -> str:
     return f"tokenizer/vocab{_p.VOCAB_SIZE}"


 def maybe_hydrate_tokenizer_cache() -> bool:
     """Try to download tokenizer artifacts from HF cache storage."""
     repo_id = _tokenizer_cache_repo()
     token = os.environ.get("HF_TOKEN")
     if not repo_id or not token:
         return False

     try:
         from huggingface_hub import hf_hub_download
     except Exception as e:  # noqa: BLE001
         print(f"[nemotron] tokenizer cache unavailable: {type(e).__name__}: {e}", flush=True)
         return False

     os.makedirs(_p.TOKENIZER_DIR, exist_ok=True)
     prefix = _tokenizer_cache_prefix()
     try:
-        tok_src = hf_hub_download(
+        hf_hub_download(
             repo_id=repo_id,
             repo_type="model",
             subfolder=prefix,
             filename="tokenizer.pkl",
             token=token,
             local_dir=_p.TOKENIZER_DIR,
         )
-        token_bytes_src = hf_hub_download(
+        hf_hub_download(
             repo_id=repo_id,
             repo_type="model",
             subfolder=prefix,
             filename="token_bytes.pt",
             token=token,
             local_dir=_p.TOKENIZER_DIR,
         )
-        shutil.copy2(tok_src, os.path.join(_p.TOKENIZER_DIR, "tokenizer.pkl"))
-        shutil.copy2(token_bytes_src, os.path.join(_p.TOKENIZER_DIR, "token_bytes.pt"))
     except Exception as e:  # noqa: BLE001
         print(f"[nemotron] tokenizer cache miss in {repo_id}/{prefix}: {type(e).__name__}: {e}", flush=True)
         return False

     print(f"[nemotron] hydrated tokenizer cache from {repo_id}/{prefix}", flush=True)
     return True


 def upload_tokenizer_cache() -> None:
     """Upload tokenizer artifacts for reuse by future jobs."""
     repo_id = _tokenizer_cache_repo()
     token = os.environ.get("HF_TOKEN")
     if not repo_id or not token:
         return

     path = os.path.join(_p.TOKENIZER_DIR, "tokenizer.pkl")
     token_bytes_path = os.path.join(_p.TOKENIZER_DIR, "token_bytes.pt")
     if not (os.path.exists(path) and os.path.exists(token_bytes_path)):
         return

     try:
         from huggingface_hub import HfApi
         api = HfApi(token=token)
         prefix = _tokenizer_cache_prefix()
         api.upload_file(path_or_fileobj=path, path_in_repo=f"{prefix}/tokenizer.pkl", repo_id=repo_id, repo_type="model")
         api.upload_file(path_or_fileobj=token_bytes_path, path_in_repo=f"{prefix}/token_bytes.pt", repo_id=repo_id, repo_type="model")
         print(f"[nemotron] uploaded tokenizer cache to {repo_id}/{prefix}", flush=True)
     except Exception as e:  # noqa: BLE001
         print(f"[nemotron] tokenizer cache upload skipped: {type(e).__name__}: {e}", flush=True)

@@ -166,7 +163,7 @@ def _phase_weights() -> dict[str, float]:
     return PHASE2_WEIGHTS if phase == "phase2" else PHASE1_WEIGHTS


 def _open_stream(config: str, split: str):
     """Open a streaming iterator over one dataset config.

     Handles two modes:

@@ -177,17 +174,17 @@
     Yields dicts; text extraction handled downstream by _extract_text.
     """
     load_dataset = importlib.import_module("datasets").load_dataset
     token = os.environ.get("HF_TOKEN")
     shuffle_buf = int(os.environ.get("HYDRA_STREAM_SHUFFLE_BUFFER", "2048"))

     if config in _BLEND_REGISTRY:
         repo, name, _text_col = _BLEND_REGISTRY[config]
         kwargs: dict[str, object] = dict(
             split="train",
             streaming=True,
             token=token,
         )
         if name is not None:
             kwargs["name"] = name
         # nemotron-specialized has multiple sub-configs; pick the first one

@@ -209,18 +206,18 @@
     return iter(ds)


 def _extract_text(row: dict[str, object]) -> str:
     """Pick the right text column — datasets have different column names.

     Priority order: text, content, prompt_completion, question, body.
     For math datasets that split into problem+solution, concatenate both.
     Fallback: concatenate all string-valued fields.
     """
     # Fast path: most datasets use "text" or "content".
     for k in ("text", "content", "prompt_completion", "question", "body"):
         value = row.get(k)
         if isinstance(value, str) and value:
             return value
     # Math datasets may have problem + solution as separate fields.
     if "problem" in row and "solution" in row:
         p = row["problem"] or ""

@@ -236,20 +233,20 @@
     return "\n".join(parts)


 class _WeightedStream:
     """Infinite weighted-round-robin over configs' streaming iterators."""

     def __init__(self, weights: dict[str, float], seed: int = 0):
         self.configs = list(weights.keys())
         self.weights = [weights[c] for c in self.configs]
         self.streams: dict[str, Iterator[dict[str, object]]] = {
             c: _open_stream(c, "train") for c in self.configs
         }
         self.rng = random.Random(seed)
         self.epoch = 1
         self._factual_docs: list[str] | None = None
         self._factual_idx = 0
         self._inject_counter = 0

     def _reopen(self, config: str):
         # stream exhausted — reopen (HF streaming typically infinite but restart on edge)

@@ -265,20 +262,20 @@
         # exist in the Nemotron configs. Controlled by HYDRA_FACTUAL_INJECT_RATE
         # (default 50 = inject one factual doc every 50 Nemotron docs = ~2%).
         inject_rate = int(os.environ.get("HYDRA_FACTUAL_INJECT_RATE", "50"))
         if inject_rate > 0 and self._factual_docs is None:
             factual_path = os.path.join(
                 os.path.dirname(os.path.abspath(__file__)), "data", "factual", "facts.txt")
             if os.path.exists(factual_path):
                 self._factual_docs = open(factual_path).read().strip().split('\n')
                 self._factual_idx = 0
                 self._inject_counter = 0
         if inject_rate > 0 and self._factual_docs:
             self._inject_counter += 1
             if self._inject_counter >= inject_rate:
                 self._inject_counter = 0
                 doc = self._factual_docs[self._factual_idx % len(self._factual_docs)]
                 self._factual_idx += 1
                 return doc, self.epoch

         config = self.rng.choices(self.configs, weights=self.weights, k=1)[0]
         try:

@@ -311,9 +308,9 @@ def _document_batches(split: str, tokenizer_batch_size: int = 128) -> Iterator[t
     stream = _WeightedStream(_phase_weights(), seed=0)

     prefetch_depth = int(os.environ.get("HYDRA_STREAM_PREFETCH", "32"))
     q: queue.Queue[StreamBatch | object] = queue.Queue(maxsize=prefetch_depth)
     sentinel_stop = object()
     error_box: list[BaseException] = []

     def producer():
         try:

@@ -338,7 +335,7 @@
             if error_box:
                 raise error_box[0]
             return
         yield cast(StreamBatch, item)


 def make_dataloader(tokenizer, B: int, T: int, split: str, buffer_size: int = 1000):

@@ -364,9 +361,9 @@ def make_dataloader(tokenizer, B: int, T: int, split: str, buffer_size: int = 10
     # Stage 2: tokenization prefetch thread. Each queue element is a list of
     # token-id lists (pre-tokenized docs). HYDRA_TOKEN_PREFETCH controls depth.
     tok_prefetch = int(os.environ.get("HYDRA_TOKEN_PREFETCH", "8"))
     tok_q: queue.Queue[TokenBatch | object] = queue.Queue(maxsize=tok_prefetch)
     tok_sentinel = object()
     tok_err_box: list[BaseException] = []

     def tokenizer_producer():
         try:

@@ -390,8 +387,8 @@
             if tok_err_box:
                 raise tok_err_box[0]
             raise StopIteration
         token_lists, epoch = cast(TokenBatch, item)
         doc_buffer.extend(token_lists)

     row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
     cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)

@@ -465,24 +462,24 @@ def evaluate_bpb(model, tokenizer, B: int) -> float:
     return total_nats / (math.log(2) * max(total_bytes, 1))


 def ensure_tokenizer():
     """Ensure rustbpe tokenizer exists. If absent, train on a Nemotron stream
     sample using the same rustbpe.train_from_iterator API that prepare.py uses
     (production path — don't fork tokenizer training logic).
     """
     import pickle
     import torch
     path = os.path.join(_p.TOKENIZER_DIR, "tokenizer.pkl")
     token_bytes_path = os.path.join(_p.TOKENIZER_DIR, "token_bytes.pt")
     if os.path.exists(path) and os.path.exists(token_bytes_path):
         print(f"[nemotron] tokenizer + token_bytes already trained at {_p.TOKENIZER_DIR}", flush=True)
         return
     if maybe_hydrate_tokenizer_cache() and os.path.exists(path) and os.path.exists(token_bytes_path):
         return
     os.makedirs(_p.TOKENIZER_DIR, exist_ok=True)
     print(f"[nemotron] training BPE (vocab_size={_p.VOCAB_SIZE}) on stream sample…", flush=True)
     import rustbpe
     import tiktoken

     # Pull a sample of docs — use full blend if active so BPE covers all 7 sources.
     n_docs = int(os.environ.get("HYDRA_BPE_TRAIN_DOCS", "20000"))

@@ -498,8 +495,8 @@
     print(f"[nemotron] collected {len(sample_texts)} sample docs; training BPE…", flush=True)

     # Train rustbpe — identical API to prepare.py's train_tokenizer().
     tokenizer_cls = getattr(rustbpe, "Tokenizer")
     tokenizer: Any = tokenizer_cls()
     vocab_size_no_special = _p.VOCAB_SIZE - len(_p.SPECIAL_TOKENS)
     tokenizer.train_from_iterator(iter(sample_texts), vocab_size_no_special, pattern=_p.SPLIT_PATTERN)

@@ -524,7 +521,7 @@
     for token_id in range(enc.n_vocab):
         tstr = enc.decode([token_id])
         token_bytes_list.append(0 if tstr in special_set else len(tstr.encode("utf-8")))
     token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
     torch.save(token_bytes_tensor, token_bytes_path)
     print(f"[nemotron] BPE + token_bytes saved to {_p.TOKENIZER_DIR}", flush=True)
     upload_tokenizer_cache()
 
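The sampling discipline in _WeightedStream reduces to: draw a config per document in proportion to its blend weight, and reopen an exhausted stream instead of dropping it. A toy standalone sketch of just that loop — the in-memory "streams" stand in for HF streaming iterators, so the names and doc contents are placeholders:

import random

weights = {"multiple-choice": 0.45, "economics": 0.20, "formal-logic": 0.35}

def open_stream(name: str):
    # Stand-in for _open_stream: a finite iterator that can be reopened.
    return iter([f"<{name} doc {i}>" for i in range(2)])

streams = {name: open_stream(name) for name in weights}
rng = random.Random(0)

def next_doc() -> str:
    # One weighted draw per document, as in _WeightedStream.__next__.
    config = rng.choices(list(weights), weights=list(weights.values()), k=1)[0]
    try:
        return next(streams[config])
    except StopIteration:
        # Exhausted stream: reopen and retry, as _reopen does.
        streams[config] = open_stream(config)
        return next(streams[config])

print([next_doc() for _ in range(5)])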
overlay/scripts/__init__.py CHANGED
@@ -1 +1 @@
 # Package marker for script-level shared utilities.
overlay/scripts/audit_overlay_sync.py CHANGED
@@ -1,100 +1,100 @@
 #!/usr/bin/env python3
 from __future__ import annotations

 import argparse
 import json
 from pathlib import Path


 DEFAULT_INCLUDE_PATHS = [
     "hydra",
     "subsystems",
     "scripts",
     "htm_rust",
     "harness",
     "configs",
     "prepare.py",
     "prepare_nemotron.py",
     "train.py",
     "pyproject.toml",
     "uv.lock",
 ]


 def _iter_files(path: Path) -> list[Path]:
     if not path.exists():
         return []
     if path.is_file():
         return [path]
     return sorted(p for p in path.rglob("*") if p.is_file())


 def classify_overlay_pairs(*, repo_root: Path, include_paths: list[str]) -> dict[str, list[str]]:
     overlay_root = repo_root / "hf_jobs" / "feather_h200_image" / "overlay"
     identical: list[str] = []
     root_ahead: list[str] = []
     overlay_only: list[str] = []
     missing_overlay: list[str] = []

     for rel in include_paths:
         root_path = repo_root / rel
         overlay_path = overlay_root / rel

         root_files = {p.relative_to(root_path).as_posix(): p for p in _iter_files(root_path)} if root_path.exists() and root_path.is_dir() else {}
         overlay_files = {p.relative_to(overlay_path).as_posix(): p for p in _iter_files(overlay_path)} if overlay_path.exists() and overlay_path.is_dir() else {}

         if root_path.is_file() or overlay_path.is_file():
             rel_name = rel.replace("\\", "/")
             if root_path.exists() and overlay_path.exists():
                 if root_path.read_bytes() == overlay_path.read_bytes():
                     identical.append(rel_name)
                 else:
                     root_ahead.append(rel_name)
             elif root_path.exists():
                 missing_overlay.append(rel_name)
             elif overlay_path.exists():
                 overlay_only.append(rel_name)
             continue

         for subrel, root_file in root_files.items():
             rel_name = f"{rel}/{subrel}".replace("\\", "/")
             overlay_file = overlay_files.get(subrel)
             if overlay_file is None:
                 missing_overlay.append(rel_name)
             elif root_file.read_bytes() == overlay_file.read_bytes():
                 identical.append(rel_name)
             else:
                 root_ahead.append(rel_name)

         for subrel in overlay_files:
             if subrel not in root_files:
                 overlay_only.append(f"{rel}/{subrel}".replace("\\", "/"))

     for bucket in (identical, root_ahead, overlay_only, missing_overlay):
         bucket.sort()

     return {
         "identical": identical,
         "root_ahead": root_ahead,
         "overlay_only": overlay_only,
         "missing_overlay": missing_overlay,
     }


 def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Audit mirrored H200 overlay files against root source-of-truth paths")
     parser.add_argument("--repo-root", type=Path, default=Path(__file__).resolve().parents[1])
     parser.add_argument("--include-path", action="append", default=[])
     return parser.parse_args(argv)


 def main(argv: list[str] | None = None) -> int:
     args = parse_args(argv)
     include_paths = args.include_path or DEFAULT_INCLUDE_PATHS
     payload = classify_overlay_pairs(repo_root=args.repo_root, include_paths=include_paths)
     print(json.dumps(payload, indent=2, sort_keys=True))
     return 0


 if __name__ == "__main__":
     raise SystemExit(main())
 
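Usage sketch for classify_overlay_pairs, assuming the module is importable as scripts.audit_overlay_sync; the temp-dir layout mirrors the hf_jobs/feather_h200_image/overlay convention hard-coded above:

import tempfile
from pathlib import Path

from scripts.audit_overlay_sync import classify_overlay_pairs

with tempfile.TemporaryDirectory() as tmp:
    repo_root = Path(tmp)
    (repo_root / "train.py").write_text("print('root')\n")
    overlay = repo_root / "hf_jobs" / "feather_h200_image" / "overlay"
    overlay.mkdir(parents=True)
    (overlay / "train.py").write_text("print('root')\n")  # byte-identical mirror
    buckets = classify_overlay_pairs(repo_root=repo_root, include_paths=["train.py"])
    print(buckets["identical"])        # ['train.py']
    print(buckets["missing_overlay"])  # []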
overlay/scripts/benchmark_assets.py CHANGED
@@ -1,124 +1,62 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import os
- import shutil
- from pathlib import Path
-
- from scripts.benchmark_checkpoint import checkpoint_candidates
-
- try:
-     from huggingface_hub import HfApi
- except Exception:  # pragma: no cover - optional import for offline test envs
-     HfApi = None
-
-
- def _download_file(*, repo_id: str, filename: str, local_dir: str, token: str | None, subfolder: str | None = None) -> Path:
-     from huggingface_hub import hf_hub_download
-
-     path = hf_hub_download(
-         repo_id=repo_id,
-         repo_type="model",
-         filename=filename,
-         subfolder=subfolder,
-         token=token,
-         local_dir=local_dir,
-         local_dir_use_symlinks=False,
-     )
-     return Path(path)
-
-
- def resolve_tokenizer_cache_repo(*, output_repo: str, retina_cache_repo: str) -> str:
-     return (
-         os.environ.get("HYDRA_TOKENIZER_CACHE_REPO")
-         or os.environ.get("FEATHER_HF_OUTPUT_REPO")
-         or os.environ.get("HF_REPO_ID")
-         or os.environ.get("HYDRA_RETINA_CACHE_REPO")
-         or os.environ.get("FEATHER_HF_RETINA_CACHE_REPO")
-         or output_repo
-         or retina_cache_repo
-     )
-
-
- def tokenizer_cache_prefix() -> str:
-     vocab_size = int(os.environ.get("HYDRA_VOCAB_SIZE", "65536"))
-     return f"tokenizer/vocab{vocab_size}"
-
-
- def choose_remote_checkpoint_path(files: list[str]) -> str | None:
-     preferred = [
-         path for path in files
-         if path.endswith("/pretrain_final.pt") or path.endswith("/best_bpb.pt") or path.endswith("/latest.pt")
-     ]
-     if not preferred:
-         return None
-     pretrain = sorted([p for p in preferred if p.endswith("/pretrain_final.pt")])
-     if pretrain:
-         return pretrain[-1]
-     best = sorted([p for p in preferred if p.endswith("/best_bpb.pt")])
-     if best:
-         return best[-1]
-     latest = sorted([p for p in preferred if p.endswith("/latest.pt")])
-     if latest:
-         return latest[-1]
-     return None
-
-
- def hydrate_benchmark_assets(*, cache_dir: Path, output_repo: str, tokenizer_repo: str, token: str | None) -> dict[str, str]:
-     cache_dir.mkdir(parents=True, exist_ok=True)
-     tok_dir = cache_dir / "tokenizer"
-     tok_dir.mkdir(parents=True, exist_ok=True)
-     tok_repo = resolve_tokenizer_cache_repo(output_repo=tokenizer_repo, retina_cache_repo=tokenizer_repo)
-     tok_prefix = tokenizer_cache_prefix()
-
-     ckpt_path = None
-     for candidate in checkpoint_candidates(cache_dir):
-         if candidate.exists():
-             ckpt_path = candidate
-             break
-         try:
-             ckpt_path = _download_file(repo_id=output_repo, filename=candidate.name, local_dir=str(cache_dir), token=token)
-             break
-         except Exception:
-             continue
-     if ckpt_path is None:
-         try:
-             if HfApi is None:
-                 raise RuntimeError("huggingface_hub unavailable")
-             files = HfApi(token=token).list_repo_files(repo_id=output_repo, repo_type="model", token=token)
-             remote_path = choose_remote_checkpoint_path(files)
-             if remote_path is not None:
-                 parent, filename = remote_path.rsplit("/", 1)
-                 downloaded_path = _download_file(
-                     repo_id=output_repo,
-                     filename=filename,
-                     local_dir=str(cache_dir),
-                     token=token,
-                     subfolder=parent,
-                 )
-                 canonical_path = cache_dir / filename
-                 if downloaded_path != canonical_path:
-                     canonical_path.parent.mkdir(parents=True, exist_ok=True)
-                     shutil.copy2(downloaded_path, canonical_path)
-                 ckpt_path = canonical_path
-         except Exception:
-             pass
-     if ckpt_path is None:
-         raise FileNotFoundError(f"No benchmark checkpoint found in cache or repo {output_repo}")
-
-     tok_path = tok_dir / "tokenizer.pkl"
-     if not tok_path.exists():
-         downloaded_tok = _download_file(repo_id=tok_repo, filename="tokenizer.pkl", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
-         if downloaded_tok != tok_path:
-             shutil.copy2(downloaded_tok, tok_path)
-
-     token_bytes_path = tok_dir / "token_bytes.pt"
-     if not token_bytes_path.exists():
-         downloaded_token_bytes = _download_file(repo_id=tok_repo, filename="token_bytes.pt", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
-         if downloaded_token_bytes != token_bytes_path:
-             shutil.copy2(downloaded_token_bytes, token_bytes_path)
-
-     return {
-         "checkpoint_path": str(ckpt_path),
-         "tokenizer_dir": str(tok_dir),
-     }

+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+
+
+ def _download_file(*, repo_id: str, filename: str, local_dir: str, token: str | None, subfolder: str | None = None) -> Path:
+     from huggingface_hub import hf_hub_download
+
+     path = hf_hub_download(
+         repo_id=repo_id,
+         repo_type="model",
+         filename=filename,
+         subfolder=subfolder,
+         token=token,
+         local_dir=local_dir,
+         local_dir_use_symlinks=False,
+     )
+     return Path(path)
+
+
+ def resolve_tokenizer_cache_repo(*, output_repo: str, retina_cache_repo: str) -> str:
+     return (
+         os.environ.get("HYDRA_TOKENIZER_CACHE_REPO")
+         or os.environ.get("FEATHER_HF_OUTPUT_REPO")
+         or os.environ.get("HF_REPO_ID")
+         or os.environ.get("HYDRA_RETINA_CACHE_REPO")
+         or os.environ.get("FEATHER_HF_RETINA_CACHE_REPO")
+         or output_repo
+         or retina_cache_repo
+     )
+
+
+ def tokenizer_cache_prefix() -> str:
+     vocab_size = int(os.environ.get("HYDRA_VOCAB_SIZE", "65536"))
+     return f"tokenizer/vocab{vocab_size}"
+
+
+ def hydrate_benchmark_assets(*, cache_dir: Path, output_repo: str, tokenizer_repo: str, token: str | None) -> dict[str, str]:
+     cache_dir.mkdir(parents=True, exist_ok=True)
+     tok_dir = cache_dir / "tokenizer"
+     tok_dir.mkdir(parents=True, exist_ok=True)
+     tok_repo = resolve_tokenizer_cache_repo(output_repo=tokenizer_repo, retina_cache_repo=tokenizer_repo)
+     tok_prefix = tokenizer_cache_prefix()
+
+     ckpt_path = cache_dir / "best_bpb.pt"
+     if not ckpt_path.exists():
+         ckpt_path = _download_file(repo_id=output_repo, filename="best_bpb.pt", local_dir=str(cache_dir), token=token)
+
+     tok_path = tok_dir / "tokenizer.pkl"
+     if not tok_path.exists():
+         tok_path = _download_file(repo_id=tok_repo, filename="tokenizer.pkl", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
+
+     token_bytes_path = tok_dir / "token_bytes.pt"
+     if not token_bytes_path.exists():
+         token_bytes_path = _download_file(repo_id=tok_repo, filename="token_bytes.pt", local_dir=str(tok_dir), token=token, subfolder=tok_prefix)
+
+     return {
+         "checkpoint_path": str(ckpt_path),
+         "tokenizer_dir": str(tok_dir),
+     }
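The env-var fallback chain in resolve_tokenizer_cache_repo is easiest to see in isolation; the repo names below are hypothetical:

import os

from scripts.benchmark_assets import resolve_tokenizer_cache_repo, tokenizer_cache_prefix

# With none of the HYDRA_*/FEATHER_*/HF_REPO_ID overrides set, resolution
# falls through to the positional arguments, output_repo first.
for var in ("HYDRA_TOKENIZER_CACHE_REPO", "FEATHER_HF_OUTPUT_REPO", "HF_REPO_ID",
            "HYDRA_RETINA_CACHE_REPO", "FEATHER_HF_RETINA_CACHE_REPO"):
    os.environ.pop(var, None)
print(resolve_tokenizer_cache_repo(output_repo="org/feather-ckpts",
                                   retina_cache_repo="org/retina-cache"))
# -> org/feather-ckpts
print(tokenizer_cache_prefix())  # -> tokenizer/vocab65536 (HYDRA_VOCAB_SIZE default)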
 
overlay/scripts/benchmark_checkpoint.py CHANGED
@@ -1,118 +1,19 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import shutil
- from pathlib import Path
-
- from scripts.hf_routing import resolve_routing
-
- try:
-     from huggingface_hub import HfApi
- except Exception:  # pragma: no cover
-     HfApi = None
-
-
- def choose_remote_checkpoint_path(files: list[str]) -> str | None:
-     preferred = [
-         path for path in files
-         if path.endswith("/pretrain_final.pt") or path.endswith("/best_bpb.pt") or path.endswith("/latest.pt")
-     ]
-     if not preferred:
-         return None
-     pretrain = sorted([p for p in preferred if p.endswith("/pretrain_final.pt")])
-     if pretrain:
-         return pretrain[-1]
-     best = sorted([p for p in preferred if p.endswith("/best_bpb.pt")])
-     if best:
-         return best[-1]
-     latest = sorted([p for p in preferred if p.endswith("/latest.pt")])
-     if latest:
-         return latest[-1]
-     return None
-
-
- def checkpoint_candidates(cache_dir: Path) -> list[Path]:
-     return [
-         cache_dir / "best_bpb.pt",
-         cache_dir / "pretrain_final.pt",
-         cache_dir / "latest.pt",
-     ]
-
-
- def choose_checkpoint_candidate(cache_dir: Path) -> Path | None:
-     for path in checkpoint_candidates(cache_dir):
-         if path.exists():
-             return path
-     return None
-
-
- def resolve_checkpoint_source(*, cache_dir: Path, output_repo: str | None) -> dict[str, str]:
-     local = choose_checkpoint_candidate(cache_dir)
-     if local is not None:
-         return {"mode": "local", "path": str(local)}
-     if output_repo:
-         return {"mode": "remote", "repo_id": output_repo}
-     routing = resolve_routing(token=None)
-     return {"mode": "remote", "repo_id": routing.output_repo}
-
-
- def _download_checkpoint_file(*, repo_id: str, filename: str, local_dir: str, token: str | None, subfolder: str | None = None) -> str:
-     from huggingface_hub import hf_hub_download
-
-     return hf_hub_download(
-         repo_id=repo_id,
-         repo_type="model",
-         filename=filename,
-         subfolder=subfolder,
-         token=token,
-         local_dir=local_dir,
-         local_dir_use_symlinks=False,
-     )
-
-
- def hydrate_checkpoint(*, cache_dir: Path, output_repo: str | None, token: str | None) -> Path | None:
-     local = choose_checkpoint_candidate(cache_dir)
-     if local is not None:
-         return local
-     source = resolve_checkpoint_source(cache_dir=cache_dir, output_repo=output_repo)
-     if source["mode"] != "remote":
-         return None
-     cache_dir.mkdir(parents=True, exist_ok=True)
-     for filename in ("best_bpb.pt", "pretrain_final.pt", "latest.pt"):
-         try:
-             path = Path(
-                 _download_checkpoint_file(
-                     repo_id=source["repo_id"],
-                     filename=filename,
-                     local_dir=str(cache_dir),
-                     token=token,
-                 )
-             )
-             if path.exists():
-                 return path
-         except Exception:
-             continue
-     try:
-         if HfApi is None:
-             raise RuntimeError("huggingface_hub unavailable")
-         files = HfApi(token=token).list_repo_files(repo_id=source["repo_id"], repo_type="model", token=token)
-         remote_path = choose_remote_checkpoint_path(files)
-         if remote_path is not None:
-             parent, filename = remote_path.rsplit("/", 1)
-             downloaded = Path(
-                 _download_checkpoint_file(
-                     repo_id=source["repo_id"],
-                     filename=filename,
-                     local_dir=str(cache_dir),
-                     token=token,
-                     subfolder=parent,
-                 )
-             )
-             canonical = cache_dir / filename
-             if downloaded != canonical:
-                 shutil.copy2(downloaded, canonical)
-             if canonical.exists():
-                 return canonical
-     except Exception:
-         pass
-     return None

+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ from pathlib import Path
+
+
+ def checkpoint_candidates(cache_dir: Path) -> list[Path]:
+     return [
+         cache_dir / "best_bpb.pt",
+         cache_dir / "pretrain_final.pt",
+         cache_dir / "latest.pt",
+     ]
+
+
+ def choose_checkpoint_candidate(cache_dir: Path) -> Path | None:
+     for path in checkpoint_candidates(cache_dir):
+         if path.exists():
+             return path
+     return None
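What survives the rewrite is a pure local probe; a usage sketch (filenames only — real checkpoints are torch saves):

import tempfile
from pathlib import Path

from scripts.benchmark_checkpoint import choose_checkpoint_candidate

with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp)
    (cache / "latest.pt").touch()
    (cache / "best_bpb.pt").touch()
    # Candidates are probed in declaration order, so best_bpb.pt wins.
    print(choose_checkpoint_candidate(cache).name)  # best_bpb.pt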
 
overlay/scripts/benchmark_checkpoint_report.py CHANGED
@@ -1,50 +1,50 @@
 #!/usr/bin/env python3
 from __future__ import annotations

 import json


 def build_checkpoint_report(files: list[str]) -> dict[str, object]:
     by_job: dict[str, dict[str, object]] = {}
     for path in files:
         parts = path.split("/")
         if len(parts) < 3 or parts[0] != "jobs":
             continue
         job_id = parts[1]
         filename = parts[-1]
         if filename not in {"best_bpb.pt", "pretrain_final.pt", "latest.pt"}:
             continue
         row = by_job.setdefault(job_id, {"job_id": job_id, "paths": []})
         row["paths"].append(path)

     candidates = []
     for job_id, row in by_job.items():
         paths = list(row["paths"])
         preferred = None
         for suffix in ("pretrain_final.pt", "best_bpb.pt", "latest.pt"):
             for path in paths:
                 if path.endswith(suffix):
                     preferred = path
                     break
             if preferred is not None:
                 break
         candidates.append({
             "job_id": job_id,
             "preferred_path": preferred,
             "available_paths": sorted(paths),
         })

     candidates.sort(key=lambda row: row["job_id"], reverse=True)
     return {
         "n_candidates": len(candidates),
         "candidates": candidates,
     }


 def main() -> int:
     print(json.dumps(build_checkpoint_report([]), indent=2, sort_keys=True))
     return 0


 if __name__ == "__main__":
     raise SystemExit(main())
 
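A worked example of build_checkpoint_report; the job ids are hypothetical, but the grouping, preference order (pretrain_final > best_bpb > latest), and newest-first sort follow directly from the code above:

from scripts.benchmark_checkpoint_report import build_checkpoint_report

files = [
    "jobs/2024-06-01T120000/best_bpb.pt",
    "jobs/2024-06-01T120000/latest.pt",
    "jobs/2024-05-30T090000/pretrain_final.pt",
    "notes/README.md",  # skipped: not under jobs/<job_id>/
]
report = build_checkpoint_report(files)
print(report["n_candidates"])                     # 2
print(report["candidates"][0]["job_id"])          # 2024-06-01T120000 (newest first)
print(report["candidates"][0]["preferred_path"])  # jobs/2024-06-01T120000/best_bpb.pt
print(report["candidates"][1]["preferred_path"])  # jobs/2024-05-30T090000/pretrain_final.pt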
overlay/scripts/benchmark_contract.py CHANGED
@@ -1,67 +1,67 @@
 #!/usr/bin/env python3
 from __future__ import annotations

 import json
 from pathlib import Path
 from typing import Any


 def _require_path(payload: dict[str, Any], path: str) -> None:
     current: Any = payload
     for part in path.split('.'):
         if not isinstance(current, dict) or part not in current:
             raise ValueError(f"missing required field: {path}")
         current = current[part]


 def validate_benchmark_contract(payload: dict[str, Any]) -> None:
     for field in [
         "cycle_id",
         "hardware_class",
         "seeds",
         "budget_modes",
         "coding_benchmarks.fast_iteration",
         "coding_benchmarks.milestone",
         "reasoning_benchmarks.fast_iteration",
         "reasoning_benchmarks.milestone",
         "variants.hydra_full",
         "variants.baseline_mamba_matched",
     ]:
         _require_path(payload, field)

     for section in [
         payload["coding_benchmarks"]["fast_iteration"],
         payload["coding_benchmarks"]["milestone"],
         payload["reasoning_benchmarks"]["fast_iteration"],
         payload["reasoning_benchmarks"]["milestone"],
     ]:
         if "name" not in section or "primary_metric" not in section or "decode" not in section:
             raise ValueError("benchmark sections require name, primary_metric, and decode")

     if not isinstance(payload["seeds"], list) or len(payload["seeds"]) < 3:
         raise ValueError("seeds must contain at least three values")

     if payload["variants"]["hydra_full"].get("status") != "runnable_now":
         raise ValueError("hydra_full must be runnable_now")

     if payload["variants"]["baseline_mamba_matched"].get("status") != "runnable_now":
         raise ValueError("baseline_mamba_matched must be runnable_now")


 def load_benchmark_contract(path: Path) -> dict[str, Any]:
     payload = json.loads(path.read_text(encoding="utf-8"))
     if not isinstance(payload, dict):
         raise ValueError("benchmark contract must be a JSON object")
     validate_benchmark_contract(payload)
     return payload


 def main() -> int:
     path = Path("artifacts/cycle_1_execution_freeze.json")
     payload = load_benchmark_contract(path)
     print(json.dumps(payload, indent=2, sort_keys=True))
     return 0


 if __name__ == "__main__":
     raise SystemExit(main())
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+
9
+ def _require_path(payload: dict[str, Any], path: str) -> None:
10
+ current: Any = payload
11
+ for part in path.split('.'):
12
+ if not isinstance(current, dict) or part not in current:
13
+ raise ValueError(f"missing required field: {path}")
14
+ current = current[part]
15
+
16
+
17
+ def validate_benchmark_contract(payload: dict[str, Any]) -> None:
18
+ for field in [
19
+ "cycle_id",
20
+ "hardware_class",
21
+ "seeds",
22
+ "budget_modes",
23
+ "coding_benchmarks.fast_iteration",
24
+ "coding_benchmarks.milestone",
25
+ "reasoning_benchmarks.fast_iteration",
26
+ "reasoning_benchmarks.milestone",
27
+ "variants.hydra_full",
28
+ "variants.baseline_mamba_matched",
29
+ ]:
30
+ _require_path(payload, field)
31
+
32
+ for section in [
33
+ payload["coding_benchmarks"]["fast_iteration"],
34
+ payload["coding_benchmarks"]["milestone"],
35
+ payload["reasoning_benchmarks"]["fast_iteration"],
36
+ payload["reasoning_benchmarks"]["milestone"],
37
+ ]:
38
+ if "name" not in section or "primary_metric" not in section or "decode" not in section:
39
+ raise ValueError("benchmark sections require name, primary_metric, and decode")
40
+
41
+ if not isinstance(payload["seeds"], list) or len(payload["seeds"]) < 3:
42
+ raise ValueError("seeds must contain at least three values")
43
+
44
+ if payload["variants"]["hydra_full"].get("status") != "runnable_now":
45
+ raise ValueError("hydra_full must be runnable_now")
46
+
47
+ if payload["variants"]["baseline_mamba_matched"].get("status") != "runnable_now":
48
+ raise ValueError("baseline_mamba_matched must be runnable_now")
49
+
50
+
51
+ def load_benchmark_contract(path: Path) -> dict[str, Any]:
52
+ payload = json.loads(path.read_text(encoding="utf-8"))
53
+ if not isinstance(payload, dict):
54
+ raise ValueError("benchmark contract must be a JSON object")
55
+ validate_benchmark_contract(payload)
56
+ return payload
57
+
58
+
59
+ def main() -> int:
60
+ path = Path("artifacts/cycle_1_execution_freeze.json")
61
+ payload = load_benchmark_contract(path)
62
+ print(json.dumps(payload, indent=2, sort_keys=True))
63
+ return 0
64
+
65
+
66
+ if __name__ == "__main__":
67
+ raise SystemExit(main())
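A minimal payload that passes validate_benchmark_contract looks roughly like this (all field values are placeholders):

    from scripts.benchmark_contract import validate_benchmark_contract

    contract = {
        "cycle_id": "cycle-1",
        "hardware_class": "a100",
        "seeds": [41, 42, 43],  # at least three seeds required
        "budget_modes": ["fast_iteration"],
        "coding_benchmarks": {
            "fast_iteration": {"name": "MBPP", "primary_metric": "pass_at_1", "decode": "greedy"},
            "milestone": {"name": "HumanEval", "primary_metric": "pass_at_1", "decode": "greedy"},
        },
        "reasoning_benchmarks": {
            "fast_iteration": {"name": "GSM8K", "primary_metric": "exact_match", "decode": "greedy"},
            "milestone": {"name": "ARC-Challenge", "primary_metric": "accuracy", "decode": "greedy"},
        },
        "variants": {
            "hydra_full": {"status": "runnable_now"},
            "baseline_mamba_matched": {"status": "runnable_now"},
        },
    }
    validate_benchmark_contract(contract)  # raises ValueError on any violation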
overlay/scripts/benchmark_datasets.py CHANGED
@@ -1,190 +1,18 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import argparse
- import json
- from pathlib import Path
- from typing import Any
-
- import pyarrow as pa
- import pyarrow.parquet as pq
-
- try:
-     from huggingface_hub import HfApi, hf_hub_download
- except Exception:  # pragma: no cover
-     HfApi = None
-     hf_hub_download = None
-
-
- CANONICAL_SUBSETS = {
-     "MBPP": Path("data/benchmarks/mbpp.cycle1.jsonl"),
-     "GSM8K": Path("data/benchmarks/gsm8k.cycle1.jsonl"),
-     "HumanEval": Path("data/benchmarks/humaneval.cycle1.jsonl"),
-     "ARC-Challenge": Path("data/benchmarks/arc_challenge.cycle1.jsonl"),
- }
-
-
- DATASET_SOURCES: dict[str, dict[str, str]] = {
-     "MBPP": {"repo_id": "Muennighoff/mbpp", "subset": "full", "split": "test", "raw_path": "data/mbpp.jsonl"},
-     "GSM8K": {"repo_id": "openai/gsm8k", "subset": "main", "split": "test"},
-     "HumanEval": {"repo_id": "openai/openai_humaneval", "subset": "default", "split": "test"},
-     "ARC-Challenge": {"repo_id": "allenai/ai2_arc", "subset": "ARC-Challenge", "split": "validation"},
- }
-
-
- def resolve_benchmark_dataset(benchmark_name: str, explicit_path: Path | None) -> Path:
-     if explicit_path is not None:
-         return explicit_path
-     if benchmark_name not in CANONICAL_SUBSETS:
-         raise ValueError(f"Unsupported benchmark dataset: {benchmark_name}")
-     return Path.cwd() / CANONICAL_SUBSETS[benchmark_name]
-
-
- def _normalize_gsm8k_answer(answer: str) -> str:
-     if "####" in answer:
-         return answer.split("####")[-1].strip()
-     return answer.strip()
-
-
- def transform_dataset_row(benchmark_name: str, row: dict[str, Any], *, row_id: int) -> dict[str, Any]:
-     if benchmark_name == "GSM8K":
-         return {
-             "question": str(row["question"]),
-             "answer": _normalize_gsm8k_answer(str(row["answer"])),
-         }
-     if benchmark_name == "ARC-Challenge":
-         choices = row["choices"]
-         labels = list(choices["label"])
-         texts = list(choices["text"])
-         answer_key = str(row["answerKey"])
-         answer_index = labels.index(answer_key)
-         return {
-             "question": str(row["question"]),
-             "choices": [str(choice) for choice in texts],
-             "answer": str(texts[answer_index]),
-         }
-     if benchmark_name == "MBPP":
-         task_id = row.get("task_id", row_id)
-         return {
-             "task_id": str(task_id),
-             "prompt": str(row["text"]),
-             "tests": [str(test) for test in row["test_list"]],
-         }
-     if benchmark_name == "HumanEval":
-         task_id = row.get("task_id", row_id)
-         return {
-             "task_id": str(task_id),
-             "prompt": str(row["prompt"]),
-             "test": str(row["test"]),
-         }
-     raise ValueError(f"Unsupported benchmark dataset: {benchmark_name}")
-
-
- def write_canonical_dataset(*, benchmark_name: str, rows: list[dict[str, Any]], out_path: Path, limit: int) -> int:
-     out_path.parent.mkdir(parents=True, exist_ok=True)
-     transformed = [transform_dataset_row(benchmark_name, row, row_id=index) for index, row in enumerate(rows[:limit])]
-     out_path.write_text("".join(json.dumps(row) + "\n" for row in transformed), encoding="utf-8")
-     return len(transformed)
-
-
- def choose_dataset_parquet_path(benchmark_name: str, files: list[str]) -> str | None:
-     source = DATASET_SOURCES[benchmark_name]
-     subset = source["subset"].lower()
-     split = source["split"].lower()
-     candidates = [path for path in files if path.endswith(".parquet")]
-     preferred = [path for path in candidates if subset in path.lower() and split in path.lower()]
-     if preferred:
-         return sorted(preferred)[0]
-     split_only = [path for path in candidates if split in path.lower()]
-     if split_only:
-         return sorted(split_only)[0]
-     return sorted(candidates)[0] if candidates else None
-
-
- def download_dataset_snapshot(benchmark_name: str, *, cache_dir: Path, token: str | None) -> Path:
-     source = DATASET_SOURCES[benchmark_name]
-     if HfApi is None or hf_hub_download is None:
-         raise RuntimeError("huggingface_hub unavailable")
-     raw_path = source.get("raw_path")
-     if raw_path:
-         if "/" in raw_path:
-             subfolder, filename = raw_path.rsplit("/", 1)
-         else:
-             subfolder, filename = None, raw_path
-         downloaded = hf_hub_download(
-             repo_id=source["repo_id"],
-             repo_type="dataset",
-             filename=filename,
-             subfolder=subfolder,
-             token=token,
-             local_dir=str(cache_dir),
-         )
-         return Path(downloaded)
-     files = HfApi(token=token).list_repo_files(repo_id=source["repo_id"], repo_type="dataset", token=token)
-     parquet_path = choose_dataset_parquet_path(benchmark_name, files)
-     if parquet_path is None:
-         raise FileNotFoundError(f"No parquet dataset file found for {benchmark_name}")
-     if "/" in parquet_path:
-         subfolder, filename = parquet_path.rsplit("/", 1)
-     else:
-         subfolder, filename = None, parquet_path
-     downloaded = hf_hub_download(
-         repo_id=source["repo_id"],
-         repo_type="dataset",
-         filename=filename,
-         subfolder=subfolder,
-         token=token,
-         local_dir=str(cache_dir),
-     )
-     return Path(downloaded)
-
-
- def hydrate_canonical_dataset(
-     *,
-     benchmark_name: str,
-     out_path: Path,
-     limit: int,
-     cache_dir: Path,
-     token: str | None,
- ) -> int:
-     source_path = download_dataset_snapshot(benchmark_name, cache_dir=cache_dir, token=token)
-     if source_path.suffix == ".jsonl":
-         rows = [json.loads(line) for line in source_path.read_text(encoding="utf-8").splitlines() if line.strip()]
-     else:
-         table = pq.read_table(source_path)
-         rows = table.to_pylist()
-     return write_canonical_dataset(benchmark_name=benchmark_name, rows=rows, out_path=out_path, limit=limit)
-
-
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-     parser = argparse.ArgumentParser(description="Hydrate a canonical benchmark dataset JSONL from a public source")
-     parser.add_argument("--benchmark", required=True, choices=list(CANONICAL_SUBSETS))
-     parser.add_argument("--samples", type=Path)
-     parser.add_argument("--out", type=Path)
-     parser.add_argument("--limit", type=int, default=20)
-     parser.add_argument("--cache-dir", type=Path, default=Path(".cache/benchmarks"))
-     parser.add_argument("--token")
-     return parser.parse_args(argv)
-
-
- def main(argv: list[str] | None = None) -> int:
-     args = parse_args(argv)
-     if args.samples is not None:
-         rows = [json.loads(line) for line in args.samples.read_text(encoding="utf-8").splitlines() if line.strip()]
-         out_path = args.out or resolve_benchmark_dataset(args.benchmark, None)
-         write_canonical_dataset(benchmark_name=args.benchmark, rows=rows, out_path=out_path, limit=args.limit)
-         return 0
-
-     out_path = args.out or resolve_benchmark_dataset(args.benchmark, None)
-     hydrate_canonical_dataset(
-         benchmark_name=args.benchmark,
-         out_path=out_path,
-         limit=args.limit,
-         cache_dir=args.cache_dir,
-         token=args.token,
-     )
-     return 0
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ from pathlib import Path
+
+
+ CANONICAL_SUBSETS = {
+     "MBPP": Path("data/benchmarks/mbpp.cycle1.jsonl"),
+     "GSM8K": Path("data/benchmarks/gsm8k.cycle1.jsonl"),
+ }
+
+
+ def resolve_benchmark_dataset(benchmark_name: str, explicit_path: Path | None) -> Path:
+     if explicit_path is not None:
+         return explicit_path
+     if benchmark_name not in CANONICAL_SUBSETS:
+         raise ValueError(f"Unsupported benchmark dataset: {benchmark_name}")
+     return Path.cwd() / CANONICAL_SUBSETS[benchmark_name]
overlay/scripts/benchmark_hyena_stack.py CHANGED
@@ -26,11 +26,8 @@ Invocation:
      # On A100/A10G (production cloud hardware), use time=900 (15 min) for
      # stable steady-state numbers.
 
- After each run the script prints:
-     BENCHMARK config=<name> tps_steady=<avg> bpb_at_500=<val> vram_peak=<MiB>
-
- If `--min-tps` is set (>0), the script exits non-zero when steady-state TPS
- falls below the threshold.
+ After each run the script prints:
+     BENCHMARK config=<name> tps_steady=<avg> bpb_at_500=<val> vram_peak=<MiB>
 
  Collate those lines into the matrix table manually, then pick the winner
  for the 6-hour production run (HYDRA_TIME_BUDGET=21600).
@@ -50,34 +47,30 @@ REPO = Path(__file__).resolve().parents[1]
 
  CONFIGS = {
      # Baseline: B=8, no flash, no train-cache. Current reference point.
-     "baseline": {
-         "HYDRA_BATCH_SIZE": "8",
-         "HYDRA_THROUGHPUT_MODE": "1",
-         "HYDRA_HYENA_LAYERS": "3,7",
+     "baseline": {
+         "HYDRA_BATCH_SIZE": "8",
+         "HYDRA_HYENA_LAYERS": "3,7",
          "HYDRA_HYENA_FLASH_FFT": "0",
          "HYDRA_HYENA_TRAIN_CACHE": "0",
          "HYDRA_HYENA_FILTER_CACHE": "0",
      },
-     "b16": {
-         "HYDRA_BATCH_SIZE": "16",
-         "HYDRA_THROUGHPUT_MODE": "1",
-         "HYDRA_HYENA_LAYERS": "3,7",
+     "b16": {
+         "HYDRA_BATCH_SIZE": "16",
+         "HYDRA_HYENA_LAYERS": "3,7",
          "HYDRA_HYENA_FLASH_FFT": "0",
          "HYDRA_HYENA_TRAIN_CACHE": "0",
          "HYDRA_HYENA_FILTER_CACHE": "0",
      },
-     "cache": {
-         "HYDRA_BATCH_SIZE": "16",
-         "HYDRA_THROUGHPUT_MODE": "1",
-         "HYDRA_HYENA_LAYERS": "3,7",
+     "cache": {
+         "HYDRA_BATCH_SIZE": "16",
+         "HYDRA_HYENA_LAYERS": "3,7",
          "HYDRA_HYENA_FLASH_FFT": "0",
          "HYDRA_HYENA_TRAIN_CACHE": "1",
          "HYDRA_HYENA_FILTER_CACHE": "1",
      },
-     "kernel": {
-         "HYDRA_BATCH_SIZE": "16",
-         "HYDRA_THROUGHPUT_MODE": "1",
-         "HYDRA_HYENA_LAYERS": "3,7",
+     "kernel": {
+         "HYDRA_BATCH_SIZE": "16",
+         "HYDRA_HYENA_LAYERS": "3,7",
          "HYDRA_HYENA_FLASH_FFT": "1",
          "HYDRA_HYENA_TRAIN_CACHE": "1",
          "HYDRA_HYENA_FILTER_CACHE": "1",
@@ -88,7 +81,7 @@ CONFIGS = {
  }
 
 
- def build_env(cfg_overrides: dict[str, str]) -> dict[str, str]:
+ def build_env(cfg_overrides: dict) -> dict:
      """Compose a full env dict from the inherited env + config overrides."""
      env = os.environ.copy()
      # Ensure the Hyena layer selection is always present (defaults to off).
@@ -98,7 +91,7 @@ def build_env(cfg_overrides: dict[str, str]) -> dict[str, str]:
      return env
 
 
- def parse_step_line(line: str) -> dict[str, float] | None:
+ def parse_step_line(line: str) -> dict | None:
      """Parse a single step=... line into a dict of metrics, or None."""
      if not line.startswith("step="):
          return None
@@ -109,7 +102,7 @@ def parse_step_line(line: str) -> dict[str, float] | None:
      return None
 
 
- def summarize(log_path: Path, warmup_steps: int = 50) -> dict[str, float]:
+ def summarize(log_path: Path, warmup_steps: int = 50) -> dict:
      """Tail log_path, compute steady-state TPS / BPB@500 / VRAM peak.
 
      Skips the first `warmup_steps` to discard CUDA graph capture / autotune
@@ -145,29 +138,20 @@ def summarize(log_path: Path, warmup_steps: int = 50) -> dict[str, float]:
      tps_sorted = sorted(tps_vals)
      tps_steady = tps_sorted[len(tps_sorted) // 2]  # median
 
-     return {
-         "tps_steady": tps_steady,
-         "bpb_at_500": bpb_at_500 or (bpbs[-1] if bpbs else 0.0),
-         "vram_peak": vram_peak,
-         "steps": len(tps_vals) + warmup_steps,
-     }
-
-
- def fails_tps_floor(summary: dict[str, float], min_tps: float) -> bool:
-     if min_tps <= 0:
-         return False
-     tps_steady = float(summary.get("tps_steady", 0.0))
-     return tps_steady < float(min_tps)
-
-
- def main() -> int:
-     ap = argparse.ArgumentParser()
-     ap.add_argument("--config", required=True, choices=list(CONFIGS))
-     ap.add_argument("--time", type=int, default=300, help="training seconds")
-     ap.add_argument("--log", default=None, help="output log path (default: run_bench_<cfg>.log)")
-     ap.add_argument("--min-tps", type=float, default=50000.0, help="Required steady-state TPS floor (set 0 to disable)")
-     ap.add_argument("--warmup-steps", type=int, default=50, help="Number of initial steps to skip before TPS median")
-     args = ap.parse_args()
+     return {
+         "tps_steady": tps_steady,
+         "bpb_at_500": bpb_at_500 or (bpbs[-1] if bpbs else 0.0),
+         "vram_peak": vram_peak,
+         "steps": len(tps_vals) + warmup_steps,
+     }
+
+
+ def main() -> int:
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--config", required=True, choices=list(CONFIGS))
+     ap.add_argument("--time", type=int, default=300, help="training seconds")
+     ap.add_argument("--log", default=None, help="output log path (default: run_bench_<cfg>.log)")
+     args = ap.parse_args()
 
      cfg = CONFIGS[args.config]
      log_path = Path(args.log or (REPO / f"run_bench_{args.config}.log"))
@@ -194,25 +178,16 @@ def main() -> int:
          print(f"BENCH FAIL config={args.config}", flush=True)
          return proc.returncode
 
-     summary = summarize(log_path, warmup_steps=max(0, int(args.warmup_steps)))
-     print(
-         f"BENCHMARK config={args.config} "
-         f"tps_steady={summary['tps_steady']:.0f} "
-         f"bpb_at_500={summary['bpb_at_500']:.4f} "
-         f"vram_peak={summary['vram_peak']:.0f}MiB "
-         f"steps={summary['steps']}",
-         flush=True,
-     )
-
-     if fails_tps_floor(summary, args.min_tps):
-         print(
-             f"BENCH FAIL config={args.config} tps_steady={summary['tps_steady']:.0f} < min_tps={args.min_tps:.0f}",
-             flush=True,
-         )
-         return 2
-
-     print(f"BENCH PASS config={args.config} min_tps={args.min_tps:.0f}", flush=True)
-     return 0
+     summary = summarize(log_path)
+     print(
+         f"BENCHMARK config={args.config} "
+         f"tps_steady={summary['tps_steady']:.0f} "
+         f"bpb_at_500={summary['bpb_at_500']:.4f} "
+         f"vram_peak={summary['vram_peak']:.0f}MiB "
+         f"steps={summary['steps']}",
+         flush=True,
+     )
+     return 0
 
 
  if __name__ == "__main__":
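The printed summary line is the machine-readable contract here. One way to drive a single matrix cell and capture it (config choice and timeout illustrative):

    import subprocess
    import sys

    # Run the 5-minute smoke benchmark for the "cache" config, then pull the
    # BENCHMARK summary line out of stdout.
    proc = subprocess.run(
        [sys.executable, "scripts/benchmark_hyena_stack.py", "--config", "cache", "--time", "300"],
        capture_output=True,
        text=True,
    )
    for line in proc.stdout.splitlines():
        if line.startswith("BENCHMARK "):
            print(line)  # BENCHMARK config=cache tps_steady=... bpb_at_500=... vram_peak=...MiB steps=...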
overlay/scripts/benchmark_preflight.py CHANGED
@@ -1,35 +1,31 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- from pathlib import Path
-
- from scripts.bootstrap_benchmark_env import build_bootstrap_report
- from scripts.benchmark_checkpoint import choose_checkpoint_candidate
-
-
- def build_readiness_report(*, cache_dir: Path, hf_token_present: bool, dependencies_present: bool = True, missing_dependencies: list[str] | None = None, output_repo: str | None = None, tokenizer_repo: str | None = None) -> dict[str, object]:
-     checkpoint = choose_checkpoint_candidate(cache_dir)
-     tokenizer_dir = cache_dir / "tokenizer"
-     retina_path = cache_dir / "retina.npz"
-     tokenizer_ready = (tokenizer_dir / "tokenizer.pkl").exists() and (tokenizer_dir / "token_bytes.pt").exists()
-     retina_ready = retina_path.exists()
-     checkpoint_present = checkpoint is not None
-     runtime = build_bootstrap_report(missing_dependencies=list(missing_dependencies or []))
-     return {
-         "cache_dir": str(cache_dir),
-         "checkpoint_present": checkpoint_present,
-         "checkpoint_path": str(checkpoint) if checkpoint is not None else None,
-         "tokenizer_ready": tokenizer_ready,
-         "retina_ready": retina_ready,
-         "retina_path": str(retina_path),
-         "hf_token_present": hf_token_present,
-         "dependencies_present": dependencies_present,
-         "missing_dependencies": list(missing_dependencies or []),
-         "install_hint": runtime["install_hint"],
-         "install_command": runtime["install_command"],
-         "install_blockers": runtime["install_blockers"],
-         "output_repo": output_repo,
-         "tokenizer_repo": tokenizer_repo,
-         "hydration_possible": bool(hf_token_present and output_repo and tokenizer_repo),
-         "ready_for_hydra_benchmarks": checkpoint_present and tokenizer_ready and retina_ready and dependencies_present,
-     }
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ from pathlib import Path
+
+ from scripts.bootstrap_benchmark_env import build_bootstrap_report
+ from scripts.benchmark_checkpoint import choose_checkpoint_candidate
+
+
+ def build_readiness_report(*, cache_dir: Path, hf_token_present: bool, dependencies_present: bool = True, missing_dependencies: list[str] | None = None, output_repo: str | None = None, tokenizer_repo: str | None = None) -> dict[str, object]:
+     checkpoint = choose_checkpoint_candidate(cache_dir)
+     tokenizer_dir = cache_dir / "tokenizer"
+     tokenizer_ready = (tokenizer_dir / "tokenizer.pkl").exists() and (tokenizer_dir / "token_bytes.pt").exists()
+     checkpoint_present = checkpoint is not None
+     runtime = build_bootstrap_report(missing_dependencies=list(missing_dependencies or []))
+     return {
+         "cache_dir": str(cache_dir),
+         "checkpoint_present": checkpoint_present,
+         "checkpoint_path": str(checkpoint) if checkpoint is not None else None,
+         "tokenizer_ready": tokenizer_ready,
+         "hf_token_present": hf_token_present,
+         "dependencies_present": dependencies_present,
+         "missing_dependencies": list(missing_dependencies or []),
+         "install_hint": runtime["install_hint"],
+         "install_command": runtime["install_command"],
+         "install_blockers": runtime["install_blockers"],
+         "output_repo": output_repo,
+         "tokenizer_repo": tokenizer_repo,
+         "hydration_possible": bool(hf_token_present and output_repo and tokenizer_repo),
+         "ready_for_hydra_benchmarks": checkpoint_present and tokenizer_ready and dependencies_present,
+     }
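build_readiness_report can be smoke-tested against an empty cache directory; the repo IDs below are placeholders, and this sketch assumes choose_checkpoint_candidate returns None when nothing is cached:

    from pathlib import Path
    from scripts.benchmark_preflight import build_readiness_report

    report = build_readiness_report(
        cache_dir=Path("/tmp/empty-cache"),
        hf_token_present=False,
        missing_dependencies=["mamba_ssm"],
        output_repo="org/hydra-checkpoints",  # placeholder repo id
        tokenizer_repo="org/hydra-tokenizer",  # placeholder repo id
    )
    # Expected: hydration_possible is False (no HF token), and
    # ready_for_hydra_benchmarks is False (no checkpoint or tokenizer files).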
overlay/scripts/benchmark_runner.py CHANGED
@@ -1,327 +1,248 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import argparse
- import json
- import re
- import sys
- from pathlib import Path
- from typing import Any, Callable
-
- REPO_ROOT = Path(__file__).resolve().parents[1]
- if str(REPO_ROOT) not in sys.path:
-     sys.path.insert(0, str(REPO_ROOT))
-
- LEDGER_TEMPLATE_PATH = REPO_ROOT / "artifacts" / "benchmark_ledger.template.json"
-
- from scripts.hydra_generation import build_hydra_generator
- from scripts.benchmark_datasets import resolve_benchmark_dataset as resolve_canonical_dataset
- from scripts.benchmark_suite import build_prompt, validate_sample
-
-
- class BenchmarkExecutionError(RuntimeError):
-     def __init__(
-         self,
-         *,
-         benchmark: str,
-         sample: dict[str, Any],
-         generated_output: str,
-         cause: BaseException,
-         extracted_output: str | None = None,
-     ):
-         super().__init__(str(cause))
-         self.benchmark = benchmark
-         self.sample = sample
-         self.generated_output = generated_output
-         self.cause = cause
-         self.extracted_output = extracted_output
-
-
- def load_jsonl_samples(path: Path) -> list[dict[str, Any]]:
-     rows: list[dict[str, Any]] = []
-     for line in path.read_text(encoding="utf-8").splitlines():
-         if line.strip():
-             rows.append(json.loads(line))
-     return rows
-
-
- def _normalize_samples_path(path: Path) -> Path:
-     return path if path.is_absolute() else REPO_ROOT / path
-
-
- def _preview_text(text: str, *, limit: int = 1000) -> str:
-     if len(text) <= limit:
-         return text
-     return text[:limit] + "\n...[truncated]"
-
-
- def extract_python_code(text: str) -> str:
-     fenced = re.search(r"```python\s*(.*?)```", text, flags=re.IGNORECASE | re.DOTALL)
-     if fenced:
-         extracted = fenced.group(1).strip("\n")
-         return extracted + "\n"
-
-     lines = text.splitlines()
-     for index, line in enumerate(lines):
-         if line.startswith("def "):
-             extracted = "\n".join(lines[index:]).strip("\n")
-             return extracted + "\n"
-     return text
-
-
- def _score_mbpp(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
-     passed = 0
-     for sample in samples:
-         validate_sample("MBPP", sample)
-         raw_output = generate_fn(build_prompt("MBPP", sample))
-         code = extract_python_code(raw_output)
-         namespace: dict[str, Any] = {}
-         try:
-             exec(code, namespace, namespace)
-             for test in sample["tests"]:
-                 exec(test, namespace, namespace)
-         except Exception as exc:
-             raise BenchmarkExecutionError(
-                 benchmark="MBPP",
-                 sample=sample,
-                 generated_output=raw_output,
-                 cause=exc,
-                 extracted_output=code,
-             ) from exc
-         passed += 1
-     return passed / len(samples) if samples else 0.0
-
-
- def _extract_last_number(text: str) -> str | None:
-     matches = re.findall(r"-?\d+(?:\.\d+)?", text)
-     return matches[-1] if matches else None
-
-
- def _score_gsm8k(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
-     passed = 0
-     for sample in samples:
-         validate_sample("GSM8K", sample)
-         output = generate_fn(build_prompt("GSM8K", sample))
-         pred = _extract_last_number(output)
-         if pred is not None and pred == str(sample["answer"]):
-             passed += 1
-     return passed / len(samples) if samples else 0.0
-
-
- def _score_humaneval(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
-     passed = 0
-     for sample in samples:
-         validate_sample("HumanEval", sample)
-         raw_output = generate_fn(build_prompt("HumanEval", sample))
-         code = extract_python_code(raw_output)
-         namespace: dict[str, Any] = {}
-         try:
-             exec(code, namespace, namespace)
-             exec(sample["test"], namespace, namespace)
-         except Exception as exc:
-             raise BenchmarkExecutionError(
-                 benchmark="HumanEval",
-                 sample=sample,
-                 generated_output=raw_output,
-                 cause=exc,
-                 extracted_output=code,
-             ) from exc
-         passed += 1
-     return passed / len(samples) if samples else 0.0
-
-
- def _score_arc(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
-     passed = 0
-     for sample in samples:
-         validate_sample("ARC-Challenge", sample)
-         output = generate_fn(build_prompt("ARC-Challenge", sample)).strip()
-         if output == str(sample["answer"]):
-             passed += 1
-     return passed / len(samples) if samples else 0.0
-
-
- def run_benchmark(benchmark_name: str, path: Path, generate_fn: Callable[[str], str]) -> dict[str, Any]:
-     samples = load_jsonl_samples(path)
-     if benchmark_name == "MBPP":
-         return {
-             "benchmark": "MBPP",
-             "primary_metric": "pass_at_1",
-             "score": _score_mbpp(samples, generate_fn),
-             "n_samples": len(samples),
-         }
-     if benchmark_name == "GSM8K":
-         return {
-             "benchmark": "GSM8K",
-             "primary_metric": "exact_match",
-             "score": _score_gsm8k(samples, generate_fn),
-             "n_samples": len(samples),
-         }
-     if benchmark_name == "HumanEval":
-         return {
-             "benchmark": "HumanEval",
-             "primary_metric": "pass_at_1",
-             "score": _score_humaneval(samples, generate_fn),
-             "n_samples": len(samples),
-         }
-     if benchmark_name == "ARC-Challenge":
-         return {
-             "benchmark": "ARC-Challenge",
-             "primary_metric": "accuracy",
-             "score": _score_arc(samples, generate_fn),
-             "n_samples": len(samples),
-         }
-     raise ValueError(f"Unsupported runnable benchmark: {benchmark_name}")
-
-
- def write_benchmark_result(path: Path, payload: dict[str, Any]) -> None:
-     path.parent.mkdir(parents=True, exist_ok=True)
-     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-
-
- def append_benchmark_run_record(
-     ledger_path: Path,
-     result: dict[str, Any],
-     *,
-     benchmark_name: str,
-     variant: str,
-     seed: int,
-     samples_path: Path,
- ) -> None:
-     if not ledger_path.exists():
-         ledger_path.parent.mkdir(parents=True, exist_ok=True)
-         ledger_path.write_text(LEDGER_TEMPLATE_PATH.read_text(encoding="utf-8"), encoding="utf-8")
-     payload = json.loads(ledger_path.read_text(encoding="utf-8"))
-     run_records = payload.setdefault("run_records", [])
-     if len(run_records) == 1 and run_records[0].get("run_id") == "example-run-0001":
-         run_records.clear()
-     run_records.append(
-         {
-             "run_id": result.get("run_id", f"{benchmark_name.lower()}-{seed}"),
-             "commit": "HEAD",
-             "model_family": "hydra",
-             "variant": variant,
-             "seed": seed,
-             "hardware": {
-                 "hardware_class": payload.get("benchmark_cycle", {}).get("hardware_class", "unknown"),
-             },
-             "budget": {
-                 "budget_mode": payload.get("benchmark_cycle", {}).get("budget_modes", [None])[0],
-             },
-             "capability": {
-                 "coding_score": result["score"] if benchmark_name in {"MBPP", "HumanEval"} else None,
-                 "reasoning_score": result["score"] if benchmark_name in {"GSM8K", "ARC-Challenge"} else None,
-             },
-             "artifacts": {
-                 "samples_path": str(samples_path),
-             },
-         }
-     )
-     ledger_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-
-
- def resolve_samples_path(benchmark_name: str, samples: Path | None, suite_path: Path) -> Path:
-     if samples is not None:
-         return _normalize_samples_path(samples)
-     payload = json.loads(suite_path.read_text(encoding="utf-8"))
-     for section in ("coding_benchmarks", "reasoning_benchmarks"):
-         if section not in payload:
-             continue
-         for slot in ("fast_iteration", "milestone"):
-             entry = payload[section].get(slot)
-             if isinstance(entry, dict) and entry.get("name") == benchmark_name and "sample_path" in entry:
-                 return _normalize_samples_path(Path(entry["sample_path"]))
-     try:
-         return _normalize_samples_path(resolve_canonical_dataset(benchmark_name, None))
-     except ValueError:
-         raise ValueError(f"No sample path found for benchmark: {benchmark_name}")
-
-
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-     parser = argparse.ArgumentParser(description="Run a local benchmark against JSONL samples")
-     parser.add_argument("--benchmark", required=True, choices=["MBPP", "GSM8K", "HumanEval", "ARC-Challenge"])
-     parser.add_argument("--samples", type=Path)
-     parser.add_argument("--suite", type=Path, default=REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json")
-     parser.add_argument("--out", type=Path)
-     parser.add_argument("--ledger", type=Path)
-     parser.add_argument("--variant", default="hydra_full")
-     parser.add_argument("--seed", type=int, default=42)
-     parser.add_argument("--generator-mode", choices=["stub", "hydra"], default="stub")
-     parser.add_argument("--checkpoint", type=Path)
-     parser.add_argument("--device")
-     parser.add_argument("--max-new-tokens", type=int, default=256)
-     parser.add_argument("--temperature", type=float, default=0.2)
-     parser.add_argument("--top-p", type=float, default=0.95)
-     return parser.parse_args(argv)
-
-
- def main(argv: list[str] | None = None) -> int:
-     args = parse_args(argv)
-     sample_path = resolve_samples_path(args.benchmark, args.samples, args.suite)
-     try:
-         if args.generator_mode == "hydra":
-             generator = build_hydra_generator(
-                 checkpoint_path=args.checkpoint,
-                 device=args.device,
-                 max_new_tokens=args.max_new_tokens,
-                 temperature=args.temperature,
-                 top_p=args.top_p,
-             )
-         else:
-             def generator(prompt: str) -> str:
-                 return prompt
-
-         result = run_benchmark(args.benchmark, sample_path, generator)
-         exit_code = 0
-     except FileNotFoundError as exc:
-         result = {
-             "benchmark": args.benchmark,
-             "status": "failed",
-             "failure_type": "missing_checkpoint",
-             "error": str(exc),
-             "n_samples": 0,
-         }
-         exit_code = 1
-     except BenchmarkExecutionError as exc:
-         result = {
-             "benchmark": args.benchmark,
-             "status": "failed",
-             "failure_type": type(exc.cause).__name__,
-             "error": str(exc.cause),
-             "n_samples": 0,
-             "debug": {
-                 "sample": {
-                     "task_id": exc.sample.get("task_id"),
-                     "question": exc.sample.get("question"),
-                 },
-                 "generated_output_preview": _preview_text(exc.generated_output),
-                 "extracted_code_preview": _preview_text(exc.extracted_output) if exc.extracted_output is not None else None,
-             },
-         }
-         exit_code = 1
-     except Exception as exc:  # noqa: BLE001
-         result = {
-             "benchmark": args.benchmark,
-             "status": "failed",
-             "failure_type": type(exc).__name__,
-             "error": str(exc),
-             "n_samples": 0,
-         }
-         exit_code = 1
-
-     if args.out is not None:
-         write_benchmark_result(args.out, result)
-     if args.ledger is not None and exit_code == 0:
-         append_benchmark_run_record(
-             args.ledger,
-             result,
-             benchmark_name=args.benchmark,
-             variant=args.variant,
-             seed=args.seed,
-             samples_path=sample_path,
-         )
-     print(json.dumps(result, indent=2, sort_keys=True))
-     return exit_code
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import re
+ import sys
+ from pathlib import Path
+ from typing import Any, Callable
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ LEDGER_TEMPLATE_PATH = REPO_ROOT / "artifacts" / "benchmark_ledger.template.json"
+
+ from scripts.hydra_generation import build_hydra_generator
+ from scripts.benchmark_datasets import resolve_benchmark_dataset as resolve_canonical_dataset
+ from scripts.benchmark_suite import build_prompt, validate_sample
+
+
+ def load_jsonl_samples(path: Path) -> list[dict[str, Any]]:
+     rows: list[dict[str, Any]] = []
+     for line in path.read_text(encoding="utf-8").splitlines():
+         if line.strip():
+             rows.append(json.loads(line))
+     return rows
+
+
+ def _score_mbpp(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+     passed = 0
+     for sample in samples:
+         validate_sample("MBPP", sample)
+         code = generate_fn(build_prompt("MBPP", sample))
+         namespace: dict[str, Any] = {}
+         exec(code, namespace, namespace)
+         for test in sample["tests"]:
+             exec(test, namespace, namespace)
+         passed += 1
+     return passed / len(samples) if samples else 0.0
+
+
+ def _extract_last_number(text: str) -> str | None:
+     matches = re.findall(r"-?\d+(?:\.\d+)?", text)
+     return matches[-1] if matches else None
+
+
+ def _score_gsm8k(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+     passed = 0
+     for sample in samples:
+         validate_sample("GSM8K", sample)
+         output = generate_fn(build_prompt("GSM8K", sample))
+         pred = _extract_last_number(output)
+         if pred is not None and pred == str(sample["answer"]):
+             passed += 1
+     return passed / len(samples) if samples else 0.0
+
+
+ def _score_humaneval(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+     passed = 0
+     for sample in samples:
+         validate_sample("HumanEval", sample)
+         code = generate_fn(build_prompt("HumanEval", sample))
+         namespace: dict[str, Any] = {}
+         exec(code, namespace, namespace)
+         exec(sample["test"], namespace, namespace)
+         passed += 1
+     return passed / len(samples) if samples else 0.0
+
+
+ def _score_arc(samples: list[dict[str, Any]], generate_fn: Callable[[str], str]) -> float:
+     passed = 0
+     for sample in samples:
+         validate_sample("ARC-Challenge", sample)
+         output = generate_fn(build_prompt("ARC-Challenge", sample)).strip()
+         if output == str(sample["answer"]):
+             passed += 1
+     return passed / len(samples) if samples else 0.0
+
+
+ def run_benchmark(benchmark_name: str, path: Path, generate_fn: Callable[[str], str]) -> dict[str, Any]:
+     samples = load_jsonl_samples(path)
+     if benchmark_name == "MBPP":
+         return {
+             "benchmark": "MBPP",
+             "primary_metric": "pass_at_1",
+             "score": _score_mbpp(samples, generate_fn),
+             "n_samples": len(samples),
+         }
+     if benchmark_name == "GSM8K":
+         return {
+             "benchmark": "GSM8K",
+             "primary_metric": "exact_match",
+             "score": _score_gsm8k(samples, generate_fn),
+             "n_samples": len(samples),
+         }
+     if benchmark_name == "HumanEval":
+         return {
+             "benchmark": "HumanEval",
+             "primary_metric": "pass_at_1",
+             "score": _score_humaneval(samples, generate_fn),
+             "n_samples": len(samples),
+         }
+     if benchmark_name == "ARC-Challenge":
+         return {
+             "benchmark": "ARC-Challenge",
+             "primary_metric": "accuracy",
+             "score": _score_arc(samples, generate_fn),
+             "n_samples": len(samples),
+         }
+     raise ValueError(f"Unsupported runnable benchmark: {benchmark_name}")
+
+
+ def write_benchmark_result(path: Path, payload: dict[str, Any]) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+
+
+ def append_benchmark_run_record(
+     ledger_path: Path,
+     result: dict[str, Any],
+     *,
+     benchmark_name: str,
+     variant: str,
+     seed: int,
+     samples_path: Path,
+ ) -> None:
+     if not ledger_path.exists():
+         ledger_path.parent.mkdir(parents=True, exist_ok=True)
+         ledger_path.write_text(LEDGER_TEMPLATE_PATH.read_text(encoding="utf-8"), encoding="utf-8")
+     payload = json.loads(ledger_path.read_text(encoding="utf-8"))
+     run_records = payload.setdefault("run_records", [])
+     if len(run_records) == 1 and run_records[0].get("run_id") == "example-run-0001":
+         run_records.clear()
+     run_records.append(
+         {
+             "run_id": result.get("run_id", f"{benchmark_name.lower()}-{seed}"),
+             "commit": "HEAD",
+             "model_family": "hydra",
+             "variant": variant,
+             "seed": seed,
+             "hardware": {
+                 "hardware_class": payload.get("benchmark_cycle", {}).get("hardware_class", "unknown"),
+             },
+             "budget": {
+                 "budget_mode": payload.get("benchmark_cycle", {}).get("budget_modes", [None])[0],
+             },
+             "capability": {
+                 "coding_score": result["score"] if benchmark_name in {"MBPP", "HumanEval"} else None,
+                 "reasoning_score": result["score"] if benchmark_name in {"GSM8K", "ARC-Challenge"} else None,
+             },
+             "artifacts": {
+                 "samples_path": str(samples_path),
+             },
+         }
+     )
+     ledger_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+
+
+ def resolve_samples_path(benchmark_name: str, samples: Path | None, suite_path: Path) -> Path:
+     if samples is not None:
+         return samples
+     payload = json.loads(suite_path.read_text(encoding="utf-8"))
+     for section in ("coding_benchmarks", "reasoning_benchmarks"):
+         if section not in payload:
+             continue
+         for slot in ("fast_iteration", "milestone"):
+             entry = payload[section].get(slot)
+             if isinstance(entry, dict) and entry.get("name") == benchmark_name and "sample_path" in entry:
+                 return Path(entry["sample_path"])
+     try:
+         return resolve_canonical_dataset(benchmark_name, None)
+     except ValueError:
+         raise ValueError(f"No sample path found for benchmark: {benchmark_name}")
+
+
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Run a local benchmark against JSONL samples")
+     parser.add_argument("--benchmark", required=True, choices=["MBPP", "GSM8K", "HumanEval", "ARC-Challenge"])
+     parser.add_argument("--samples", type=Path)
+     parser.add_argument("--suite", type=Path, default=REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json")
+     parser.add_argument("--out", type=Path)
+     parser.add_argument("--ledger", type=Path)
+     parser.add_argument("--variant", default="hydra_full")
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--generator-mode", choices=["stub", "hydra"], default="stub")
+     parser.add_argument("--checkpoint", type=Path)
+     parser.add_argument("--device")
+     parser.add_argument("--max-new-tokens", type=int, default=256)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--top-p", type=float, default=0.95)
+     return parser.parse_args(argv)
+
+
+ def main(argv: list[str] | None = None) -> int:
+     args = parse_args(argv)
+     sample_path = resolve_samples_path(args.benchmark, args.samples, args.suite)
+     try:
+         if args.generator_mode == "hydra":
+             generator = build_hydra_generator(
+                 checkpoint_path=args.checkpoint,
+                 device=args.device,
+                 max_new_tokens=args.max_new_tokens,
+                 temperature=args.temperature,
+                 top_p=args.top_p,
+             )
+         else:
+             def generator(prompt: str) -> str:
+                 return prompt
+
+         result = run_benchmark(args.benchmark, sample_path, generator)
+         exit_code = 0
+     except FileNotFoundError as exc:
+         result = {
+             "benchmark": args.benchmark,
+             "status": "failed",
+             "failure_type": "missing_checkpoint",
+             "error": str(exc),
+             "n_samples": 0,
+         }
+         exit_code = 1
+     except Exception as exc:  # noqa: BLE001
+         result = {
+             "benchmark": args.benchmark,
+             "status": "failed",
+             "failure_type": type(exc).__name__,
+             "error": str(exc),
+             "n_samples": 0,
+         }
+         exit_code = 1
+
+     if args.out is not None:
+         write_benchmark_result(args.out, result)
+     if args.ledger is not None and exit_code == 0:
+         append_benchmark_run_record(
+             args.ledger,
+             result,
+             benchmark_name=args.benchmark,
+             variant=args.variant,
+             seed=args.seed,
+             samples_path=sample_path,
+         )
+     print(json.dumps(result, indent=2, sort_keys=True))
+     return exit_code
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
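The stub generator simply echoes the prompt back, so scores are meaningless, but it exercises sample loading, result writing, and the ledger plumbing without a checkpoint. A hedged invocation sketch (output path illustrative):

    import subprocess
    import sys

    subprocess.run(
        [
            sys.executable, "scripts/benchmark_runner.py",
            "--benchmark", "GSM8K",
            "--generator-mode", "stub",
            "--seed", "42",
            "--out", "artifacts/runs/gsm8k_stub_42.json",
        ],
        check=False,  # a non-zero exit marks a failed run; the result JSON is still printed
    )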
overlay/scripts/benchmark_suite.py CHANGED
@@ -1,84 +1,84 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import json
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any
-
-
- @dataclass(frozen=True)
- class BenchmarkSpec:
-     name: str
-     family: str
-     required_fields: tuple[str, ...]
-
-
- REGISTRY: dict[str, BenchmarkSpec] = {
-     "MBPP": BenchmarkSpec("MBPP", "coding", ("task_id", "prompt", "tests")),
-     "HumanEval": BenchmarkSpec("HumanEval", "coding", ("task_id", "prompt", "test")),
-     "GSM8K": BenchmarkSpec("GSM8K", "reasoning", ("question", "answer")),
-     "ARC-Challenge": BenchmarkSpec("ARC-Challenge", "reasoning", ("question", "choices", "answer")),
- }
-
-
- def validate_sample(benchmark_name: str, sample: dict[str, Any]) -> None:
-     spec = REGISTRY[benchmark_name]
-     for field in spec.required_fields:
-         if field not in sample:
-             raise ValueError(f"{benchmark_name} sample missing required field: {field}")
-
-
- def build_prompt(benchmark_name: str, sample: dict[str, Any]) -> str:
-     validate_sample(benchmark_name, sample)
-     if benchmark_name == "MBPP":
-         tests = sample["tests"]
-         rendered_tests = "\n".join(str(t) for t in tests)
-         return (
-             "Write a Python function that solves the task below.\n\n"
-             f"Task:\n{sample['prompt']}\n\n"
-             f"Tests:\n{rendered_tests}\n"
-         )
-     if benchmark_name == "HumanEval":
-         return (
-             "Complete the following Python function exactly as specified.\n\n"
-             f"Prompt:\n{sample['prompt']}\n\n"
-             f"Reference test:\n{sample['test']}\n"
-         )
-     if benchmark_name == "GSM8K":
-         return f"Solve the following math word problem. Return only the final answer.\n\nQuestion: {sample['question']}\n"
-     if benchmark_name == "ARC-Challenge":
-         choices = sample["choices"]
-         rendered_choices = "\n".join(f"- {choice}" for choice in choices)
-         return (
-             "Answer the following multiple-choice science question. Return only the correct option text or label.\n\n"
-             f"Question: {sample['question']}\nChoices:\n{rendered_choices}\n"
-         )
-     raise ValueError(f"Unknown benchmark: {benchmark_name}")
-
-
- def load_cycle_benchmark_suite(path: Path) -> dict[str, dict[str, BenchmarkSpec]]:
-     payload = json.loads(path.read_text(encoding="utf-8"))
-     out: dict[str, dict[str, BenchmarkSpec]] = {"coding_benchmarks": {}, "reasoning_benchmarks": {}}
-     for section in ("coding_benchmarks", "reasoning_benchmarks"):
-         if section not in payload:
-             raise ValueError(f"missing benchmark section: {section}")
-         for slot in ("fast_iteration", "milestone"):
-             if slot not in payload[section]:
-                 raise ValueError(f"missing benchmark slot: {section}.{slot}")
-             name = payload[section][slot]["name"]
-             if name not in REGISTRY:
-                 raise ValueError(f"unsupported benchmark: {name}")
-             out[section][slot] = REGISTRY[name]
-     return out
-
-
- def main() -> int:
-     path = Path("artifacts/benchmark_suite.cycle1.json")
-     suite = load_cycle_benchmark_suite(path)
-     print(json.dumps({k: {slot: spec.name for slot, spec in section.items()} for k, section in suite.items()}, indent=2))
-     return 0
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+
+ @dataclass(frozen=True)
+ class BenchmarkSpec:
+     name: str
+     family: str
+     required_fields: tuple[str, ...]
+
+
+ REGISTRY: dict[str, BenchmarkSpec] = {
+     "MBPP": BenchmarkSpec("MBPP", "coding", ("task_id", "prompt", "tests")),
+     "HumanEval": BenchmarkSpec("HumanEval", "coding", ("task_id", "prompt", "test")),
+     "GSM8K": BenchmarkSpec("GSM8K", "reasoning", ("question", "answer")),
+     "ARC-Challenge": BenchmarkSpec("ARC-Challenge", "reasoning", ("question", "choices", "answer")),
+ }
+
+
+ def validate_sample(benchmark_name: str, sample: dict[str, Any]) -> None:
+     spec = REGISTRY[benchmark_name]
+     for field in spec.required_fields:
+         if field not in sample:
+             raise ValueError(f"{benchmark_name} sample missing required field: {field}")
+
+
+ def build_prompt(benchmark_name: str, sample: dict[str, Any]) -> str:
+     validate_sample(benchmark_name, sample)
+     if benchmark_name == "MBPP":
+         tests = sample["tests"]
+         rendered_tests = "\n".join(str(t) for t in tests)
+         return (
+             "Write a Python function that solves the task below.\n\n"
+             f"Task:\n{sample['prompt']}\n\n"
+             f"Tests:\n{rendered_tests}\n"
+         )
+     if benchmark_name == "HumanEval":
+         return (
+             "Complete the following Python function exactly as specified.\n\n"
+             f"Prompt:\n{sample['prompt']}\n\n"
+             f"Reference test:\n{sample['test']}\n"
+         )
+     if benchmark_name == "GSM8K":
+         return f"Solve the following math word problem. Return only the final answer.\n\nQuestion: {sample['question']}\n"
+     if benchmark_name == "ARC-Challenge":
+         choices = sample["choices"]
+         rendered_choices = "\n".join(f"- {choice}" for choice in choices)
+         return (
+             "Answer the following multiple-choice science question. Return only the correct option text or label.\n\n"
+             f"Question: {sample['question']}\nChoices:\n{rendered_choices}\n"
+         )
+     raise ValueError(f"Unknown benchmark: {benchmark_name}")
+
+
+ def load_cycle_benchmark_suite(path: Path) -> dict[str, dict[str, BenchmarkSpec]]:
+     payload = json.loads(path.read_text(encoding="utf-8"))
+     out: dict[str, dict[str, BenchmarkSpec]] = {"coding_benchmarks": {}, "reasoning_benchmarks": {}}
+     for section in ("coding_benchmarks", "reasoning_benchmarks"):
+         if section not in payload:
+             raise ValueError(f"missing benchmark section: {section}")
+         for slot in ("fast_iteration", "milestone"):
+             if slot not in payload[section]:
+                 raise ValueError(f"missing benchmark slot: {section}.{slot}")
+             name = payload[section][slot]["name"]
+             if name not in REGISTRY:
+                 raise ValueError(f"unsupported benchmark: {name}")
+             out[section][slot] = REGISTRY[name]
+     return out
+
+
+ def main() -> int:
+     path = Path("artifacts/benchmark_suite.cycle1.json")
+     suite = load_cycle_benchmark_suite(path)
+     print(json.dumps({k: {slot: spec.name for slot, spec in section.items()} for k, section in suite.items()}, indent=2))
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
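build_prompt validates the sample before formatting, so malformed rows fail fast. A small sketch with an invented GSM8K row:

    from scripts.benchmark_suite import build_prompt

    sample = {"question": "What is 2 + 3?", "answer": "5"}  # invented sample
    print(build_prompt("GSM8K", sample))
    # Solve the following math word problem. Return only the final answer.
    #
    # Question: What is 2 + 3?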
overlay/scripts/bootstrap_benchmark_env.py CHANGED
@@ -1,63 +1,63 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import json
- import shutil
-
- import torch
-
-
- PACKAGE_MAP = {
-     "mamba_ssm": "mamba-ssm",
-     "transformers": "transformers",
- }
-
-
- def build_install_command(*, missing_dependencies: list[str]) -> list[str]:
-     packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
-     return [] if not packages else ["python", "-m", "pip", "install", *packages]
-
-
- def diagnose_install_blockers(
-     *,
-     missing_dependencies: list[str],
-     torch_version: str,
-     cuda_available: bool,
-     nvcc_present: bool,
- ) -> list[str]:
-     blockers: list[str] = []
-     if "mamba_ssm" in missing_dependencies:
-         if "+cpu" in torch_version or not cuda_available:
-             blockers.append("mamba_ssm install likely blocked by CPU-only torch runtime")
-         if not nvcc_present:
-             blockers.append("mamba_ssm install likely blocked because nvcc is unavailable")
-     return blockers
-
-
- def build_bootstrap_report(*, missing_dependencies: list[str]) -> dict[str, object]:
-     ready = len(missing_dependencies) == 0
-     packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
-     install_hint = "" if ready else f"Install missing benchmark dependencies: {', '.join(packages)}"
-     blockers = diagnose_install_blockers(
-         missing_dependencies=missing_dependencies,
-         torch_version=getattr(torch, "__version__", "unknown"),
-         cuda_available=torch.cuda.is_available(),
-         nvcc_present=shutil.which("nvcc") is not None,
-     )
-     return {
-         "ready": ready,
-         "missing_dependencies": list(missing_dependencies),
-         "install_hint": install_hint,
-         "install_command": build_install_command(missing_dependencies=missing_dependencies),
-         "install_blockers": blockers,
-     }
-
-
- def main() -> int:
-     report = build_bootstrap_report(missing_dependencies=["mamba_ssm"])
-     print(json.dumps(report, indent=2, sort_keys=True))
-     return 0
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import json
+ import shutil
+
+ import torch
+
+
+ PACKAGE_MAP = {
+     "mamba_ssm": "mamba-ssm",
+     "transformers": "transformers",
+ }
+
+
+ def build_install_command(*, missing_dependencies: list[str]) -> list[str]:
+     packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
+     return [] if not packages else ["python", "-m", "pip", "install", *packages]
+
+
+ def diagnose_install_blockers(
+     *,
+     missing_dependencies: list[str],
+     torch_version: str,
+     cuda_available: bool,
+     nvcc_present: bool,
+ ) -> list[str]:
+     blockers: list[str] = []
+     if "mamba_ssm" in missing_dependencies:
+         if "+cpu" in torch_version or not cuda_available:
+             blockers.append("mamba_ssm install likely blocked by CPU-only torch runtime")
+         if not nvcc_present:
+             blockers.append("mamba_ssm install likely blocked because nvcc is unavailable")
+     return blockers
+
+
+ def build_bootstrap_report(*, missing_dependencies: list[str]) -> dict[str, object]:
+     ready = len(missing_dependencies) == 0
+     packages = [PACKAGE_MAP.get(name, name) for name in missing_dependencies]
+     install_hint = "" if ready else f"Install missing benchmark dependencies: {', '.join(packages)}"
+     blockers = diagnose_install_blockers(
+         missing_dependencies=missing_dependencies,
+         torch_version=getattr(torch, "__version__", "unknown"),
+         cuda_available=torch.cuda.is_available(),
+         nvcc_present=shutil.which("nvcc") is not None,
+     )
+     return {
+         "ready": ready,
+         "missing_dependencies": list(missing_dependencies),
+         "install_hint": install_hint,
+         "install_command": build_install_command(missing_dependencies=missing_dependencies),
+         "install_blockers": blockers,
+     }
+
+
+ def main() -> int:
+     report = build_bootstrap_report(missing_dependencies=["mamba_ssm"])
+     print(json.dumps(report, indent=2, sort_keys=True))
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
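The blocker diagnosis is pure string logic, so it can be checked without a GPU:

    from scripts.bootstrap_benchmark_env import diagnose_install_blockers

    blockers = diagnose_install_blockers(
        missing_dependencies=["mamba_ssm"],
        torch_version="2.6.0+cpu",
        cuda_available=False,
        nvcc_present=False,
    )
    assert len(blockers) == 2  # CPU-only torch and missing nvcc both fire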
overlay/scripts/cycle1a_report.py CHANGED
@@ -1,52 +1,52 @@
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import json
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any
+
+
+ def build_cycle1a_report(run_dir: Path) -> dict[str, Any]:
+     runs = []
+     for path in sorted(run_dir.glob("*.json")):
+         try:
+             payload = json.loads(path.read_text(encoding="utf-8"))
+         except Exception:
+             continue
+         if isinstance(payload, dict) and "benchmark" in payload:
+             runs.append((path.name, payload))
+
+     n_failed = sum(1 for _, payload in runs if payload.get("status") == "failed")
+     by_benchmark: dict[str, dict[str, dict[str, float]]] = defaultdict(dict)
+     for filename, payload in runs:
+         if payload.get("status") == "failed":
+             continue
+         parts = filename.removesuffix('.json').split('_')
+         if len(parts) < 3:
+             continue
+         benchmark = payload["benchmark"]
+         variant = '_'.join(parts[1:-1])
+         score = float(payload.get("score", 0.0))
+         slot = by_benchmark.setdefault(benchmark, {}).setdefault(variant, {"scores": []})
+         slot["scores"].append(score)
+
+     for benchmark, variants in by_benchmark.items():
+         for variant, slot in variants.items():
+             scores = slot.pop("scores")
+             slot["mean_score"] = sum(scores) / len(scores)
+             slot["n_scores"] = len(scores)
+
+     if runs and n_failed == len(runs):
+         panel_status = "blocked"
+     elif n_failed > 0:
+         panel_status = "partial"
+     else:
+         panel_status = "ready"
+
+     return {
+         "n_runs": len(runs),
+         "n_failed": n_failed,
+         "panel_status": panel_status,
+         "by_benchmark": by_benchmark,
+     }
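
A small self-contained sketch of the report contract (not part of the commit; the filename convention `<benchmark>_<variant>_<seed>.json` is inferred from the parser above, and the import path is assumed):

import json
import tempfile
from pathlib import Path

from scripts.cycle1a_report import build_cycle1a_report  # assumed import path

run_dir = Path(tempfile.mkdtemp())
(run_dir / "humaneval_baseline_seed1.json").write_text(
    json.dumps({"benchmark": "HumanEval", "status": "ok", "score": 0.25}), encoding="utf-8")
(run_dir / "humaneval_baseline_seed2.json").write_text(
    json.dumps({"benchmark": "HumanEval", "status": "failed"}), encoding="utf-8")

report = build_cycle1a_report(run_dir)
assert report["panel_status"] == "partial"  # one of two runs failed
assert report["by_benchmark"]["HumanEval"]["baseline"]["n_scores"] == 1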
overlay/scripts/cycle_executor.py CHANGED
@@ -1,332 +1,312 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import argparse
- import importlib.util
- import importlib
- import json
- import os
- import subprocess
- import sys
- from pathlib import Path
- from typing import Any
-
- from scripts.benchmark_preflight import build_readiness_report
- from scripts.hf_routing import resolve_routing
-
-
- REPO_ROOT = Path(__file__).resolve().parents[1]
- FREEZE_PATH = REPO_ROOT / "artifacts" / "cycle_1_execution_freeze.json"
- RUNNER_PATH = REPO_ROOT / "scripts" / "benchmark_runner.py"
-
-
- def active_hf_token() -> str | None:
-     token = os.environ.get("HF_TOKEN")
-     if token:
-         return token
-     try:
-         from huggingface_hub.utils import get_token
-         return get_token()
-     except Exception:
-         return None
-
-
- def missing_benchmark_dependencies() -> list[str]:
-     required = ["mamba_ssm", "transformers"]
-     missing: list[str] = []
-     for name in required:
-         try:
-             spec = importlib.util.find_spec(name)
-         except (ImportError, ValueError):
-             spec = None
-         if spec is None:
-             try:
-                 importlib.import_module(name)
-             except Exception:
-                 missing.append(name)
-     return missing
-
-
- def load_cycle_freeze(path: Path) -> dict[str, Any]:
-     return json.loads(path.read_text(encoding="utf-8"))
-
-
- def load_cycle_benchmarks(path: Path) -> list[str]:
-     payload = json.loads(path.read_text(encoding="utf-8"))
-     out: list[str] = []
-     for section in ("coding_benchmarks", "reasoning_benchmarks"):
-         for slot in ("fast_iteration", "milestone"):
-             entry = payload.get(section, {}).get(slot)
-             if isinstance(entry, dict) and entry.get("name"):
-                 out.append(str(entry["name"]))
-     return out
-
-
- def variant_env_for_benchmark(freeze: dict[str, Any], variant: str) -> dict[str, str]:
-     variant_cfg = freeze["variants"][variant]
-     return {str(k): str(v) for k, v in variant_cfg.get("env", {}).items()}
-
-
- def decode_config_for_benchmark(freeze: dict[str, Any], benchmark: str) -> dict[str, Any]:
-     for section in ("coding_benchmarks", "reasoning_benchmarks"):
-         for slot in ("fast_iteration", "milestone"):
-             entry = freeze.get(section, {}).get(slot)
-             if isinstance(entry, dict) and entry.get("name") == benchmark:
-                 return dict(entry.get("decode", {}))
-     return {}
-
-
- def build_preflight_report(
-     *,
-     cache_dir: Path,
-     output_repo: str | None = None,
-     tokenizer_repo: str | None = None,
- ) -> dict[str, object]:
-     return build_readiness_report(
-         cache_dir=cache_dir,
-         hf_token_present=bool(active_hf_token()),
-         dependencies_present=not bool(missing_benchmark_dependencies()),
-         missing_dependencies=missing_benchmark_dependencies(),
-         output_repo=output_repo,
-         tokenizer_repo=tokenizer_repo,
-     )
-
-
- def write_preflight_report(path: Path, payload: dict[str, object]) -> None:
-     path.parent.mkdir(parents=True, exist_ok=True)
-     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-
-
- def write_cycle_summary(path: Path, payload: list[dict[str, Any]]) -> None:
-     path.parent.mkdir(parents=True, exist_ok=True)
-     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-
-
- def build_remote_checkpoint_report(output_repo: str, token: str | None) -> dict[str, Any]:
-     from huggingface_hub import HfApi
-
-     from scripts.benchmark_checkpoint_report import build_checkpoint_report
-
-     files = HfApi(token=token).list_repo_files(repo_id=output_repo, repo_type="model", token=token)
-     return build_checkpoint_report(files)
-
-
- def ensure_benchmark_assets(
-     *,
-     cache_dir: Path,
-     output_repo: str,
-     tokenizer_repo: str,
-     token: str | None,
-     hydrate: bool,
- ) -> dict[str, str] | None:
-     if not hydrate:
-         return None
-     from scripts.benchmark_assets import hydrate_benchmark_assets
-
-     return hydrate_benchmark_assets(
-         cache_dir=cache_dir,
-         output_repo=output_repo,
-         tokenizer_repo=tokenizer_repo,
-         token=token,
-     )
-
-
- def build_benchmark_command(
-     freeze: dict[str, Any],
-     *,
-     benchmark: str,
-     variant: str,
-     seed: int,
-     out_dir: Path,
- ) -> tuple[list[str], dict[str, str]]:
-     env = os.environ.copy()
-     env.update(variant_env_for_benchmark(freeze, variant))
-     env["HYDRA_SEED"] = str(seed)
-     decode_cfg = decode_config_for_benchmark(freeze, benchmark)
-
-     out_dir.mkdir(parents=True, exist_ok=True)
-     result_path = out_dir / f"{benchmark.lower()}_{variant}_seed{seed}.json"
-     ledger_path = out_dir / "benchmark_ledger.json"
-     cmd = [
-         sys.executable,
-         str(RUNNER_PATH),
-         "--benchmark",
-         benchmark,
-         "--generator-mode",
-         "hydra",
-         "--out",
-         str(result_path),
-         "--ledger",
-         str(ledger_path),
-         "--variant",
-         variant,
-         "--seed",
-         str(seed),
-         "--max-new-tokens",
-         str(int(decode_cfg.get("max_tokens", 256))),
-         "--temperature",
-         str(float(decode_cfg.get("temperature", 0.2))),
-         "--top-p",
-         str(float(decode_cfg.get("top_p", 0.95))),
-     ]
-     return cmd, env
-
-
- def build_cycle_plan(freeze: dict[str, Any], *, benchmark: str, out_dir: Path) -> list[dict[str, Any]]:
-     runnable_variants = [
-         name for name, cfg in freeze.get("variants", {}).items()
-         if isinstance(cfg, dict) and cfg.get("status") == "runnable_now"
-     ]
-     seeds = [int(seed) for seed in freeze.get("seeds", [])]
-     plan: list[dict[str, Any]] = []
-     for variant in runnable_variants:
-         for seed in seeds:
-             cmd, env = build_benchmark_command(
-                 freeze,
-                 benchmark=benchmark,
-                 variant=variant,
-                 seed=seed,
-                 out_dir=out_dir,
-             )
-             plan.append({
-                 "benchmark": benchmark,
-                 "variant": variant,
-                 "seed": seed,
-                 "command": cmd,
-                 "env": env,
-             })
-     return plan
-
-
- def execute_cycle_plan(plan: list[dict[str, Any]], *, repo_root: Path) -> list[dict[str, Any]]:
-     results: list[dict[str, Any]] = []
-     for item in plan:
-         proc = subprocess.run(item["command"], cwd=str(repo_root), env=item["env"])
-         results.append(
-             {
-                 "benchmark": item["benchmark"],
-                 "variant": item["variant"],
-                 "seed": item["seed"],
-                 "returncode": proc.returncode,
-             }
-         )
-     return results
-
-
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-     parser = argparse.ArgumentParser(description="Execute a frozen Cycle 1 benchmark run")
-     parser.add_argument("--freeze", type=Path, default=FREEZE_PATH)
-     parser.add_argument("--suite", type=Path, default=REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json")
-     parser.add_argument("--benchmark", required=True)
-     parser.add_argument("--variant", required=True)
-     parser.add_argument("--seed", type=int, required=True)
-     parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "artifacts" / "runs")
-     parser.add_argument("--preflight-out", type=Path)
-     parser.add_argument("--summary-out", type=Path)
-     parser.add_argument("--hydrate-assets", action="store_true")
-     parser.add_argument("--all-runnable", action="store_true")
-     parser.add_argument("--all-benchmarks", action="store_true")
-     parser.add_argument("--require-ready", action="store_true")
-     parser.add_argument("--output-repo")
-     parser.add_argument("--tokenizer-repo")
-     return parser.parse_args(argv)
-
-
- def main(argv: list[str] | None = None) -> int:
-     args = parse_args(argv)
-     cache_dir = Path(os.path.expanduser("~/.cache/autoresearch"))
-     report = None
-     token = active_hf_token()
-     routing = resolve_routing(token=token)
-     output_repo = args.output_repo or routing.output_repo
-     tokenizer_repo = args.tokenizer_repo or routing.output_repo
-     if args.hydrate_assets:
-         try:
-             ensure_benchmark_assets(
-                 cache_dir=cache_dir,
-                 output_repo=output_repo,
-                 tokenizer_repo=tokenizer_repo,
-                 token=token,
-                 hydrate=True,
-             )
-         except FileNotFoundError as exc:
-             checkpoint_report = None
-             try:
-                 checkpoint_report = build_remote_checkpoint_report(output_repo, token)
-             except Exception:
-                 checkpoint_report = None
-             if args.summary_out is not None:
-                 write_cycle_summary(
-                     args.summary_out,
-                     [{
-                         "status": "blocked",
-                         "reason": "asset_hydration_failed",
-                         "error": str(exc),
-                         "checkpoint_candidates": checkpoint_report,
-                     }],
-                 )
-             return 3
-     if args.preflight_out is not None:
-         report = build_preflight_report(
-             cache_dir=cache_dir,
-             output_repo=output_repo,
-             tokenizer_repo=tokenizer_repo,
-         )
-         write_preflight_report(args.preflight_out, report)
-     if args.require_ready:
-         if report is None:
-             report = build_preflight_report(
-                 cache_dir=cache_dir,
-                 output_repo=output_repo,
-                 tokenizer_repo=tokenizer_repo,
-             )
-         if not bool(report.get("ready_for_hydra_benchmarks")):
-             checkpoint_report = None
-             try:
-                 checkpoint_report = build_remote_checkpoint_report(output_repo, token)
-             except Exception:
-                 checkpoint_report = None
-             if args.summary_out is not None:
-                 write_cycle_summary(
-                     args.summary_out,
-                     [{
-                         "status": "blocked",
-                         "reason": "preflight_not_ready",
-                         "preflight": report,
-                         "checkpoint_candidates": checkpoint_report,
-                     }],
-                 )
-             return 2
-     freeze = load_cycle_freeze(args.freeze)
-     if args.all_runnable:
-         benchmarks = load_cycle_benchmarks(args.suite) if args.all_benchmarks else [args.benchmark]
-         plan = []
-         for benchmark in benchmarks:
-             plan.extend(build_cycle_plan(freeze, benchmark=benchmark, out_dir=args.out_dir))
-         results = execute_cycle_plan(plan, repo_root=REPO_ROOT)
-         if args.summary_out is not None:
-             write_cycle_summary(args.summary_out, results)
-         return 0 if all(item["returncode"] == 0 for item in results) else 1
-     cmd, env = build_benchmark_command(
-         freeze,
-         benchmark=args.benchmark,
-         variant=args.variant,
-         seed=args.seed,
-         out_dir=args.out_dir,
-     )
-     proc = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env)
-     if args.summary_out is not None:
-         write_cycle_summary(
-             args.summary_out,
-             [{
-                 "benchmark": args.benchmark,
-                 "variant": args.variant,
-                 "seed": args.seed,
-                 "returncode": proc.returncode,
-             }],
-         )
-     return proc.returncode
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
 
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import importlib.util
+ import importlib
+ import json
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+ from typing import Any
+
+ from scripts.benchmark_preflight import build_readiness_report
+ from scripts.hf_routing import resolve_routing
+
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ FREEZE_PATH = REPO_ROOT / "artifacts" / "cycle_1_execution_freeze.json"
+ RUNNER_PATH = REPO_ROOT / "scripts" / "benchmark_runner.py"
+
+
+ def active_hf_token() -> str | None:
+     token = os.environ.get("HF_TOKEN")
+     if token:
+         return token
+     try:
+         from huggingface_hub.utils import get_token
+         return get_token()
+     except Exception:
+         return None
+
+
+ def missing_benchmark_dependencies() -> list[str]:
+     required = ["mamba_ssm", "transformers"]
+     missing: list[str] = []
+     for name in required:
+         try:
+             spec = importlib.util.find_spec(name)
+         except (ImportError, ValueError):
+             spec = None
+         if spec is None:
+             try:
+                 importlib.import_module(name)
+             except Exception:
+                 missing.append(name)
+     return missing
+
+
+ def load_cycle_freeze(path: Path) -> dict[str, Any]:
+     return json.loads(path.read_text(encoding="utf-8"))
+
+
+ def load_cycle_benchmarks(path: Path) -> list[str]:
+     payload = json.loads(path.read_text(encoding="utf-8"))
+     out: list[str] = []
+     for section in ("coding_benchmarks", "reasoning_benchmarks"):
+         for slot in ("fast_iteration", "milestone"):
+             entry = payload.get(section, {}).get(slot)
+             if isinstance(entry, dict) and entry.get("name"):
+                 out.append(str(entry["name"]))
+     return out
+
+
+ def build_preflight_report(
+     *,
+     cache_dir: Path,
+     output_repo: str | None = None,
+     tokenizer_repo: str | None = None,
+ ) -> dict[str, object]:
+     return build_readiness_report(
+         cache_dir=cache_dir,
+         hf_token_present=bool(active_hf_token()),
+         dependencies_present=not bool(missing_benchmark_dependencies()),
+         missing_dependencies=missing_benchmark_dependencies(),
+         output_repo=output_repo,
+         tokenizer_repo=tokenizer_repo,
+     )
+
+
+ def write_preflight_report(path: Path, payload: dict[str, object]) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+
+
+ def write_cycle_summary(path: Path, payload: list[dict[str, Any]]) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+
+
+ def build_remote_checkpoint_report(output_repo: str, token: str | None) -> dict[str, Any]:
+     from huggingface_hub import HfApi
+
+     from scripts.benchmark_checkpoint_report import build_checkpoint_report
+
+     files = HfApi(token=token).list_repo_files(repo_id=output_repo, repo_type="model", token=token)
+     return build_checkpoint_report(files)
+
+
+ def ensure_benchmark_assets(
+     *,
+     cache_dir: Path,
+     output_repo: str,
+     tokenizer_repo: str,
+     token: str | None,
+     hydrate: bool,
+ ) -> dict[str, str] | None:
+     if not hydrate:
+         return None
+     from scripts.benchmark_assets import hydrate_benchmark_assets
+
+     return hydrate_benchmark_assets(
+         cache_dir=cache_dir,
+         output_repo=output_repo,
+         tokenizer_repo=tokenizer_repo,
+         token=token,
+     )
+
+
+ def build_benchmark_command(
+     freeze: dict[str, Any],
+     *,
+     benchmark: str,
+     variant: str,
+     seed: int,
+     out_dir: Path,
+ ) -> tuple[list[str], dict[str, str]]:
+     variant_cfg = freeze["variants"][variant]
+     env = os.environ.copy()
+     env.update({str(k): str(v) for k, v in variant_cfg.get("env", {}).items()})
+     env["HYDRA_SEED"] = str(seed)
+
+     out_dir.mkdir(parents=True, exist_ok=True)
+     result_path = out_dir / f"{benchmark.lower()}_{variant}_seed{seed}.json"
+     ledger_path = out_dir / "benchmark_ledger.json"
+     cmd = [
+         sys.executable,
+         str(RUNNER_PATH),
+         "--benchmark",
+         benchmark,
+         "--generator-mode",
+         "hydra",
+         "--out",
+         str(result_path),
+         "--ledger",
+         str(ledger_path),
+         "--variant",
+         variant,
+         "--seed",
+         str(seed),
+     ]
+     return cmd, env
+
+
+ def build_cycle_plan(freeze: dict[str, Any], *, benchmark: str, out_dir: Path) -> list[dict[str, Any]]:
+     runnable_variants = [
+         name for name, cfg in freeze.get("variants", {}).items()
+         if isinstance(cfg, dict) and cfg.get("status") == "runnable_now"
+     ]
+     seeds = [int(seed) for seed in freeze.get("seeds", [])]
+     plan: list[dict[str, Any]] = []
+     for variant in runnable_variants:
+         for seed in seeds:
+             cmd, env = build_benchmark_command(
+                 freeze,
+                 benchmark=benchmark,
+                 variant=variant,
+                 seed=seed,
+                 out_dir=out_dir,
+             )
+             plan.append({
+                 "benchmark": benchmark,
+                 "variant": variant,
+                 "seed": seed,
+                 "command": cmd,
+                 "env": env,
+             })
+     return plan
+
+
+ def execute_cycle_plan(plan: list[dict[str, Any]], *, repo_root: Path) -> list[dict[str, Any]]:
+     results: list[dict[str, Any]] = []
+     for item in plan:
+         proc = subprocess.run(item["command"], cwd=str(repo_root), env=item["env"])
+         results.append(
+             {
+                 "benchmark": item["benchmark"],
+                 "variant": item["variant"],
+                 "seed": item["seed"],
+                 "returncode": proc.returncode,
+             }
+         )
+     return results
+
+
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Execute a frozen Cycle 1 benchmark run")
+     parser.add_argument("--freeze", type=Path, default=FREEZE_PATH)
+     parser.add_argument("--suite", type=Path, default=REPO_ROOT / "artifacts" / "benchmark_suite.cycle1.json")
+     parser.add_argument("--benchmark", required=True)
+     parser.add_argument("--variant", required=True)
+     parser.add_argument("--seed", type=int, required=True)
+     parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "artifacts" / "runs")
+     parser.add_argument("--preflight-out", type=Path)
+     parser.add_argument("--summary-out", type=Path)
+     parser.add_argument("--hydrate-assets", action="store_true")
+     parser.add_argument("--all-runnable", action="store_true")
+     parser.add_argument("--all-benchmarks", action="store_true")
+     parser.add_argument("--require-ready", action="store_true")
+     parser.add_argument("--output-repo")
+     parser.add_argument("--tokenizer-repo")
+     return parser.parse_args(argv)
+
+
+ def main(argv: list[str] | None = None) -> int:
+     args = parse_args(argv)
+     cache_dir = Path(os.path.expanduser("~/.cache/autoresearch"))
+     report = None
+     token = active_hf_token()
+     routing = resolve_routing(token=token)
+     output_repo = args.output_repo or routing.output_repo
+     tokenizer_repo = args.tokenizer_repo or routing.output_repo
+     if args.hydrate_assets:
+         try:
+             ensure_benchmark_assets(
+                 cache_dir=cache_dir,
+                 output_repo=output_repo,
+                 tokenizer_repo=tokenizer_repo,
+                 token=token,
+                 hydrate=True,
+             )
+         except FileNotFoundError as exc:
+             checkpoint_report = None
+             try:
+                 checkpoint_report = build_remote_checkpoint_report(output_repo, token)
+             except Exception:
+                 checkpoint_report = None
+             if args.summary_out is not None:
+                 write_cycle_summary(
+                     args.summary_out,
+                     [{
+                         "status": "blocked",
+                         "reason": "asset_hydration_failed",
+                         "error": str(exc),
+                         "checkpoint_candidates": checkpoint_report,
+                     }],
+                 )
+             return 3
+     if args.preflight_out is not None:
+         report = build_preflight_report(
+             cache_dir=cache_dir,
+             output_repo=output_repo,
+             tokenizer_repo=tokenizer_repo,
+         )
+         write_preflight_report(args.preflight_out, report)
+     if args.require_ready:
+         if report is None:
+             report = build_preflight_report(
+                 cache_dir=cache_dir,
+                 output_repo=output_repo,
+                 tokenizer_repo=tokenizer_repo,
+             )
+         if not bool(report.get("ready_for_hydra_benchmarks")):
+             checkpoint_report = None
+             try:
+                 checkpoint_report = build_remote_checkpoint_report(output_repo, token)
+             except Exception:
+                 checkpoint_report = None
+             if args.summary_out is not None:
+                 write_cycle_summary(
+                     args.summary_out,
+                     [{
+                         "status": "blocked",
+                         "reason": "preflight_not_ready",
+                         "preflight": report,
+                         "checkpoint_candidates": checkpoint_report,
+                     }],
+                 )
+             return 2
+     freeze = load_cycle_freeze(args.freeze)
+     if args.all_runnable:
+         benchmarks = load_cycle_benchmarks(args.suite) if args.all_benchmarks else [args.benchmark]
+         plan = []
+         for benchmark in benchmarks:
+             plan.extend(build_cycle_plan(freeze, benchmark=benchmark, out_dir=args.out_dir))
+         results = execute_cycle_plan(plan, repo_root=REPO_ROOT)
+         if args.summary_out is not None:
+             write_cycle_summary(args.summary_out, results)
+         return 0 if all(item["returncode"] == 0 for item in results) else 1
+     cmd, env = build_benchmark_command(
+         freeze,
+         benchmark=args.benchmark,
+         variant=args.variant,
+         seed=args.seed,
+         out_dir=args.out_dir,
+     )
+     proc = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env)
+     if args.summary_out is not None:
+         write_cycle_summary(
+             args.summary_out,
+             [{
+                 "benchmark": args.benchmark,
+                 "variant": args.variant,
+                 "seed": args.seed,
+                 "returncode": proc.returncode,
+             }],
+         )
+     return proc.returncode
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
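
A sketch (not part of the commit) of building a plan without executing it; the freeze payload below is hypothetical but mirrors the shape `build_cycle_plan` reads (variants with a `status` and `env` overrides, plus `seeds`), and it assumes the repo's `scripts` package and its dependencies resolve:

from pathlib import Path

from scripts.cycle_executor import build_cycle_plan  # assumed import path

freeze = {  # hypothetical freeze payload
    "variants": {
        "baseline": {"status": "runnable_now", "env": {"HYDRA_HTM": "0"}},
        "htm_on": {"status": "blocked", "env": {"HYDRA_HTM": "1"}},
    },
    "seeds": [1, 2],
}
plan = build_cycle_plan(freeze, benchmark="HumanEval", out_dir=Path("/tmp/runs"))
assert len(plan) == 2  # only the runnable_now variant, once per seed
assert all(item["env"]["HYDRA_SEED"] in {"1", "2"} for item in plan)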
overlay/scripts/export_hpo_priors.py CHANGED
@@ -1,94 +1,94 @@
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import datetime as dt
+ import json
+ from pathlib import Path
+ from typing import Any
+
+ import optuna
+
+ from scripts.hpo_leaderboard import build_leaderboard
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Export top Optuna trials as transfer-learning priors")
+     parser.add_argument("--study-name", action="append", default=[], help="Repeat to merge multiple studies")
+     parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
+     parser.add_argument("--top-k", type=int, default=20)
+     parser.add_argument("--out", type=Path, default=Path("docs") / "hpo_transfer_priors.json")
+     parser.add_argument("--metric", default="val_bpb")
+     return parser.parse_args()
+
+
+ def _completed_trials(study: optuna.Study) -> list[optuna.trial.FrozenTrial]:
+     trials = [t for t in study.trials if t.value is not None]
+     reverse = study.direction == optuna.study.StudyDirection.MAXIMIZE
+     return sorted(trials, key=lambda t: float(t.value), reverse=reverse)
+
+
+ def _serialize_trial(trial: optuna.trial.FrozenTrial) -> dict[str, Any]:
+     return {
+         "trial_number": trial.number,
+         "value": float(trial.value) if trial.value is not None else None,
+         "params": dict(trial.params),
+         "user_attrs": dict(trial.user_attrs),
+     }
+
+
+ def collect_prior_trials(*, storage: str, study_names: list[str], top_k: int, metric: str) -> dict[str, Any]:
+     leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
+     selected = leaderboard["clean_trials"][: max(0, top_k)]
+     trials = [
+         {
+             "study_name": row["study_name"],
+             "trial_number": row["trial_number"],
+             "value": row["value"],
+             "params": row["params"],
+             "user_attrs": row["user_attrs"],
+         }
+         for row in selected
+     ]
+     quarantined = [
+         {
+             "study_name": row["study_name"],
+             "trial_number": row["trial_number"],
+             "value": row["value"],
+             "params": row["params"],
+             "user_attrs": row["user_attrs"],
+             "contamination_reason": row["contamination_reason"],
+         }
+         for row in leaderboard["contaminated_trials"]
+     ]
+     return {
+         "schema_version": 2,
+         "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
+         "study_names": study_names,
+         "metric": metric,
+         "n_total_trials": sum(int(s["n_trials"]) for s in leaderboard["studies"]),
+         "n_completed_trials": sum(int(s["n_completed"]) for s in leaderboard["studies"]),
+         "n_exported_trials": len(trials),
+         "n_quarantined_trials": len(quarantined),
+         "top_k": top_k,
+         "trials": trials,
+         "quarantined_trials": quarantined,
+     }
+
+
+ def main() -> int:
+     args = parse_args()
+     study_names = args.study_name or ["hydra_hpo"]
+     payload = collect_prior_trials(storage=args.storage, study_names=study_names, top_k=args.top_k, metric=args.metric)
+
+     args.out.parent.mkdir(parents=True, exist_ok=True)
+     args.out.write_text(json.dumps(payload, indent=2), encoding="utf-8")
+     print(
+         f"[hpo-priors] wrote {args.out} with {payload['n_exported_trials']} clean trials "
+         f"({payload['n_quarantined_trials']} quarantined)"
+     )
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
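
A sketch of exporting priors from a throwaway study (not part of the commit; the storage path, study name, and user-attr values are illustrative stand-ins for a real run):

import optuna

from scripts.export_hpo_priors import collect_prior_trials  # assumed import path

storage = "sqlite:///demo_hpo.db"  # throwaway database for the sketch
study = optuna.create_study(study_name="hydra_hpo", storage=storage, direction="minimize")
study.add_trial(optuna.trial.create_trial(
    params={"dropout": 0.1},
    distributions={"dropout": optuna.distributions.FloatDistribution(0.0, 0.5)},
    value=0.92,
    user_attrs={"objective_source": "final_val", "eval_status": "completed", "tps": 1200.0},
))
payload = collect_prior_trials(storage=storage, study_names=["hydra_hpo"], top_k=5, metric="val_bpb")
assert payload["n_exported_trials"] == 1  # the trial passes the contamination checks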
overlay/scripts/hf_routing.py CHANGED
@@ -1,94 +1,94 @@
+ from __future__ import annotations
+
+ import os
+ from dataclasses import dataclass
+
+ from huggingface_hub import HfApi
+
+
+ _OWNER_ALIASES = {
+     'jack': 'jackoatmon',
+     'jackoatmon': 'jackoatmon',
+     'icarus': 'icarus112',
+     'icarus112': 'icarus112',
+ }
+
+
+ def _normalize_owner(value: str | None) -> str | None:
+     if not value:
+         return None
+     normalized = value.strip().lower().lstrip('@')
+     if not normalized:
+         return None
+     return _OWNER_ALIASES.get(normalized, normalized)
+
+
+ def _owner_from_env() -> str | None:
+     for key in (
+         'FEATHER_HF_OWNER',
+         'FEATHER_HF_NAMESPACE_OWNER',
+         'FEATHER_HF_PROFILE',
+         'FEATHER_HF_NAMESPACE',
+     ):
+         owner = _normalize_owner(os.environ.get(key))
+         if owner:
+             return owner
+     return None
+
+
+ def resolve_owner(token: str | None = None) -> str:
+     """Resolve active HF owner in a collaborator-safe way.
+
+     Resolution precedence:
+     1) explicit env owner override (FEATHER_HF_OWNER/...)
+     2) Hugging Face `whoami` from HF_TOKEN (unless disabled)
+     3) default to jackoatmon
+     """
+     owner = _owner_from_env()
+     if owner:
+         return owner
+
+     if os.environ.get('FEATHER_HF_DISABLE_WHOAMI', '0') != '1':
+         active_token = token or os.environ.get('HF_TOKEN')
+         if active_token:
+             try:
+                 info = HfApi(token=active_token).whoami(token=active_token)
+                 if isinstance(info, dict):
+                     whoami_owner = _normalize_owner(info.get('name'))
+                     if whoami_owner:
+                         return whoami_owner
+             except Exception:
+                 # We intentionally fail-open to deterministic defaults.
+                 pass
+
+     return 'jackoatmon'
+
+
+ @dataclass(frozen=True)
+ class HfRouting:
+     owner: str
+     space_repo: str
+     output_repo: str
+     retina_cache_repo: str
+     job_namespace: str
+
+
+ def resolve_routing(token: str | None = None) -> HfRouting:
+     owner = resolve_owner(token=token)
+
+     space_name = os.environ.get('FEATHER_HF_SPACE_NAME', 'feather-a10-runtime')
+     output_name = os.environ.get('FEATHER_HF_OUTPUT_REPO_NAME', 'feather-pretrain-checkpoints')
+     retina_name = os.environ.get('FEATHER_HF_RETINA_REPO_NAME', 'feather-retina-cache')
+
+     space_repo = os.environ.get('FEATHER_HF_SPACE_REPO') or f'{owner}/{space_name}'
+     output_repo = os.environ.get('FEATHER_HF_OUTPUT_REPO') or f'{owner}/{output_name}'
+     retina_cache_repo = os.environ.get('FEATHER_HF_RETINA_CACHE_REPO') or f'{owner}/{retina_name}'
+     job_namespace = os.environ.get('FEATHER_HF_JOB_NAMESPACE') or owner
+
+     return HfRouting(
+         owner=owner,
+         space_repo=space_repo,
+         output_repo=output_repo,
+         retina_cache_repo=retina_cache_repo,
+         job_namespace=job_namespace,
+     )
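
A sketch of the deterministic resolution path (no network call; the expected value follows the alias table and defaults visible above, and assumes no FEATHER_HF_*_REPO override is already set):

import os

from scripts.hf_routing import resolve_routing  # assumed import path

os.environ["FEATHER_HF_OWNER"] = "jack"        # alias normalizes to 'jackoatmon'
os.environ["FEATHER_HF_DISABLE_WHOAMI"] = "1"  # skip the HfApi().whoami lookup
routing = resolve_routing()
assert routing.output_repo == "jackoatmon/feather-pretrain-checkpoints"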
overlay/scripts/hpo_component_report.py CHANGED
@@ -1,130 +1,130 @@
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import datetime as dt
+ import json
+ import math
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any
+
+ from scripts.hpo_leaderboard import build_leaderboard
+
+
+ _COMPONENT_KEYS = [
+     "engram_subsample",
+     "htm_subsample",
+     "htm_learn_every",
+     "engram_n_columns",
+     "engram_layer_idx",
+     "sdr_target_active",
+     "mamba3_chunk",
+     "dropout",
+     "hyena_layers",
+ ]
+
+
+ def _recover_params(row: dict[str, Any]) -> dict[str, Any]:
+     params = dict(row.get("params") or {})
+     attrs = row.get("user_attrs") or {}
+     for key, value in attrs.items():
+         if key.startswith("param_"):
+             params.setdefault(key.removeprefix("param_"), value)
+     return params
+
+
+ def _pearson(xs: list[float], ys: list[float]) -> float | None:
+     if len(xs) < 2 or len(xs) != len(ys):
+         return None
+     mean_x = sum(xs) / len(xs)
+     mean_y = sum(ys) / len(ys)
+     cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
+     var_x = sum((x - mean_x) ** 2 for x in xs)
+     var_y = sum((y - mean_y) ** 2 for y in ys)
+     if var_x <= 0 or var_y <= 0:
+         return None
+     return cov / math.sqrt(var_x * var_y)
+
+
+ def build_component_report(*, storage: str, study_names: list[str], metric: str = "val_bpb") -> dict[str, Any]:
+     leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
+     clean_trials = leaderboard["clean_trials"]
+
+     ablations: dict[str, list[dict[str, Any]]] = {}
+     numeric_correlations: list[dict[str, Any]] = []
+
+     for key in _COMPONENT_KEYS:
+         grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
+         numeric_x: list[float] = []
+         metric_y: list[float] = []
+         tps_y: list[float] = []
+         for row in clean_trials:
+             params = _recover_params(row)
+             if key not in params:
+                 continue
+             value = params[key]
+             grouped[str(value)].append({"value": value, "metric": float(row["value"]), "tps": row.get("tps")})
+             if isinstance(value, (int, float)) and isinstance(row.get("tps"), (int, float)):
+                 numeric_x.append(float(value))
+                 metric_y.append(float(row["value"]))
+                 tps_y.append(float(row["tps"]))
+
+         rows: list[dict[str, Any]] = []
+         for grouped_rows in grouped.values():
+             value = grouped_rows[0]["value"]
+             metric_vals = [r["metric"] for r in grouped_rows]
+             tps_vals = [float(r["tps"]) for r in grouped_rows if isinstance(r["tps"], (int, float))]
+             rows.append({
+                 "value": value,
+                 "n_trials": len(grouped_rows),
+                 "mean_metric": sum(metric_vals) / len(metric_vals),
+                 "mean_tps": (sum(tps_vals) / len(tps_vals)) if tps_vals else None,
+             })
+         if rows:
+             rows.sort(key=lambda row: str(row["value"]))
+             ablations[key] = rows
+
+         pearson_metric = _pearson(numeric_x, metric_y)
+         pearson_tps = _pearson(numeric_x, tps_y)
+         if pearson_metric is not None or pearson_tps is not None:
+             numeric_correlations.append({
+                 "param": key,
+                 "pearson_with_metric": pearson_metric,
+                 "pearson_with_tps": pearson_tps,
+                 "n_points": len(numeric_x),
+             })
+
+     numeric_correlations.sort(key=lambda row: row["param"])
+     return {
+         "schema_version": 1,
+         "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
+         "metric": metric,
+         "study_names": study_names,
+         "n_clean_trials": len(clean_trials),
+         "component_ablations": ablations,
+         "numeric_correlations": numeric_correlations,
+     }
+
+
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Build component ablation and correlation report from clean HPO trials")
+     parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
+     parser.add_argument("--study-name", action="append", default=[])
+     parser.add_argument("--metric", default="val_bpb")
+     parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "component_report.json")
+     return parser.parse_args(argv)
+
+
+ def main(argv: list[str] | None = None) -> int:
+     args = parse_args(argv)
+     study_names = args.study_name or ["hydra_hpo"]
+     payload = build_component_report(storage=args.storage, study_names=study_names, metric=args.metric)
+     args.out.parent.mkdir(parents=True, exist_ok=True)
+     args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+     print(json.dumps(payload, indent=2, sort_keys=True))
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
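
A hand-checkable look at the `_pearson` guardrails (not part of the commit; values chosen by hand, import path assumed):

from scripts.hpo_component_report import _pearson  # assumed import path

print(_pearson([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # 1.0: cov=4, var_x=2, var_y=8
print(_pearson([1.0, 1.0, 1.0], [2.0, 4.0, 6.0]))  # None: zero variance in xs
print(_pearson([1.0], [2.0]))                      # None: fewer than two points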
overlay/scripts/hpo_leaderboard.py CHANGED
@@ -1,156 +1,156 @@
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import datetime as dt
6
+ import json
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import optuna
11
+
12
+
13
+ def _trial_direction(study: optuna.Study) -> str:
14
+ return "maximize" if study.direction == optuna.study.StudyDirection.MAXIMIZE else "minimize"
15
+
16
+
17
+ def _contamination_reason(trial: optuna.trial.FrozenTrial, metric: str) -> str | None:
18
+ if trial.value is None:
19
+ return "missing_value"
20
+ attrs = trial.user_attrs
21
+ source = attrs.get("objective_source")
22
+ eval_status = attrs.get("eval_status")
23
+ objective_metric = attrs.get("objective_metric")
24
+
25
+ if source in {"train_log_fallback", "missing_metric", "missing_metrics", "missing_final_val"}:
26
+ return f"objective_source={source}"
27
+ if eval_status not in {None, "completed"}:
28
+ return f"eval_status={eval_status}"
29
+ if objective_metric not in {None, metric}:
30
+ return f"objective_metric={objective_metric}"
31
+ return None
32
+
33
+
34
+ def _serialize_trial(study_name: str, trial: optuna.trial.FrozenTrial, metric: str) -> dict[str, Any]:
35
+ attrs = dict(trial.user_attrs)
36
+ source = attrs.get("objective_source") or "legacy_completed_value"
37
+ row = {
38
+ "study_name": study_name,
39
+ "trial_number": trial.number,
40
+ "value": float(trial.value) if trial.value is not None else None,
41
+ "metric": metric,
42
+ "objective_source": source,
43
+ "objective_metric": attrs.get("objective_metric", metric),
44
+ "eval_status": attrs.get("eval_status"),
45
+ "hf_job_id": attrs.get("hf_job_id"),
46
+ "tps": attrs.get("tps"),
47
+ "params": dict(trial.params),
48
+ "user_attrs": attrs,
49
+ }
50
+ reason = _contamination_reason(trial, metric)
51
+ if reason is not None:
52
+ row["contamination_reason"] = reason
53
+     return row
+
+
+ def _is_pareto_dominated(candidate: dict[str, Any], peers: list[dict[str, Any]]) -> bool:
+     candidate_value = float(candidate["value"])
+     candidate_tps = float(candidate["tps"])
+     for peer in peers:
+         if peer is candidate or peer.get("tps") is None:
+             continue
+         peer_value = float(peer["value"])
+         peer_tps = float(peer["tps"])
+         no_worse = peer_value <= candidate_value and peer_tps >= candidate_tps
+         strictly_better = peer_value < candidate_value or peer_tps > candidate_tps
+         if no_worse and strictly_better:
+             return True
+     return False
+
+
+ def _annotate_pareto(clean_trials: list[dict[str, Any]]) -> list[dict[str, Any]]:
+     pareto_trials: list[dict[str, Any]] = []
+     comparable = [row for row in clean_trials if row.get("tps") is not None]
+     for row in clean_trials:
+         if row.get("tps") is None:
+             row["pareto_frontier"] = False
+             row["pareto_dominated"] = None
+             row["pareto_reason"] = "missing_tps"
+             continue
+         dominated = _is_pareto_dominated(row, comparable)
+         row["pareto_frontier"] = not dominated
+         row["pareto_dominated"] = dominated
+         row["pareto_reason"] = "frontier" if not dominated else "dominated"
+         if not dominated:
+             pareto_trials.append(row)
+     pareto_trials.sort(key=lambda row: (float(row["value"]), -float(row["tps"])))
+     return pareto_trials
+
+
+ def build_leaderboard(*, storage: str, study_names: list[str], metric: str = "val_bpb") -> dict[str, Any]:
+     clean_trials: list[dict[str, Any]] = []
+     contaminated_trials: list[dict[str, Any]] = []
+     study_summaries: list[dict[str, Any]] = []
+     direction = "minimize"
+
+     for study_name in study_names:
+         study = optuna.load_study(study_name=study_name, storage=storage)
+         direction = _trial_direction(study)
+         completed = [t for t in study.trials if t.value is not None]
+         study_summaries.append({
+             "study_name": study_name,
+             "direction": direction,
+             "n_trials": len(study.trials),
+             "n_completed": len(completed),
+         })
+         for trial in completed:
+             row = _serialize_trial(study_name, trial, metric)
+             if "contamination_reason" in row:
+                 contaminated_trials.append(row)
+             else:
+                 clean_trials.append(row)
+
+     reverse = direction == "maximize"
+     clean_trials.sort(key=lambda row: float(row["value"]), reverse=reverse)
+     contaminated_trials.sort(key=lambda row: float(row["value"]), reverse=reverse)
+     pareto_trials = _annotate_pareto(clean_trials)
+
+     return {
+         "schema_version": 1,
+         "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
+         "metric": metric,
+         "direction": direction,
+         "study_names": study_names,
+         "studies": study_summaries,
+         "n_clean_trials": len(clean_trials),
+         "n_contaminated_trials": len(contaminated_trials),
+         "pareto_metric_x": metric,
+         "pareto_metric_y": "tps",
+         "n_pareto_trials": len(pareto_trials),
+         "clean_trials": clean_trials,
+         "contaminated_trials": contaminated_trials,
+         "pareto_trials": pareto_trials,
+     }
+
+
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Build a clean Optuna HPO leaderboard")
+     parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
+     parser.add_argument("--study-name", action="append", default=[], help="Repeat to merge multiple studies")
+     parser.add_argument("--metric", default="val_bpb")
+     parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "leaderboard.json")
+     return parser.parse_args(argv)
+
+
+ def main(argv: list[str] | None = None) -> int:
+     args = parse_args(argv)
+     study_names = args.study_name or ["hydra_hpo"]
+     payload = build_leaderboard(storage=args.storage, study_names=study_names, metric=args.metric)
+     args.out.parent.mkdir(parents=True, exist_ok=True)
+     args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+     print(json.dumps(payload, indent=2, sort_keys=True))
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
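
The dominance test above implements a standard two-objective Pareto rule: a trial is dominated only if some peer is no worse on both the optimization metric (lower `value` wins under `minimize`) and throughput (`tps`, higher wins), and strictly better on at least one. A minimal standalone sketch of the same rule, with toy numbers rather than real trial rows:

    # Toy illustration of the leaderboard's Pareto rule (hypothetical values).
    trials = [
        {"value": 1.10, "tps": 60000.0},  # frontier: best metric
        {"value": 1.20, "tps": 90000.0},  # frontier: best throughput
        {"value": 1.25, "tps": 55000.0},  # dominated by both peers above
    ]

    def dominated(candidate, peers):
        return any(
            peer is not candidate
            and peer["value"] <= candidate["value"]
            and peer["tps"] >= candidate["tps"]
            and (peer["value"] < candidate["value"] or peer["tps"] > candidate["tps"])
            for peer in peers
        )

    frontier = [t for t in trials if not dominated(t, trials)]
    print(len(frontier))  # 2: the third trial is strictly worse on both axes

Trials without a recorded `tps` are excluded from the comparison set and flagged with `pareto_reason="missing_tps"` rather than being silently dropped.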
overlay/scripts/hpo_orchestrator.py CHANGED
@@ -1,25 +1,25 @@
  #!/usr/bin/env python3
  from __future__ import annotations
 
- import argparse
- import json
- import os
- import subprocess
- import sys
- from pathlib import Path
- from typing import Any
-
- import optuna
-
-
- REPO_ROOT = Path(__file__).resolve().parents[1]
- if str(REPO_ROOT) not in sys.path:
-     sys.path.insert(0, str(REPO_ROOT))
-
- from scripts.hf_routing import resolve_routing
- from scripts.optuna_hpo import _enqueue_transfer_priors
-
- HPO_SCRIPT = REPO_ROOT / "scripts" / "optuna_hpo.py"
 
 
  def _run_worker(args: list[str]) -> int:
@@ -28,7 +28,7 @@ def _run_worker(args: list[str]) -> int:
      return proc.returncode
 
 
- def _study_stats(storage: str, study_name: str) -> dict[str, Any]:
      try:
          study = optuna.load_study(study_name=study_name, storage=storage)
      except KeyError:
@@ -62,29 +62,29 @@ def _study_stats(storage: str, study_name: str) -> dict[str, Any]:
              "best_trial_user_attrs": study.best_trial.user_attrs,
          }
      )
-     return stats
-
-
- def _prime_transfer_priors(storage: str, study_name: str, priors_file: Path, apply_priors: bool) -> int:
-     if not apply_priors:
-         return 0
-     study = optuna.create_study(
-         study_name=study_name,
-         storage=storage,
-         load_if_exists=True,
-         direction="minimize",
-     )
-     return _enqueue_transfer_priors(study, priors_file, apply_priors=True)
-
-
- def _disable_worker_priors(worker_args: list[str]) -> list[str]:
-     cleaned: list[str] = []
-     for item in worker_args:
-         if item in {"--apply-priors", "--no-apply-priors"}:
-             continue
-         cleaned.append(item)
-     cleaned.append("--no-apply-priors")
-     return cleaned
 
 
  def _phase_args(phase: str, base: argparse.Namespace) -> list[str]:
@@ -117,25 +117,25 @@ def _phase_args(phase: str, base: argparse.Namespace) -> list[str]:
          base.hf_command,
          "--hf-token-env",
          base.hf_token_env,
-         "--hf-poll-interval",
-         str(base.hf_poll_interval),
-         "--hf-launcher-script",
-         str(base.hf_launcher_script),
-         "--priors-file",
-         str(base.priors_file),
-     ]
      if base.hf_output_repo:
          common.extend(["--hf-output-repo", base.hf_output_repo])
      if base.hf_use_bash:
          common.append("--hf-use-bash")
-     if base.hf_stop_after_metric:
-         common.append("--hf-stop-after-metric")
-     else:
-         common.append("--no-hf-stop-after-metric")
-     if base.apply_priors:
-         common.append("--apply-priors")
-     else:
-         common.append("--no-apply-priors")
      if phase == "phase1":
          return [
              *common,
@@ -184,32 +184,32 @@ def cmd_phase(args: argparse.Namespace) -> int:
      return rc
 
 
- def cmd_parallel(args: argparse.Namespace) -> int:
-     enqueued_priors = _prime_transfer_priors(args.storage, args.study_name, args.priors_file, args.apply_priors)
-     worker_args = _disable_worker_priors(_phase_args(args.phase, args))
-     procs: list[subprocess.Popen[str]] = []
-     for _ in range(args.workers):
-         cmd = [sys.executable, str(HPO_SCRIPT), *worker_args]
          procs.append(subprocess.Popen(cmd, cwd=str(REPO_ROOT), text=True))
 
      exit_codes = [p.wait() for p in procs]
      stats = _study_stats(args.storage, args.study_name)
      payload = {
          "phase": args.phase,
-         "workers": args.workers,
-         "exit_codes": exit_codes,
-         "enqueued_priors": enqueued_priors,
-         "stats": stats,
-     }
      args.summary_out.parent.mkdir(parents=True, exist_ok=True)
      args.summary_out.write_text(json.dumps(payload, indent=2), encoding="utf-8")
      print(json.dumps(payload, indent=2))
      return 0 if all(code == 0 for code in exit_codes) else 1
 
 
- def cmd_recommend(args: argparse.Namespace) -> int:
-     stats = _study_stats(args.storage, args.study_name)
-     min_tps_floor = float(args.min_tps)
      if stats.get("status") == "missing":
          payload = {
              "stats": stats,
@@ -226,41 +226,41 @@ def cmd_recommend(args: argparse.Namespace) -> int:
 
      n_completed = int(stats.get("n_completed", 0))
 
-     if n_completed < 10:
-         recommendation = {
-             "status": "insufficient_data",
-             "next_step": "Run phase1 with 2-4 parallel workers until >=10 completed trials.",
-             "early_stop_policy": {
-                 "patience_trials": 8,
-                 "min_improvement": 0.001,
-             },
-             "throughput_guard": {
-                 "min_tps": min_tps_floor,
-                 "note": "Trials below this TPS floor are pruned.",
-             },
-             "transfer_learning": {
-                 "export_priors": f"python scripts/export_hpo_priors.py --storage {args.storage} --study-name {args.study_name} --top-k 10 --out docs/hpo_transfer_priors.json",
-                 "use_priors": "Enabled by default in scripts/optuna_hpo.py (override with --no-apply-priors)",
-             },
-         }
-     else:
-         recommendation = {
-             "status": "ready_for_full_optimization",
-             "next_step": "Run phase2 with 3-4 parallel workers.",
              "suggested_full_run": {
                  "trials": 60,
-                 "workers": 4,
-                 "trial_time_budget": 300,
-                 "trial_timeout": 900,
-                 "min_tps": min_tps_floor,
-                 "patience_trials": 12,
-                 "min_improvement": 0.0005,
-             },
-             "transfer_learning": {
-                 "refresh_priors": f"python scripts/export_hpo_priors.py --storage {args.storage} --study-name {args.study_name} --top-k 20 --out docs/hpo_transfer_priors.json",
-                 "notes": "Carry priors into new studies unless architecture/objective diverges significantly.",
-             },
-         }
 
      payload = {"stats": stats, "recommendation": recommendation}
      args.summary_out.parent.mkdir(parents=True, exist_ok=True)
@@ -269,9 +269,9 @@ def cmd_recommend(args: argparse.Namespace) -> int:
      return 0
 
 
- def build_parser() -> argparse.ArgumentParser:
-     routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
-     parser = argparse.ArgumentParser(description="Phase-oriented orchestration for Optuna HPO")
      sub = parser.add_subparsers(dest="cmd", required=True)
 
      def add_common(p: argparse.ArgumentParser) -> None:
@@ -283,21 +283,21 @@ def build_parser() -> argparse.ArgumentParser:
          p.add_argument("--min-tps", type=float, default=50000.0)
          p.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "orchestrator_summary.json")
          p.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local")
-         p.add_argument("--hf-namespace", default=routing_defaults.job_namespace)
-         p.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}")
          p.add_argument("--hf-flavor", default="a10g-large")
          p.add_argument("--hf-timeout", default="25m")
          p.add_argument("--hf-command", default="/app/entrypoint.py")
          p.add_argument("--hf-use-bash", action="store_true")
          p.add_argument("--hf-token-env", default="HF_TOKEN")
-         p.add_argument("--hf-poll-interval", type=int, default=12)
-         p.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py")
-         p.add_argument("--hf-output-repo", default=routing_defaults.output_repo)
-         p.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json")
-         p.add_argument("--apply-priors", action="store_true", default=True)
-         p.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
-         p.add_argument("--hf-stop-after-metric", action="store_true", default=True)
-         p.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")
 
      # Phase-1 defaults
      p.add_argument("--phase1-trials", type=int, default=30)
 
  #!/usr/bin/env python3
  from __future__ import annotations
 
+ import argparse
+ import json
+ import os
+ import subprocess
+ import sys
+ from pathlib import Path
+ from typing import Any
+
+ import optuna
+
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ from scripts.hf_routing import resolve_routing
+ from scripts.optuna_hpo import _enqueue_transfer_priors
+
+ HPO_SCRIPT = REPO_ROOT / "scripts" / "optuna_hpo.py"
 
 
  def _run_worker(args: list[str]) -> int:
 
      return proc.returncode
 
 
+ def _study_stats(storage: str, study_name: str) -> dict[str, Any]:
      try:
          study = optuna.load_study(study_name=study_name, storage=storage)
      except KeyError:
 
              "best_trial_user_attrs": study.best_trial.user_attrs,
          }
      )
+     return stats
+
+
+ def _prime_transfer_priors(storage: str, study_name: str, priors_file: Path, apply_priors: bool) -> int:
+     if not apply_priors:
+         return 0
+     study = optuna.create_study(
+         study_name=study_name,
+         storage=storage,
+         load_if_exists=True,
+         direction="minimize",
+     )
+     return _enqueue_transfer_priors(study, priors_file, apply_priors=True)
+
+
+ def _disable_worker_priors(worker_args: list[str]) -> list[str]:
+     cleaned: list[str] = []
+     for item in worker_args:
+         if item in {"--apply-priors", "--no-apply-priors"}:
+             continue
+         cleaned.append(item)
+     cleaned.append("--no-apply-priors")
+     return cleaned
 
 
  def _phase_args(phase: str, base: argparse.Namespace) -> list[str]:
 
          base.hf_command,
          "--hf-token-env",
          base.hf_token_env,
+         "--hf-poll-interval",
+         str(base.hf_poll_interval),
+         "--hf-launcher-script",
+         str(base.hf_launcher_script),
+         "--priors-file",
+         str(base.priors_file),
+     ]
      if base.hf_output_repo:
          common.extend(["--hf-output-repo", base.hf_output_repo])
      if base.hf_use_bash:
          common.append("--hf-use-bash")
+     if base.hf_stop_after_metric:
+         common.append("--hf-stop-after-metric")
+     else:
+         common.append("--no-hf-stop-after-metric")
+     if base.apply_priors:
+         common.append("--apply-priors")
+     else:
+         common.append("--no-apply-priors")
      if phase == "phase1":
          return [
              *common,
 
      return rc
 
 
+ def cmd_parallel(args: argparse.Namespace) -> int:
+     enqueued_priors = _prime_transfer_priors(args.storage, args.study_name, args.priors_file, args.apply_priors)
+     worker_args = _disable_worker_priors(_phase_args(args.phase, args))
+     procs: list[subprocess.Popen[str]] = []
+     for _ in range(args.workers):
+         cmd = [sys.executable, str(HPO_SCRIPT), *worker_args]
          procs.append(subprocess.Popen(cmd, cwd=str(REPO_ROOT), text=True))
 
      exit_codes = [p.wait() for p in procs]
      stats = _study_stats(args.storage, args.study_name)
      payload = {
          "phase": args.phase,
+         "workers": args.workers,
+         "exit_codes": exit_codes,
+         "enqueued_priors": enqueued_priors,
+         "stats": stats,
+     }
      args.summary_out.parent.mkdir(parents=True, exist_ok=True)
      args.summary_out.write_text(json.dumps(payload, indent=2), encoding="utf-8")
      print(json.dumps(payload, indent=2))
      return 0 if all(code == 0 for code in exit_codes) else 1
 
 
+ def cmd_recommend(args: argparse.Namespace) -> int:
+     stats = _study_stats(args.storage, args.study_name)
+     min_tps_floor = float(args.min_tps)
      if stats.get("status") == "missing":
          payload = {
              "stats": stats,
 
 
      n_completed = int(stats.get("n_completed", 0))
 
+     if n_completed < 10:
+         recommendation = {
+             "status": "insufficient_data",
+             "next_step": "Run phase1 with 2-4 parallel workers until >=10 completed trials.",
+             "early_stop_policy": {
+                 "patience_trials": 8,
+                 "min_improvement": 0.001,
+             },
+             "throughput_guard": {
+                 "min_tps": min_tps_floor,
+                 "note": "Trials below this TPS floor are pruned.",
+             },
+             "transfer_learning": {
+                 "export_priors": f"python scripts/export_hpo_priors.py --storage {args.storage} --study-name {args.study_name} --top-k 10 --out docs/hpo_transfer_priors.json",
+                 "use_priors": "Enabled by default in scripts/optuna_hpo.py (override with --no-apply-priors)",
+             },
+         }
+     else:
+         recommendation = {
+             "status": "ready_for_full_optimization",
+             "next_step": "Run phase2 with 3-4 parallel workers.",
              "suggested_full_run": {
                  "trials": 60,
+                 "workers": 4,
+                 "trial_time_budget": 300,
+                 "trial_timeout": 900,
+                 "min_tps": min_tps_floor,
+                 "patience_trials": 12,
+                 "min_improvement": 0.0005,
+             },
+             "transfer_learning": {
+                 "refresh_priors": f"python scripts/export_hpo_priors.py --storage {args.storage} --study-name {args.study_name} --top-k 20 --out docs/hpo_transfer_priors.json",
+                 "notes": "Carry priors into new studies unless architecture/objective diverges significantly.",
+             },
+         }
 
      payload = {"stats": stats, "recommendation": recommendation}
      args.summary_out.parent.mkdir(parents=True, exist_ok=True)
 
      return 0
 
 
+ def build_parser() -> argparse.ArgumentParser:
+     routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
+     parser = argparse.ArgumentParser(description="Phase-oriented orchestration for Optuna HPO")
      sub = parser.add_subparsers(dest="cmd", required=True)
 
      def add_common(p: argparse.ArgumentParser) -> None:
 
          p.add_argument("--min-tps", type=float, default=50000.0)
          p.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "orchestrator_summary.json")
          p.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local")
+         p.add_argument("--hf-namespace", default=routing_defaults.job_namespace)
+         p.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}")
          p.add_argument("--hf-flavor", default="a10g-large")
          p.add_argument("--hf-timeout", default="25m")
          p.add_argument("--hf-command", default="/app/entrypoint.py")
          p.add_argument("--hf-use-bash", action="store_true")
          p.add_argument("--hf-token-env", default="HF_TOKEN")
+         p.add_argument("--hf-poll-interval", type=int, default=12)
+         p.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py")
+         p.add_argument("--hf-output-repo", default=routing_defaults.output_repo)
+         p.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json")
+         p.add_argument("--apply-priors", action="store_true", default=True)
+         p.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
+         p.add_argument("--hf-stop-after-metric", action="store_true", default=True)
+         p.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")
 
      # Phase-1 defaults
      p.add_argument("--phase1-trials", type=int, default=30)
overlay/scripts/hpo_retest.py CHANGED
@@ -1,151 +1,151 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import argparse
- import datetime as dt
- import json
- from pathlib import Path
- from typing import Any
-
- import optuna
-
- from scripts.hpo_leaderboard import build_leaderboard
-
-
- _PARAM_TO_ENV = {
-     "d_model": "HYDRA_D_MODEL",
-     "n_layer": "HYDRA_N_LAYER",
-     "d_state": "HYDRA_D_STATE",
-     "headdim": "HYDRA_HEADDIM",
-     "expand": "HYDRA_EXPAND",
-     "seq_len": "HYDRA_SEQ_LEN",
-     "batch_size": "HYDRA_BATCH_SIZE",
-     "matrix_lr": "HYDRA_MATRIX_LR",
-     "embed_lr": "HYDRA_EMBED_LR",
-     "unembed_lr": "HYDRA_UNEMBED_LR",
-     "engram_n_columns": "HYDRA_ENGRAM_N_COLUMNS",
-     "engram_layer_idx": "HYDRA_ENGRAM_LAYER_IDX",
-     "sdr_target_active": "HYDRA_SDR_TARGET_ACTIVE",
-     "htm_learn_every": "HYDRA_HTM_LEARN_EVERY",
-     "htm_subsample": "HYDRA_HTM_SUBSAMPLE",
-     "engram_subsample": "HYDRA_ENGRAM_SUBSAMPLE",
-     "mamba3_chunk": "HYDRA_MAMBA3_CHUNK",
-     "dropout": "HYDRA_DROPOUT",
- }
-
- _DEFAULT_ENV = {
-     "HYDRA_USE_NEMOTRON": "1",
-     "HYDRA_LOCAL_SHARDS_ONLY": "0",
-     "HYDRA_THROUGHPUT_MODE": "0",
-     "HYDRA_FASTPATH": "0",
-     "HYDRA_FORCE_HTM_CPU": "0",
-     "HYDRA_INERT_MAMBA": "0",
-     "HYDRA_ALLOW_SYNTHETIC_RETINA": "0",
-     "HYDRA_HTM_FUSED": "1",
-     "HYDRA_HYENA_LAYERS": "",
-     "HYDRA_CKPT_INTERVAL": "0",
-     "HYDRA_ENGRAM_SUBSAMPLE": "1",
-     "HYDRA_HTM_SUBSAMPLE": "2",
-     "HYDRA_HTM_LEARN_EVERY": "8",
- }
-
-
- def _recover_params(row: dict[str, Any]) -> dict[str, Any]:
-     params = dict(row.get("params") or {})
-     attrs = row.get("user_attrs") or {}
-     for key, value in attrs.items():
-         if key.startswith("param_"):
-             params.setdefault(key.removeprefix("param_"), value)
-     return params
-
-
- def _candidate_env(params: dict[str, Any], *, eval_tokens: int, eval_batch: int, time_budget: int) -> dict[str, str]:
-     env = dict(_DEFAULT_ENV)
-     env["HYDRA_EVAL_TOKENS"] = str(eval_tokens)
-     env["HYDRA_EVAL_BATCH"] = str(eval_batch)
-     env["HYDRA_TIME_BUDGET"] = str(time_budget)
-     for key, value in params.items():
-         env_key = _PARAM_TO_ENV.get(key)
-         if env_key is not None:
-             env[env_key] = str(value)
-     if "HYDRA_BATCH_SIZE" in env and "HYDRA_SEQ_LEN" in env:
-         grad_accum = int(params.get("grad_accum", 16))
-         env["HYDRA_TOTAL_BATCH"] = str(int(env["HYDRA_BATCH_SIZE"]) * int(env["HYDRA_SEQ_LEN"]) * grad_accum)
-     return env
-
-
- def build_retest_plan(
-     *,
-     storage: str,
-     study_names: list[str],
-     top_k: int,
-     metric: str = "val_bpb",
-     eval_tokens: int = 16384,
-     eval_batch: int = 2,
-     time_budget: int = 420,
- ) -> dict[str, Any]:
-     leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
-     rows = [*leaderboard["contaminated_trials"], *leaderboard["clean_trials"]]
-     reverse = leaderboard["direction"] == "maximize"
-     rows.sort(key=lambda row: float(row["value"]), reverse=reverse)
-     candidates = []
-     for row in rows[: max(0, top_k)]:
-         params = _recover_params(row)
-         env = _candidate_env(params, eval_tokens=eval_tokens, eval_batch=eval_batch, time_budget=time_budget)
-         reason = row.get("contamination_reason") or "canonical_truth_eval_retest"
-         candidates.append({
-             "study_name": row["study_name"],
-             "trial_number": row["trial_number"],
-             "source_value": row["value"],
-             "source_objective": row["objective_source"],
-             "source_job_id": row.get("hf_job_id"),
-             "needs_retest_reason": reason,
-             "params": params,
-             "env": env,
-         })
-     return {
-         "schema_version": 1,
-         "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
-         "metric": metric,
-         "study_names": study_names,
-         "eval_tokens": eval_tokens,
-         "eval_batch": eval_batch,
-         "time_budget": time_budget,
-         "n_candidates": len(candidates),
-         "candidates": candidates,
-     }
-
-
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-     parser = argparse.ArgumentParser(description="Plan canonical-eval retests for historical HPO configs")
-     parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
-     parser.add_argument("--study-name", action="append", default=[])
-     parser.add_argument("--metric", default="val_bpb")
-     parser.add_argument("--top-k", type=int, default=10)
-     parser.add_argument("--eval-tokens", type=int, default=16384)
-     parser.add_argument("--eval-batch", type=int, default=2)
-     parser.add_argument("--time-budget", type=int, default=420)
-     parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "retest_plan.json")
-     return parser.parse_args(argv)
-
-
- def main(argv: list[str] | None = None) -> int:
-     args = parse_args(argv)
-     study_names = args.study_name or ["hydra_hpo"]
-     payload = build_retest_plan(
-         storage=args.storage,
-         study_names=study_names,
-         top_k=args.top_k,
-         metric=args.metric,
-         eval_tokens=args.eval_tokens,
-         eval_batch=args.eval_batch,
-         time_budget=args.time_budget,
-     )
-     args.out.parent.mkdir(parents=True, exist_ok=True)
-     args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-     print(json.dumps(payload, indent=2, sort_keys=True))
-     return 0
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
 
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import datetime as dt
+ import json
+ from pathlib import Path
+ from typing import Any
+
+ import optuna
+
+ from scripts.hpo_leaderboard import build_leaderboard
+
+
+ _PARAM_TO_ENV = {
+     "d_model": "HYDRA_D_MODEL",
+     "n_layer": "HYDRA_N_LAYER",
+     "d_state": "HYDRA_D_STATE",
+     "headdim": "HYDRA_HEADDIM",
+     "expand": "HYDRA_EXPAND",
+     "seq_len": "HYDRA_SEQ_LEN",
+     "batch_size": "HYDRA_BATCH_SIZE",
+     "matrix_lr": "HYDRA_MATRIX_LR",
+     "embed_lr": "HYDRA_EMBED_LR",
+     "unembed_lr": "HYDRA_UNEMBED_LR",
+     "engram_n_columns": "HYDRA_ENGRAM_N_COLUMNS",
+     "engram_layer_idx": "HYDRA_ENGRAM_LAYER_IDX",
+     "sdr_target_active": "HYDRA_SDR_TARGET_ACTIVE",
+     "htm_learn_every": "HYDRA_HTM_LEARN_EVERY",
+     "htm_subsample": "HYDRA_HTM_SUBSAMPLE",
+     "engram_subsample": "HYDRA_ENGRAM_SUBSAMPLE",
+     "mamba3_chunk": "HYDRA_MAMBA3_CHUNK",
+     "dropout": "HYDRA_DROPOUT",
+ }
+
+ _DEFAULT_ENV = {
+     "HYDRA_USE_NEMOTRON": "1",
+     "HYDRA_LOCAL_SHARDS_ONLY": "0",
+     "HYDRA_THROUGHPUT_MODE": "0",
+     "HYDRA_FASTPATH": "0",
+     "HYDRA_FORCE_HTM_CPU": "0",
+     "HYDRA_INERT_MAMBA": "0",
+     "HYDRA_ALLOW_SYNTHETIC_RETINA": "0",
+     "HYDRA_HTM_FUSED": "1",
+     "HYDRA_HYENA_LAYERS": "",
+     "HYDRA_CKPT_INTERVAL": "0",
+     "HYDRA_ENGRAM_SUBSAMPLE": "1",
+     "HYDRA_HTM_SUBSAMPLE": "2",
+     "HYDRA_HTM_LEARN_EVERY": "8",
+ }
+
+
+ def _recover_params(row: dict[str, Any]) -> dict[str, Any]:
+     params = dict(row.get("params") or {})
+     attrs = row.get("user_attrs") or {}
+     for key, value in attrs.items():
+         if key.startswith("param_"):
+             params.setdefault(key.removeprefix("param_"), value)
+     return params
+
+
+ def _candidate_env(params: dict[str, Any], *, eval_tokens: int, eval_batch: int, time_budget: int) -> dict[str, str]:
+     env = dict(_DEFAULT_ENV)
+     env["HYDRA_EVAL_TOKENS"] = str(eval_tokens)
+     env["HYDRA_EVAL_BATCH"] = str(eval_batch)
+     env["HYDRA_TIME_BUDGET"] = str(time_budget)
+     for key, value in params.items():
+         env_key = _PARAM_TO_ENV.get(key)
+         if env_key is not None:
+             env[env_key] = str(value)
+     if "HYDRA_BATCH_SIZE" in env and "HYDRA_SEQ_LEN" in env:
+         grad_accum = int(params.get("grad_accum", 16))
+         env["HYDRA_TOTAL_BATCH"] = str(int(env["HYDRA_BATCH_SIZE"]) * int(env["HYDRA_SEQ_LEN"]) * grad_accum)
+     return env
+
+
+ def build_retest_plan(
+     *,
+     storage: str,
+     study_names: list[str],
+     top_k: int,
+     metric: str = "val_bpb",
+     eval_tokens: int = 16384,
+     eval_batch: int = 2,
+     time_budget: int = 420,
+ ) -> dict[str, Any]:
+     leaderboard = build_leaderboard(storage=storage, study_names=study_names, metric=metric)
+     rows = [*leaderboard["contaminated_trials"], *leaderboard["clean_trials"]]
+     reverse = leaderboard["direction"] == "maximize"
+     rows.sort(key=lambda row: float(row["value"]), reverse=reverse)
+     candidates = []
+     for row in rows[: max(0, top_k)]:
+         params = _recover_params(row)
+         env = _candidate_env(params, eval_tokens=eval_tokens, eval_batch=eval_batch, time_budget=time_budget)
+         reason = row.get("contamination_reason") or "canonical_truth_eval_retest"
+         candidates.append({
+             "study_name": row["study_name"],
+             "trial_number": row["trial_number"],
+             "source_value": row["value"],
+             "source_objective": row["objective_source"],
+             "source_job_id": row.get("hf_job_id"),
+             "needs_retest_reason": reason,
+             "params": params,
+             "env": env,
+         })
+     return {
+         "schema_version": 1,
+         "generated_at": dt.datetime.now(dt.UTC).isoformat(timespec="seconds"),
+         "metric": metric,
+         "study_names": study_names,
+         "eval_tokens": eval_tokens,
+         "eval_batch": eval_batch,
+         "time_budget": time_budget,
+         "n_candidates": len(candidates),
+         "candidates": candidates,
+     }
+
+
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Plan canonical-eval retests for historical HPO configs")
+     parser.add_argument("--storage", default="sqlite:///optuna_hpo.db")
+     parser.add_argument("--study-name", action="append", default=[])
+     parser.add_argument("--metric", default="val_bpb")
+     parser.add_argument("--top-k", type=int, default=10)
+     parser.add_argument("--eval-tokens", type=int, default=16384)
+     parser.add_argument("--eval-batch", type=int, default=2)
+     parser.add_argument("--time-budget", type=int, default=420)
+     parser.add_argument("--out", type=Path, default=Path(".tmp") / "optuna" / "retest_plan.json")
+     return parser.parse_args(argv)
+
+
+ def main(argv: list[str] | None = None) -> int:
+     args = parse_args(argv)
+     study_names = args.study_name or ["hydra_hpo"]
+     payload = build_retest_plan(
+         storage=args.storage,
+         study_names=study_names,
+         top_k=args.top_k,
+         metric=args.metric,
+         eval_tokens=args.eval_tokens,
+         eval_batch=args.eval_batch,
+         time_budget=args.time_budget,
+     )
+     args.out.parent.mkdir(parents=True, exist_ok=True)
+     args.out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+     print(json.dumps(payload, indent=2, sort_keys=True))
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
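
Note how `_candidate_env` derives the token budget: when a retest candidate carries both `batch_size` and `seq_len`, `HYDRA_TOTAL_BATCH` is set to `batch_size * seq_len * grad_accum`, with `grad_accum` falling back to 16 when the source trial did not record it. A quick worked example with hypothetical trial parameters:

    # Hypothetical params recovered from a historical trial.
    params = {"batch_size": 8, "seq_len": 2048}
    total_batch = params["batch_size"] * params["seq_len"] * params.get("grad_accum", 16)
    print(total_batch)  # 8 * 2048 * 16 = 262144 tokens per optimizer step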
overlay/scripts/hydra_generation.py CHANGED
@@ -1,183 +1,180 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import os
- from pathlib import Path
- from typing import Callable
-
- import torch
-
- from scripts.benchmark_checkpoint import hydrate_checkpoint
- from scripts.hf_routing import resolve_routing
-
-
- def default_checkpoint_path() -> Path:
-     return Path(os.path.expanduser("~/.cache/autoresearch/latest.pt"))
-
-
- def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
-     base = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch"))
-     return [
-         base / "best_bpb.pt",
-         base / "pretrain_final.pt",
-         base / "latest.pt",
-     ]
-
-
- def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
-     if explicit_path is not None:
-         return explicit_path
-     env_checkpoint = os.environ.get("HYDRA_HF_CHECKPOINT_PATH")
-     if env_checkpoint:
-         return Path(env_checkpoint).expanduser()
-     for candidate in checkpoint_candidates(cache_dir=cache_dir):
-         if candidate.exists():
-             return candidate
-     return default_checkpoint_path()
-
-
- def validate_checkpoint_compatibility(
-     *,
-     baseline_arch: str,
-     missing_keys: list[str],
-     unexpected_keys: list[str],
-     total_model_keys: int,
- ) -> None:
-     if baseline_arch == "transformer" and (missing_keys or unexpected_keys):
-         raise RuntimeError(
-             "checkpoint incompatible with transformer baseline architecture; "
-             "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
-         )
-     mismatch_count = len(missing_keys) + len(unexpected_keys)
-     if total_model_keys > 0 and mismatch_count > max(8, total_model_keys // 2):
-         raise RuntimeError("checkpoint incompatible with requested model architecture")
-
-
- def generate_from_callable(
-     generator: Callable[[str], str] | Callable[..., str],
-     prompt: str,
-     *,
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
- ) -> str:
-     text = generator(
-         prompt,
-         max_new_tokens=max_new_tokens,
-         temperature=temperature,
-         top_p=top_p,
-     )
-     return str(text).strip()
-
-
- def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
-     ckpt_path = resolve_checkpoint_path(checkpoint_path)
-     if not ckpt_path.exists():
-         hydrated = hydrate_checkpoint(
-             cache_dir=ckpt_path.parent,
-             output_repo=resolve_routing(token=os.environ.get("HF_TOKEN")).output_repo,
-             token=os.environ.get("HF_TOKEN"),
-         )
-         if hydrated is not None:
-             ckpt_path = hydrated
-     if not ckpt_path.exists():
-         raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
-
-     from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
-     from transformers.modeling_outputs import CausalLMOutputWithPast
-
-     from hydra.config import PostSemClawConfig
-     from hydra.model import PostSemClawModel
-     from prepare import Tokenizer
-
-     resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
-
-     class _HydraGenConfig(PretrainedConfig):
-         model_type = "hydra"
-
-         def __init__(self, vocab_size: int = 65536, **kw):
-             super().__init__(**kw)
-             self.vocab_size = vocab_size
-
-     class HydraForCausalLM(PreTrainedModel, GenerationMixin):
-         config_class = _HydraGenConfig
-
-         def __init__(self, gen_config, inner_model):
-             super().__init__(gen_config)
-             self.inner = inner_model
-             self.config.vocab_size = gen_config.vocab_size
-
-         def forward(self, input_ids, attention_mask=None, **kw):
-             logits = self.inner(input_ids)
-             return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)
-
-         def prepare_inputs_for_generation(self, input_ids, **kw):
-             return {"input_ids": input_ids}
-
-         def get_input_embeddings(self):
-             return self.inner.wte
-
-         def can_generate(self) -> bool:
-             return True
-
-         @property
-         def _supports_cache_class(self):
-             return False
-
-     tokenizer = Tokenizer.from_directory()
-     vocab_size = tokenizer.get_vocab_size()
-     bos = tokenizer.get_bos_token_id()
-     ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
-     cfg = PostSemClawConfig(**ckpt["config"])
-     with torch.device("meta"):
-         inner = PostSemClawModel(cfg)
-     inner.to_empty(device=resolved_device)
-     missing, unexpected = inner.load_state_dict(ckpt["model_state_dict"], strict=False)
-     validate_checkpoint_compatibility(
-         baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
-         missing_keys=list(missing),
-         unexpected_keys=list(unexpected),
-         total_model_keys=len(inner.state_dict()),
-     )
-     inner.eval()
-
-     gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
-     gen_cfg.bos_token_id = bos
-     gen_cfg.eos_token_id = bos
-     gen_cfg.pad_token_id = bos
-     model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
-     model.eval()
-     return tokenizer, model, bos, resolved_device, GenerationConfig
-
-
- def build_hydra_generator(
-     *,
-     checkpoint_path: Path | None = None,
-     device: str | None = None,
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
- ):
-     tokenizer, model, bos, resolved_device, GenerationConfig = load_hydra_causal_lm(checkpoint_path=checkpoint_path, device=device)
-
-     def _generate(prompt: str) -> str:
-         ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=resolved_device)
-         gen_config = GenerationConfig(
-             max_new_tokens=max_new_tokens,
-             use_cache=False,
-             do_sample=temperature > 0.0,
-             temperature=temperature,
-             top_p=top_p,
-             bos_token_id=bos,
-             eos_token_id=bos,
-             pad_token_id=bos,
-         )
-         if str(resolved_device).startswith("cuda"):
-             with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-                 out = model.generate(ids, generation_config=gen_config)
-         else:
-             with torch.no_grad():
-                 out = model.generate(ids, generation_config=gen_config)
-         return tokenizer.decode(out[0].tolist())
-
-     return _generate
 
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+ from typing import Callable
+
+ import torch
+
+ from scripts.benchmark_checkpoint import hydrate_checkpoint
+ from scripts.hf_routing import resolve_routing
+
+
+ def default_checkpoint_path() -> Path:
+     return Path(os.path.expanduser("~/.cache/autoresearch/latest.pt"))
+
+
+ def checkpoint_candidates(*, cache_dir: Path | None = None) -> list[Path]:
+     base = cache_dir or Path(os.path.expanduser("~/.cache/autoresearch"))
+     return [
+         base / "best_bpb.pt",
+         base / "pretrain_final.pt",
+         base / "latest.pt",
+     ]
+
+
+ def resolve_checkpoint_path(explicit_path: Path | None, *, cache_dir: Path | None = None) -> Path:
+     if explicit_path is not None:
+         return explicit_path
+     for candidate in checkpoint_candidates(cache_dir=cache_dir):
+         if candidate.exists():
+             return candidate
+     return default_checkpoint_path()
+
+
+ def validate_checkpoint_compatibility(
+     *,
+     baseline_arch: str,
+     missing_keys: list[str],
+     unexpected_keys: list[str],
+     total_model_keys: int,
+ ) -> None:
+     if baseline_arch == "transformer" and (missing_keys or unexpected_keys):
+         raise RuntimeError(
+             "checkpoint incompatible with transformer baseline architecture; "
+             "use a transformer-trained checkpoint or keep HYDRA_BASELINE_ARCH=mamba3"
+         )
+     mismatch_count = len(missing_keys) + len(unexpected_keys)
+     if total_model_keys > 0 and mismatch_count > max(8, total_model_keys // 2):
+         raise RuntimeError("checkpoint incompatible with requested model architecture")
+
+
+ def generate_from_callable(
+     generator: Callable[[str], str] | Callable[..., str],
+     prompt: str,
+     *,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+ ) -> str:
+     text = generator(
+         prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+     )
+     return str(text).strip()
+
+
+ def load_hydra_causal_lm(checkpoint_path: Path | None = None, device: str | None = None):
+     ckpt_path = resolve_checkpoint_path(checkpoint_path)
+     if not ckpt_path.exists():
+         hydrated = hydrate_checkpoint(
+             cache_dir=ckpt_path.parent,
+             output_repo=resolve_routing(token=os.environ.get("HF_TOKEN")).output_repo,
+             token=os.environ.get("HF_TOKEN"),
+         )
+         if hydrated is not None:
+             ckpt_path = hydrated
+     if not ckpt_path.exists():
+         raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
+
+     from transformers import GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel
+     from transformers.modeling_outputs import CausalLMOutputWithPast
+
+     from hydra.config import PostSemClawConfig
+     from hydra.model import PostSemClawModel
+     from prepare import Tokenizer
+
+     resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+     class _HydraGenConfig(PretrainedConfig):
+         model_type = "hydra"
+
+         def __init__(self, vocab_size: int = 65536, **kw):
+             super().__init__(**kw)
+             self.vocab_size = vocab_size
+
+     class HydraForCausalLM(PreTrainedModel, GenerationMixin):
+         config_class = _HydraGenConfig
+
+         def __init__(self, gen_config, inner_model):
+             super().__init__(gen_config)
+             self.inner = inner_model
+             self.config.vocab_size = gen_config.vocab_size
+
+         def forward(self, input_ids, attention_mask=None, **kw):
+             logits = self.inner(input_ids)
+             return CausalLMOutputWithPast(loss=None, logits=logits, past_key_values=None)
+
+         def prepare_inputs_for_generation(self, input_ids, **kw):
+             return {"input_ids": input_ids}
+
+         def get_input_embeddings(self):
+             return self.inner.wte
+
+         def can_generate(self) -> bool:
+             return True
+
+         @property
+         def _supports_cache_class(self):
+             return False
+
+     tokenizer = Tokenizer.from_directory()
+     vocab_size = tokenizer.get_vocab_size()
+     bos = tokenizer.get_bos_token_id()
+     ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
+     cfg = PostSemClawConfig(**ckpt["config"])
+     with torch.device("meta"):
+         inner = PostSemClawModel(cfg)
+     inner.to_empty(device=resolved_device)
+     missing, unexpected = inner.load_state_dict(ckpt["model_state_dict"], strict=False)
+     validate_checkpoint_compatibility(
+         baseline_arch=os.environ.get("HYDRA_BASELINE_ARCH", "mamba3").strip().lower(),
+         missing_keys=list(missing),
+         unexpected_keys=list(unexpected),
+         total_model_keys=len(inner.state_dict()),
+     )
+     inner.eval()
+
+     gen_cfg = _HydraGenConfig(vocab_size=vocab_size)
+     gen_cfg.bos_token_id = bos
+     gen_cfg.eos_token_id = bos
+     gen_cfg.pad_token_id = bos
+     model = HydraForCausalLM(gen_cfg, inner).to(resolved_device)
+     model.eval()
+     return tokenizer, model, bos, resolved_device, GenerationConfig
+
+
+ def build_hydra_generator(
+     *,
+     checkpoint_path: Path | None = None,
+     device: str | None = None,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+ ):
+     tokenizer, model, bos, resolved_device, GenerationConfig = load_hydra_causal_lm(checkpoint_path=checkpoint_path, device=device)
+
+     def _generate(prompt: str) -> str:
+         ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=resolved_device)
+         gen_config = GenerationConfig(
+             max_new_tokens=max_new_tokens,
+             use_cache=False,
+             do_sample=temperature > 0.0,
+             temperature=temperature,
+             top_p=top_p,
+             bos_token_id=bos,
+             eos_token_id=bos,
+             pad_token_id=bos,
+         )
+         if str(resolved_device).startswith("cuda"):
+             with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+                 out = model.generate(ids, generation_config=gen_config)
+         else:
+             with torch.no_grad():
+                 out = model.generate(ids, generation_config=gen_config)
+         return tokenizer.decode(out[0].tolist())
+
+     return _generate
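
Because `do_sample` is set to `temperature > 0.0`, passing `temperature=0.0` yields greedy decoding, while any positive temperature enables sampling with the given `top_p`. A hypothetical usage sketch, assuming a compatible checkpoint and tokenizer cache are already present under `~/.cache/autoresearch`:

    from scripts.hydra_generation import build_hydra_generator

    # Placeholder decode settings; the checkpoint path is resolved from the
    # cache candidates (best_bpb.pt, pretrain_final.pt, latest.pt) when omitted.
    generate = build_hydra_generator(max_new_tokens=64, temperature=0.2, top_p=0.95)
    print(generate("The Feather runtime"))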
overlay/scripts/launch_benchmark_hf_job.py CHANGED
@@ -1,222 +1,157 @@
- #!/usr/bin/env python3
- from __future__ import annotations
-
- import argparse
- import json
- import os
- import sys
- from pathlib import Path
-
- REPO_ROOT = Path(__file__).resolve().parents[1]
- if str(REPO_ROOT) not in sys.path:
-     sys.path.insert(0, str(REPO_ROOT))
-
- from huggingface_hub import HfApi
- from huggingface_hub.utils import get_token
-
- from scripts.cycle_executor import decode_config_for_benchmark, load_cycle_freeze, variant_env_for_benchmark
- from scripts.hf_routing import resolve_routing
- from scripts.launch_feather_hf_job import IMAGE_DIR, sync_overlay_from_repo, wait_for_space
-
-
- FREEZE_PATH = REPO_ROOT / "artifacts" / "cycle_1_execution_freeze.json"
-
-
- def resolve_variant_checkpoint_override(*, variant: str) -> str | None:
-     if variant == "baseline_transformer_matched":
-         return os.environ.get("HYDRA_TRANSFORMER_CHECKPOINT_PATH") or os.environ.get("HYDRA_HF_CHECKPOINT_PATH")
-     if variant == "baseline_mamba_matched":
-         return os.environ.get("HYDRA_MAMBA_CHECKPOINT_PATH") or os.environ.get("HYDRA_HF_CHECKPOINT_PATH")
-     return os.environ.get("HYDRA_HF_CHECKPOINT_PATH")
-
-
- def validate_variant_ready_for_submission(*, variant: str, freeze: dict[str, object]) -> None:
-     variants = freeze.get("variants")
-     variant_cfg: object = variants.get(variant, {}) if isinstance(variants, dict) else {}
-     status = variant_cfg.get("status") if isinstance(variant_cfg, dict) else None
-     checkpoint_override = resolve_variant_checkpoint_override(variant=variant)
-     if variant == "baseline_transformer_matched" and status == "blocked_checkpoint_incompatible" and not checkpoint_override:
-         raise SystemExit(
-             "baseline_transformer_matched is blocked by checkpoint incompatibility; set HYDRA_TRANSFORMER_CHECKPOINT_PATH to a transformer-compatible checkpoint before submission"
-         )
-
-
- def build_benchmark_job_env(
-     *,
-     benchmark: str,
-     variant: str,
-     seed: int,
-     output_repo: str,
-     tokenizer_repo: str,
-     retina_repo: str,
-     freeze: dict[str, object],
- ) -> dict[str, str]:
-     env = {
-         "FEATHER_HF_OUTPUT_REPO": output_repo,
-         "FEATHER_HF_RETINA_CACHE_REPO": retina_repo,
-         "HF_REPO_ID": output_repo,
-         "FEATHER_RUNTIME_MODE": "benchmark",
-         "HYDRA_TOKENIZER_CACHE_REPO": tokenizer_repo,
-         "HYDRA_RETINA_CACHE_REPO": retina_repo,
-         "HYDRA_BENCHMARK_NAME": benchmark,
-         "HYDRA_BENCHMARK_VARIANT": variant,
-         "HYDRA_SEED": str(seed),
-         "PYTHONUNBUFFERED": "1",
-     }
-     env.update(variant_env_for_benchmark(freeze, variant))
-     checkpoint_override = resolve_variant_checkpoint_override(variant=variant)
-     if checkpoint_override:
-         env["HYDRA_HF_CHECKPOINT_PATH"] = checkpoint_override
-     for key, value in os.environ.items():
-         if key.startswith("HYDRA_") and key not in env:
-             env[key] = value
-     return env
-
-
- def build_benchmark_job_command(
-     *,
-     benchmark: str,
-     variant: str,
-     seed: int,
-     suite_path: Path | None,
-     freeze: dict[str, object],
- ) -> list[str]:
-     decode_cfg = decode_config_for_benchmark(freeze, benchmark)
-     command = [
-         "python",
-         "/app/entrypoint.py",
-         "--max-new-tokens",
-         str(int(decode_cfg.get("max_tokens", 256))),
-         "--temperature",
-         str(float(decode_cfg.get("temperature", 0.2))),
-         "--top-p",
-         str(float(decode_cfg.get("top_p", 0.95))),
-     ]
-     if suite_path is not None:
-         command.extend(["--suite", str(suite_path)])
-     return command
-
-
- def submit_benchmark_job(
-     *,
-     api,
-     image: str,
-     command: list[str],
-     env: dict[str, str],
-     token: str,
-     namespace: str,
-     flavor: str,
-     timeout: str,
- ) -> dict[str, str]:
-     job = api.run_job(
-         image=image,
-         command=command,
-         env=env,
-         secrets={"HF_TOKEN": token},
-         flavor=flavor,
-         timeout=timeout,
-         namespace=namespace,
-         token=token,
-     )
-     return {
-         "job_id": job.id,
-         "job_url": job.url,
-         "job_stage": str(job.status.stage),
-     }
-
-
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-     routing = resolve_routing(token=os.environ.get("HF_TOKEN"))
-     parser = argparse.ArgumentParser(description="Prepare or submit a remote HF benchmark job")
-     parser.add_argument("--benchmark", required=True)
-     parser.add_argument("--variant", required=True)
-     parser.add_argument("--seed", type=int, required=True)
-     parser.add_argument("--output-repo", default=routing.output_repo)
-     parser.add_argument("--tokenizer-repo", default=routing.output_repo)
-     parser.add_argument("--retina-repo", default=routing.retina_cache_repo)
-     parser.add_argument("--freeze", type=Path, default=FREEZE_PATH)
-     parser.add_argument("--suite", type=Path)
-     parser.add_argument("--image", default=f"hf.co/spaces/{routing.space_repo}")
-     parser.add_argument("--namespace", default=routing.job_namespace)
-     parser.add_argument("--flavor", default="a10g-small")
-     parser.add_argument("--timeout", default="30m")
-     parser.add_argument("--summary-out", type=Path)
-     parser.add_argument("--dry-run", action="store_true")
-     parser.add_argument("--refresh-image", action="store_true")
-     parser.add_argument("--sync-overlay", action="store_true")
-     return parser.parse_args(argv)
-
-
- def main(argv: list[str] | None = None) -> int:
-     args = parse_args(argv)
-     freeze = load_cycle_freeze(args.freeze)
-     validate_variant_ready_for_submission(variant=args.variant, freeze=freeze)
-     env = build_benchmark_job_env(
-         benchmark=args.benchmark,
-         variant=args.variant,
-         seed=args.seed,
-         output_repo=args.output_repo,
-         tokenizer_repo=args.tokenizer_repo,
-         retina_repo=args.retina_repo,
-         freeze=freeze,
-     )
-     command = build_benchmark_job_command(
-         benchmark=args.benchmark,
-         variant=args.variant,
-         seed=args.seed,
-         suite_path=args.suite,
-         freeze=freeze,
-     )
-     payload = {
-         "benchmark": args.benchmark,
-         "variant": args.variant,
-         "seed": args.seed,
-         "output_repo": args.output_repo,
-         "tokenizer_repo": args.tokenizer_repo,
-         "retina_repo": args.retina_repo,
-         "freeze": str(args.freeze),
-         "suite": str(args.suite) if args.suite is not None else None,
-         "image": args.image,
-         "namespace": args.namespace,
-         "command": command,
-         "env": env,
-         "dry_run": args.dry_run,
-     }
-     if not args.dry_run:
-         token = os.environ.get("HF_TOKEN") or get_token()
-         if not token:
-             raise SystemExit("HF_TOKEN must be set or cached via huggingface-cli login")
-         api = HfApi(token=token)
-         if args.refresh_image:
-             space_repo = args.image.removeprefix("hf.co/spaces/")
-             if args.sync_overlay:
-                 sync_overlay_from_repo()
-             api.upload_folder(
-                 repo_id=space_repo,
-                 repo_type="space",
-                 folder_path=str(IMAGE_DIR),
-                 commit_message="Update benchmark runtime image",
-                 token=token,
-             )
-             wait_for_space(api, space_repo, token=token)
-         payload.update(
-             submit_benchmark_job(
-                 api=api,
-                 image=args.image,
-                 command=command,
-                 env=env,
-                 token=token,
-                 namespace=args.namespace,
-                 flavor=args.flavor,
-                 timeout=args.timeout,
-             )
-         )
-     if args.summary_out is not None:
-         args.summary_out.parent.mkdir(parents=True, exist_ok=True)
-         args.summary_out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
-     print(json.dumps(payload, indent=2, sort_keys=True))
-     return 0
-
-
- if __name__ == "__main__":
-     raise SystemExit(main())
 
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import sys
+ from pathlib import Path
+
+ REPO_ROOT = Path(__file__).resolve().parents[1]
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ from huggingface_hub import HfApi
+ from huggingface_hub.utils import get_token
+
+ from scripts.hf_routing import resolve_routing
+ from scripts.launch_feather_hf_job import IMAGE_DIR, sync_overlay_from_repo, wait_for_space
+
+
+ def build_benchmark_job_env(
+     *,
+     benchmark: str,
+     variant: str,
+     seed: int,
+     output_repo: str,
+     tokenizer_repo: str,
+ ) -> dict[str, str]:
+     env = {
+         "FEATHER_HF_OUTPUT_REPO": output_repo,
+         "FEATHER_RUNTIME_MODE": "benchmark",
+         "HYDRA_TOKENIZER_CACHE_REPO": tokenizer_repo,
+         "HYDRA_BENCHMARK_NAME": benchmark,
+         "HYDRA_BENCHMARK_VARIANT": variant,
+         "HYDRA_SEED": str(seed),
+         "PYTHONUNBUFFERED": "1",
+     }
+     for key, value in os.environ.items():
+         if key.startswith("HYDRA_") and key not in env:
+             env[key] = value
+     return env
+
+
+ def build_benchmark_job_command(*, benchmark: str, variant: str, seed: int) -> list[str]:
+     return [
+         "python",
+         "/app/entrypoint.py",
+     ]
+
+
+ def submit_benchmark_job(
+     *,
+     api,
+     image: str,
+     command: list[str],
+     env: dict[str, str],
+     token: str,
+     namespace: str,
+     flavor: str,
+     timeout: str,
+ ) -> dict[str, str]:
+     job = api.run_job(
+         image=image,
+         command=command,
+         env=env,
+         secrets={"HF_TOKEN": token},
+         flavor=flavor,
+         timeout=timeout,
+         namespace=namespace,
+         token=token,
+     )
+     return {
+         "job_id": job.id,
+         "job_url": job.url,
+         "job_stage": str(job.status.stage),
+     }
+
+
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     routing = resolve_routing(token=os.environ.get("HF_TOKEN"))
+     parser = argparse.ArgumentParser(description="Prepare or submit a remote HF benchmark job")
+     parser.add_argument("--benchmark", required=True)
+     parser.add_argument("--variant", required=True)
+     parser.add_argument("--seed", type=int, required=True)
+     parser.add_argument("--output-repo", default=routing.output_repo)
+     parser.add_argument("--tokenizer-repo", default=routing.output_repo)
+     parser.add_argument("--image", default=f"hf.co/spaces/{routing.space_repo}")
+     parser.add_argument("--namespace", default=routing.job_namespace)
+     parser.add_argument("--flavor", default="a10g-small")
+     parser.add_argument("--timeout", default="30m")
+     parser.add_argument("--summary-out", type=Path)
+     parser.add_argument("--dry-run", action="store_true")
+     parser.add_argument("--refresh-image", action="store_true")
+     parser.add_argument("--sync-overlay", action="store_true")
+     return parser.parse_args(argv)
+
+
+ def main(argv: list[str] | None = None) -> int:
+     args = parse_args(argv)
+     env = build_benchmark_job_env(
+         benchmark=args.benchmark,
+         variant=args.variant,
+         seed=args.seed,
+         output_repo=args.output_repo,
+         tokenizer_repo=args.tokenizer_repo,
+     )
+     command = build_benchmark_job_command(benchmark=args.benchmark, variant=args.variant, seed=args.seed)
+     payload = {
+         "benchmark": args.benchmark,
+         "variant": args.variant,
+         "seed": args.seed,
+         "output_repo": args.output_repo,
+         "tokenizer_repo": args.tokenizer_repo,
+         "image": args.image,
+         "namespace": args.namespace,
+         "command": command,
+         "env": env,
+         "dry_run": args.dry_run,
+     }
+     if not args.dry_run:
+         token = os.environ.get("HF_TOKEN") or get_token()
+         if not token:
+             raise SystemExit("HF_TOKEN must be set or cached via huggingface-cli login")
+         api = HfApi(token=token)
+         if args.refresh_image:
+             space_repo = args.image.removeprefix("hf.co/spaces/")
+             if args.sync_overlay:
+                 sync_overlay_from_repo()
+             api.upload_folder(
+                 repo_id=space_repo,
+                 repo_type="space",
+                 folder_path=str(IMAGE_DIR),
+                 commit_message="Update benchmark runtime image",
+                 token=token,
+             )
+             wait_for_space(api, space_repo, token=token)
+         payload.update(
+             submit_benchmark_job(
+                 api=api,
+                 image=args.image,
+                 command=command,
+                 env=env,
+                 token=token,
+                 namespace=args.namespace,
+                 flavor=args.flavor,
+                 timeout=args.timeout,
+             )
+         )
+     if args.summary_out is not None:
+         args.summary_out.parent.mkdir(parents=True, exist_ok=True)
+         args.summary_out.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+     print(json.dumps(payload, indent=2, sort_keys=True))
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
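
The `--dry-run` path assembles and prints the full job payload (command, env, image, namespace) without contacting the Hub, which makes it easy to inspect a submission before spending GPU time. A hypothetical invocation; the benchmark and variant names here are placeholders, not values defined by this repo:

    from scripts.launch_benchmark_hf_job import main

    main([
        "--benchmark", "hellaswag",   # placeholder benchmark name
        "--variant", "hydra_full",    # placeholder variant name
        "--seed", "1234",
        "--dry-run",
    ])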
overlay/scripts/launch_feather_hf_job.py CHANGED
@@ -1,224 +1,218 @@
  #!/usr/bin/env python3
  from __future__ import annotations
 
- import os
- import shutil
- import sys
- import time
- import json
- from collections.abc import Mapping, Sequence
- from typing import Any, cast
- from pathlib import Path
-
- import httpx
- from huggingface_hub import HfApi
- from huggingface_hub.utils import HfHubHTTPError
- from huggingface_hub.utils import get_token
-
- REPO_ROOT = Path(__file__).resolve().parents[1]
- if str(REPO_ROOT) not in sys.path:
-     sys.path.insert(0, str(REPO_ROOT))
-
- from scripts.hf_routing import resolve_routing
- from configs.harness_config import HarnessConfig
-
- DEFAULT_IMAGE = os.environ.get('FEATHER_HF_IMAGE', 'ghcr.io/slapglif/feather-hf-runtime:latest')
- IMAGE_DIR = Path(__file__).resolve().parents[1] / 'hf_jobs' / 'feather_h200_image'
- TIMEOUT = os.environ.get('FEATHER_HF_JOB_TIMEOUT', '12h')
- TARGET_SHARDS = os.environ.get('HYDRA_TARGET_SHARDS', '2048')
- TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '43200')
- DOWNLOAD_WORKERS = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '16')
- CKPT_INTERVAL = os.environ.get('HYDRA_CKPT_INTERVAL', '1000')
- JOB_FLAVOR = os.environ.get('FEATHER_HF_FLAVOR', 'a10g-small')
- DRY_RUN = os.environ.get('FEATHER_HF_DRY_RUN', '0') == '1'
- USE_SPACE_IMAGE = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
  # When true, assume the Space image has already been built by a previous
  # invocation and skip the upload+build wait. Used by sweep drivers that fan
  # out many jobs against a single pre-uploaded image.
- SKIP_UPLOAD = os.environ.get('FEATHER_HF_SKIP_UPLOAD', '0') == '1'
- SYNC_OVERLAY = os.environ.get('FEATHER_HF_SYNC_OVERLAY', '1') == '1'
- JOB_SUBMIT_RETRIES = max(1, int(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRIES', '3')))
- JOB_SUBMIT_RETRY_BASE_S = float(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRY_BASE_S', '5'))
- BUILD_LOG_TAIL_LINES = max(1, int(os.environ.get('FEATHER_HF_BUILD_LOG_TAIL_LINES', '120')))
-
-
- def should_enable_fast_start_streaming(target_shards: str, time_budget: str) -> bool:
-     """Use streaming data path for short-budget launch profiles."""
-     try:
-         shards = int(target_shards)
-         budget = int(time_budget)
-     except ValueError:
-         return False
-     return shards > 0 and shards <= 256 and budget > 0 and budget <= 1800
-
-
- def apply_a10_env_profile(
-     env: dict[str, str],
-     *,
-     job_flavor: str,
-     parent_env: Mapping[str, str] = os.environ,
- ) -> str | None:
-     if not job_flavor.startswith('a10'):
-         return None
-
-     full_arch = parent_env.get('HYDRA_THROUGHPUT_MODE') == '0' or env.get('HYDRA_THROUGHPUT_MODE') == '0'
-     if full_arch:
-         defaults = {
-             'HYDRA_THROUGHPUT_MODE': '0',
-             'HYDRA_MUON_COMPILE': '0',
-             'HYDRA_FORCE_HTM_CPU': '0',
-             'HYDRA_INERT_MAMBA': '0',
-             'HYDRA_ALLOW_SYNTHETIC_RETINA': '0',
-             'HYDRA_FASTPATH': '0',
-         }
-         profile = 'full-architecture'
-     else:
-         defaults = {
-             'HYDRA_MUON_COMPILE': '0',
-             'HYDRA_FORCE_HTM_CPU': '1',
-             'HYDRA_INERT_MAMBA': '1',
-             'HYDRA_ALLOW_SYNTHETIC_RETINA': '1',
-             'HYDRA_FASTPATH': '1',
-         }
-         profile = 'compatibility'
-
-     for key, default in defaults.items():
-         env[key] = parent_env.get(key, env.get(key, default))
-     return profile
-
-
- def _http_status(exc: BaseException) -> int | None:
-     response = getattr(exc, 'response', None)
-     status = getattr(response, 'status_code', None)
-     return int(status) if isinstance(status, int) else None
-
-
- def submit_job_with_retry(
-     api: HfApi,
-     *,
-     image: str,
-     command: Sequence[str],
-     env: dict[str, str],
-     secrets: dict[str, str],
-     flavor: Any,
-     timeout: str,
-     token: str,
-     namespace: str,
- ):
-     for attempt in range(1, JOB_SUBMIT_RETRIES + 1):
-         try:
-             return api.run_job(
-                 image=image,
-                 command=list(command),
-                 env=env,
-                 secrets=secrets,
-                 flavor=flavor,
-                 timeout=timeout,
-                 namespace=namespace,
-                 token=token,
-             )
-         except HfHubHTTPError as exc:
-             status = _http_status(exc)
-             if status is None or status < 500 or status >= 600:
-                 raise
-             if attempt >= JOB_SUBMIT_RETRIES:
-                 raise SystemExit(
-                     f'HF job submit failed after {JOB_SUBMIT_RETRIES} attempts '
-                     f'(http {status}); failing fast'
-                 ) from exc
-             time.sleep(JOB_SUBMIT_RETRY_BASE_S * attempt)
-     raise SystemExit('HF job submit failed unexpectedly; failing fast')
-
-
- def fetch_space_build_log_tail(
-     api: HfApi,
-     repo_id: str,
-     token: str,
-     *,
-     limit: int = BUILD_LOG_TAIL_LINES,
- ) -> str:
-     try:
-         lines = list(api.fetch_space_logs(repo_id, build=True, follow=False, token=token))
-     except Exception as exc:
-         return f'[space-build-log] failed to fetch build logs: {exc!r}'
-     tail = lines[-limit:]
-     text = ''.join(tail)
-     if text and not text.endswith('\n'):
-         text += '\n'
-     return text or '[space-build-log] no buffered build logs available\n'
-
-
- def sync_overlay_from_repo() -> None:
-     """Refresh Space overlay with required project files."""
-     overlay = IMAGE_DIR / 'overlay'
-     overlay.mkdir(parents=True, exist_ok=True)
-
-     for child in overlay.iterdir():
-         if child.is_dir():
-             shutil.rmtree(child)
-         else:
-             child.unlink()
-
-     include_paths = [
-         'hydra',
-         'subsystems',
-         'scripts',
-         'data',
-         'htm_rust',
-         'harness',
-         'configs',
-         'artifacts',
-         'prepare.py',
-         'prepare_nemotron.py',
-         'train.py',
-         'pyproject.toml',
-         'uv.lock',
-     ]
-     ignore = shutil.ignore_patterns(
-         '__pycache__',
-         '.pytest_cache',
-         '.ruff_cache',
-         '.venv',
-         '.git',
-         'target',
-         '*.pyc',
-         '_tmp*',
-         'cycle1a_runs',
-         'cycle1a_probe',
-         'remote_benchmark_submission*.json',
-     )
-
-     copied: list[str] = []
-     for rel in include_paths:
-         src = REPO_ROOT / rel
-         dst = overlay / rel
-         if not src.exists():
-             continue
-         if src.is_dir():
-             shutil.copytree(src, dst, dirs_exist_ok=True, ignore=ignore)
-         else:
-             dst.parent.mkdir(parents=True, exist_ok=True)
-             shutil.copy2(src, dst)
-         copied.append(rel)
-
-     scripts_dir = overlay / 'scripts'
-     if scripts_dir.exists():
-         for sh_path in scripts_dir.rglob('*.sh'):
-             data = sh_path.read_bytes()
-             data = data.replace(b'\r\n', b'\n').replace(b'\r', b'\n')
-             sh_path.write_bytes(data)
-
-     print(f'[launch] overlay synced from repo ({len(copied)} paths): {copied}', flush=True)
-
-
- def require_token() -> str:
-     token = os.environ.get('HF_TOKEN') or get_token()
-     if not token:
-         raise SystemExit('HF_TOKEN must be set or cached via huggingface-cli login for launch_feather_hf_job.py')
-     return token
-
-
- def wait_for_space(api: HfApi, repo_id: str, token: str, timeout_s: int = 1800) -> None:
      """Wait until the Space image has been built.
 
      We use the Space purely as a container-image builder for HF Jobs. The Space
@@ -233,134 +227,134 @@ def wait_for_space(api: HfApi, repo_id: str, token: str, timeout_s: int = 1800)
      and APP_STARTING_ERROR after a successful BUILDING→APP_STARTING transition
      are acceptable — the image exists in the registry and Jobs can use it.
      """
-     start = time.time()
-     seen_build_completion = False
-     seen_building = False
-     while True:
-         try:
-             runtime = api.get_space_runtime(repo_id, token=token)
-         except httpx.TransportError as exc:
-             if time.time() - start > timeout_s:
-                 raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with network errors') from exc
-             print(f'[space] transient runtime endpoint network error: {exc!r}; retrying', flush=True)
246
- time.sleep(20)
247
- continue
248
- except HfHubHTTPError as exc:
249
- status = _http_status(exc)
250
- if status is not None and 500 <= status < 600:
251
- if time.time() - start > timeout_s:
252
- raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with HTTP {status}') from exc
253
- time.sleep(20)
254
- continue
255
- raise
256
- stage = getattr(runtime, 'stage', None)
257
- hardware = getattr(runtime, 'hardware', None)
258
- err = getattr(runtime, 'errorMessage', None) or getattr(runtime, 'error_message', None)
259
- print(f'[space] stage={stage} hardware={hardware}', flush=True)
260
- if stage == 'BUILDING':
261
- seen_building = True
262
- if stage in {'APP_STARTING', 'RUNNING', 'PAUSED', 'SLEEPING'}:
263
- seen_build_completion = True
264
- if stage in {'RUNNING', 'PAUSED', 'SLEEPING'}:
265
- return
266
- # Image is built — Jobs can use it regardless of Space boot outcome.
267
- if (seen_build_completion or seen_building) and stage in {'RUNTIME_ERROR', 'APP_STARTING_ERROR'}:
268
- print(f'[space] Space boot failed with {stage} but built image is '
269
- f'available in the Space registry and is usable by HF Jobs.',
270
- flush=True)
271
- return
272
- # Hard build failures — no image was produced.
273
- if stage in {'BUILD_ERROR', 'CONFIG_ERROR', 'NO_APP_FILE'}:
274
- build_log_tail = fetch_space_build_log_tail(api, repo_id, token)
275
- raise RuntimeError(
276
- f'Space {repo_id} build failed: stage={stage} error={err!r}\n'
277
- f'--- Space build log tail ---\n{build_log_tail}'
278
- )
279
  if time.time() - start > timeout_s:
280
  raise TimeoutError(f'Space {repo_id} did not become ready in {timeout_s}s (last stage={stage})')
281
  time.sleep(20)
282
 
283
 
284
- def main() -> int:
285
- token = require_token()
286
- routing = resolve_routing(token=token)
287
- api = HfApi(token=token)
288
- secondary_gates = HarnessConfig().to_secondary_gates()
289
-
290
- print(f'[launch] image_dir={IMAGE_DIR}', flush=True)
291
- print(f'[launch] owner={routing.owner}', flush=True)
292
- print(f'[launch] space_repo={routing.space_repo}', flush=True)
293
- print(f'[launch] output_repo={routing.output_repo}', flush=True)
294
- print(f'[launch] retina_cache_repo={routing.retina_cache_repo}', flush=True)
295
- print(f'[launch] target_shards={TARGET_SHARDS} time_budget={TIME_BUDGET} timeout={TIMEOUT}', flush=True)
296
- print(f'[launch] flavor={JOB_FLAVOR}', flush=True)
297
- print(f'[launch] namespace={routing.job_namespace}', flush=True)
298
- print(f'[launch] image_mode={"space" if USE_SPACE_IMAGE else "ghcr"}', flush=True)
299
- print(f'[launch] secondary_gates={json.dumps(secondary_gates, sort_keys=True)}', flush=True)
300
- if not USE_SPACE_IMAGE:
301
- print(f'[launch] image={DEFAULT_IMAGE}', flush=True)
302
-
303
- api.create_repo(repo_id=routing.space_repo, repo_type='space', space_sdk='docker', private=True, exist_ok=True, token=token)
304
- api.create_repo(repo_id=routing.output_repo, repo_type='model', private=True, exist_ok=True, token=token)
305
 
306
  if DRY_RUN:
307
  print('[launch] dry-run mode; skipping upload and job submission', flush=True)
308
  return 0
309
 
310
- image_ref = DEFAULT_IMAGE
311
- if USE_SPACE_IMAGE:
312
- if SKIP_UPLOAD:
313
- print('[launch] FEATHER_HF_SKIP_UPLOAD=1; reusing existing Space image', flush=True)
314
- else:
315
- if SYNC_OVERLAY:
316
- sync_overlay_from_repo()
317
- print('[launch] uploading custom Docker Space image context...', flush=True)
318
- api.upload_folder(
319
- repo_id=routing.space_repo,
320
- repo_type='space',
321
- folder_path=str(IMAGE_DIR),
322
- commit_message='Update Feather training runtime image',
323
- token=token,
324
- )
325
-
326
- print('[launch] waiting for Space image build to become ready...', flush=True)
327
- wait_for_space(api, routing.space_repo, token=token)
328
- image_ref = f'hf.co/spaces/{routing.space_repo}'
329
-
330
- env = {
331
- 'HF_REPO_ID': routing.output_repo,
332
- 'FEATHER_HF_OWNER': routing.owner,
333
- 'FEATHER_HF_SPACE_REPO': routing.space_repo,
334
- 'FEATHER_HF_OUTPUT_REPO': routing.output_repo,
335
- 'FEATHER_HF_RETINA_CACHE_REPO': routing.retina_cache_repo,
336
- 'HYDRA_RETINA_CACHE_REPO': routing.retina_cache_repo,
337
- 'HYDRA_TARGET_SHARDS': TARGET_SHARDS,
338
- 'HYDRA_TIME_BUDGET': TIME_BUDGET,
339
- 'HYDRA_DOWNLOAD_WORKERS': DOWNLOAD_WORKERS,
340
  'HYDRA_CKPT_INTERVAL': CKPT_INTERVAL,
341
  'PYTHONUNBUFFERED': '1',
342
- 'FEATHER_RUNTIME_MODE': 'job',
343
- }
344
- if 'HYDRA_USE_NEMOTRON' not in os.environ and should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET):
345
- env['HYDRA_USE_NEMOTRON'] = '1'
346
- print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
347
- # A10 profile: default compatibility mode avoids known PTX/compile runtime
348
- # pitfalls, but HYDRA_THROUGHPUT_MODE=0 explicitly selects the full
349
- # SDR/HTM/Engram architecture instead of silently inheriting bypass defaults.
350
- _a10_profile = apply_a10_env_profile(env, job_flavor=JOB_FLAVOR)
351
- if _a10_profile is not None:
352
- if env.get('HYDRA_INERT_MAMBA') == '0' and 'HYDRA_FASTPATH' not in os.environ:
353
- env['HYDRA_FASTPATH'] = '0'
354
- print(
355
- f'[launch] applied A10 {_a10_profile} env profile '
356
- f"(HYDRA_MUON_COMPILE={env['HYDRA_MUON_COMPILE']}, "
357
- f"HYDRA_THROUGHPUT_MODE={env.get('HYDRA_THROUGHPUT_MODE', 'unset')}, "
358
- f"HYDRA_FORCE_HTM_CPU={env['HYDRA_FORCE_HTM_CPU']}, "
359
- f"HYDRA_INERT_MAMBA={env['HYDRA_INERT_MAMBA']}, "
360
- f"HYDRA_ALLOW_SYNTHETIC_RETINA={env['HYDRA_ALLOW_SYNTHETIC_RETINA']}, "
361
- f"HYDRA_FASTPATH={env['HYDRA_FASTPATH']})",
362
- flush=True,
363
- )
364
  # Pass through any HYDRA_* / FEATHER_* overrides from the caller's env so
365
  # sweep drivers can set HYDRA_N_LAYER, HYDRA_SDR_TARGET_ACTIVE,
366
  # HYDRA_LAYER_DIAGNOSTICS, HYDRA_METRICS_OUT, HYDRA_MID_VAL_INTERVAL, etc.
@@ -370,18 +364,18 @@ def main() -> int:
370
  env[_k] = _v
371
  secrets = {'HF_TOKEN': token}
372
 
373
- print(f'[launch] submitting HF Job on flavor={JOB_FLAVOR}...', flush=True)
374
- job = submit_job_with_retry(
375
- api,
376
- image=image_ref,
377
- command=['python', '/app/entrypoint.py'],
378
- env=env,
379
- secrets=secrets,
380
- flavor=cast(Any, JOB_FLAVOR),
381
- timeout=TIMEOUT,
382
- namespace=routing.job_namespace,
383
- token=token,
384
- )
385
  print(f'[launch] submitted job_id={job.id} status={job.status.stage} url={job.url}', flush=True)
386
  return 0
387
 
 
1
  #!/usr/bin/env python3
2
  from __future__ import annotations
3
 
4
+ import os
5
+ import shutil
6
+ import sys
7
+ import time
8
+ import json
9
+ from collections.abc import Mapping, Sequence
10
+ from typing import Any, cast
11
+ from pathlib import Path
12
+
13
+ import httpx
14
+ from huggingface_hub import HfApi
15
+ from huggingface_hub.utils import HfHubHTTPError
16
+ from huggingface_hub.utils import get_token
17
+
18
+ REPO_ROOT = Path(__file__).resolve().parents[1]
19
+ if str(REPO_ROOT) not in sys.path:
20
+ sys.path.insert(0, str(REPO_ROOT))
21
+
22
+ from scripts.hf_routing import resolve_routing
23
+ from configs.harness_config import HarnessConfig
24
+
25
+ DEFAULT_IMAGE = os.environ.get('FEATHER_HF_IMAGE', 'ghcr.io/slapglif/feather-hf-runtime:latest')
26
+ IMAGE_DIR = Path(__file__).resolve().parents[1] / 'hf_jobs' / 'feather_h200_image'
27
+ TIMEOUT = os.environ.get('FEATHER_HF_JOB_TIMEOUT', '12h')
28
+ TARGET_SHARDS = os.environ.get('HYDRA_TARGET_SHARDS', '2048')
29
+ TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '43200')
30
+ DOWNLOAD_WORKERS = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '16')
31
+ CKPT_INTERVAL = os.environ.get('HYDRA_CKPT_INTERVAL', '1000')
32
+ JOB_FLAVOR = os.environ.get('FEATHER_HF_FLAVOR', 'a10g-small')
33
+ DRY_RUN = os.environ.get('FEATHER_HF_DRY_RUN', '0') == '1'
34
+ USE_SPACE_IMAGE = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
35
  # When true, assume the Space image has already been built by a previous
36
  # invocation and skip the upload+build wait. Used by sweep drivers that fan
37
  # out many jobs against a single pre-uploaded image.
38
+ SKIP_UPLOAD = os.environ.get('FEATHER_HF_SKIP_UPLOAD', '0') == '1'
39
+ SYNC_OVERLAY = os.environ.get('FEATHER_HF_SYNC_OVERLAY', '1') == '1'
40
+ JOB_SUBMIT_RETRIES = max(1, int(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRIES', '3')))
41
+ JOB_SUBMIT_RETRY_BASE_S = float(os.environ.get('FEATHER_HF_JOB_SUBMIT_RETRY_BASE_S', '5'))
42
+ BUILD_LOG_TAIL_LINES = max(1, int(os.environ.get('FEATHER_HF_BUILD_LOG_TAIL_LINES', '120')))
43
+
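Editor's note: the skip-upload knob is meant to be paired with the Space image mode. A hypothetical sweep driver (the loop and parameter values below are illustrative, not part of this repo) would let the first launch upload and build the Space image, then fan out reusing it:

# Hypothetical fan-out sketch: build the Space image once, then reuse it.
import os
import subprocess

base_env = {**os.environ, 'FEATHER_HF_USE_SPACE_IMAGE': '1'}
for i, n_layer in enumerate(['2', '3', '4']):
    env = {**base_env, 'HYDRA_N_LAYER': n_layer}
    if i > 0:
        env['FEATHER_HF_SKIP_UPLOAD'] = '1'  # image already built by the first launch
    subprocess.run(
        ['python', 'scripts/launch_feather_hf_job.py'],
        env=env,
        check=True,
    )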
44
+
45
+ def should_enable_fast_start_streaming(target_shards: str, time_budget: str) -> bool:
46
+ """Use streaming data path for short-budget launch profiles."""
47
+ try:
48
+ shards = int(target_shards)
49
+ budget = int(time_budget)
50
+ except ValueError:
51
+ return False
52
+ return shards > 0 and shards <= 256 and budget > 0 and budget <= 1800
53
+
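Editor's note: a few boundary cases for reference; these follow directly from the thresholds hard-coded above (at most 256 shards and an 1800 s budget, both strictly positive):

# Sketch: boundary behavior of should_enable_fast_start_streaming.
assert should_enable_fast_start_streaming('128', '900') is True
assert should_enable_fast_start_streaming('256', '1800') is True      # inclusive upper bounds
assert should_enable_fast_start_streaming('2048', '43200') is False   # default launch profile
assert should_enable_fast_start_streaming('0', '900') is False        # non-positive shards
assert should_enable_fast_start_streaming('abc', '900') is False      # non-numeric -> False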
54
+
55
+ def apply_a10_env_profile(
56
+ env: dict[str, str],
57
+ *,
58
+ job_flavor: str,
59
+ parent_env: Mapping[str, str] = os.environ,
60
+ ) -> str | None:
61
+ if not job_flavor.startswith('a10'):
62
+ return None
63
+
64
+ full_arch = parent_env.get('HYDRA_THROUGHPUT_MODE') == '0' or env.get('HYDRA_THROUGHPUT_MODE') == '0'
65
+ if full_arch:
66
+ defaults = {
67
+ 'HYDRA_THROUGHPUT_MODE': '0',
68
+ 'HYDRA_MUON_COMPILE': '0',
69
+ 'HYDRA_FORCE_HTM_CPU': '0',
70
+ 'HYDRA_INERT_MAMBA': '0',
71
+ 'HYDRA_ALLOW_SYNTHETIC_RETINA': '0',
72
+ 'HYDRA_FASTPATH': '0',
73
+ }
74
+ profile = 'full-architecture'
75
+ else:
76
+ defaults = {
77
+ 'HYDRA_MUON_COMPILE': '0',
78
+ 'HYDRA_FORCE_HTM_CPU': '1',
79
+ 'HYDRA_INERT_MAMBA': '1',
80
+ 'HYDRA_ALLOW_SYNTHETIC_RETINA': '1',
81
+ 'HYDRA_FASTPATH': '1',
82
+ }
83
+ profile = 'compatibility'
84
+
85
+ for key, default in defaults.items():
86
+ env[key] = parent_env.get(key, env.get(key, default))
87
+ return profile
88
+
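Editor's note: to make the precedence explicit, parent_env wins over an existing env entry, which wins over the profile default. A minimal sketch (flavor strings and values illustrative):

# Sketch: profile selection and override precedence for apply_a10_env_profile.
env: dict[str, str] = {}
assert apply_a10_env_profile(env, job_flavor='a10g-small', parent_env={}) == 'compatibility'
assert env['HYDRA_INERT_MAMBA'] == '1'  # compatibility defaults applied

env = {}
profile = apply_a10_env_profile(
    env, job_flavor='a10g-large', parent_env={'HYDRA_THROUGHPUT_MODE': '0'},
)
assert profile == 'full-architecture' and env['HYDRA_FASTPATH'] == '0'

# Non-A10 flavors are left untouched.
assert apply_a10_env_profile({}, job_flavor='h100', parent_env={}) is None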
89
+
90
+ def _http_status(exc: BaseException) -> int | None:
91
+ response = getattr(exc, 'response', None)
92
+ status = getattr(response, 'status_code', None)
93
+ return int(status) if isinstance(status, int) else None
94
+
95
+
96
+ def submit_job_with_retry(
97
+ api: HfApi,
98
+ *,
99
+ image: str,
100
+ command: Sequence[str],
101
+ env: dict[str, str],
102
+ secrets: dict[str, str],
103
+ flavor: Any,
104
+ timeout: str,
105
+ token: str,
106
+ namespace: str,
107
+ ):
108
+ for attempt in range(1, JOB_SUBMIT_RETRIES + 1):
109
+ try:
110
+ return api.run_job(
111
+ image=image,
112
+ command=list(command),
113
+ env=env,
114
+ secrets=secrets,
115
+ flavor=flavor,
116
+ timeout=timeout,
117
+ namespace=namespace,
118
+ token=token,
119
+ )
120
+ except HfHubHTTPError as exc:
121
+ status = _http_status(exc)
122
+ if status is None or status < 500 or status >= 600:
123
+ raise
124
+ if attempt >= JOB_SUBMIT_RETRIES:
125
+ raise SystemExit(
126
+ f'HF job submit failed after {JOB_SUBMIT_RETRIES} attempts '
127
+ f'(http {status}); failing fast'
128
+ ) from exc
129
+ time.sleep(JOB_SUBMIT_RETRY_BASE_S * attempt)
130
+ raise SystemExit('HF job submit failed unexpectedly; failing fast')
131
+
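Editor's note: the backoff here is linear in the attempt number, and only 5xx responses are retried; anything else (401, 404, non-HTTP errors) propagates immediately. With the defaults above that works out to:

# Sketch: retry schedule under the default 3 attempts / 5 s base.
retries, base_s = 3, 5.0
sleeps = [base_s * attempt for attempt in range(1, retries)]
assert sleeps == [5.0, 10.0]  # attempt 3 raises SystemExit instead of sleeping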
132
+
133
+ def fetch_space_build_log_tail(
134
+ api: HfApi,
135
+ repo_id: str,
136
+ token: str,
137
+ *,
138
+ limit: int = BUILD_LOG_TAIL_LINES,
139
+ ) -> str:
140
+ try:
141
+ lines = list(api.fetch_space_logs(repo_id, build=True, follow=False, token=token))
142
+ except Exception as exc:
143
+ return f'[space-build-log] failed to fetch build logs: {exc!r}'
144
+ tail = lines[-limit:]
145
+ text = ''.join(tail)
146
+ if text and not text.endswith('\n'):
147
+ text += '\n'
148
+ return text or '[space-build-log] no buffered build logs available\n'
149
+
150
+
151
+ def sync_overlay_from_repo() -> None:
152
+ """Refresh Space overlay with required project files."""
153
+ overlay = IMAGE_DIR / 'overlay'
154
+ overlay.mkdir(parents=True, exist_ok=True)
155
+
156
+ for child in overlay.iterdir():
157
+ if child.is_dir():
158
+ shutil.rmtree(child)
159
+ else:
160
+ child.unlink()
161
+
162
+ include_paths = [
163
+ 'hydra',
164
+ 'subsystems',
165
+ 'scripts',
166
+ 'htm_rust',
167
+ 'harness',
168
+ 'configs',
169
+ 'prepare.py',
170
+ 'prepare_nemotron.py',
171
+ 'train.py',
172
+ 'pyproject.toml',
173
+ 'uv.lock',
174
+ ]
175
+ ignore = shutil.ignore_patterns(
176
+ '__pycache__',
177
+ '.pytest_cache',
178
+ '.ruff_cache',
179
+ '.venv',
180
+ '.git',
181
+ 'target',
182
+ '*.pyc',
183
+ )
184
+
185
+ copied: list[str] = []
186
+ for rel in include_paths:
187
+ src = REPO_ROOT / rel
188
+ dst = overlay / rel
189
+ if not src.exists():
190
+ continue
191
+ if src.is_dir():
192
+ shutil.copytree(src, dst, dirs_exist_ok=True, ignore=ignore)
193
+ else:
194
+ dst.parent.mkdir(parents=True, exist_ok=True)
195
+ shutil.copy2(src, dst)
196
+ copied.append(rel)
197
+
198
+ scripts_dir = overlay / 'scripts'
199
+ if scripts_dir.exists():
200
+ for sh_path in scripts_dir.rglob('*.sh'):
201
+ data = sh_path.read_bytes()
202
+ data = data.replace(b'\r\n', b'\n').replace(b'\r', b'\n')
203
+ sh_path.write_bytes(data)
204
+
205
+ print(f'[launch] overlay synced from repo ({len(copied)} paths): {copied}', flush=True)
206
+
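Editor's note: for reference, shutil.ignore_patterns returns a callable that copytree consults per directory; a minimal sketch (paths invented) of how the spec above filters entries:

# Sketch: shutil.ignore_patterns returns fn(path, names) -> names to skip.
import shutil

ignore = shutil.ignore_patterns('__pycache__', '*.pyc', 'target')
skipped = ignore('/src/hydra', ['model.py', 'model.pyc', '__pycache__'])
assert skipped == {'model.pyc', '__pycache__'}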
207
+
208
+ def require_token() -> str:
209
+ token = os.environ.get('HF_TOKEN') or get_token()
210
+ if not token:
211
+ raise SystemExit('HF_TOKEN must be set or cached via huggingface-cli login for launch_feather_hf_job.py')
212
+ return token
213
+
214
+
215
+ def wait_for_space(api: HfApi, repo_id: str, token: str, timeout_s: int = 1800) -> None:
216
  """Wait until the Space image has been built.
217
 
218
  We use the Space purely as a container-image builder for HF Jobs. The Space
 
227
  and APP_STARTING_ERROR after a successful BUILDING→APP_STARTING transition
228
  are acceptable — the image exists in the registry and Jobs can use it.
229
  """
230
+ start = time.time()
231
+ seen_build_completion = False
232
+ seen_building = False
233
+ while True:
234
+ try:
235
+ runtime = api.get_space_runtime(repo_id, token=token)
236
+ except httpx.TransportError as exc:
237
+ if time.time() - start > timeout_s:
238
+ raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with network errors') from exc
239
+ print(f'[space] transient runtime endpoint network error: {exc!r}; retrying', flush=True)
240
+ time.sleep(20)
241
+ continue
242
+ except HfHubHTTPError as exc:
243
+ status = _http_status(exc)
244
+ if status is not None and 500 <= status < 600:
245
+ if time.time() - start > timeout_s:
246
+ raise TimeoutError(f'Space {repo_id} runtime endpoint kept failing with HTTP {status}') from exc
247
+ time.sleep(20)
248
+ continue
249
+ raise
250
+ stage = getattr(runtime, 'stage', None)
251
+ hardware = getattr(runtime, 'hardware', None)
252
+ err = getattr(runtime, 'errorMessage', None) or getattr(runtime, 'error_message', None)
253
+ print(f'[space] stage={stage} hardware={hardware}', flush=True)
254
+ if stage == 'BUILDING':
255
+ seen_building = True
256
+ if stage in {'APP_STARTING', 'RUNNING', 'PAUSED', 'SLEEPING'}:
257
+ seen_build_completion = True
258
+ if stage in {'RUNNING', 'PAUSED', 'SLEEPING'}:
259
+ return
260
+ # Image is built — Jobs can use it regardless of Space boot outcome.
261
+ if (seen_build_completion or seen_building) and stage in {'RUNTIME_ERROR', 'APP_STARTING_ERROR'}:
262
+ print(f'[space] Space boot failed with {stage} but built image is '
263
+ 'available in the Space registry and is usable by HF Jobs.',
264
+ flush=True)
265
+ return
266
+ # Hard build failures — no image was produced.
267
+ if stage in {'BUILD_ERROR', 'CONFIG_ERROR', 'NO_APP_FILE'}:
268
+ build_log_tail = fetch_space_build_log_tail(api, repo_id, token)
269
+ raise RuntimeError(
270
+ f'Space {repo_id} build failed: stage={stage} error={err!r}\n'
271
+ f'--- Space build log tail ---\n{build_log_tail}'
272
+ )
273
  if time.time() - start > timeout_s:
274
  raise TimeoutError(f'Space {repo_id} did not become ready in {timeout_s}s (last stage={stage})')
275
  time.sleep(20)
276
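Editor's note: the stage handling above reduces to a small decision table; a standalone restatement of that policy for clarity (a sketch, not part of the script):

# Sketch of the stage policy: 'ok' means a usable image exists, 'fail' means
# no image was produced, None means keep polling.
def classify_stage(stage: str, seen_build: bool) -> str | None:
    if stage in {'RUNNING', 'PAUSED', 'SLEEPING'}:
        return 'ok'
    if seen_build and stage in {'RUNTIME_ERROR', 'APP_STARTING_ERROR'}:
        return 'ok'  # Space boot failed, but the built image is in the registry
    if stage in {'BUILD_ERROR', 'CONFIG_ERROR', 'NO_APP_FILE'}:
        return 'fail'
    return None

assert classify_stage('BUILDING', seen_build=False) is None
assert classify_stage('RUNTIME_ERROR', seen_build=True) == 'ok'
assert classify_stage('BUILD_ERROR', seen_build=False) == 'fail'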
 
277
 
278
+ def main() -> int:
279
+ token = require_token()
280
+ routing = resolve_routing(token=token)
281
+ api = HfApi(token=token)
282
+ secondary_gates = HarnessConfig().to_secondary_gates()
283
+
284
+ print(f'[launch] image_dir={IMAGE_DIR}', flush=True)
285
+ print(f'[launch] owner={routing.owner}', flush=True)
286
+ print(f'[launch] space_repo={routing.space_repo}', flush=True)
287
+ print(f'[launch] output_repo={routing.output_repo}', flush=True)
288
+ print(f'[launch] retina_cache_repo={routing.retina_cache_repo}', flush=True)
289
+ print(f'[launch] target_shards={TARGET_SHARDS} time_budget={TIME_BUDGET} timeout={TIMEOUT}', flush=True)
290
+ print(f'[launch] flavor={JOB_FLAVOR}', flush=True)
291
+ print(f'[launch] namespace={routing.job_namespace}', flush=True)
292
+ print(f'[launch] image_mode={"space" if USE_SPACE_IMAGE else "ghcr"}', flush=True)
293
+ print(f'[launch] secondary_gates={json.dumps(secondary_gates, sort_keys=True)}', flush=True)
294
+ if not USE_SPACE_IMAGE:
295
+ print(f'[launch] image={DEFAULT_IMAGE}', flush=True)
296
+
297
+ api.create_repo(repo_id=routing.space_repo, repo_type='space', space_sdk='docker', private=True, exist_ok=True, token=token)
298
+ api.create_repo(repo_id=routing.output_repo, repo_type='model', private=True, exist_ok=True, token=token)
299
 
300
  if DRY_RUN:
301
  print('[launch] dry-run mode; skipping upload and job submission', flush=True)
302
  return 0
303
 
304
+ image_ref = DEFAULT_IMAGE
305
+ if USE_SPACE_IMAGE:
306
+ if SKIP_UPLOAD:
307
+ print('[launch] FEATHER_HF_SKIP_UPLOAD=1; reusing existing Space image', flush=True)
308
+ else:
309
+ if SYNC_OVERLAY:
310
+ sync_overlay_from_repo()
311
+ print('[launch] uploading custom Docker Space image context...', flush=True)
312
+ api.upload_folder(
313
+ repo_id=routing.space_repo,
314
+ repo_type='space',
315
+ folder_path=str(IMAGE_DIR),
316
+ commit_message='Update Feather training runtime image',
317
+ token=token,
318
+ )
319
+
320
+ print('[launch] waiting for Space image build to become ready...', flush=True)
321
+ wait_for_space(api, routing.space_repo, token=token)
322
+ image_ref = f'hf.co/spaces/{routing.space_repo}'
323
+
324
+ env = {
325
+ 'HF_REPO_ID': routing.output_repo,
326
+ 'FEATHER_HF_OWNER': routing.owner,
327
+ 'FEATHER_HF_SPACE_REPO': routing.space_repo,
328
+ 'FEATHER_HF_OUTPUT_REPO': routing.output_repo,
329
+ 'FEATHER_HF_RETINA_CACHE_REPO': routing.retina_cache_repo,
330
+ 'HYDRA_RETINA_CACHE_REPO': routing.retina_cache_repo,
331
+ 'HYDRA_TARGET_SHARDS': TARGET_SHARDS,
332
+ 'HYDRA_TIME_BUDGET': TIME_BUDGET,
333
+ 'HYDRA_DOWNLOAD_WORKERS': DOWNLOAD_WORKERS,
334
  'HYDRA_CKPT_INTERVAL': CKPT_INTERVAL,
335
  'PYTHONUNBUFFERED': '1',
336
+ 'FEATHER_RUNTIME_MODE': 'job',
337
+ }
338
+ if 'HYDRA_USE_NEMOTRON' not in os.environ and should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET):
339
+ env['HYDRA_USE_NEMOTRON'] = '1'
340
+ print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
341
+ # A10 profile: default compatibility mode avoids known PTX/compile runtime
342
+ # pitfalls, but HYDRA_THROUGHPUT_MODE=0 explicitly selects the full
343
+ # SDR/HTM/Engram architecture instead of silently inheriting bypass defaults.
344
+ _a10_profile = apply_a10_env_profile(env, job_flavor=JOB_FLAVOR)
345
+ if _a10_profile is not None:
346
+ if env.get('HYDRA_INERT_MAMBA') == '0' and 'HYDRA_FASTPATH' not in os.environ:
347
+ env['HYDRA_FASTPATH'] = '0'
348
+ print(
349
+ f'[launch] applied A10 {_a10_profile} env profile '
350
+ f"(HYDRA_MUON_COMPILE={env['HYDRA_MUON_COMPILE']}, "
351
+ f"HYDRA_THROUGHPUT_MODE={env.get('HYDRA_THROUGHPUT_MODE', 'unset')}, "
352
+ f"HYDRA_FORCE_HTM_CPU={env['HYDRA_FORCE_HTM_CPU']}, "
353
+ f"HYDRA_INERT_MAMBA={env['HYDRA_INERT_MAMBA']}, "
354
+ f"HYDRA_ALLOW_SYNTHETIC_RETINA={env['HYDRA_ALLOW_SYNTHETIC_RETINA']}, "
355
+ f"HYDRA_FASTPATH={env['HYDRA_FASTPATH']})",
356
+ flush=True,
357
+ )
358
  # Pass through any HYDRA_* / FEATHER_* overrides from the caller's env so
359
  # sweep drivers can set HYDRA_N_LAYER, HYDRA_SDR_TARGET_ACTIVE,
360
  # HYDRA_LAYER_DIAGNOSTICS, HYDRA_METRICS_OUT, HYDRA_MID_VAL_INTERVAL, etc.
 
364
  env[_k] = _v
365
  secrets = {'HF_TOKEN': token}
366
 
367
+ print(f'[launch] submitting HF Job on flavor={JOB_FLAVOR}...', flush=True)
368
+ job = submit_job_with_retry(
369
+ api,
370
+ image=image_ref,
371
+ command=['python', '/app/entrypoint.py'],
372
+ env=env,
373
+ secrets=secrets,
374
+ flavor=cast(Any, JOB_FLAVOR),
375
+ timeout=TIMEOUT,
376
+ namespace=routing.job_namespace,
377
+ token=token,
378
+ )
379
  print(f'[launch] submitted job_id={job.id} status={job.status.stage} url={job.url}', flush=True)
380
  return 0
381
 
overlay/scripts/optuna_hpo.py CHANGED
@@ -5,131 +5,131 @@ import argparse
5
  import json
6
  import os
7
  import re
8
- import subprocess
9
- import sys
10
- import time
11
- import tempfile
12
- from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
13
- from pathlib import Path
14
- from typing import Any
15
-
16
- import optuna
17
-
18
-
19
- _HF_ENV_KEY_RE = re.compile(r"^[A-Z][A-Z0-9_]*$")
20
-
21
-
22
- REPO_ROOT = Path(__file__).resolve().parents[1]
23
- if str(REPO_ROOT) not in sys.path:
24
- sys.path.insert(0, str(REPO_ROOT))
25
-
26
- from scripts.hf_routing import resolve_routing
27
-
28
- TRAIN_ENTRYPOINT = REPO_ROOT / "train.py"
29
- SEARCH_SPACE_KEYS = {
30
- "d_model",
31
- "n_layer",
32
- "d_state",
33
- "headdim",
34
- "expand",
35
- "seq_len",
36
- "batch_size",
37
- "grad_accum",
38
- "matrix_lr",
39
- "embed_lr",
40
- "unembed_lr",
41
- "hyena_layers",
42
- "engram_n_columns",
43
- "engram_layer_idx",
44
- "sdr_target_active",
45
- "htm_learn_every",
46
- "htm_subsample",
47
- "engram_subsample",
48
- "mamba3_chunk",
49
- "dropout",
50
- }
51
-
52
-
53
- def _filter_prior_params(raw: dict[str, Any]) -> dict[str, Any]:
54
- return {k: v for k, v in raw.items() if k in SEARCH_SPACE_KEYS}
55
-
56
-
57
- def _load_prior_param_sets(path: Path) -> list[dict[str, Any]]:
58
- if not path.exists():
59
- return []
60
-
61
- payload = json.loads(path.read_text(encoding="utf-8"))
62
- if isinstance(payload, dict):
63
- rows = payload.get("trials", [])
64
- elif isinstance(payload, list):
65
- rows = payload
66
- else:
67
- rows = []
68
-
69
- out: list[dict[str, Any]] = []
70
- for item in rows:
71
- if not isinstance(item, dict):
72
- continue
73
- params_obj = item.get("params", item)
74
- if not isinstance(params_obj, dict):
75
- continue
76
- filtered = _filter_prior_params(params_obj)
77
- if filtered:
78
- out.append(filtered)
79
- return out
80
-
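Editor's note: the loader accepts two payload shapes; a minimal sketch of both (field values illustrative), including the SEARCH_SPACE_KEYS filtering:

# Sketch: both payload shapes accepted by _load_prior_param_sets. Keys outside
# SEARCH_SPACE_KEYS (e.g. "val_bpb") are dropped by _filter_prior_params.
import json
import tempfile
from pathlib import Path

wrapped = {"trials": [{"params": {"d_model": 96, "n_layer": 2, "val_bpb": 1.43}}]}
bare = [{"d_model": 96, "n_layer": 2}]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "priors.json"
    for payload in (wrapped, bare):
        path.write_text(json.dumps(payload), encoding="utf-8")
        assert _load_prior_param_sets(path) == [{"d_model": 96, "n_layer": 2}]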
81
-
82
- def _enqueue_transfer_priors(study: optuna.Study, priors_file: Path, apply_priors: bool) -> int:
83
- if not apply_priors:
84
- return 0
85
-
86
- priors_raw = _load_prior_param_sets(priors_file)
87
- if not priors_raw:
88
- return 0
89
-
90
- # Deduplicate param sets across merged studies.
91
- priors: list[dict[str, Any]] = []
92
- seen: set[str] = set()
93
- for params in priors_raw:
94
- key = json.dumps(params, sort_keys=True)
95
- if key in seen:
96
- continue
97
- seen.add(key)
98
- priors.append(params)
99
-
100
- enqueued = 0
101
- for params in priors:
102
- before = len(study.get_trials(deepcopy=False))
103
- try:
104
- study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"}, skip_if_exists=True)
105
- except TypeError:
106
- study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"})
107
- after = len(study.get_trials(deepcopy=False))
108
- if after > before:
109
- enqueued += 1
110
- return enqueued
111
-
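Editor's note: the dedup key above is canonical JSON; the sketch below shows why sort_keys matters when merging studies whose param dicts differ only in key order:

# Sketch: sort_keys gives an order-insensitive canonical form for the dedup set.
import json

a = {"d_model": 96, "n_layer": 2}
b = {"n_layer": 2, "d_model": 96}
assert json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)
assert json.dumps(a) != json.dumps(b)  # without sort_keys, insertion order leaks through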
112
-
113
- def _enqueue_quality_anchors(study: optuna.Study, priors_file: Path, quality_mode_local: bool, top_k: int) -> int:
114
- if not quality_mode_local or top_k <= 0:
115
- return 0
116
-
117
- priors = _load_prior_param_sets(priors_file)[:top_k]
118
- enqueued = 0
119
- for params in priors:
120
- before = len(study.get_trials(deepcopy=False))
121
- try:
122
- study.enqueue_trial(
123
- params,
124
- user_attrs={"seed_source": "quality_anchor"},
125
- skip_if_exists=True,
126
- )
127
- except TypeError:
128
- study.enqueue_trial(params, user_attrs={"seed_source": "quality_anchor"})
129
- after = len(study.get_trials(deepcopy=False))
130
- if after > before:
131
- enqueued += 1
132
- return enqueued
133
 
134
 
135
  def _parse_metrics_from_stdout(stdout: str) -> dict[str, Any] | None:
@@ -164,241 +164,241 @@ def _parse_metrics_from_log_lines(lines: list[str]) -> dict[str, Any] | None:
164
  return None
165
 
166
 
167
- def _parse_last_train_bpb_from_logs(lines: list[str]) -> float | None:
168
- """Best-effort fallback when final eval crashes before metrics JSON write."""
169
- last: float | None = None
170
- for line in lines:
171
- m = re.search(r"\bbpb=([0-9]+(?:\.[0-9]+)?)", line)
172
  if m:
173
- last = float(m.group(1))
174
- return last
175
-
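Editor's note: a quick illustration of the fallback parse (log lines invented for the example); the regex keeps the last bpb seen, so a crash in the final eval still yields the most recent training value:

# Sketch: _parse_last_train_bpb_from_logs keeps the last matching bpb.
lines = [
    "step 100 | loss=2.31 bpb=1.71 tps=51000",
    "step 200 | loss=2.05 bpb=1.58 tps=50500",
    "Traceback (most recent call last): ...",
]
assert _parse_last_train_bpb_from_logs(lines) == 1.58
assert _parse_last_train_bpb_from_logs(["no metrics here"]) is None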
176
-
177
- def _persist_trial_artifacts(
178
- *,
179
- trial_dir: Path,
180
- metrics: dict[str, Any] | None,
181
- log_lines: list[str] | None,
182
- log_name: str,
183
- metadata: dict[str, Any],
184
- ) -> dict[str, str | None]:
185
- trial_dir.mkdir(parents=True, exist_ok=True)
186
- metrics_path = trial_dir / "metrics.json"
187
- log_path = trial_dir / log_name
188
- manifest_path = trial_dir / "trial_artifacts.json"
189
-
190
- if metrics is not None:
191
- metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True), encoding="utf-8")
192
- if log_lines is not None:
193
- log_path.write_text("\n".join(log_lines), encoding="utf-8")
194
-
195
- manifest = {
196
- **metadata,
197
- "metrics_path": str(metrics_path) if metrics is not None else None,
198
- "log_path": str(log_path) if log_lines is not None else None,
199
- }
200
- manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
201
- return {
202
- "metrics_path": str(metrics_path) if metrics is not None else None,
203
- "log_path": str(log_path) if log_lines is not None else None,
204
- "manifest_path": str(manifest_path),
205
- }
206
-
207
-
208
- def _resolve_objective_metric(
209
- trial: optuna.Trial,
210
- *,
211
- metric_key: str,
212
- metrics: dict[str, Any] | None,
213
- allow_log_metric_fallback: bool,
214
- fallback_bpb: float | None,
215
- tps_seen: float | None,
216
- ) -> float:
217
- """Resolve the objective value while labeling where it came from.
218
-
219
- Validation metrics and live training-log fallbacks are intentionally
220
- different sources. Keeping that distinction in trial attrs prevents a
221
- skipped/OOM eval from being mistaken for a real validation result.
222
- """
223
- if metrics is None:
224
- if allow_log_metric_fallback and metric_key == "val_bpb" and fallback_bpb is not None:
225
- trial.set_user_attr("objective_source", "train_log_fallback")
226
- trial.set_user_attr("objective_metric", "train_bpb")
227
- trial.set_user_attr("eval_status", "missing_metrics")
228
- trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
229
- if tps_seen is not None:
230
- trial.set_user_attr("tps", float(tps_seen))
231
- return float(fallback_bpb)
232
- trial.set_user_attr("objective_source", "missing_metrics")
233
- raise optuna.TrialPruned("No metrics payload found")
234
-
235
- eval_status = str(
236
- metrics.get(
237
- "eval_status",
238
- "completed" if metrics.get("val_bpb") is not None else "unknown",
239
- )
240
- )
241
- trial.set_user_attr("eval_status", eval_status)
242
-
243
- if fallback_bpb is not None:
244
- trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
245
-
246
- if metric_key not in metrics or metrics[metric_key] is None:
247
- trial.set_user_attr("objective_source", "missing_metric")
248
- trial.set_user_attr("objective_metric", metric_key)
249
- raise optuna.TrialPruned(f"Metric '{metric_key}' missing in metrics payload")
250
-
251
- value = float(metrics[metric_key])
252
- trial.set_user_attr("objective_metric", metric_key)
253
- if metric_key == "val_bpb":
254
- trial.set_user_attr("objective_source", "final_val")
255
- trial.set_user_attr("final_val_bpb", value)
256
- else:
257
- trial.set_user_attr("objective_source", "metrics_json")
258
- return value
259
-
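Editor's note: a sketch of the labeling contract on the train-log fallback path; _StubTrial is a test double standing in for optuna.Trial (it is not an optuna API):

# Sketch: objective-source labeling when metrics JSON is missing but a
# train-log bpb fallback exists.
class _StubTrial:
    def __init__(self) -> None:
        self.attrs: dict[str, object] = {}

    def set_user_attr(self, key: str, value: object) -> None:
        self.attrs[key] = value

stub = _StubTrial()
value = _resolve_objective_metric(
    stub,  # type: ignore[arg-type]  # duck-typed stand-in for optuna.Trial
    metric_key="val_bpb",
    metrics=None,
    allow_log_metric_fallback=True,
    fallback_bpb=1.62,
    tps_seen=None,
)
assert value == 1.62
assert stub.attrs["objective_source"] == "train_log_fallback"
assert stub.attrs["objective_metric"] == "train_bpb"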
260
-
261
- def _fetch_job_logs_safe(
262
- api,
263
- *,
264
- job_id: str,
265
- token: str,
266
- namespace: str,
267
- retries: int = 3,
268
- sleep_s: float = 2.0,
269
- timeout_s: float = 20.0,
270
- ) -> list[str]:
271
- last_exc: Exception | None = None
272
- for attempt in range(1, retries + 1):
273
- try:
274
- with ThreadPoolExecutor(max_workers=1) as executor:
275
- future = executor.submit(
276
- lambda: list(api.fetch_job_logs(job_id=job_id, follow=False, token=token, namespace=namespace))
277
- )
278
- return future.result(timeout=timeout_s)
279
- except FuturesTimeoutError:
280
- last_exc = TimeoutError(f"Timed out fetching HF job logs for {job_id} after {timeout_s:.1f}s")
281
- except Exception as exc: # noqa: BLE001
282
- last_exc = exc
283
- if attempt >= retries:
284
- raise
285
- time.sleep(sleep_s)
286
- if last_exc is not None:
287
- raise last_exc
288
- return []
289
-
290
-
291
- def _effective_min_tps(args: argparse.Namespace) -> float | None:
292
- min_tps = args.min_tps
293
- if getattr(args, "quality_mode_local", False) and min_tps == 50000.0:
294
- return 0.0
295
- return min_tps
296
-
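Editor's note: the TPS floor interacts with quality mode only through the default; an explicitly passed --min-tps survives. For example:

# Sketch: quality mode zeroes the floor only when it is still the 50k default.
import argparse

args = argparse.Namespace(min_tps=50000.0, quality_mode_local=True)
assert _effective_min_tps(args) == 0.0
args = argparse.Namespace(min_tps=20000.0, quality_mode_local=True)
assert _effective_min_tps(args) == 20000.0  # explicit floor is respected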
297
-
298
- def _trial_env(trial: optuna.Trial, args: argparse.Namespace, metrics_path: Path) -> dict[str, str]:
299
- env = os.environ.copy()
300
- full_arch_hpo = env.get("HYDRA_HPO_FULL_ARCH", "0") == "1"
301
- speed_arch_hpo = full_arch_hpo and env.get("HYDRA_HPO_SPEED_ARCH", "0") == "1"
302
- quality_mode_local = bool(getattr(args, "quality_mode_local", False))
303
 
304
  # Runtime and reporting
305
  env["HYDRA_METRICS_OUT"] = str(metrics_path)
306
  env["HYDRA_TIME_BUDGET"] = str(args.trial_time_budget)
307
  env["PYTHONUNBUFFERED"] = "1"
308
 
309
- # Search space — fully env-driven to match existing training stack.
310
- if speed_arch_hpo:
311
- # Full-arch speed mode targets A10 underutilization observed in HPO:
312
- # low VRAM/MFU, strong BPB from shallow models, and fixed SDR/HTM
313
- # overhead dominating small microbatches. Keep all components enabled
314
- # while amortizing overhead over more tokens.
315
- env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96]))
316
- env["HYDRA_N_LAYER"] = str(trial.suggest_categorical("n_layer", [2]))
317
- env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
318
- env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
319
- env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
320
- elif quality_mode_local and full_arch_hpo:
321
- env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128]))
322
- env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 2, 3))
323
- env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
324
- env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
325
- env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
326
- else:
327
- env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128, 160, 192]))
328
- env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 1, 4))
329
- env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32, 48]))
330
- env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [8, 16, 32]))
331
- env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
332
-
333
- if speed_arch_hpo:
334
- seq_len = trial.suggest_categorical("seq_len", [64, 128])
335
- batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
336
- grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
337
- elif quality_mode_local and full_arch_hpo:
338
- seq_len = trial.suggest_categorical("seq_len", [64])
339
- batch_size = trial.suggest_categorical("batch_size", [4, 8])
340
- grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
341
- else:
342
- seq_len = trial.suggest_categorical("seq_len", [32, 64])
343
- batch_size = trial.suggest_categorical("batch_size", [4, 8] if full_arch_hpo else [4, 8, 16])
344
- grad_accum = trial.suggest_categorical("grad_accum", [1, 4, 8, 16] if full_arch_hpo else [8, 16, 32, 64])
345
  # Keep TOTAL_BATCH_SIZE divisible by DEVICE_BATCH_SIZE * MAX_SEQ_LEN.
346
  total_batch = batch_size * seq_len * grad_accum
347
  env["HYDRA_SEQ_LEN"] = str(seq_len)
348
  env["HYDRA_BATCH_SIZE"] = str(batch_size)
349
  env["HYDRA_TOTAL_BATCH"] = str(total_batch)
350
 
351
- if quality_mode_local and full_arch_hpo:
352
- env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.008, 0.03, log=True))
353
- env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.15, 0.6, log=True))
354
- env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.001, 0.01, log=True))
355
- else:
356
- env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.005, 0.2, log=True))
357
- env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.05, 1.0, log=True))
358
- env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.0005, 0.02, log=True))
359
-
360
- if full_arch_hpo:
361
- env["HYDRA_HYENA_LAYERS"] = ""
362
- env["HYDRA_ENGRAM_N_COLUMNS"] = str(
363
- trial.suggest_categorical(
364
- "engram_n_columns",
365
- [512, 1024] if (speed_arch_hpo or quality_mode_local) else [512, 1024, 2048],
366
- )
367
- )
368
- env["HYDRA_ENGRAM_LAYER_IDX"] = str(trial.suggest_int("engram_layer_idx", 0, max(0, int(env["HYDRA_N_LAYER"]) - 1)))
369
- env["HYDRA_SDR_TARGET_ACTIVE"] = str(
370
- trial.suggest_categorical(
371
- "sdr_target_active",
372
- [327] if quality_mode_local else ([164, 327] if speed_arch_hpo else [164, 327, 512]),
373
- )
374
- )
375
- env["HYDRA_HTM_LEARN_EVERY"] = str(
376
- trial.suggest_categorical("htm_learn_every", [8, 16] if (speed_arch_hpo or quality_mode_local) else [4, 8, 16])
377
- )
378
- env["HYDRA_HTM_SUBSAMPLE"] = str(
379
- trial.suggest_categorical("htm_subsample", [1, 2] if quality_mode_local else ([4, 8, 16] if speed_arch_hpo else [1, 2, 4, 8]))
380
- )
381
- env["HYDRA_ENGRAM_SUBSAMPLE"] = str(
382
- trial.suggest_categorical("engram_subsample", [1, 2] if quality_mode_local else ([1, 2, 4] if speed_arch_hpo else [1]))
383
- )
384
- env["HYDRA_MAMBA3_CHUNK"] = str(trial.suggest_categorical("mamba3_chunk", [32, 64]))
385
- env["HYDRA_DROPOUT"] = str(trial.suggest_categorical("dropout", [0.0, 0.1] if (speed_arch_hpo or quality_mode_local) else [0.0, 0.1, 0.2]))
386
- else:
387
- env["HYDRA_HYENA_LAYERS"] = trial.suggest_categorical("hyena_layers", ["", "0", "1", "0,1"])
388
 
389
  # Keep trials alive long enough to emit metrics.
390
  env["HYDRA_FAIL_LOSS_THRESHOLD"] = "1000000"
391
  env["HYDRA_USE_NEMOTRON"] = os.environ.get("HYDRA_USE_NEMOTRON", "1")
392
  env["HYDRA_LOCAL_SHARDS_ONLY"] = os.environ.get("HYDRA_LOCAL_SHARDS_ONLY", "0")
393
  # Strict optimal-path defaults (no forced fallback profile).
394
- env["HYDRA_MUON_COMPILE"] = os.environ.get("HYDRA_MUON_COMPILE", "1")
395
- env["HYDRA_THROUGHPUT_MODE"] = os.environ.get("HYDRA_THROUGHPUT_MODE", "0" if full_arch_hpo else "1")
396
- env["HYDRA_FORCE_HTM_CPU"] = os.environ.get("HYDRA_FORCE_HTM_CPU", "0")
397
- env["HYDRA_ALLOW_SYNTHETIC_RETINA"] = os.environ.get("HYDRA_ALLOW_SYNTHETIC_RETINA", "0")
398
- env["HYDRA_INERT_MAMBA"] = os.environ.get("HYDRA_INERT_MAMBA", "0")
399
- env["HYDRA_FASTPATH"] = os.environ.get("HYDRA_FASTPATH", "0" if full_arch_hpo else "1")
400
-
401
- return env
402
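Editor's note: the divisibility comment inside _trial_env holds by construction, since HYDRA_TOTAL_BATCH is derived from the other two draws; a worked example:

# Worked example: HYDRA_TOTAL_BATCH is always a multiple of batch_size * seq_len.
batch_size, seq_len, grad_accum = 8, 64, 16
total_batch = batch_size * seq_len * grad_accum  # 8192
assert total_batch % (batch_size * seq_len) == 0
assert total_batch // (batch_size * seq_len) == grad_accum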
 
403
 
404
  def _sanitize_hf_env(env: dict[str, str]) -> dict[str, str]:
@@ -410,7 +410,7 @@ def _sanitize_hf_env(env: dict[str, str]) -> dict[str, str]:
410
  return sanitized
411
 
412
 
413
- def _hf_command_candidates(args: argparse.Namespace) -> list[list[str]]:
414
  if args.hf_use_bash:
415
  return [["bash", "-lc", args.hf_command]]
416
 
@@ -432,20 +432,20 @@ def _hf_command_candidates(args: argparse.Namespace) -> list[list[str]]:
432
  uniq.append(c)
433
  return uniq
434
 
435
- return [raw.split()]
436
-
437
-
438
- def _space_repo_from_hf_image(image: str, namespace: str) -> str:
439
- prefix = "hf.co/spaces/"
440
- if image.startswith(prefix):
441
- return image[len(prefix):]
442
- return os.environ.get("FEATHER_HF_SPACE_REPO", f"{namespace}/feather-a10-runtime")
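Editor's note: for reference, the mapping this helper performs ('acme' is illustrative; the second case assumes FEATHER_HF_SPACE_REPO is unset):

# Sketch: Space-backed image refs map back to their repo id; other images fall
# through to FEATHER_HF_SPACE_REPO or the namespaced default.
assert _space_repo_from_hf_image('hf.co/spaces/acme/feather-a10-runtime', 'acme') \
    == 'acme/feather-a10-runtime'
# With FEATHER_HF_SPACE_REPO unset:
#   _space_repo_from_hf_image('ghcr.io/acme/img:latest', 'acme')
#   -> 'acme/feather-a10-runtime'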
443
-
444
 
445
- def _objective_local(args: argparse.Namespace):
446
- effective_min_tps = _effective_min_tps(args)
447
-
448
- def objective(trial: optuna.Trial) -> float:
449
  trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
450
  metrics_path = trial_dir / "metrics.json"
451
 
@@ -460,67 +460,67 @@ def _objective_local(args: argparse.Namespace):
460
  timeout=args.trial_timeout,
461
  )
462
 
463
- metrics: dict[str, Any] | None = None
464
  if metrics_path.exists():
465
  try:
466
  metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
467
  except json.JSONDecodeError:
468
  metrics = None
469
- if metrics is None:
470
- metrics = _parse_metrics_from_stdout(proc.stdout)
471
-
472
- artifact_paths = _persist_trial_artifacts(
473
- trial_dir=trial_dir,
474
- metrics=metrics,
475
- log_lines=(proc.stdout or "").splitlines(),
476
- log_name="train_stdout.log",
477
- metadata={"runner": "local", "returncode": proc.returncode},
478
- )
479
- (trial_dir / "train_stderr.log").write_text(proc.stderr or "", encoding="utf-8")
480
-
481
- fallback_bpb = _parse_last_train_bpb_from_logs(proc.stdout.splitlines())
482
- if metrics is None:
483
- _resolve_objective_metric(
484
- trial,
485
- metric_key=args.metric,
486
- metrics=None,
487
- allow_log_metric_fallback=args.allow_log_metric_fallback,
488
- fallback_bpb=fallback_bpb,
489
- tps_seen=None,
490
- )
491
- raise optuna.TrialPruned("No metrics found (HYDRA_METRICS_OUT/[METRICS_JSON])")
492
 
493
  if proc.returncode != 0:
494
  raise optuna.TrialPruned(f"Training failed rc={proc.returncode}")
495
 
496
- metric_key = args.metric
497
 
498
  tps_val = metrics.get("tps")
499
  if tps_val is not None:
500
  tps_f = float(tps_val)
501
  trial.set_user_attr("tps", tps_f)
502
- if effective_min_tps is not None and tps_f < effective_min_tps:
503
- raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
504
-
505
- value = _resolve_objective_metric(
506
- trial,
507
- metric_key=metric_key,
508
- metrics=metrics,
509
- allow_log_metric_fallback=args.allow_log_metric_fallback,
510
- fallback_bpb=fallback_bpb,
511
- tps_seen=None,
512
- )
513
-
514
- # Keep useful context on trial
515
- trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
516
- trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
517
-
518
- return value
519
 
520
  return objective
521
 
522
 
523
- def _objective_hf_job(args: argparse.Namespace):
524
  from huggingface_hub import HfApi
525
  from huggingface_hub.utils import get_token
526
 
@@ -530,9 +530,9 @@ def _objective_hf_job(args: argparse.Namespace):
530
  f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
531
  )
532
 
533
- api = HfApi(token=token)
534
- terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
535
- effective_min_tps = _effective_min_tps(args)
536
 
537
  def objective(trial: optuna.Trial) -> float:
538
  trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
@@ -568,14 +568,14 @@ def _objective_hf_job(args: argparse.Namespace):
568
  info = api.inspect_job(job_id=job.id, token=token, namespace=args.hf_namespace)
569
  bootstrap_stage = str(info.status.stage)
570
  bootstrap_msg = str(getattr(info.status, "message", "") or "")
571
- bootstrap_logs = _fetch_job_logs_safe(
572
- api,
573
- job_id=job.id,
574
- token=token,
575
- namespace=args.hf_namespace,
576
- retries=2,
577
- sleep_s=1.0,
578
- )
579
  if bootstrap_stage in {"RUNNING", "COMPLETED"} or bootstrap_logs:
580
  break
581
  if bootstrap_stage in {"ERROR", "FAILED", "CANCELLED", "CANCELED", "TIMEOUT"}:
@@ -611,12 +611,12 @@ def _objective_hf_job(args: argparse.Namespace):
611
  info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
612
  stage = str(info.status.stage)
613
  terminal_detail = str(getattr(info.status, "message", "")) or terminal_detail
614
- log_lines = _fetch_job_logs_safe(
615
- api,
616
- job_id=job_id,
617
- token=token,
618
- namespace=args.hf_namespace,
619
- )
620
 
621
  m = _parse_metrics_from_log_lines(log_lines)
622
  if m is not None:
@@ -643,66 +643,66 @@ def _objective_hf_job(args: argparse.Namespace):
643
  except Exception:
644
  pass
645
 
646
- artifact_paths = _persist_trial_artifacts(
647
- trial_dir=trial_dir,
648
- metrics=metrics,
649
- log_lines=log_lines,
650
- log_name="hf_job.log",
651
- metadata={"runner": "hf-job", "hf_job_id": job_id, "hf_stage": stage},
652
- )
653
- trial.set_user_attr("hf_stage", stage)
654
- trial.set_user_attr("hf_log_lines", len(log_lines))
655
  if terminal_detail:
656
  trial.set_user_attr("hf_status_message", terminal_detail)
657
 
658
- fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
659
- if metrics is None:
660
- try:
661
- value = _resolve_objective_metric(
662
- trial,
663
- metric_key=args.metric,
664
- metrics=None,
665
- allow_log_metric_fallback=args.allow_log_metric_fallback,
666
- fallback_bpb=fallback_bpb,
667
- tps_seen=tps_seen,
668
- )
669
- if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
670
- raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
671
- return value
672
- except optuna.TrialPruned:
673
- pass
674
- if tps_seen is not None:
675
- trial.set_user_attr("tps", tps_seen)
676
- detail = f"stage={stage}, logs={len(log_lines)}"
677
- if terminal_detail:
678
- detail = f"{detail}, message={terminal_detail}"
679
  raise optuna.TrialPruned(f"No metrics found from HF job ({detail})")
680
 
681
- metric_key = args.metric
682
 
683
  tps_val = metrics.get("tps")
684
  if tps_val is not None:
685
  tps_f = float(tps_val)
686
  trial.set_user_attr("tps", tps_f)
687
- if effective_min_tps is not None and tps_f < effective_min_tps:
688
- raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
689
-
690
- value = _resolve_objective_metric(
691
- trial,
692
- metric_key=metric_key,
693
- metrics=metrics,
694
- allow_log_metric_fallback=args.allow_log_metric_fallback,
695
- fallback_bpb=fallback_bpb,
696
- tps_seen=tps_seen,
697
- )
698
- trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
699
- trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
700
- return value
701
 
702
  return objective
703
 
704
 
705
- def _objective_hf_launcher(args: argparse.Namespace):
706
  from huggingface_hub import HfApi
707
  from huggingface_hub.utils import get_token
708
 
@@ -712,9 +712,9 @@ def _objective_hf_launcher(args: argparse.Namespace):
712
  f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
713
  )
714
 
715
- api = HfApi(token=token)
716
- terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
717
- effective_min_tps = _effective_min_tps(args)
718
 
719
  def objective(trial: optuna.Trial) -> float:
720
  trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
@@ -725,11 +725,11 @@ def _objective_hf_launcher(args: argparse.Namespace):
725
  local_env = os.environ.copy()
726
  local_env.update(env)
727
  local_env[args.hf_token_env] = token
728
- local_env["FEATHER_HF_NAMESPACE"] = args.hf_namespace
729
- local_env["FEATHER_HF_FLAVOR"] = args.hf_flavor
730
- local_env["FEATHER_HF_JOB_TIMEOUT"] = args.hf_timeout
731
- local_env["FEATHER_HF_IMAGE"] = args.hf_image
732
- local_env["FEATHER_HF_SPACE_REPO"] = _space_repo_from_hf_image(args.hf_image, args.hf_namespace)
733
  if args.hf_output_repo:
734
  local_env["FEATHER_HF_OUTPUT_REPO"] = args.hf_output_repo
735
  else:
@@ -766,12 +766,12 @@ def _objective_hf_launcher(args: argparse.Namespace):
766
  info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
767
  stage = str(info.status.stage)
768
  terminal_detail = str(getattr(info.status, "message", "") or "") or terminal_detail
769
- log_lines = _fetch_job_logs_safe(
770
- api,
771
- job_id=job_id,
772
- token=token,
773
- namespace=args.hf_namespace,
774
- )
775
 
776
  mtr = _parse_metrics_from_log_lines(log_lines)
777
  if mtr is not None:
@@ -796,85 +796,85 @@ def _objective_hf_launcher(args: argparse.Namespace):
796
  except Exception:
797
  pass
798
 
799
- artifact_paths = _persist_trial_artifacts(
800
- trial_dir=trial_dir,
801
- metrics=metrics,
802
- log_lines=log_lines,
803
- log_name="hf_job.log",
804
- metadata={"runner": "hf-launcher", "hf_job_id": job_id, "hf_stage": stage},
805
- )
806
- trial.set_user_attr("hf_stage", stage)
807
- trial.set_user_attr("hf_log_lines", len(log_lines))
808
  if terminal_detail:
809
  trial.set_user_attr("hf_status_message", terminal_detail)
810
 
811
- fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
812
- if metrics is None:
813
- try:
814
- value = _resolve_objective_metric(
815
- trial,
816
- metric_key=args.metric,
817
- metrics=None,
818
- allow_log_metric_fallback=args.allow_log_metric_fallback,
819
- fallback_bpb=fallback_bpb,
820
- tps_seen=tps_seen,
821
- )
822
- if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
823
- raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
824
- return value
825
- except optuna.TrialPruned:
826
- pass
827
- if tps_seen is not None:
828
- trial.set_user_attr("tps", tps_seen)
829
- detail = f"stage={stage}, logs={len(log_lines)}"
830
- if terminal_detail:
831
- detail = f"{detail}, message={terminal_detail}"
832
  raise optuna.TrialPruned(f"No metrics found from HF launcher job ({detail})")
833
 
834
- metric_key = args.metric
835
 
836
  tps_val = metrics.get("tps")
837
  if tps_val is not None:
838
  tps_f = float(tps_val)
839
  trial.set_user_attr("tps", tps_f)
840
- if effective_min_tps is not None and tps_f < effective_min_tps:
841
- raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
842
-
843
- value = _resolve_objective_metric(
844
- trial,
845
- metric_key=metric_key,
846
- metrics=metrics,
847
- allow_log_metric_fallback=args.allow_log_metric_fallback,
848
- fallback_bpb=fallback_bpb,
849
- tps_seen=tps_seen,
850
- )
851
- trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
852
- trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
853
- return value
854
 
855
  return objective
856
 
857
 
858
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
859
- routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
860
- parser = argparse.ArgumentParser(description="Optuna HPO runner for HYDRA train.py")
861
  parser.add_argument("--study-name", default="hydra_hpo", help="Optuna study name")
862
  parser.add_argument("--storage", default="sqlite:///optuna_hpo.db", help="Optuna storage URL")
863
  parser.add_argument("--direction", choices=["minimize", "maximize"], default="minimize")
864
  parser.add_argument("--metric", default="val_bpb", help="Metric key to optimize from HYDRA metrics")
865
- parser.add_argument(
866
- "--min-tps",
867
- type=float,
868
- default=50000.0,
869
- help="TPS floor; prune trials under this value (set 0 to disable)",
870
- )
871
  parser.add_argument("--trials", type=int, default=20, help="Number of Optuna trials")
872
  parser.add_argument("--study-timeout", type=int, default=None, help="Study timeout in seconds")
873
  parser.add_argument("--trial-time-budget", type=int, default=300, help="HYDRA_TIME_BUDGET passed to each trial")
874
  parser.add_argument("--trial-timeout", type=int, default=900, help="Subprocess timeout per trial in seconds")
875
  parser.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local", help="Trial execution backend")
876
- parser.add_argument("--hf-namespace", default=routing_defaults.job_namespace, help="HF namespace for jobs")
877
- parser.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}", help="HF jobs image")
878
  parser.add_argument("--hf-flavor", default="a10g-large", help="HF jobs hardware flavor")
879
  parser.add_argument("--hf-timeout", default="25m", help="HF job timeout string")
880
  parser.add_argument("--hf-command", default="/app/entrypoint.py", help="Command executed inside HF job")
@@ -886,23 +886,23 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
886
  parser.add_argument("--hf-token-env", default="HF_TOKEN", help="Token env key passed as HF job secret")
887
  parser.add_argument("--hf-stop-after-metric", action="store_true", default=True, help="Cancel running job after metrics captured")
888
  parser.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")
889
- parser.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py", help="Local launcher script for hf-launcher runner")
890
- parser.add_argument("--hf-output-repo", default=routing_defaults.output_repo, help="Optional FEATHER_HF_OUTPUT_REPO override for launcher runner")
891
- parser.add_argument("--allow-log-metric-fallback", action="store_true", default=False, help="When metrics JSON is missing, allow val_bpb fallback from latest logged train bpb")
892
- parser.add_argument("--no-allow-log-metric-fallback", action="store_false", dest="allow_log_metric_fallback")
893
- parser.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json", help="Path to transfer-learning prior trials JSON")
894
- parser.add_argument("--apply-priors", action="store_true", default=True, help="Enqueue transfer-learning prior trials before optimize")
895
- parser.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
896
- parser.add_argument("--quality-mode-local", action="store_true", default=False, help="Narrow local full-architecture search around the proven quality-winning region")
897
- parser.add_argument("--quality-anchor-top-k", type=int, default=3, help="Number of top clean priors to enqueue as deterministic local quality anchors")
898
- parser.add_argument("--seed", type=int, default=42, help="Seed for sampler")
899
  parser.add_argument("--n-startup-trials", type=int, default=5, help="Pruner startup trials before pruning")
900
  parser.add_argument("--n-warmup-steps", type=int, default=0, help="Pruner warmup steps")
901
  parser.add_argument("--patience-trials", type=int, default=None, help="Stop study after this many completed trials without meaningful improvement")
902
  parser.add_argument("--min-improvement", type=float, default=0.0, help="Minimum best-value improvement to reset patience")
903
  parser.add_argument("--work-dir", type=Path, default=REPO_ROOT / ".tmp" / "optuna", help="Directory for trial artifacts")
904
  parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "best_summary.json")
905
- return parser.parse_args(argv)
906
 
907
 
908
  def main() -> int:
@@ -916,22 +916,22 @@ def main() -> int:
916
  n_warmup_steps=args.n_warmup_steps,
917
  )
918
 
919
- study = optuna.create_study(
920
- study_name=args.study_name,
921
- storage=args.storage,
922
- load_if_exists=True,
923
- direction=args.direction,
924
- sampler=sampler,
925
- pruner=pruner,
926
- )
927
-
928
- enqueued_quality_anchors = _enqueue_quality_anchors(study, args.priors_file, args.quality_mode_local, args.quality_anchor_top_k)
929
- if enqueued_quality_anchors:
930
- print(f"[hpo] enqueued {enqueued_quality_anchors} local quality anchors from {args.priors_file}")
931
-
932
- enqueued_priors = _enqueue_transfer_priors(study, args.priors_file, args.apply_priors)
933
- if enqueued_priors:
934
- print(f"[hpo] enqueued {enqueued_priors} transfer priors from {args.priors_file}")
935
 
936
  state: dict[str, Any] = {
937
  "best": None,
@@ -990,29 +990,29 @@ def main() -> int:
990
  "best_trial_number": study.best_trial.number,
991
  "best_trial_user_attrs": study.best_trial.user_attrs,
992
  "n_trials": len(study.trials),
993
- "n_completed": len(completed),
994
- "patience_trials": args.patience_trials,
995
- "min_improvement": args.min_improvement,
996
- "quality_mode_local": args.quality_mode_local,
997
- "enqueued_quality_anchors": enqueued_quality_anchors,
998
- "enqueued_priors": enqueued_priors,
999
- }
1000
- else:
1001
- best = {
1002
  "study_name": study.study_name,
1003
  "direction": args.direction,
1004
  "metric": args.metric,
1005
  "best_value": None,
1006
  "best_params": {},
1007
- "best_trial_number": None,
1008
- "best_trial_user_attrs": {},
1009
- "n_trials": len(study.trials),
1010
- "n_completed": 0,
1011
- "quality_mode_local": args.quality_mode_local,
1012
- "enqueued_quality_anchors": enqueued_quality_anchors,
1013
- "enqueued_priors": enqueued_priors,
1014
- "note": "No completed trials with metrics found.",
1015
- }
1016
  args.summary_out.write_text(json.dumps(best, indent=2), encoding="utf-8")
1017
  print(json.dumps(best, indent=2))
1018
  return 0
 
5
  import json
6
  import os
7
  import re
8
+ import subprocess
9
+ import sys
10
+ import tempfile
11
+ import time
12
+ from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ import optuna
17
+
18
+
19
+ _HF_ENV_KEY_RE = re.compile(r"^[A-Z][A-Z0-9_]*$")
20
+
21
+
22
+ REPO_ROOT = Path(__file__).resolve().parents[1]
23
+ if str(REPO_ROOT) not in sys.path:
24
+ sys.path.insert(0, str(REPO_ROOT))
25
+
26
+ from scripts.hf_routing import resolve_routing
27
+
28
+ TRAIN_ENTRYPOINT = REPO_ROOT / "train.py"
29
+ SEARCH_SPACE_KEYS = {
30
+ "d_model",
31
+ "n_layer",
32
+ "d_state",
33
+ "headdim",
34
+ "expand",
35
+ "seq_len",
36
+ "batch_size",
37
+ "grad_accum",
38
+ "matrix_lr",
39
+ "embed_lr",
40
+ "unembed_lr",
41
+ "hyena_layers",
42
+ "engram_n_columns",
43
+ "engram_layer_idx",
44
+ "sdr_target_active",
45
+ "htm_learn_every",
46
+ "htm_subsample",
47
+ "engram_subsample",
48
+ "mamba3_chunk",
49
+ "dropout",
50
+ }
51
+
52
+
53
+ def _filter_prior_params(raw: dict[str, Any]) -> dict[str, Any]:
54
+ return {k: v for k, v in raw.items() if k in SEARCH_SPACE_KEYS}
55
+
56
+
57
+ def _load_prior_param_sets(path: Path) -> list[dict[str, Any]]:
58
+ if not path.exists():
59
+ return []
60
+
61
+ payload = json.loads(path.read_text(encoding="utf-8"))
62
+ if isinstance(payload, dict):
63
+ rows = payload.get("trials", [])
64
+ elif isinstance(payload, list):
65
+ rows = payload
66
+ else:
67
+ rows = []
68
+
69
+ out: list[dict[str, Any]] = []
70
+ for item in rows:
71
+ if not isinstance(item, dict):
72
+ continue
73
+ params_obj = item.get("params", item)
74
+ if not isinstance(params_obj, dict):
75
+ continue
76
+ filtered = _filter_prior_params(params_obj)
77
+ if filtered:
78
+ out.append(filtered)
79
+ return out
80
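# Illustration (hypothetical values): _load_prior_param_sets accepts either a
# bare list of param dicts or a {"trials": [...]} wrapper; keys outside
# SEARCH_SPACE_KEYS (e.g. "val_bpb") are dropped by _filter_prior_params.
#
#     priors = {"trials": [{"params": {"d_model": 96, "n_layer": 2, "matrix_lr": 0.015, "val_bpb": 1.23}}]}
#     Path("docs/hpo_transfer_priors.json").write_text(json.dumps(priors, indent=2), encoding="utf-8")
#     _load_prior_param_sets(Path("docs/hpo_transfer_priors.json"))
#     # -> [{"d_model": 96, "n_layer": 2, "matrix_lr": 0.015}]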
+
81
+
82
+ def _enqueue_transfer_priors(study: optuna.Study, priors_file: Path, apply_priors: bool) -> int:
83
+ if not apply_priors:
84
+ return 0
85
+
86
+ priors_raw = _load_prior_param_sets(priors_file)
87
+ if not priors_raw:
88
+ return 0
89
+
90
+ # Deduplicate param sets across merged studies.
91
+ priors: list[dict[str, Any]] = []
92
+ seen: set[str] = set()
93
+ for params in priors_raw:
94
+ key = json.dumps(params, sort_keys=True)
95
+ if key in seen:
96
+ continue
97
+ seen.add(key)
98
+ priors.append(params)
99
+
100
+ enqueued = 0
101
+ for params in priors:
102
+ before = len(study.get_trials(deepcopy=False))
103
+ try:
104
+ study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"}, skip_if_exists=True)
105
+ except TypeError:
106
+ study.enqueue_trial(params, user_attrs={"seed_source": "transfer_priors"})
107
+ after = len(study.get_trials(deepcopy=False))
108
+ if after > before:
109
+ enqueued += 1
110
+ return enqueued
111
+
112
+
113
+ def _enqueue_quality_anchors(study: optuna.Study, priors_file: Path, quality_mode_local: bool, top_k: int) -> int:
114
+ if not quality_mode_local or top_k <= 0:
115
+ return 0
116
+
117
+ priors = _load_prior_param_sets(priors_file)[:top_k]
118
+ enqueued = 0
119
+ for params in priors:
120
+ before = len(study.get_trials(deepcopy=False))
121
+ try:
122
+ study.enqueue_trial(
123
+ params,
124
+ user_attrs={"seed_source": "quality_anchor"},
125
+ skip_if_exists=True,
126
+ )
127
+ except TypeError:
128
+ study.enqueue_trial(params, user_attrs={"seed_source": "quality_anchor"})
129
+ after = len(study.get_trials(deepcopy=False))
130
+ if after > before:
131
+ enqueued += 1
132
+ return enqueued
133
 
134
 
135
  def _parse_metrics_from_stdout(stdout: str) -> dict[str, Any] | None:
 
164
  return None
165
 
166
 
167
+ def _parse_last_train_bpb_from_logs(lines: list[str]) -> float | None:
168
+ """Best-effort fallback when final eval crashes before metrics JSON write."""
169
+ last: float | None = None
170
+ for line in lines:
171
+ m = re.search(r"\bbpb=([0-9]+(?:\.[0-9]+)?)", line)
172
  if m:
173
+ last = float(m.group(1))
174
+ return last
175
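# Quick check of the fallback regex on a representative log line (the exact
# log shape is an assumption, not guaranteed by this module):
#
#     line = "step=120 loss=0.91 bpb=1.3125 tps=52000"
#     m = re.search(r"\bbpb=([0-9]+(?:\.[0-9]+)?)", line)
#     assert m is not None and float(m.group(1)) == 1.3125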
+
176
+
177
+ def _persist_trial_artifacts(
178
+ *,
179
+ trial_dir: Path,
180
+ metrics: dict[str, Any] | None,
181
+ log_lines: list[str] | None,
182
+ log_name: str,
183
+ metadata: dict[str, Any],
184
+ ) -> dict[str, str | None]:
185
+ trial_dir.mkdir(parents=True, exist_ok=True)
186
+ metrics_path = trial_dir / "metrics.json"
187
+ log_path = trial_dir / log_name
188
+ manifest_path = trial_dir / "trial_artifacts.json"
189
+
190
+ if metrics is not None:
191
+ metrics_path.write_text(json.dumps(metrics, indent=2, sort_keys=True), encoding="utf-8")
192
+ if log_lines is not None:
193
+ log_path.write_text("\n".join(log_lines), encoding="utf-8")
194
+
195
+ manifest = {
196
+ **metadata,
197
+ "metrics_path": str(metrics_path) if metrics is not None else None,
198
+ "log_path": str(log_path) if log_lines is not None else None,
199
+ }
200
+ manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
201
+ return {
202
+ "metrics_path": str(metrics_path) if metrics is not None else None,
203
+ "log_path": str(log_path) if log_lines is not None else None,
204
+ "manifest_path": str(manifest_path),
205
+ }
206
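# Usage sketch (paths and metadata illustrative): the manifest is always
# written, and skipped artifacts come back as None entries.
#
#     paths = _persist_trial_artifacts(
#         trial_dir=Path("/tmp/optuna_trial_0"),
#         metrics=None,                      # metrics.json is skipped
#         log_lines=["step=1 bpb=2.0"],
#         log_name="train_stdout.log",
#         metadata={"runner": "local", "returncode": 0},
#     )
#     # paths["metrics_path"] is None; log_path and manifest_path are set.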
+
207
+
208
+ def _resolve_objective_metric(
209
+ trial: optuna.Trial,
210
+ *,
211
+ metric_key: str,
212
+ metrics: dict[str, Any] | None,
213
+ allow_log_metric_fallback: bool,
214
+ fallback_bpb: float | None,
215
+ tps_seen: float | None,
216
+ ) -> float:
217
+ """Resolve the objective value while labeling where it came from.
218
+
219
+ Validation metrics and live training-log fallbacks are intentionally
220
+ different sources. Keeping that distinction in trial attrs prevents a
221
+ skipped/OOM eval from being mistaken for a real validation result.
222
+ """
223
+ if metrics is None:
224
+ if allow_log_metric_fallback and metric_key == "val_bpb" and fallback_bpb is not None:
225
+ trial.set_user_attr("objective_source", "train_log_fallback")
226
+ trial.set_user_attr("objective_metric", "train_bpb")
227
+ trial.set_user_attr("eval_status", "missing_metrics")
228
+ trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
229
+ if tps_seen is not None:
230
+ trial.set_user_attr("tps", float(tps_seen))
231
+ return float(fallback_bpb)
232
+ trial.set_user_attr("objective_source", "missing_metrics")
233
+ raise optuna.TrialPruned("No metrics payload found")
234
+
235
+ eval_status = str(
236
+ metrics.get(
237
+ "eval_status",
238
+ "completed" if metrics.get("val_bpb") is not None else "unknown",
239
+ )
240
+ )
241
+ trial.set_user_attr("eval_status", eval_status)
242
+
243
+ if fallback_bpb is not None:
244
+ trial.set_user_attr("train_bpb_fallback", float(fallback_bpb))
245
+
246
+ if metric_key not in metrics or metrics[metric_key] is None:
247
+ trial.set_user_attr("objective_source", "missing_metric")
248
+ trial.set_user_attr("objective_metric", metric_key)
249
+ raise optuna.TrialPruned(f"Metric '{metric_key}' missing in metrics payload")
250
+
251
+ value = float(metrics[metric_key])
252
+ trial.set_user_attr("objective_metric", metric_key)
253
+ if metric_key == "val_bpb":
254
+ trial.set_user_attr("objective_source", "final_val")
255
+ trial.set_user_attr("final_val_bpb", value)
256
+ else:
257
+ trial.set_user_attr("objective_source", "metrics_json")
258
+ return value
259
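# Example of the eval_status derivation above: {"val_bpb": 1.2} with no
# explicit eval_status yields "completed"; {} yields "unknown"; an explicit
# metrics["eval_status"] always wins.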
+
260
+
261
+ def _fetch_job_logs_safe(
262
+ api,
263
+ *,
264
+ job_id: str,
265
+ token: str,
266
+ namespace: str,
267
+ retries: int = 3,
268
+ sleep_s: float = 2.0,
269
+ timeout_s: float = 20.0,
270
+ ) -> list[str]:
271
+ last_exc: Exception | None = None
272
+ for attempt in range(1, retries + 1):
273
+ try:
274
+ # No `with` block here: ThreadPoolExecutor.__exit__ calls
+ # shutdown(wait=True), which would block on a hung fetch thread and
+ # defeat the result() timeout below.
+ executor = ThreadPoolExecutor(max_workers=1)
275
+ try:
+ future = executor.submit(
276
+ lambda: list(api.fetch_job_logs(job_id=job_id, follow=False, token=token, namespace=namespace))
277
+ )
278
+ return future.result(timeout=timeout_s)
+ finally:
+ executor.shutdown(wait=False)
279
+ except FuturesTimeoutError:
280
+ last_exc = TimeoutError(f"Timed out fetching HF job logs for {job_id} after {timeout_s:.1f}s")
281
+ except Exception as exc: # noqa: BLE001
282
+ last_exc = exc
283
+ if attempt >= retries:
284
+ raise
285
+ time.sleep(sleep_s)
286
+ if last_exc is not None:
287
+ raise last_exc
288
+ return []
289
+
290
+
291
+ def _effective_min_tps(args: argparse.Namespace) -> float | None:
292
+ min_tps = args.min_tps
293
+ # 50000.0 is the argparse default; in quality mode an untouched floor means
+ # "no explicit floor", so the TPS gate is disabled entirely.
+ if getattr(args, "quality_mode_local", False) and min_tps == 50000.0:
294
+ return 0.0
295
+ return min_tps
296
+
297
+
298
+ def _trial_env(trial: optuna.Trial, args: argparse.Namespace, metrics_path: Path) -> dict[str, str]:
299
+ env = os.environ.copy()
300
+ full_arch_hpo = env.get("HYDRA_HPO_FULL_ARCH", "0") == "1"
301
+ speed_arch_hpo = full_arch_hpo and env.get("HYDRA_HPO_SPEED_ARCH", "0") == "1"
302
+ quality_mode_local = bool(getattr(args, "quality_mode_local", False))
303
 
304
  # Runtime and reporting
305
  env["HYDRA_METRICS_OUT"] = str(metrics_path)
306
  env["HYDRA_TIME_BUDGET"] = str(args.trial_time_budget)
307
  env["PYTHONUNBUFFERED"] = "1"
308
 
309
+ # Search space — fully env-driven to match existing training stack.
310
+ if speed_arch_hpo:
311
+ # Full-arch speed mode targets A10 underutilization observed in HPO:
312
+ # low VRAM/MFU, strong BPB from shallow models, and fixed SDR/HTM
313
+ # overhead dominating small microbatches. Keep all components enabled
314
+ # while amortizing overhead over more tokens.
315
+ env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96]))
316
+ env["HYDRA_N_LAYER"] = str(trial.suggest_categorical("n_layer", [2]))
317
+ env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
318
+ env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
319
+ env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
320
+ elif quality_mode_local and full_arch_hpo:
321
+ env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128]))
322
+ env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 2, 3))
323
+ env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32]))
324
+ env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [16, 32]))
325
+ env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
326
+ else:
327
+ env["HYDRA_D_MODEL"] = str(trial.suggest_categorical("d_model", [64, 96, 128, 160, 192]))
328
+ env["HYDRA_N_LAYER"] = str(trial.suggest_int("n_layer", 1, 4))
329
+ env["HYDRA_D_STATE"] = str(trial.suggest_categorical("d_state", [16, 32, 48]))
330
+ env["HYDRA_HEADDIM"] = str(trial.suggest_categorical("headdim", [8, 16, 32]))
331
+ env["HYDRA_EXPAND"] = str(trial.suggest_categorical("expand", [1, 2]))
332
+
333
+ if speed_arch_hpo:
334
+ seq_len = trial.suggest_categorical("seq_len", [64, 128])
335
+ batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
336
+ grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
337
+ elif quality_mode_local and full_arch_hpo:
338
+ seq_len = trial.suggest_categorical("seq_len", [64])
339
+ batch_size = trial.suggest_categorical("batch_size", [4, 8])
340
+ grad_accum = trial.suggest_categorical("grad_accum", [4, 8, 16])
341
+ else:
342
+ seq_len = trial.suggest_categorical("seq_len", [32, 64])
343
+ batch_size = trial.suggest_categorical("batch_size", [4, 8] if full_arch_hpo else [4, 8, 16])
344
+ grad_accum = trial.suggest_categorical("grad_accum", [1, 4, 8, 16] if full_arch_hpo else [8, 16, 32, 64])
345
  # Keep TOTAL_BATCH_SIZE divisible by DEVICE_BATCH_SIZE * MAX_SEQ_LEN.
346
  total_batch = batch_size * seq_len * grad_accum
347
  env["HYDRA_SEQ_LEN"] = str(seq_len)
348
  env["HYDRA_BATCH_SIZE"] = str(batch_size)
349
  env["HYDRA_TOTAL_BATCH"] = str(total_batch)
350
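# Worked example of the invariant (hypothetical trial draw): batch_size=8,
# seq_len=64, grad_accum=16 -> total_batch = 8 * 64 * 16 = 8192, and
# 8192 % (8 * 64) == 0 holds by construction because total_batch is an exact
# multiple of batch_size * seq_len.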
 
351
+ if quality_mode_local and full_arch_hpo:
352
+ env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.008, 0.03, log=True))
353
+ env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.15, 0.6, log=True))
354
+ env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.001, 0.01, log=True))
355
+ else:
356
+ env["HYDRA_MATRIX_LR"] = str(trial.suggest_float("matrix_lr", 0.005, 0.2, log=True))
357
+ env["HYDRA_EMBED_LR"] = str(trial.suggest_float("embed_lr", 0.05, 1.0, log=True))
358
+ env["HYDRA_UNEMBED_LR"] = str(trial.suggest_float("unembed_lr", 0.0005, 0.02, log=True))
359
+
360
+ if full_arch_hpo:
361
+ env["HYDRA_HYENA_LAYERS"] = ""
362
+ env["HYDRA_ENGRAM_N_COLUMNS"] = str(
363
+ trial.suggest_categorical(
364
+ "engram_n_columns",
365
+ [512, 1024] if (speed_arch_hpo or quality_mode_local) else [512, 1024, 2048],
366
+ )
367
+ )
368
+ env["HYDRA_ENGRAM_LAYER_IDX"] = str(trial.suggest_int("engram_layer_idx", 0, max(0, int(env["HYDRA_N_LAYER"]) - 1)))
369
+ env["HYDRA_SDR_TARGET_ACTIVE"] = str(
370
+ trial.suggest_categorical(
371
+ "sdr_target_active",
372
+ [327] if quality_mode_local else ([164, 327] if speed_arch_hpo else [164, 327, 512]),
373
+ )
374
+ )
375
+ env["HYDRA_HTM_LEARN_EVERY"] = str(
376
+ trial.suggest_categorical("htm_learn_every", [8, 16] if (speed_arch_hpo or quality_mode_local) else [4, 8, 16])
377
+ )
378
+ env["HYDRA_HTM_SUBSAMPLE"] = str(
379
+ trial.suggest_categorical("htm_subsample", [1, 2] if quality_mode_local else ([4, 8, 16] if speed_arch_hpo else [1, 2, 4, 8]))
380
+ )
381
+ env["HYDRA_ENGRAM_SUBSAMPLE"] = str(
382
+ trial.suggest_categorical("engram_subsample", [1, 2] if quality_mode_local else ([1, 2, 4] if speed_arch_hpo else [1]))
383
+ )
384
+ env["HYDRA_MAMBA3_CHUNK"] = str(trial.suggest_categorical("mamba3_chunk", [32, 64]))
385
+ env["HYDRA_DROPOUT"] = str(trial.suggest_categorical("dropout", [0.0, 0.1] if (speed_arch_hpo or quality_mode_local) else [0.0, 0.1, 0.2]))
386
+ else:
387
+ env["HYDRA_HYENA_LAYERS"] = trial.suggest_categorical("hyena_layers", ["", "0", "1", "0,1"])
388
 
389
  # Keep trials alive long enough to emit metrics.
390
  env["HYDRA_FAIL_LOSS_THRESHOLD"] = "1000000"
391
  env["HYDRA_USE_NEMOTRON"] = os.environ.get("HYDRA_USE_NEMOTRON", "1")
392
  env["HYDRA_LOCAL_SHARDS_ONLY"] = os.environ.get("HYDRA_LOCAL_SHARDS_ONLY", "0")
393
  # Strict optimal-path defaults (no forced fallback profile).
394
+ env["HYDRA_MUON_COMPILE"] = os.environ.get("HYDRA_MUON_COMPILE", "1")
395
+ env["HYDRA_THROUGHPUT_MODE"] = os.environ.get("HYDRA_THROUGHPUT_MODE", "0" if full_arch_hpo else "1")
396
+ env["HYDRA_FORCE_HTM_CPU"] = os.environ.get("HYDRA_FORCE_HTM_CPU", "0")
397
+ env["HYDRA_ALLOW_SYNTHETIC_RETINA"] = os.environ.get("HYDRA_ALLOW_SYNTHETIC_RETINA", "0")
398
+ env["HYDRA_INERT_MAMBA"] = os.environ.get("HYDRA_INERT_MAMBA", "0")
399
+ env["HYDRA_FASTPATH"] = os.environ.get("HYDRA_FASTPATH", "0" if full_arch_hpo else "1")
400
+
401
+ return env
402
 
403
 
404
  def _sanitize_hf_env(env: dict[str, str]) -> dict[str, str]:
 
410
  return sanitized
411
 
412
 
413
+ def _hf_command_candidates(args: argparse.Namespace) -> list[list[str]]:
414
  if args.hf_use_bash:
415
  return [["bash", "-lc", args.hf_command]]
416
 
 
432
  uniq.append(c)
433
  return uniq
434
 
435
+ return [raw.split()]
436
+
 
 
 
 
 
 
437
 
438
+ def _space_repo_from_hf_image(image: str, namespace: str) -> str:
439
+ prefix = "hf.co/spaces/"
440
+ if image.startswith(prefix):
441
+ return image[len(prefix):]
442
+ return os.environ.get("FEATHER_HF_SPACE_REPO", f"{namespace}/feather-a10-runtime")
443
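# Behavior sketch ("acme" is a placeholder namespace):
#
#     _space_repo_from_hf_image("hf.co/spaces/acme/feather-runtime", "acme")
#     # -> "acme/feather-runtime"
#     _space_repo_from_hf_image("nvidia/cuda:12.4", "acme")
#     # -> FEATHER_HF_SPACE_REPO if set, else "acme/feather-a10-runtime"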
 
444
+
445
+ def _objective_local(args: argparse.Namespace):
446
+ effective_min_tps = _effective_min_tps(args)
447
+
448
+ def objective(trial: optuna.Trial) -> float:
449
  trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
450
  metrics_path = trial_dir / "metrics.json"
451
 
 
460
  timeout=args.trial_timeout,
461
  )
462
 
463
+ metrics: dict[str, Any] | None = None
464
  if metrics_path.exists():
465
  try:
466
  metrics = json.loads(metrics_path.read_text(encoding="utf-8"))
467
  except json.JSONDecodeError:
468
  metrics = None
469
+ if metrics is None:
470
+ metrics = _parse_metrics_from_stdout(proc.stdout)
471
+
472
+ artifact_paths = _persist_trial_artifacts(
473
+ trial_dir=trial_dir,
474
+ metrics=metrics,
475
+ log_lines=(proc.stdout or "").splitlines(),
476
+ log_name="train_stdout.log",
477
+ metadata={"runner": "local", "returncode": proc.returncode},
478
+ )
479
+ (trial_dir / "train_stderr.log").write_text(proc.stderr or "", encoding="utf-8")
480
+
481
+ fallback_bpb = _parse_last_train_bpb_from_logs(proc.stdout.splitlines())
482
+ if metrics is None:
483
+ # Return the log-derived fallback when permitted; otherwise prune with a
+ # pointer at where metrics were expected.
+ try:
484
+ return _resolve_objective_metric(
485
+ trial,
486
+ metric_key=args.metric,
487
+ metrics=None,
488
+ allow_log_metric_fallback=args.allow_log_metric_fallback,
489
+ fallback_bpb=fallback_bpb,
490
+ tps_seen=None,
491
+ )
+ except optuna.TrialPruned:
+ pass
+ raise optuna.TrialPruned("No metrics found (HYDRA_METRICS_OUT/[METRICS_JSON])")
492
 
493
  if proc.returncode != 0:
494
  raise optuna.TrialPruned(f"Training failed rc={proc.returncode}")
495
 
496
+ metric_key = args.metric
497
 
498
  tps_val = metrics.get("tps")
499
  if tps_val is not None:
500
  tps_f = float(tps_val)
501
  trial.set_user_attr("tps", tps_f)
502
+ if effective_min_tps is not None and tps_f < effective_min_tps:
503
+ raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
504
+
505
+ value = _resolve_objective_metric(
506
+ trial,
507
+ metric_key=metric_key,
508
+ metrics=metrics,
509
+ allow_log_metric_fallback=args.allow_log_metric_fallback,
510
+ fallback_bpb=fallback_bpb,
511
+ tps_seen=None,
512
+ )
513
+
514
+ # Keep useful context on trial
515
+ trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
516
+ trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
517
+
518
+ return value
519
 
520
  return objective
521
 
522
 
523
+ def _objective_hf_job(args: argparse.Namespace):
524
  from huggingface_hub import HfApi
525
  from huggingface_hub.utils import get_token
526
 
 
530
  f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
531
  )
532
 
533
+ api = HfApi(token=token)
534
+ terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
535
+ effective_min_tps = _effective_min_tps(args)
536
 
537
  def objective(trial: optuna.Trial) -> float:
538
  trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
 
568
  info = api.inspect_job(job_id=job.id, token=token, namespace=args.hf_namespace)
569
  bootstrap_stage = str(info.status.stage)
570
  bootstrap_msg = str(getattr(info.status, "message", "") or "")
571
+ bootstrap_logs = _fetch_job_logs_safe(
572
+ api,
573
+ job_id=job.id,
574
+ token=token,
575
+ namespace=args.hf_namespace,
576
+ retries=2,
577
+ sleep_s=1.0,
578
+ )
579
  if bootstrap_stage in {"RUNNING", "COMPLETED"} or bootstrap_logs:
580
  break
581
  if bootstrap_stage in {"ERROR", "FAILED", "CANCELLED", "CANCELED", "TIMEOUT"}:
 
611
  info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
612
  stage = str(info.status.stage)
613
  terminal_detail = str(getattr(info.status, "message", "")) or terminal_detail
614
+ log_lines = _fetch_job_logs_safe(
615
+ api,
616
+ job_id=job_id,
617
+ token=token,
618
+ namespace=args.hf_namespace,
619
+ )
620
 
621
  m = _parse_metrics_from_log_lines(log_lines)
622
  if m is not None:
 
643
  except Exception:
644
  pass
645
 
646
+ artifact_paths = _persist_trial_artifacts(
647
+ trial_dir=trial_dir,
648
+ metrics=metrics,
649
+ log_lines=log_lines,
650
+ log_name="hf_job.log",
651
+ metadata={"runner": "hf-job", "hf_job_id": job_id, "hf_stage": stage},
652
+ )
653
+ trial.set_user_attr("hf_stage", stage)
654
+ trial.set_user_attr("hf_log_lines", len(log_lines))
655
  if terminal_detail:
656
  trial.set_user_attr("hf_status_message", terminal_detail)
657
 
658
+ fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
659
+ if metrics is None:
660
+ # Apply the TPS floor before resolving the fallback so its specific
+ # prune message is not swallowed below and rewritten as "no metrics".
+ if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
661
+ trial.set_user_attr("tps", tps_seen)
+ raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
662
+ try:
663
+ return _resolve_objective_metric(
664
+ trial,
665
+ metric_key=args.metric,
666
+ metrics=None,
667
+ allow_log_metric_fallback=args.allow_log_metric_fallback,
668
+ fallback_bpb=fallback_bpb,
669
+ tps_seen=tps_seen,
670
+ )
671
+ except optuna.TrialPruned:
672
+ pass
673
674
+ if tps_seen is not None:
675
+ trial.set_user_attr("tps", tps_seen)
676
+ detail = f"stage={stage}, logs={len(log_lines)}"
677
+ if terminal_detail:
678
+ detail = f"{detail}, message={terminal_detail}"
679
  raise optuna.TrialPruned(f"No metrics found from HF job ({detail})")
680
 
681
+ metric_key = args.metric
682
 
683
  tps_val = metrics.get("tps")
684
  if tps_val is not None:
685
  tps_f = float(tps_val)
686
  trial.set_user_attr("tps", tps_f)
687
+ if effective_min_tps is not None and tps_f < effective_min_tps:
688
+ raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
689
+
690
+ value = _resolve_objective_metric(
691
+ trial,
692
+ metric_key=metric_key,
693
+ metrics=metrics,
694
+ allow_log_metric_fallback=args.allow_log_metric_fallback,
695
+ fallback_bpb=fallback_bpb,
696
+ tps_seen=tps_seen,
697
+ )
698
+ trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
699
+ trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
700
+ return value
701
 
702
  return objective
703
 
704
 
705
+ def _objective_hf_launcher(args: argparse.Namespace):
706
  from huggingface_hub import HfApi
707
  from huggingface_hub.utils import get_token
708
 
 
712
  f"No Hugging Face token found. Set {args.hf_token_env} or run huggingface-cli login."
713
  )
714
 
715
+ api = HfApi(token=token)
716
+ terminal_states = {"ERROR", "COMPLETED", "CANCELLED", "TIMEOUT", "FAILED", "CANCELED"}
717
+ effective_min_tps = _effective_min_tps(args)
718
 
719
  def objective(trial: optuna.Trial) -> float:
720
  trial_dir = Path(tempfile.mkdtemp(prefix=f"optuna_trial_{trial.number}_", dir=str(args.work_dir)))
 
725
  local_env = os.environ.copy()
726
  local_env.update(env)
727
  local_env[args.hf_token_env] = token
728
+ local_env["FEATHER_HF_NAMESPACE"] = args.hf_namespace
729
+ local_env["FEATHER_HF_FLAVOR"] = args.hf_flavor
730
+ local_env["FEATHER_HF_JOB_TIMEOUT"] = args.hf_timeout
731
+ local_env["FEATHER_HF_IMAGE"] = args.hf_image
732
+ local_env["FEATHER_HF_SPACE_REPO"] = _space_repo_from_hf_image(args.hf_image, args.hf_namespace)
733
  if args.hf_output_repo:
734
  local_env["FEATHER_HF_OUTPUT_REPO"] = args.hf_output_repo
735
  else:
 
766
  info = api.inspect_job(job_id=job_id, token=token, namespace=args.hf_namespace)
767
  stage = str(info.status.stage)
768
  terminal_detail = str(getattr(info.status, "message", "") or "") or terminal_detail
769
+ log_lines = _fetch_job_logs_safe(
770
+ api,
771
+ job_id=job_id,
772
+ token=token,
773
+ namespace=args.hf_namespace,
774
+ )
775
 
776
  mtr = _parse_metrics_from_log_lines(log_lines)
777
  if mtr is not None:
 
796
  except Exception:
797
  pass
798
 
799
+ artifact_paths = _persist_trial_artifacts(
800
+ trial_dir=trial_dir,
801
+ metrics=metrics,
802
+ log_lines=log_lines,
803
+ log_name="hf_job.log",
804
+ metadata={"runner": "hf-launcher", "hf_job_id": job_id, "hf_stage": stage},
805
+ )
806
+ trial.set_user_attr("hf_stage", stage)
807
+ trial.set_user_attr("hf_log_lines", len(log_lines))
808
  if terminal_detail:
809
  trial.set_user_attr("hf_status_message", terminal_detail)
810
 
811
+ fallback_bpb = _parse_last_train_bpb_from_logs(log_lines)
812
+ if metrics is None:
813
+ # Apply the TPS floor before resolving the fallback so its specific
+ # prune message is not swallowed below and rewritten as "no metrics".
+ if tps_seen is not None and effective_min_tps is not None and tps_seen < effective_min_tps:
814
+ trial.set_user_attr("tps", tps_seen)
+ raise optuna.TrialPruned(f"TPS below floor: {tps_seen} < {effective_min_tps}")
815
+ try:
816
+ return _resolve_objective_metric(
817
+ trial,
818
+ metric_key=args.metric,
819
+ metrics=None,
820
+ allow_log_metric_fallback=args.allow_log_metric_fallback,
821
+ fallback_bpb=fallback_bpb,
822
+ tps_seen=tps_seen,
823
+ )
824
+ except optuna.TrialPruned:
825
+ pass
826
827
+ if tps_seen is not None:
828
+ trial.set_user_attr("tps", tps_seen)
829
+ detail = f"stage={stage}, logs={len(log_lines)}"
830
+ if terminal_detail:
831
+ detail = f"{detail}, message={terminal_detail}"
832
  raise optuna.TrialPruned(f"No metrics found from HF launcher job ({detail})")
833
 
834
+ metric_key = args.metric
835
 
836
  tps_val = metrics.get("tps")
837
  if tps_val is not None:
838
  tps_f = float(tps_val)
839
  trial.set_user_attr("tps", tps_f)
840
+ if effective_min_tps is not None and tps_f < effective_min_tps:
841
+ raise optuna.TrialPruned(f"TPS below floor: {tps_f} < {effective_min_tps}")
842
+
843
+ value = _resolve_objective_metric(
844
+ trial,
845
+ metric_key=metric_key,
846
+ metrics=metrics,
847
+ allow_log_metric_fallback=args.allow_log_metric_fallback,
848
+ fallback_bpb=fallback_bpb,
849
+ tps_seen=tps_seen,
850
+ )
851
+ trial.set_user_attr("summary_path", metrics.get("summary_path") or artifact_paths["manifest_path"])
852
+ trial.set_user_attr("run_log_path", metrics.get("run_log_path") or artifact_paths["log_path"])
853
+ return value
854
 
855
  return objective
856
 
857
 
858
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
859
+ routing_defaults = resolve_routing(token=os.environ.get("HF_TOKEN"))
860
+ parser = argparse.ArgumentParser(description="Optuna HPO runner for HYDRA train.py")
861
  parser.add_argument("--study-name", default="hydra_hpo", help="Optuna study name")
862
  parser.add_argument("--storage", default="sqlite:///optuna_hpo.db", help="Optuna storage URL")
863
  parser.add_argument("--direction", choices=["minimize", "maximize"], default="minimize")
864
  parser.add_argument("--metric", default="val_bpb", help="Metric key to optimize from HYDRA metrics")
865
+ parser.add_argument(
866
+ "--min-tps",
867
+ type=float,
868
+ default=50000.0,
869
+ help="TPS floor; prune trials under this value (set 0 to disable)",
870
+ )
871
  parser.add_argument("--trials", type=int, default=20, help="Number of Optuna trials")
872
  parser.add_argument("--study-timeout", type=int, default=None, help="Study timeout in seconds")
873
  parser.add_argument("--trial-time-budget", type=int, default=300, help="HYDRA_TIME_BUDGET passed to each trial")
874
  parser.add_argument("--trial-timeout", type=int, default=900, help="Subprocess timeout per trial in seconds")
875
  parser.add_argument("--runner", choices=["local", "hf-job", "hf-launcher"], default="local", help="Trial execution backend")
876
+ parser.add_argument("--hf-namespace", default=routing_defaults.job_namespace, help="HF namespace for jobs")
877
+ parser.add_argument("--hf-image", default=f"hf.co/spaces/{routing_defaults.space_repo}", help="HF jobs image")
878
  parser.add_argument("--hf-flavor", default="a10g-large", help="HF jobs hardware flavor")
879
  parser.add_argument("--hf-timeout", default="25m", help="HF job timeout string")
880
  parser.add_argument("--hf-command", default="/app/entrypoint.py", help="Command executed inside HF job")
 
886
  parser.add_argument("--hf-token-env", default="HF_TOKEN", help="Token env key passed as HF job secret")
887
  parser.add_argument("--hf-stop-after-metric", action="store_true", default=True, help="Cancel running job after metrics captured")
888
  parser.add_argument("--no-hf-stop-after-metric", action="store_false", dest="hf_stop_after_metric")
889
+ parser.add_argument("--hf-launcher-script", type=Path, default=REPO_ROOT / "scripts" / "launch_feather_hf_job.py", help="Local launcher script for hf-launcher runner")
890
+ parser.add_argument("--hf-output-repo", default=routing_defaults.output_repo, help="Optional FEATHER_HF_OUTPUT_REPO override for launcher runner")
891
+ parser.add_argument("--allow-log-metric-fallback", action="store_true", default=False, help="When metrics JSON is missing, allow val_bpb fallback from latest logged train bpb")
892
+ parser.add_argument("--no-allow-log-metric-fallback", action="store_false", dest="allow_log_metric_fallback")
893
+ parser.add_argument("--priors-file", type=Path, default=REPO_ROOT / "docs" / "hpo_transfer_priors.json", help="Path to transfer-learning prior trials JSON")
894
+ parser.add_argument("--apply-priors", action="store_true", default=True, help="Enqueue transfer-learning prior trials before optimize")
895
+ parser.add_argument("--no-apply-priors", action="store_false", dest="apply_priors")
896
+ parser.add_argument("--quality-mode-local", action="store_true", default=False, help="Narrow local full-architecture search around the proven quality-winning region")
897
+ parser.add_argument("--quality-anchor-top-k", type=int, default=3, help="Number of top clean priors to enqueue as deterministic local quality anchors")
898
+ parser.add_argument("--seed", type=int, default=42, help="Seed for sampler")
899
  parser.add_argument("--n-startup-trials", type=int, default=5, help="Pruner startup trials before pruning")
900
  parser.add_argument("--n-warmup-steps", type=int, default=0, help="Pruner warmup steps")
901
  parser.add_argument("--patience-trials", type=int, default=None, help="Stop study after this many completed trials without meaningful improvement")
902
  parser.add_argument("--min-improvement", type=float, default=0.0, help="Minimum best-value improvement to reset patience")
903
  parser.add_argument("--work-dir", type=Path, default=REPO_ROOT / ".tmp" / "optuna", help="Directory for trial artifacts")
904
  parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / ".tmp" / "optuna" / "best_summary.json")
905
+ return parser.parse_args(argv)
906
 
907
 
908
  def main() -> int:
 
916
  n_warmup_steps=args.n_warmup_steps,
917
  )
918
 
919
+ study = optuna.create_study(
920
+ study_name=args.study_name,
921
+ storage=args.storage,
922
+ load_if_exists=True,
923
+ direction=args.direction,
924
+ sampler=sampler,
925
+ pruner=pruner,
926
+ )
927
+
928
+ enqueued_quality_anchors = _enqueue_quality_anchors(study, args.priors_file, args.quality_mode_local, args.quality_anchor_top_k)
929
+ if enqueued_quality_anchors:
930
+ print(f"[hpo] enqueued {enqueued_quality_anchors} local quality anchors from {args.priors_file}")
931
+
932
+ enqueued_priors = _enqueue_transfer_priors(study, args.priors_file, args.apply_priors)
933
+ if enqueued_priors:
934
+ print(f"[hpo] enqueued {enqueued_priors} transfer priors from {args.priors_file}")
935
 
936
  state: dict[str, Any] = {
937
  "best": None,
 
990
  "best_trial_number": study.best_trial.number,
991
  "best_trial_user_attrs": study.best_trial.user_attrs,
992
  "n_trials": len(study.trials),
993
+ "n_completed": len(completed),
994
+ "patience_trials": args.patience_trials,
995
+ "min_improvement": args.min_improvement,
996
+ "quality_mode_local": args.quality_mode_local,
997
+ "enqueued_quality_anchors": enqueued_quality_anchors,
998
+ "enqueued_priors": enqueued_priors,
999
+ }
1000
+ else:
1001
+ best = {
1002
  "study_name": study.study_name,
1003
  "direction": args.direction,
1004
  "metric": args.metric,
1005
  "best_value": None,
1006
  "best_params": {},
1007
+ "best_trial_number": None,
1008
+ "best_trial_user_attrs": {},
1009
+ "n_trials": len(study.trials),
1010
+ "n_completed": 0,
1011
+ "quality_mode_local": args.quality_mode_local,
1012
+ "enqueued_quality_anchors": enqueued_quality_anchors,
1013
+ "enqueued_priors": enqueued_priors,
1014
+ "note": "No completed trials with metrics found.",
1015
+ }
1016
  args.summary_out.write_text(json.dumps(best, indent=2), encoding="utf-8")
1017
  print(json.dumps(best, indent=2))
1018
  return 0
overlay/scripts/run_cycle1a.py CHANGED
@@ -1,46 +1,45 @@
1
- #!/usr/bin/env python3
2
- from __future__ import annotations
3
-
4
- import argparse
5
- import sys
6
- from pathlib import Path
7
-
8
- REPO_ROOT = Path(__file__).resolve().parents[1]
9
- if str(REPO_ROOT) not in sys.path:
10
- sys.path.insert(0, str(REPO_ROOT))
11
-
12
- from scripts import cycle_executor
13
-
14
-
15
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
16
- parser = argparse.ArgumentParser(description="Run the full local Cycle 1a benchmark suite")
17
- parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_runs")
18
- parser.add_argument("--preflight-out", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_preflight.json")
19
- parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_summary.json")
20
- parser.add_argument("--hydrate-assets", action="store_true")
21
- parser.add_argument("--require-ready", action="store_true")
22
- parser.add_argument("--output-repo")
23
- parser.add_argument("--tokenizer-repo")
24
- return parser.parse_args(argv)
25
-
26
-
27
- def main(argv: list[str] | None = None) -> int:
28
- args = parse_args(argv)
29
- return cycle_executor.main([
30
- "--benchmark", "GSM8K",
31
- "--variant", "hydra_full",
32
- "--seed", "42",
33
- "--out-dir", str(args.out_dir),
34
- "--preflight-out", str(args.preflight_out),
35
- "--summary-out", str(args.summary_out),
36
- "--all-runnable",
37
- "--all-benchmarks",
38
- *( ["--hydrate-assets"] if args.hydrate_assets else [] ),
39
- *( ["--require-ready"] if args.require_ready else [] ),
40
- *( ["--output-repo", args.output_repo] if args.output_repo else [] ),
41
- *( ["--tokenizer-repo", args.tokenizer_repo] if args.tokenizer_repo else [] ),
42
- ])
43
-
44
-
45
- if __name__ == "__main__":
46
- raise SystemExit(main())
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ REPO_ROOT = Path(__file__).resolve().parents[1]
9
+ if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
10
+
11
+ from scripts import cycle_executor
12
+
13
+
14
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
15
+ parser = argparse.ArgumentParser(description="Run the full local Cycle 1a benchmark suite")
16
+ parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_runs")
17
+ parser.add_argument("--preflight-out", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_preflight.json")
18
+ parser.add_argument("--summary-out", type=Path, default=REPO_ROOT / "artifacts" / "cycle1a_summary.json")
19
+ parser.add_argument("--hydrate-assets", action="store_true")
20
+ parser.add_argument("--require-ready", action="store_true")
21
+ parser.add_argument("--output-repo")
22
+ parser.add_argument("--tokenizer-repo")
23
+ return parser.parse_args(argv)
24
+
25
+
26
+ def main(argv: list[str] | None = None) -> int:
27
+ args = parse_args(argv)
28
+ return cycle_executor.main([
29
+ "--benchmark", "GSM8K",
30
+ "--variant", "hydra_full",
31
+ "--seed", "42",
32
+ "--out-dir", str(args.out_dir),
33
+ "--preflight-out", str(args.preflight_out),
34
+ "--summary-out", str(args.summary_out),
35
+ "--all-runnable",
36
+ "--all-benchmarks",
37
+ *( ["--hydrate-assets"] if args.hydrate_assets else [] ),
38
+ *( ["--require-ready"] if args.require_ready else [] ),
39
+ *( ["--output-repo", args.output_repo] if args.output_repo else [] ),
40
+ *( ["--tokenizer-repo", args.tokenizer_repo] if args.tokenizer_repo else [] ),
41
+ ])
42
+
43
+
44
+ if __name__ == "__main__":
45
+ raise SystemExit(main())
 
overlay/scripts/setup.sh CHANGED
@@ -25,4 +25,3 @@ echo "=== Setup complete ==="
25
  echo "Run experiments with: uv run train.py"
26
  echo "Run orchestrator with: uv run -m harness.orchestrator"
27
  echo "Run Phase 1 subsystems with: bash scripts/run_phase1.sh"
28
- echo "For WSL/CUDA throughput gate: see docs/WSL_TPS_RUNBOOK.md"
 
25
  echo "Run experiments with: uv run train.py"
26
  echo "Run orchestrator with: uv run -m harness.orchestrator"
27
  echo "Run Phase 1 subsystems with: bash scripts/run_phase1.sh"
 
overlay/scripts/sweep_depth_aggregate.py CHANGED
@@ -11,77 +11,77 @@ Usage:
11
  """
12
  from __future__ import annotations
13
 
14
- import json
15
- import os
16
- import statistics
17
- import re
18
- import sys
19
- from pathlib import Path
20
-
21
- from configs.harness_config import HarnessConfig
22
-
23
- type MetricValue = float | int | str | bool | None
24
- type MetricsDict = dict[str, MetricValue]
25
-
26
- MANIFEST = Path(sys.argv[1] if len(sys.argv) > 1 else '/tmp/sweep_depth_manifest.txt')
27
- STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
28
- MIN_TPS = float(os.environ.get('SWEEP_MIN_TPS', '0'))
29
- TARGET_TOKENS_M = float(os.environ.get('SWEEP_TARGET_TOKENS_M', '0'))
30
- TARGET_SECONDS = float(os.environ.get('SWEEP_TARGET_SECONDS', '0'))
31
-
32
-
33
- def _zero_shot_score(result: MetricsDict) -> float:
34
- """Composite quality score for tie-breaking among BPB-near runs."""
35
- factual = float(result.get('factual_english_score', 0.0) or 0.0)
36
- instruction = float(result.get('instruction_following_score', 0.0) or 0.0)
37
- distinct_2 = float(result.get('distinct_2', 0.0) or 0.0)
38
- repetition = float(result.get('repetition_rate', 0.0) or 0.0)
39
- return factual + instruction + distinct_2 - repetition
40
-
41
-
42
- def _metric_float(result: MetricsDict, key: str, default: float = 0.0) -> float:
43
- value = result.get(key, default)
44
- return float(value) if isinstance(value, (int, float)) else default
45
-
46
-
47
- def _metric_int(result: MetricsDict, key: str, default: int = 0) -> int:
48
- value = result.get(key, default)
49
- return int(value) if isinstance(value, int) else default
50
-
51
-
52
- def _fixed_budget_ranking(results: dict[int, MetricsDict], *, metric_key: str, target: float) -> list[tuple[int, MetricsDict, float]]:
53
- ranked: list[tuple[int, MetricsDict, float]] = []
54
- for n_layer, row in results.items():
55
- budget_val = row.get(metric_key)
56
- if not isinstance(budget_val, (int, float)):
57
- continue
58
- gap = abs(float(budget_val) - target)
59
- ranked.append((n_layer, row, gap))
60
- ranked.sort(
61
- key=lambda item: (
62
- item[2],
63
- _metric_float(item[1], 'val_bpb', float('inf')),
64
- -_zero_shot_score(item[1]),
65
- -_metric_float(item[1], 'tps_median', 0.0),
66
- )
67
- )
68
- return ranked
69
-
70
-
71
- def _percentile_linear(sorted_values: list[float], pct: float) -> float:
72
- if not sorted_values:
73
- return 0.0
74
- if len(sorted_values) == 1:
75
- return sorted_values[0]
76
- rank = (len(sorted_values) - 1) * (pct / 100.0)
77
- lo = int(rank)
78
- hi = min(lo + 1, len(sorted_values) - 1)
79
- frac = rank - lo
80
- return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
81
-
82
-
83
- def fetch_metrics_from_job(job_id: str) -> MetricsDict | None:
84
- """Fetch HF Job stdout and parse the [METRICS_JSON] line."""
85
  try:
86
  from huggingface_hub import HfApi # type: ignore
87
  except Exception as e:
@@ -94,73 +94,73 @@ def fetch_metrics_from_job(job_id: str) -> MetricsDict | None:
94
  print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
95
  return None
96
 
97
- last_json = None
98
- tps_samples: list[tuple[int, int]] = []
99
- warmup_steps = 25
100
- for line in logs_stream:
101
- # HfApi returns strings or JobLogEntry-like objects depending on version.
102
- text = getattr(line, 'data', None) or str(line)
103
-
104
- wm = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", text)
105
- if wm:
106
- warmup_steps = int(wm.group(1))
107
-
108
- sm = STEP_TPS_PATTERN.search(text)
109
- if sm:
110
- tps_samples.append((int(sm.group(1)), int(sm.group(2))))
111
-
112
- if '[METRICS_JSON]' in text:
113
- payload = text.split('[METRICS_JSON]', 1)[1].strip()
114
- try:
115
- last_json = json.loads(payload)
116
- except Exception:
117
- # Might be truncated on a line boundary — keep looking.
118
- pass
119
- if last_json is None:
120
- return None
121
-
122
- steady_tps = [float(tps) for step, tps in tps_samples if step >= warmup_steps]
123
- if not steady_tps:
124
- steady_tps = [float(tps) for _, tps in tps_samples]
125
- if steady_tps:
126
- sorted_tps = sorted(steady_tps)
127
- last_json['tps_samples'] = len(steady_tps)
128
- last_json['tps_median'] = float(statistics.median(steady_tps))
129
- last_json['tps_p10'] = float(_percentile_linear(sorted_tps, 10.0))
130
- last_json['tps_min'] = float(sorted_tps[0])
131
- last_json['tps_max'] = float(sorted_tps[-1])
132
- last_json['tps_warmup_steps'] = int(warmup_steps)
133
-
134
- return last_json
135
-
136
-
137
- def compare(results: dict[int, MetricsDict]) -> None:
138
- """Pretty-print comparison across n_layer values."""
139
- if not results:
140
- print('[agg] no results')
141
- return
142
- sorted_n = sorted(results.keys())
143
- secondary_gates = HarnessConfig().to_secondary_gates()
144
-
145
- print('\n=== Active secondary gates ===')
146
- for metric, thresholds in sorted(secondary_gates.items()):
147
- print(f' {metric}: {json.dumps(thresholds, sort_keys=True)}')
148
-
149
- # Top-level scalars
150
- print('\n=== Top-level scalars ===')
151
  hdr = ['metric'] + [f'L={n}' for n in sorted_n]
152
  print(' '.join(f'{h:>14}' for h in hdr))
153
- for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
154
- 'training_seconds', 'peak_vram_mb', 'sdr_target_active',
155
- 'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits',
156
- 'tps_median', 'tps_p10', 'tps_min', 'tps_max', 'tps_samples'):
157
- row = [key] + [f'{results[n].get(key, float("nan")):.4f}' if isinstance(results[n].get(key), (int, float)) else 'n/a' for n in sorted_n]
158
- print(' '.join(f'{c:>14}' for c in row))
159
 
160
  # Per-layer panel — one table per metric.
161
  print('\n=== Per-layer: delta_ratio (residual contribution) ===')
162
  print(' '.join(['layer'] + [f'L={n:>2}' for n in sorted_n]))
163
- max_depth = max(_metric_int(results[n], 'n_layer', 0) for n in sorted_n)
164
  for li in range(max_depth):
165
  row = [f'L{li:02d}']
166
  for n in sorted_n:
@@ -197,62 +197,62 @@ def compare(results: dict[int, MetricsDict]) -> None:
197
 
198
  # Dead-layer detection
199
  print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
200
- for n in sorted_n:
201
- r = results[n]
202
- n_layer = _metric_int(r, 'n_layer', 0)
203
  dead = []
204
  for li in range(n_layer):
205
  v = r.get(f'layer_{li}_delta_ratio')
206
  if isinstance(v, (int, float)) and v < 0.02:
207
  dead.append(li)
208
- status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
209
- print(f' n_layer={n:2d} val_bpb={r.get("val_bpb", float("nan")):.4f} {status}')
210
-
211
- print('\n=== Throughput-constrained ranking ===')
212
- ranked = sorted(
213
- ((n, r) for n, r in results.items() if isinstance(r.get('val_bpb'), (int, float))),
214
- key=lambda x: (
215
- (MIN_TPS > 0) and (_metric_float(x[1], 'tps_median', 0.0) < MIN_TPS),
216
- _metric_float(x[1], 'val_bpb', float('inf')),
217
- -_zero_shot_score(x[1]),
218
- ),
219
- )
220
- feasible_count = 0
221
- for n, r in ranked:
222
- tps_median = _metric_float(r, 'tps_median', 0.0)
223
- feasible = (MIN_TPS <= 0) or (tps_median >= MIN_TPS)
224
- zero_shot_score = _zero_shot_score(r)
225
- if feasible:
226
- feasible_count += 1
227
- print(
228
- f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
229
- f"tps_median={tps_median:.0f} zero_shot_score={zero_shot_score:.4f} feasible={feasible}",
230
- flush=True,
231
- )
232
- if MIN_TPS > 0:
233
- print(f"[agg] throughput gate: tps_median >= {MIN_TPS:.0f}; feasible={feasible_count}/{len(ranked)}")
234
-
235
- if TARGET_TOKENS_M > 0:
236
- print('\n=== Fixed-token champion comparison ===')
237
- print(f' target_tokens_M={TARGET_TOKENS_M:.4f}')
238
- for n, r, gap in _fixed_budget_ranking(results, metric_key='total_tokens_M', target=TARGET_TOKENS_M):
239
- print(
240
- f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
241
- f"total_tokens_M={_metric_float(r, 'total_tokens_M', float('nan')):.4f} "
242
- f"token_gap_M={gap:.4f} tps_median={_metric_float(r, 'tps_median', 0.0):.0f}",
243
- flush=True,
244
- )
245
-
246
- if TARGET_SECONDS > 0:
247
- print('\n=== Fixed-time champion comparison ===')
248
- print(f' target_seconds={TARGET_SECONDS:.1f}')
249
- for n, r, gap in _fixed_budget_ranking(results, metric_key='training_seconds', target=TARGET_SECONDS):
250
- print(
251
- f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
252
- f"training_seconds={_metric_float(r, 'training_seconds', float('nan')):.1f} "
253
- f"time_gap_s={gap:.1f} tps_median={_metric_float(r, 'tps_median', 0.0):.0f}",
254
- flush=True,
255
- )
256
 
257
 
258
  def main() -> int:
@@ -273,7 +273,7 @@ def main() -> int:
273
  jobs[n_layer] = job_id
274
 
275
  print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
276
- results: dict[int, MetricsDict] = {}
277
  for n, jid in jobs.items():
278
  print(f'[agg] fetching job={jid} (n_layer={n}) ...')
279
  m = fetch_metrics_from_job(jid)
 
11
  """
12
  from __future__ import annotations
13
 
14
+ import json
15
+ import os
16
+ import re
17
+ import statistics
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ from configs.harness_config import HarnessConfig
22
+
23
+ type MetricValue = float | int | str | bool | None
24
+ type MetricsDict = dict[str, MetricValue]
25
+
26
+ MANIFEST = Path(sys.argv[1] if len(sys.argv) > 1 else '/tmp/sweep_depth_manifest.txt')
27
+ STEP_TPS_PATTERN = re.compile(r"step=(\d+).*?\btps=(\d+)\b")
28
+ MIN_TPS = float(os.environ.get('SWEEP_MIN_TPS', '0'))
29
+ TARGET_TOKENS_M = float(os.environ.get('SWEEP_TARGET_TOKENS_M', '0'))
30
+ TARGET_SECONDS = float(os.environ.get('SWEEP_TARGET_SECONDS', '0'))
31
+
32
+
33
+ def _zero_shot_score(result: MetricsDict) -> float:
34
+ """Composite quality score for tie-breaking among BPB-near runs."""
35
+ factual = float(result.get('factual_english_score', 0.0) or 0.0)
36
+ instruction = float(result.get('instruction_following_score', 0.0) or 0.0)
37
+ distinct_2 = float(result.get('distinct_2', 0.0) or 0.0)
38
+ repetition = float(result.get('repetition_rate', 0.0) or 0.0)
39
+ return factual + instruction + distinct_2 - repetition
40
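# Worked example: factual=0.50, instruction=0.40, distinct_2=0.30,
# repetition=0.10 -> 0.50 + 0.40 + 0.30 - 0.10 = 1.10; higher repetition
# directly lowers the tie-break score.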
+
41
+
42
+ def _metric_float(result: MetricsDict, key: str, default: float = 0.0) -> float:
43
+ value = result.get(key, default)
44
+ return float(value) if isinstance(value, (int, float)) else default
45
+
46
+
47
+ def _metric_int(result: MetricsDict, key: str, default: int = 0) -> int:
48
+ value = result.get(key, default)
49
+ return int(value) if isinstance(value, int) else default
50
+
51
+
52
+ def _fixed_budget_ranking(results: dict[int, MetricsDict], *, metric_key: str, target: float) -> list[tuple[int, MetricsDict, float]]:
53
+ ranked: list[tuple[int, MetricsDict, float]] = []
54
+ for n_layer, row in results.items():
55
+ budget_val = row.get(metric_key)
56
+ if not isinstance(budget_val, (int, float)):
57
+ continue
58
+ gap = abs(float(budget_val) - target)
59
+ ranked.append((n_layer, row, gap))
60
+ ranked.sort(
61
+ key=lambda item: (
62
+ item[2],
63
+ _metric_float(item[1], 'val_bpb', float('inf')),
64
+ -_zero_shot_score(item[1]),
65
+ -_metric_float(item[1], 'tps_median', 0.0),
66
+ )
67
+ )
68
+ return ranked
69
+
70
+
71
+ def _percentile_linear(sorted_values: list[float], pct: float) -> float:
72
+ if not sorted_values:
73
+ return 0.0
74
+ if len(sorted_values) == 1:
75
+ return sorted_values[0]
76
+ rank = (len(sorted_values) - 1) * (pct / 100.0)
77
+ lo = int(rank)
78
+ hi = min(lo + 1, len(sorted_values) - 1)
79
+ frac = rank - lo
80
+ return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac
81
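# Worked example (matches numpy's default "linear" interpolation): for
# sorted_values=[10.0, 20.0, 30.0, 40.0] and pct=10, rank = 3 * 0.10 = 0.3,
# so lo=0, hi=1, frac=0.3 and the result is 10.0 * 0.7 + 20.0 * 0.3 = 13.0.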
+
82
+
83
+ def fetch_metrics_from_job(job_id: str) -> MetricsDict | None:
84
+ """Fetch HF Job stdout and parse the [METRICS_JSON] line."""
85
  try:
86
  from huggingface_hub import HfApi # type: ignore
87
  except Exception as e:
 
94
  print(f'[agg] could not fetch logs for job={job_id}: {e}', file=sys.stderr)
95
  return None
96
 
97
+ last_json = None
98
+ tps_samples: list[tuple[int, int]] = []
99
+ warmup_steps = 25
100
+ for line in logs_stream:
101
+ # HfApi returns strings or JobLogEntry-like objects depending on version.
102
+ text = getattr(line, 'data', None) or str(line)
103
+
104
+ wm = re.search(r"\[TPS_GUARD\] enabled .*?warmup_steps=(\d+)", text)
105
+ if wm:
106
+ warmup_steps = int(wm.group(1))
107
+
108
+ sm = STEP_TPS_PATTERN.search(text)
109
+ if sm:
110
+ tps_samples.append((int(sm.group(1)), int(sm.group(2))))
111
+
112
+ if '[METRICS_JSON]' in text:
113
+ payload = text.split('[METRICS_JSON]', 1)[1].strip()
114
+ try:
115
+ last_json = json.loads(payload)
116
+ except Exception:
117
+ # Might be truncated on a line boundary — keep looking.
118
+ pass
119
+ if last_json is None:
120
+ return None
121
+
122
+ steady_tps = [float(tps) for step, tps in tps_samples if step >= warmup_steps]
123
+ if not steady_tps:
124
+ steady_tps = [float(tps) for _, tps in tps_samples]
125
+ if steady_tps:
126
+ sorted_tps = sorted(steady_tps)
127
+ last_json['tps_samples'] = len(steady_tps)
128
+ last_json['tps_median'] = float(statistics.median(steady_tps))
129
+ last_json['tps_p10'] = float(_percentile_linear(sorted_tps, 10.0))
130
+ last_json['tps_min'] = float(sorted_tps[0])
131
+ last_json['tps_max'] = float(sorted_tps[-1])
132
+ last_json['tps_warmup_steps'] = int(warmup_steps)
133
+
134
+ return last_json
135
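# Steady-state filtering sketch (hypothetical samples): with warmup_steps=25
# and tps_samples=[(10, 40000), (30, 52000), (50, 54000)], only the two
# post-warmup samples survive, giving tps_median=53000.0, tps_min=52000.0,
# tps_max=54000.0, and tps_p10 = 52000 * 0.9 + 54000 * 0.1 = 52200.0.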
+
136
+
137
+ def compare(results: dict[int, MetricsDict]) -> None:
138
+ """Pretty-print comparison across n_layer values."""
139
+ if not results:
140
+ print('[agg] no results')
141
+ return
142
+ sorted_n = sorted(results.keys())
143
+ secondary_gates = HarnessConfig().to_secondary_gates()
144
+
145
+ print('\n=== Active secondary gates ===')
146
+ for metric, thresholds in sorted(secondary_gates.items()):
147
+ print(f' {metric}: {json.dumps(thresholds, sort_keys=True)}')
148
+
149
+ # Top-level scalars
150
+ print('\n=== Top-level scalars ===')
151
  hdr = ['metric'] + [f'L={n}' for n in sorted_n]
152
  print(' '.join(f'{h:>14}' for h in hdr))
153
+ for key in ('val_bpb', 'val_ppl', 'num_params_M', 'total_tokens_M',
154
+ 'training_seconds', 'peak_vram_mb', 'sdr_target_active',
155
+ 'htm_anomaly', 'engram_hit_rate', 'sdr_active_bits',
156
+ 'tps_median', 'tps_p10', 'tps_min', 'tps_max', 'tps_samples'):
157
+ row = [key] + [f'{results[n].get(key, float("nan")):.4f}' if isinstance(results[n].get(key), (int, float)) else 'n/a' for n in sorted_n]
158
+ print(' '.join(f'{c:>14}' for c in row))
159
 
160
  # Per-layer panel — one table per metric.
161
  print('\n=== Per-layer: delta_ratio (residual contribution) ===')
162
  print(' '.join(['layer'] + [f'L={n:>2}' for n in sorted_n]))
163
+ max_depth = max(_metric_int(results[n], 'n_layer', 0) for n in sorted_n)
164
  for li in range(max_depth):
165
  row = [f'L{li:02d}']
166
  for n in sorted_n:
 
197
 
198
  # Dead-layer detection
199
  print('\n=== Dead-layer detection (delta_ratio < 0.02) ===')
200
+ for n in sorted_n:
201
+ r = results[n]
202
+ n_layer = _metric_int(r, 'n_layer', 0)
203
  dead = []
204
  for li in range(n_layer):
205
  v = r.get(f'layer_{li}_delta_ratio')
206
  if isinstance(v, (int, float)) and v < 0.02:
207
  dead.append(li)
208
+ status = 'ALL LIVE' if not dead else f'DEAD LAYERS: {dead}'
209
+ print(f' n_layer={n:2d} val_bpb={r.get("val_bpb", float("nan")):.4f} {status}')
210
+
211
+ print('\n=== Throughput-constrained ranking ===')
212
+ ranked = sorted(
213
+ ((n, r) for n, r in results.items() if isinstance(r.get('val_bpb'), (int, float))),
214
+ key=lambda x: (
215
+ (MIN_TPS > 0) and (_metric_float(x[1], 'tps_median', 0.0) < MIN_TPS),
216
+ _metric_float(x[1], 'val_bpb', float('inf')),
217
+ -_zero_shot_score(x[1]),
218
+ ),
219
+ )
220
+ feasible_count = 0
221
+ for n, r in ranked:
222
+ tps_median = _metric_float(r, 'tps_median', 0.0)
223
+ feasible = (MIN_TPS <= 0) or (tps_median >= MIN_TPS)
224
+ zero_shot_score = _zero_shot_score(r)
225
+ if feasible:
226
+ feasible_count += 1
227
+ print(
228
+ f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
229
+ f"tps_median={tps_median:.0f} zero_shot_score={zero_shot_score:.4f} feasible={feasible}",
230
+ flush=True,
231
+ )
232
+ if MIN_TPS > 0:
233
+ print(f"[agg] throughput gate: tps_median >= {MIN_TPS:.0f}; feasible={feasible_count}/{len(ranked)}")
234
+
235
+ if TARGET_TOKENS_M > 0:
236
+ print('\n=== Fixed-token champion comparison ===')
237
+ print(f' target_tokens_M={TARGET_TOKENS_M:.4f}')
238
+ for n, r, gap in _fixed_budget_ranking(results, metric_key='total_tokens_M', target=TARGET_TOKENS_M):
239
+ print(
240
+ f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
241
+ f"total_tokens_M={_metric_float(r, 'total_tokens_M', float('nan')):.4f} "
242
+ f"token_gap_M={gap:.4f} tps_median={_metric_float(r, 'tps_median', 0.0):.0f}",
243
+ flush=True,
244
+ )
245
+
246
+ if TARGET_SECONDS > 0:
247
+ print('\n=== Fixed-time champion comparison ===')
248
+ print(f' target_seconds={TARGET_SECONDS:.1f}')
249
+ for n, r, gap in _fixed_budget_ranking(results, metric_key='training_seconds', target=TARGET_SECONDS):
250
+ print(
251
+ f" n_layer={n:2d} val_bpb={_metric_float(r, 'val_bpb', float('nan')):.4f} "
252
+ f"training_seconds={_metric_float(r, 'training_seconds', float('nan')):.1f} "
253
+ f"time_gap_s={gap:.1f} tps_median={_metric_float(r, 'tps_median', 0.0):.0f}",
254
+ flush=True,
255
+ )
256
 
257
 
258
  def main() -> int:
 
273
  jobs[n_layer] = job_id
274
 
275
  print(f'[agg] reading {len(jobs)} jobs from {MANIFEST}')
276
+ results: dict[int, MetricsDict] = {}
277
  for n, jid in jobs.items():
278
  print(f'[agg] fetching job={jid} (n_layer={n}) ...')
279
  m = fetch_metrics_from_job(jid)
overlay/scripts/watch_benchmark_hf_job.py CHANGED
@@ -1,81 +1,33 @@
- #!/usr/bin/env python3
- from __future__ import annotations
- 
- import argparse
- import json
- import time
- from pathlib import Path
- 
- from huggingface_hub import HfApi
- from huggingface_hub.utils import get_token
- 
- 
- def parse_benchmark_result_from_logs(lines: list[str]):
-     for line in reversed(lines):
-         text = line.strip()
-         if not text.startswith("{"):
-             continue
-         try:
-             payload = json.loads(text)
-         except json.JSONDecodeError:
-             continue
-         if isinstance(payload, dict) and "benchmark" in payload:
-             return payload
-     return None
- 
- 
- def collect_job_snapshot(api, *, job_id: str, token: str, namespace: str) -> dict[str, object]:
-     info = api.inspect_job(job_id=job_id, token=token, namespace=namespace)
-     stage = str(info.status.stage)
-     message = str(getattr(info.status, "message", "") or "")
-     logs = list(api.fetch_job_logs(job_id=job_id, follow=False, token=token, namespace=namespace))
-     texts = [(getattr(line, "data", None) or str(line)) for line in logs]
-     return {
-         "job_id": job_id,
-         "stage": stage,
-         "message": message,
-         "log_lines": len(texts),
-         "result": parse_benchmark_result_from_logs(texts),
-     }
- 
- 
- def wait_for_terminal_snapshot(api, *, job_id: str, token: str, namespace: str, poll_interval: float = 10.0, timeout_s: float = 1800.0) -> dict[str, object]:
-     deadline = time.time() + timeout_s
-     terminal = {"COMPLETED", "ERROR", "FAILED", "CANCELLED", "CANCELED", "TIMEOUT"}
-     while True:
-         payload = collect_job_snapshot(api, job_id=job_id, token=token, namespace=namespace)
-         if payload["stage"] in terminal:
-             return payload
-         if time.time() >= deadline:
-             return payload
-         time.sleep(poll_interval)
- 
- 
- def write_watch_summary(path: Path, payload: dict[str, object]) -> None:
-     path.parent.mkdir(parents=True, exist_ok=True)
-     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
- 
- 
- def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
-     parser = argparse.ArgumentParser(description="Watch or snapshot a remote benchmark job")
-     parser.add_argument("--job-id", required=True)
-     parser.add_argument("--namespace", default="jackoatmon")
-     parser.add_argument("--summary-out", type=Path)
-     return parser.parse_args(argv)
- 
- 
- def main(argv: list[str] | None = None) -> int:
-     args = parse_args(argv)
-     token = get_token()
-     if not token:
-         raise SystemExit("HF_TOKEN must be set or cached via huggingface-cli login")
-     api = HfApi(token=token)
-     payload = collect_job_snapshot(api, job_id=args.job_id, token=token, namespace=args.namespace)
-     if args.summary_out is not None:
-         write_watch_summary(args.summary_out, payload)
-     print(json.dumps(payload, indent=2, sort_keys=True))
-     return 0
- 
- 
- if __name__ == "__main__":
-     raise SystemExit(main())
 
+ #!/usr/bin/env python3
+ from __future__ import annotations
+ 
+ import argparse
+ import json
+ from pathlib import Path
+ 
+ 
+ def parse_benchmark_result_from_logs(lines: list[str]):
+     for line in reversed(lines):
+         text = line.strip()
+         if not text.startswith("{"):
+             continue
+         try:
+             payload = json.loads(text)
+         except json.JSONDecodeError:
+             continue
+         if isinstance(payload, dict) and "benchmark" in payload:
+             return payload
+     return None
+ 
+ 
+ def write_watch_summary(path: Path, payload: dict[str, object]) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
+ 
+ 
+ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Watch or snapshot a remote benchmark job")
+     parser.add_argument("--job-id", required=True)
+     parser.add_argument("--namespace", default="jackoatmon")
+     parser.add_argument("--summary-out", type=Path)
+     return parser.parse_args(argv)
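The surviving `parse_benchmark_result_from_logs` scans the log tail in reverse for the last JSON object that carries a "benchmark" key, skipping progress lines and non-JSON noise. A small illustrative exercise of the function exactly as defined above; the sample log content is invented:

# Illustrative only: the log lines below are made up; the function is the
# one defined in the new watch_benchmark_hf_job.py above.
sample_logs = [
    "booting benchmark harness ...",
    '{"progress": 0.5}',                          # JSON, but no "benchmark" key: skipped
    '{"benchmark": "suite", "val_bpb": 1.234}',   # last matching payload wins
    "run complete",                               # not JSON: skipped
]
result = parse_benchmark_result_from_logs(sample_logs)
assert result == {"benchmark": "suite", "val_bpb": 1.234}
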
overlay/subsystems/htm.py CHANGED
@@ -29,46 +29,46 @@ copy is small compared to the SP/TM compute.
  from __future__ import annotations
  
  import time
- from concurrent.futures import ThreadPoolExecutor
- from typing import Any
  
  import numpy as np
  import torch
  import torch.nn as nn
  
- import htm_rust
- 
- _HTM_REGION: Any = getattr(htm_rust, "HTMRegion", None)
- _HTM_REGION_GPU: Any = getattr(htm_rust, "HTMRegionGpu", None)
- _HTM_STEP_BATCH_FUSED_CUDA: Any = getattr(htm_rust, "step_batch_fused_cuda", None)
- 
- # step_many releases the GIL for the whole pass, so multiple threads can
- # truly run regions in parallel — wall-clock scales with B up to CPU cores.
- _HTM_HAS_STEP_MANY = hasattr(_HTM_REGION, "step_many")
  # GPU backend: built with `maturin develop --features gpu`. One CUDA region
  # per batch slot, persistent device state for SP synapses. Transparent
  # fallback to CPU when not available.
- _HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")
  # Zero-copy CUDA path: consumes torch CUDA tensors directly via the
  # __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
  # and the D2H of outputs. Huge win when the input SDR already lives on GPU
  # (which is the train.py hot path — retina is a device buffer).
- _HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(_HTM_REGION_GPU, "step_many_cuda")
  # Fused megakernel path: collapses all T timesteps + SP + TM into a single
  # CUDA launch per forward. Replaces global top-K with per-column threshold
  # inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
  # Opt-in via env var (default on when available).
  import os as _os_fused
- _HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(_HTM_REGION_GPU, "step_many_fused_cuda")
- _HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))
- 
- 
- def _is_fused_unavailable_error(exc: RuntimeError) -> bool:
-     message = str(exc)
-     return (
-         "Fused HTM kernel is unavailable" in message
-         or "fused HTM kernel disabled for this CUDA arch" in message
-     )
  
  
  class HTMLayer(nn.Module):
@@ -93,11 +93,11 @@ class HTMLayer(nn.Module):
          learn: bool = True,
          reset_each_forward: bool = True,
          use_gpu: bool | None = None,
-     ) -> None:
-         super().__init__()
-         self.input_bits = input_bits
-         self.n_columns = n_columns
-         self.cells_per_column = cells_per_column
          self.learn = learn
          self.reset_each_forward = reset_each_forward
          self._seed_base = seed
@@ -107,23 +107,23 @@ class HTMLayer(nn.Module):
          # converges since the EMA accumulates over many calls. Env:
          # HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward, 0 = disabled).
          import os as _os
-         self._learn_every = int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1"))
-         self._forward_counter = 0
-         force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
-         # GPU backend gate. Default: auto-detect — use GPU when the pyo3
-         # module was built with --features gpu AND CUDA is actually usable.
-         if use_gpu is None:
-             use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
-         elif use_gpu and not _HTM_HAS_GPU:
-             raise RuntimeError(
-                 "HTMLayer(use_gpu=True) but htm_rust was not built with "
-                 "--features gpu. Re-run `maturin develop --features gpu`."
-             )
-         elif use_gpu and force_cpu:
-             use_gpu = False
-         self._use_gpu = bool(use_gpu)
-         cls = _HTM_REGION_GPU if self._use_gpu else _HTM_REGION
-         self._region_cls = cls
          self._regions = [
              cls(input_bits, n_columns, cells_per_column, seed + i)
              for i in range(batch_size)
@@ -144,19 +144,19 @@ class HTMLayer(nn.Module):
              )
          )
  
-     def reset(self) -> None:
-         """Clear TM predictive state on every region (keeps SP synapses)."""
-         for r in self._regions:
-             r.reset()
- 
-     def _next_learn_flag(self) -> bool:
-         self._forward_counter += 1
-         return bool(
-             self.learn
-             and self.training
-             and self._learn_every > 0
-             and (self._forward_counter % self._learn_every == 0)
-         )
  
      @torch.no_grad()
      def forward(self, sdr: torch.Tensor) -> torch.Tensor:
@@ -167,9 +167,9 @@ class HTMLayer(nn.Module):
          if self.reset_each_forward:
              self.reset()
  
-         # Learn-gate: run learn kernels only every N forwards (skips 56% of
-         # HTM CUDA time on skip-forwards; Hebbian EMA still converges).
-         learn = self._next_learn_flag()
  
          # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
          # so we skip sdr.cpu()/numpy round-trip AND the output D2H. The Rust
@@ -178,30 +178,30 @@ class HTMLayer(nn.Module):
          if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
              sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
              cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
-             anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
-             # Pick fused (1 launch) or legacy (12*T launches) path.
-             if _HTM_USE_FUSED:
-                 try:
-                     for b in range(B):
-                         self._regions[b].step_many_fused_cuda(
-                             sdr_u8[b].__cuda_array_interface__,
-                             cols_out[b].__cuda_array_interface__,
-                             anom_out[b].__cuda_array_interface__,
-                             learn,
-                         )
-                 except RuntimeError as exc:
-                     if not _is_fused_unavailable_error(exc):
-                         raise
-                     for b in range(B):
-                         self._regions[b].step_many_cuda(
-                             sdr_u8[b].__cuda_array_interface__,
-                             cols_out[b].__cuda_array_interface__,
-                             anom_out[b].__cuda_array_interface__,
-                             learn,
-                         )
-             else:
-                 for b in range(B):
-                     self._regions[b].step_many_cuda(
                          sdr_u8[b].__cuda_array_interface__,
                          cols_out[b].__cuda_array_interface__,
                          anom_out[b].__cuda_array_interface__,
@@ -275,7 +275,7 @@ class HTMLayer(nn.Module):
          self._ensure_regions(B)
          if self.reset_each_forward:
              self.reset()
-         learn = self._next_learn_flag()
  
          if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
              sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
@@ -287,61 +287,61 @@ class HTMLayer(nn.Module):
              # grid.y = B processes all regions concurrently — ~B× speedup.
              # Falls back to sequential dispatch if the batched entry isn't
              # available (older htm_rust wheel).
-             if _HTM_USE_FUSED and _HTM_STEP_BATCH_FUSED_CUDA is not None:
                  # Slice self._regions to match B: _ensure_regions may have
                  # allocated more regions than the current batch size needs
                  # (e.g. factual eval uses smaller batches than training).
                  try:
-                     _HTM_STEP_BATCH_FUSED_CUDA(
                          self._regions[:B],
                          [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                          [cols_out[b].__cuda_array_interface__ for b in range(B)],
                          [anom_out[b].__cuda_array_interface__ for b in range(B)],
                          learn,
                      )
-                 except RuntimeError as _e:
-                     if "COOPERATIVE_LAUNCH_TOO_LARGE" in str(_e):
-                         # Batch too large for cooperative grid. Fall back to
-                         # sequential per-region fused launches (each B=1).
-                         for b in range(B):
                              self._regions[b].step_many_fused_cuda(
                                  sdr_u8[b].__cuda_array_interface__,
                                  cols_out[b].__cuda_array_interface__,
-                                 anom_out[b].__cuda_array_interface__,
-                                 learn,
-                             )
-                     elif _is_fused_unavailable_error(_e):
-                         for b in range(B):
-                             self._regions[b].step_many_cuda(
-                                 sdr_u8[b].__cuda_array_interface__,
-                                 cols_out[b].__cuda_array_interface__,
-                                 anom_out[b].__cuda_array_interface__,
-                                 learn,
-                             )
-                     else:
-                         raise
-             elif _HTM_USE_FUSED:
-                 try:
-                     for b in range(B):
-                         self._regions[b].step_many_fused_cuda(
-                             sdr_u8[b].__cuda_array_interface__,
-                             cols_out[b].__cuda_array_interface__,
-                             anom_out[b].__cuda_array_interface__,
-                             learn,
-                         )
-                 except RuntimeError as exc:
-                     if not _is_fused_unavailable_error(exc):
-                         raise
-                     for b in range(B):
-                         self._regions[b].step_many_cuda(
-                             sdr_u8[b].__cuda_array_interface__,
-                             cols_out[b].__cuda_array_interface__,
-                             anom_out[b].__cuda_array_interface__,
-                             learn,
-                         )
-             else:
-                 for b in range(B):
-                     self._regions[b].step_many_cuda(
                          sdr_u8[b].__cuda_array_interface__,
                          cols_out[b].__cuda_array_interface__,
                          anom_out[b].__cuda_array_interface__,
 
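The dispatch in these hunks tries the fused megakernel first and transparently degrades: a cooperative-launch overflow falls back to per-region fused launches, and an unavailable fused kernel falls back to the legacy per-step entry point. A distilled sketch of that control flow, with hypothetical fused_step/legacy_step stand-ins for the htm_rust bindings:

# Distilled control-flow sketch only; fused_step and legacy_step are
# hypothetical stand-ins for the htm_rust entry points used above.
def step_with_fallback(regions, args, *, use_fused: bool) -> None:
    if not use_fused:
        for region in regions:           # legacy path: many small launches
            region.legacy_step(*args)
        return
    try:
        for region in regions:           # fused path: one launch per region
            region.fused_step(*args)
    except RuntimeError as exc:
        if "kernel is unavailable" not in str(exc):
            raise                        # unrelated CUDA error: propagate
        for region in regions:           # fused kernel missing: degrade
            region.legacy_step(*args)

The new-side rendering of the same hunks follows.
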
  from __future__ import annotations
  
  import time
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any
  
  import numpy as np
  import torch
  import torch.nn as nn
  
+ import htm_rust
+ 
+ _HTM_REGION: Any = getattr(htm_rust, "HTMRegion", None)
+ _HTM_REGION_GPU: Any = getattr(htm_rust, "HTMRegionGpu", None)
+ _HTM_STEP_BATCH_FUSED_CUDA: Any = getattr(htm_rust, "step_batch_fused_cuda", None)
+ 
+ # step_many releases the GIL for the whole pass, so multiple threads can
+ # truly run regions in parallel — wall-clock scales with B up to CPU cores.
+ _HTM_HAS_STEP_MANY = hasattr(_HTM_REGION, "step_many")
  # GPU backend: built with `maturin develop --features gpu`. One CUDA region
  # per batch slot, persistent device state for SP synapses. Transparent
  # fallback to CPU when not available.
+ _HTM_HAS_GPU = hasattr(htm_rust, "HTMRegionGpu")
  # Zero-copy CUDA path: consumes torch CUDA tensors directly via the
  # __cuda_array_interface__ protocol, skipping the sdr.cpu()/numpy round-trip
  # and the D2H of outputs. Huge win when the input SDR already lives on GPU
  # (which is the train.py hot path — retina is a device buffer).
+ _HTM_HAS_CAI = _HTM_HAS_GPU and hasattr(_HTM_REGION_GPU, "step_many_cuda")
  # Fused megakernel path: collapses all T timesteps + SP + TM into a single
  # CUDA launch per forward. Replaces global top-K with per-column threshold
  # inhibition (see htm_rust/docs/GPU_HTM.md §Fused Kernel).
  # Opt-in via env var (default on when available).
  import os as _os_fused
+ _HTM_HAS_FUSED = _HTM_HAS_GPU and hasattr(_HTM_REGION_GPU, "step_many_fused_cuda")
+ _HTM_USE_FUSED = _HTM_HAS_FUSED and bool(int(_os_fused.environ.get("HYDRA_HTM_FUSED", "1")))
+ 
+ 
+ def _is_fused_unavailable_error(exc: RuntimeError) -> bool:
+     message = str(exc)
+     return (
+         "Fused HTM kernel is unavailable" in message
+         or "fused HTM kernel disabled for this CUDA arch" in message
+     )
  
  
  class HTMLayer(nn.Module):

          learn: bool = True,
          reset_each_forward: bool = True,
          use_gpu: bool | None = None,
+     ) -> None:
+         super().__init__()
+         self.input_bits = input_bits
+         self.n_columns = n_columns
+         self.cells_per_column = cells_per_column
          self.learn = learn
          self.reset_each_forward = reset_each_forward
          self._seed_base = seed

          # converges since the EMA accumulates over many calls. Env:
          # HYDRA_HTM_LEARN_EVERY=N (default 1 = every forward, 0 = disabled).
          import os as _os
+         self._learn_every = int(_os.environ.get("HYDRA_HTM_LEARN_EVERY", "1"))
+         self._forward_counter = 0
+         force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
+         # GPU backend gate. Default: auto-detect — use GPU when the pyo3
+         # module was built with --features gpu AND CUDA is actually usable.
+         if use_gpu is None:
+             use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
+         elif use_gpu and not _HTM_HAS_GPU:
+             raise RuntimeError(
+                 "HTMLayer(use_gpu=True) but htm_rust was not built with "
+                 "--features gpu. Re-run `maturin develop --features gpu`."
+             )
+         elif use_gpu and force_cpu:
+             use_gpu = False
+         self._use_gpu = bool(use_gpu)
+         cls = _HTM_REGION_GPU if self._use_gpu else _HTM_REGION
+         self._region_cls = cls
          self._regions = [
              cls(input_bits, n_columns, cells_per_column, seed + i)
              for i in range(batch_size)

              )
          )
  
+     def reset(self) -> None:
+         """Clear TM predictive state on every region (keeps SP synapses)."""
+         for r in self._regions:
+             r.reset()
+ 
+     def _next_learn_flag(self) -> bool:
+         self._forward_counter += 1
+         return bool(
+             self.learn
+             and self.training
+             and self._learn_every > 0
+             and (self._forward_counter % self._learn_every == 0)
+         )
  
      @torch.no_grad()
      def forward(self, sdr: torch.Tensor) -> torch.Tensor:

          if self.reset_each_forward:
              self.reset()
  
+         # Learn-gate: run learn kernels only every N forwards (skips 56% of
+         # HTM CUDA time on skip-forwards; Hebbian EMA still converges).
+         learn = self._next_learn_flag()
  
          # Zero-copy CUDA hot path. SDR already lives on GPU (retina buffer),
          # so we skip sdr.cpu()/numpy round-trip AND the output D2H. The Rust

          if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
              sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()
              cols_out = torch.empty((B, T, self.n_columns), dtype=torch.uint8, device=sdr.device)
+             anom_out = torch.empty((B, T), dtype=torch.float32, device=sdr.device)
+             # Pick fused (1 launch) or legacy (12*T launches) path.
+             if _HTM_USE_FUSED:
+                 try:
+                     for b in range(B):
+                         self._regions[b].step_many_fused_cuda(
+                             sdr_u8[b].__cuda_array_interface__,
+                             cols_out[b].__cuda_array_interface__,
+                             anom_out[b].__cuda_array_interface__,
+                             learn,
+                         )
+                 except RuntimeError as exc:
+                     if not _is_fused_unavailable_error(exc):
+                         raise
+                     for b in range(B):
+                         self._regions[b].step_many_cuda(
+                             sdr_u8[b].__cuda_array_interface__,
+                             cols_out[b].__cuda_array_interface__,
+                             anom_out[b].__cuda_array_interface__,
+                             learn,
+                         )
+             else:
+                 for b in range(B):
+                     self._regions[b].step_many_cuda(
                          sdr_u8[b].__cuda_array_interface__,
                          cols_out[b].__cuda_array_interface__,
                          anom_out[b].__cuda_array_interface__,

          self._ensure_regions(B)
          if self.reset_each_forward:
              self.reset()
+         learn = self._next_learn_flag()
  
          if _HTM_HAS_CAI and self._use_gpu and sdr.is_cuda:
              sdr_u8 = sdr.contiguous().to(torch.uint8) if sdr.dtype != torch.uint8 else sdr.contiguous()

              # grid.y = B processes all regions concurrently — ~B× speedup.
              # Falls back to sequential dispatch if the batched entry isn't
              # available (older htm_rust wheel).
+             if _HTM_USE_FUSED and _HTM_STEP_BATCH_FUSED_CUDA is not None:
                  # Slice self._regions to match B: _ensure_regions may have
                  # allocated more regions than the current batch size needs
                  # (e.g. factual eval uses smaller batches than training).
                  try:
+                     _HTM_STEP_BATCH_FUSED_CUDA(
                          self._regions[:B],
                          [sdr_u8[b].__cuda_array_interface__ for b in range(B)],
                          [cols_out[b].__cuda_array_interface__ for b in range(B)],
                          [anom_out[b].__cuda_array_interface__ for b in range(B)],
                          learn,
                      )
+                 except RuntimeError as _e:
+                     if "COOPERATIVE_LAUNCH_TOO_LARGE" in str(_e):
+                         # Batch too large for cooperative grid. Fall back to
+                         # sequential per-region fused launches (each B=1).
+                         for b in range(B):
                              self._regions[b].step_many_fused_cuda(
                                  sdr_u8[b].__cuda_array_interface__,
                                  cols_out[b].__cuda_array_interface__,
+                                 anom_out[b].__cuda_array_interface__,
+                                 learn,
+                             )
+                     elif _is_fused_unavailable_error(_e):
+                         for b in range(B):
+                             self._regions[b].step_many_cuda(
+                                 sdr_u8[b].__cuda_array_interface__,
+                                 cols_out[b].__cuda_array_interface__,
+                                 anom_out[b].__cuda_array_interface__,
+                                 learn,
+                             )
+                     else:
+                         raise
+             elif _HTM_USE_FUSED:
+                 try:
+                     for b in range(B):
+                         self._regions[b].step_many_fused_cuda(
+                             sdr_u8[b].__cuda_array_interface__,
+                             cols_out[b].__cuda_array_interface__,
+                             anom_out[b].__cuda_array_interface__,
+                             learn,
+                         )
+                 except RuntimeError as exc:
+                     if not _is_fused_unavailable_error(exc):
+                         raise
+                     for b in range(B):
+                         self._regions[b].step_many_cuda(
+                             sdr_u8[b].__cuda_array_interface__,
+                             cols_out[b].__cuda_array_interface__,
+                             anom_out[b].__cuda_array_interface__,
+                             learn,
+                         )
+             else:
+                 for b in range(B):
+                     self._regions[b].step_many_cuda(
                          sdr_u8[b].__cuda_array_interface__,
                          cols_out[b].__cuda_array_interface__,
                          anom_out[b].__cuda_array_interface__,
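Both hot paths hand torch CUDA tensors to the Rust kernels through the standard __cuda_array_interface__ protocol, which is what makes the exchange zero-copy: only a device pointer plus shape/dtype metadata crosses the language boundary. A minimal sketch of what that attribute exposes; it needs a CUDA-enabled PyTorch build and a visible GPU, and the shapes are illustrative:

# Illustrative: the per-slot metadata htm.py passes to htm_rust.
# Requires a CUDA build of PyTorch and an available GPU to run.
import torch

sdr_u8 = torch.zeros((16, 2048), dtype=torch.uint8, device="cuda")
cai = sdr_u8.__cuda_array_interface__   # plain dict, no data copy
print(cai["shape"])      # (16, 2048)
print(cai["typestr"])    # '|u1' (unsigned 8-bit, matching the kernels)
print(cai["data"][0])    # raw device pointer the Rust side reads/writes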