Exoplanet Transit Detector 🔭🪐

A multi-branch 1D CNN for detecting exoplanet transits in stellar light curves, based on the AstroNet and ExoMiner++ (NASA Ames) architectures.

Model Architecture

Multi-Branch 1D CNN with 5 input branches:

  • Global flux branch: Phase-folded full light curve (201 bins)
  • Local flux branch: Zoomed transit view (81 bins)
  • Odd transit branch: Odd-numbered transits (201 bins)
  • Even transit branch: Even-numbered transits (201 bins)
  • Scalar features: Period, duration, depth, stellar parameters (9 features)
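The four flux inputs are phase-folded, binned views of a light curve. The NumPy sketch below shows one way such views could be built. It assumes a known period, transit epoch (t0) and duration, median binning, a local window of ±2 transit durations, full-period odd/even views, and epochs counted from 0 at t0. None of these choices is confirmed by this card, and the helper names (phase_fold, bin_view, make_views) are hypothetical.

```python
import numpy as np


def phase_fold(time, period, t0):
    """Phase in days relative to the transit centre t0, wrapped to [-period/2, period/2)."""
    return (time - t0 + 0.5 * period) % period - 0.5 * period


def bin_view(phase, flux, n_bins, half_width):
    """Median-bin flux into n_bins equal bins spanning [-half_width, half_width)."""
    edges = np.linspace(-half_width, half_width, n_bins + 1)
    idx = np.digitize(phase, edges) - 1  # out-of-range points get indices outside 0..n_bins-1
    view = np.full(n_bins, np.nan)
    for b in range(n_bins):
        in_bin = flux[idx == b]
        if in_bin.size:
            view[b] = np.median(in_bin)
    # Fill empty bins by linear interpolation between neighbouring bin centres.
    centres = 0.5 * (edges[:-1] + edges[1:])
    good = ~np.isnan(view)
    if good.any() and not good.all():
        view[~good] = np.interp(centres[~good], centres[good], view[good])
    return view


def make_views(time, flux, period, t0, duration_days):
    """Build global (201), local (81), odd (201) and even (201) views (window widths assumed)."""
    phase = phase_fold(time, period, t0)
    global_view = bin_view(phase, flux, 201, 0.5 * period)
    local_view = bin_view(phase, flux, 81, 2.0 * duration_days)  # assumed: +/- 2 durations
    epoch = np.round((time - t0) / period).astype(int)  # index of the nearest transit
    odd = epoch % 2 == 1
    odd_view = bin_view(phase[odd], flux[odd], 201, 0.5 * period)
    even_view = bin_view(phase[~odd], flux[~odd], 201, 0.5 * period)
    return global_view, local_view, odd_view, even_view
```

Real pipelines also detrend the light curve before folding; that step is omitted here.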

Each flux branch uses 2 convolutional blocks with 3 conv layers each (8→16 filters), batch normalization, and max pooling. The branch outputs are fused and fed through a 4-layer fully-connected classifier head.
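The PyTorch sketch below shows one way to realise the branch-and-fuse design described above. The kernel size (5), ReLU activations, pooling (kernel 3, stride 2), separate weights per branch, concatenation as the fusion step, and head widths (128/64/32) are all assumptions. MultiBranchSketch is therefore a hypothetical class: its parameter count will not match 244,181 and it cannot load model.safetensors.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int, kernel_size: int = 5) -> nn.Sequential:
    """Three Conv1d -> BatchNorm1d -> ReLU layers, then max pooling (sizes assumed)."""
    layers = []
    for i in range(3):
        layers += [
            nn.Conv1d(in_ch if i == 0 else out_ch, out_ch, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
        ]
    layers.append(nn.MaxPool1d(kernel_size=3, stride=2))
    return nn.Sequential(*layers)


def pooled_len(n_bins: int) -> int:
    """Sequence length after the two pooling stages ('same' padding keeps conv lengths)."""
    for _ in range(2):
        n_bins = (n_bins - 3) // 2 + 1
    return n_bins


class FluxBranch(nn.Module):
    """Hypothetical flux branch: 1 -> 8 -> 16 channels, flattened at the end."""

    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.Sequential(conv_block(1, 8), conv_block(8, 16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, n_bins) -> (batch, 1, n_bins) -> (batch, 16 * pooled_len(n_bins))
        return self.blocks(x.unsqueeze(1)).flatten(start_dim=1)


class MultiBranchSketch(nn.Module):
    """Hypothetical 5-input model: four flux branches + scalars, concatenated into an MLP head."""

    def __init__(self, n_scalars: int = 9, num_classes: int = 3) -> None:
        super().__init__()
        self.global_branch = FluxBranch()
        self.local_branch = FluxBranch()
        self.odd_branch = FluxBranch()
        self.even_branch = FluxBranch()
        fused = 16 * (3 * pooled_len(201) + pooled_len(81)) + n_scalars
        self.head = nn.Sequential(  # four fully-connected layers
            nn.Linear(fused, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, flux_global, flux_local, flux_odd, flux_even, scalars):
        fused = torch.cat(
            [
                self.global_branch(flux_global),
                self.local_branch(flux_local),
                self.odd_branch(flux_odd),
                self.even_branch(flux_even),
                scalars,
            ],
            dim=1,
        )
        return self.head(fused)  # raw logits, shape (batch, num_classes)
```

Concatenation is the simplest fusion choice. The card only says the branches are "fused", so other fusion schemes would be equally consistent with the description.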

Total parameters: 244,181

Performance

Metric                 Test set
Accuracy               89.10%
F1 (weighted)          89.03%
Precision (weighted)   89.03%
Recall (weighted)      89.10%
Loss                   0.2804
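The weighted scores average per-class values, with each class weighted by its share of the true labels (its support). The NumPy sketch below is intended to mirror scikit-learn's average="weighted" with zero_division=0; the function name is hypothetical. It also makes one property visible: weighted recall always equals accuracy, because Σ_c (n_c/N)(TP_c/n_c) = Σ_c TP_c/N. That is why the Accuracy and Recall (weighted) rows match.

```python
import numpy as np


def weighted_metrics(y_true, y_pred, num_classes=3):
    """Accuracy plus support-weighted precision, recall and F1."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    support = np.bincount(y_true, minlength=num_classes)
    weights = support / support.sum()
    precision = np.zeros(num_classes)
    recall = np.zeros(num_classes)
    f1 = np.zeros(num_classes)
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        n_pred = np.sum(y_pred == c)
        precision[c] = tp / n_pred if n_pred else 0.0
        recall[c] = tp / support[c] if support[c] else 0.0
        denom = precision[c] + recall[c]
        f1[c] = 2 * precision[c] * recall[c] / denom if denom else 0.0
    return {
        "accuracy": float(np.mean(y_true == y_pred)),
        "precision": float(weights @ precision),
        "recall": float(weights @ recall),
        "f1": float(weights @ f1),
    }
```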

Training Details

  • Dataset: bingbangboom/exoplanet-transit-detection
    • Multi-mission: Kepler + TESS + K2
    • 18,853 train / 2,357 val / 2,357 test samples
    • 3-class: PLANET, FALSE_POSITIVE, NO_SIGNAL
  • Optimizer: AdamW (lr=5e-4, cosine schedule, 5% warmup)
  • Loss: Weighted cross-entropy (inverse frequency balancing)
  • Epochs: 30 (best model at epoch 15)
  • Batch size: 128
  • Architecture reference: ExoMiner++ (arxiv:2502.09790) & AstroNet-Triage-v2 (arxiv:2301.01371)
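The training recipe combines inverse-frequency class weights with AdamW and a warmup-plus-cosine learning-rate schedule. The pure-Python sketch below illustrates both pieces. The weight normalization w_c = N / (K · n_c), the linear warmup from zero, the decay to zero, and the step budget (ceil(18,853 / 128) = 148 batches per epoch for 30 epochs, i.e. the last partial batch is kept) are assumptions, not details taken from the training script. The loss mirrors PyTorch's nn.CrossEntropyLoss(weight=...) with reduction="mean", which divides by the sum of the target-class weights.

```python
import math
from collections import Counter


def inverse_frequency_weights(labels, num_classes=3):
    """Class weights w_c = N / (num_classes * n_c); the normalization is an assumption."""
    counts = Counter(labels)
    n = len(labels)
    return [n / (num_classes * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]


def weighted_cross_entropy(logits_batch, targets, class_weights):
    """Mean weighted cross-entropy: sum(w_y * nll) / sum(w_y)."""
    total, norm = 0.0, 0.0
    for logits, y in zip(logits_batch, targets):
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        total += class_weights[y] * (log_z - logits[y])
        norm += class_weights[y]
    return total / norm


def warmup_cosine_lr(step, total_steps, base_lr=5e-4, warmup_frac=0.05):
    """Linear warmup from 0 to base_lr, then cosine decay to 0 (assumed shape)."""
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Assumed step budget: 148 batches per epoch for 30 epochs.
steps_per_epoch = math.ceil(18_853 / 128)
total_steps = 30 * steps_per_epoch
```

Under these assumptions the run has 4,440 optimizer steps, the first 222 of which are warmup.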

Usage

import torch
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# 1. Download the weights
model_path = hf_hub_download("sarojpatil16/exoplanet-transit-detector", "model.safetensors")

# 2. Recreate the architecture: copy the AstroNetCNN class definition
#    from train.py, then load the weights into it.
# model = AstroNetCNN(n_scalars=9, num_classes=3)
# model.load_state_dict(load_file(model_path))
# model.eval()

# 3. Prepare inputs from light curve data
# flux_global: (1, 201) - phase-folded full light curve, median-subtracted & MAD-normalized
# flux_local:  (1, 81)  - zoomed transit view
# flux_odd:    (1, 201) - odd-numbered transits
# flux_even:   (1, 201) - even-numbered transits
# scalars:     (1, 9)   - [period_days, duration_hrs, depth_ppm, teff, logg, radius, mass, metallicity, kepmag]
#                          (period, duration, depth are log1p-transformed)

# 4. Predict
# with torch.no_grad():
#     output = model(flux_global, flux_local, flux_odd, flux_even, scalars)
#     probabilities = torch.softmax(output.logits, dim=-1)
#     pred_class = torch.argmax(output.logits, dim=-1)
#     # 0=PLANET, 1=FALSE_POSITIVE, 2=NO_SIGNAL
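Step 3 above lists the expected input shapes and transforms. The NumPy sketch below implements those transforms: median subtraction followed by division by the median absolute deviation, and log1p on period, duration and depth. Whether the MAD is rescaled (e.g. by 1.4826), whether the remaining stellar scalars are standardized, and the helper names are assumptions. The resulting arrays still need converting to tensors, e.g. with torch.from_numpy, before being passed to the model.

```python
import numpy as np

# Scalar order as listed in step 3 of the usage snippet.
SCALAR_ORDER = [
    "period_days", "duration_hrs", "depth_ppm", "teff", "logg",
    "radius", "mass", "metallicity", "kepmag",
]
LOG1P_KEYS = {"period_days", "duration_hrs", "depth_ppm"}


def normalize_view(view):
    """Median-subtract, then divide by the median absolute deviation (unscaled MAD assumed)."""
    view = np.asarray(view, dtype=np.float64)
    med = np.median(view)
    mad = np.median(np.abs(view - med))
    out = (view - med) / (mad if mad > 0 else 1.0)  # guard against a flat view
    return out.astype(np.float32)[None, :]  # add batch dimension -> (1, n_bins)


def build_scalars(features):
    """Return a (1, 9) float32 array, log1p-transforming period, duration and depth."""
    values = [
        np.log1p(features[key]) if key in LOG1P_KEYS else features[key]
        for key in SCALAR_ORDER
    ]
    return np.asarray(values, dtype=np.float32)[None, :]
```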

Label Mapping

Class ID   Label            Description
0          PLANET           Confirmed or candidate exoplanet transit
1          FALSE_POSITIVE   Signal is not a planet (eclipsing binary, stellar variability, etc.)
2          NO_SIGNAL        No significant transit signal detected
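To turn model outputs into the labels in this table, apply a softmax over the logits and take the argmax. The NumPy sketch below does this; ID2LABEL restates the mapping above, and decode is a hypothetical helper that expects a (batch, 3) array of logits, e.g. a logits tensor converted with .numpy().

```python
import numpy as np

ID2LABEL = {0: "PLANET", 1: "FALSE_POSITIVE", 2: "NO_SIGNAL"}


def decode(logits):
    """Map a (batch, 3) array of logits to (label, probability) pairs."""
    logits = np.asarray(logits, dtype=np.float64)
    z = logits - logits.max(axis=-1, keepdims=True)  # stabilize the softmax
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    ids = probs.argmax(axis=-1)
    return [(ID2LABEL[int(i)], float(p[i])) for i, p in zip(ids, probs)]
```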

