""" Shared fitness function for threshold circuit LLM integration. Randomized tests, no answer supervision - fitness IS the training signal. """ import torch import random from typing import Callable, Dict, Tuple, List OPERATIONS = ['add', 'sub', 'mul', 'gt', 'lt', 'eq'] def ground_truth(a: int, b: int, op: str) -> int: """Compute expected result (8-bit arithmetic).""" if op == 'add': return (a + b) & 0xFF elif op == 'sub': return (a - b) & 0xFF elif op == 'mul': return (a * b) & 0xFF elif op == 'gt': return 1 if a > b else 0 elif op == 'lt': return 1 if a < b else 0 elif op == 'eq': return 1 if a == b else 0 else: raise ValueError(f"Unknown op: {op}") def int_to_bits(val: int, n_bits: int = 8) -> torch.Tensor: """Convert integer to bit tensor (MSB first).""" bits = torch.zeros(n_bits) for i in range(n_bits): bits[n_bits - 1 - i] = (val >> i) & 1 return bits def bits_to_int(bits: torch.Tensor) -> int: """Convert bit tensor to integer (MSB first).""" val = 0 n_bits = bits.shape[-1] for i in range(n_bits): val += int(bits[..., i].item()) << (n_bits - 1 - i) return val def op_to_idx(op: str) -> int: """Convert operation string to index.""" return OPERATIONS.index(op) def idx_to_op(idx: int) -> str: """Convert index to operation string.""" return OPERATIONS[idx] def generate_batch(batch_size: int, device: str = 'cuda') -> Dict[str, torch.Tensor]: """ Generate a batch of random arithmetic problems. Returns: Dict with: 'a': [batch_size] int tensor of first operands 'b': [batch_size] int tensor of second operands 'op': [batch_size] int tensor of operation indices 'a_bits': [batch_size, 8] bit tensor 'b_bits': [batch_size, 8] bit tensor 'op_onehot': [batch_size, 6] one-hot operation tensor 'expected': [batch_size] int tensor of expected results 'expected_bits': [batch_size, 8] bit tensor of expected results """ a_vals = torch.randint(0, 256, (batch_size,), device=device) b_vals = torch.randint(0, 256, (batch_size,), device=device) op_indices = torch.randint(0, len(OPERATIONS), (batch_size,), device=device) a_bits = torch.zeros(batch_size, 8, device=device) b_bits = torch.zeros(batch_size, 8, device=device) for i in range(8): a_bits[:, 7-i] = (a_vals >> i) & 1 b_bits[:, 7-i] = (b_vals >> i) & 1 op_onehot = torch.zeros(batch_size, len(OPERATIONS), device=device) op_onehot.scatter_(1, op_indices.unsqueeze(1), 1.0) expected = torch.zeros(batch_size, dtype=torch.long, device=device) for i in range(batch_size): a, b, op_idx = a_vals[i].item(), b_vals[i].item(), op_indices[i].item() expected[i] = ground_truth(a, b, idx_to_op(op_idx)) expected_bits = torch.zeros(batch_size, 8, device=device) for i in range(8): expected_bits[:, 7-i] = (expected >> i) & 1 return { 'a': a_vals, 'b': b_vals, 'op': op_indices, 'a_bits': a_bits.float(), 'b_bits': b_bits.float(), 'op_onehot': op_onehot.float(), 'expected': expected, 'expected_bits': expected_bits.float(), } def compute_fitness( model_fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], n_samples: int = 10000, batch_size: int = 256, device: str = 'cuda', return_details: bool = False ) -> float | Tuple[float, Dict]: """ Compute fitness score for a model. Args: model_fn: Function that takes (a_bits, b_bits, op_onehot) and returns result_bits n_samples: Number of test cases batch_size: Batch size for evaluation device: Device to run on return_details: If True, return per-operation breakdown Returns: Fitness score in [0, 1], optionally with details dict """ correct = 0 total = 0 op_correct = {op: 0 for op in OPERATIONS} op_total = {op: 0 for op in OPERATIONS} for _ in range(0, n_samples, batch_size): actual_batch = min(batch_size, n_samples - total) batch = generate_batch(actual_batch, device) with torch.no_grad(): pred_bits = model_fn(batch['a_bits'], batch['b_bits'], batch['op_onehot']) pred_bits_binary = (pred_bits > 0.5).float() for i in range(actual_batch): pred_val = 0 for j in range(8): pred_val += int(pred_bits_binary[i, j].item()) << (7 - j) expected_val = batch['expected'][i].item() op_name = idx_to_op(batch['op'][i].item()) op_total[op_name] += 1 total += 1 if pred_val == expected_val: correct += 1 op_correct[op_name] += 1 fitness = correct / total if total > 0 else 0.0 if return_details: details = { 'correct': correct, 'total': total, 'by_op': { op: { 'correct': op_correct[op], 'total': op_total[op], 'accuracy': op_correct[op] / op_total[op] if op_total[op] > 0 else 0.0 } for op in OPERATIONS } } return fitness, details return fitness def compute_bit_accuracy(pred_bits: torch.Tensor, expected_bits: torch.Tensor) -> float: """Compute per-bit accuracy (for gradient signal analysis).""" pred_binary = (pred_bits > 0.5).float() return (pred_binary == expected_bits).float().mean().item() def compute_loss(pred_bits: torch.Tensor, expected_bits: torch.Tensor) -> torch.Tensor: """Binary cross-entropy loss on output bits.""" pred_clamped = pred_bits.clamp(1e-7, 1 - 1e-7) return -((expected_bits * torch.log(pred_clamped) + (1 - expected_bits) * torch.log(1 - pred_clamped))).mean() if __name__ == "__main__": print("Testing fitness module...") batch = generate_batch(8, 'cpu') print(f"\nSample batch:") for i in range(4): a, b = batch['a'][i].item(), batch['b'][i].item() op = idx_to_op(batch['op'][i].item()) expected = batch['expected'][i].item() print(f" {a} {op} {b} = {expected}") def random_model(a_bits, b_bits, op_onehot): return torch.rand(a_bits.shape[0], 8, device=a_bits.device) fitness = compute_fitness(random_model, n_samples=1000, batch_size=100, device='cpu') print(f"\nRandom model fitness: {fitness:.4f} (expected ~0.004 for 8-bit)") def perfect_model(a_bits, b_bits, op_onehot): batch_size = a_bits.shape[0] results = torch.zeros(batch_size, 8, device=a_bits.device) for i in range(batch_size): a = sum(int(a_bits[i, j].item()) << (7-j) for j in range(8)) b = sum(int(b_bits[i, j].item()) << (7-j) for j in range(8)) op_idx = op_onehot[i].argmax().item() result = ground_truth(a, b, idx_to_op(op_idx)) for j in range(8): results[i, 7-j] = (result >> j) & 1 return results fitness = compute_fitness(perfect_model, n_samples=1000, batch_size=100, device='cpu') print(f"Perfect model fitness: {fitness:.4f} (expected 1.0)")