SDFT
Self-Distillation Fine-Tuning (SDFT) is described in the paper "Self-Distillation Enables Continual Learning" by Idan Shenfeld, Mehul Damani, Jonas Hübotter, and Pulkit Agrawal. The abstract from the paper is the following:
Continual learning, enabling models to acquire new skills and knowledge without degrading existing capabilities, remains a fundamental challenge for foundation models. While on-policy reinforcement learning can reduce forgetting, it requires explicit reward functions that are often unavailable. Learning from expert demonstrations, the primary alternative, is dominated by supervised fine-tuning (SFT), which is inherently off-policy. We introduce Self-Distillation Fine-Tuning (SDFT), a simple method that enables on-policy learning directly from demonstrations. SDFT leverages in-context learning by using a demonstration-conditioned model as its own teacher, generating on-policy training signals that preserve prior capabilities while acquiring new skills. Across skill learning and knowledge acquisition tasks, SDFT consistently outperforms SFT, achieving higher new-task accuracy while substantially reducing catastrophic forgetting. In sequential learning experiments, SDFT enables a single model to accumulate multiple skills over time without performance regression, establishing on-policy distillation as a practical path to continual learning from demonstrations.
How it works
Plain supervised fine-tuning trains on the demonstration text off-policy, which tends to overwrite prior capabilities. SDFT learns on-policy instead: the student generates from the plain prompt, a teacher — the same model shown the prompt plus the example’s privileged_context — re-scores those tokens, and its demonstration-conditioned distribution is distilled back into the student. Teacher and student are one network differing only in what they see, creating a self-distillation loop.
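Here is a toy, pure-Python sketch of one self-distillation step at a single completion position. It is not TRL code. toy_next_token_probs is a hypothetical stand-in for the shared network. The default teacher_prompt_template is assumed to be filled with str.format-style substitution. A per-position reverse KL, KL(student || teacher), is used only for illustration, since the trainer's actual objective is chosen by distillation_mode and distillation_alpha.

```python
import math
import random

TEMPLATE = "{prompt}\n\n{privileged_context}"  # default teacher_prompt_template


def toy_next_token_probs(context: str) -> dict[str, float]:
    """Hypothetical stand-in for the shared network (one set of weights).

    When the demonstration is visible, probability mass shifts toward "4".
    """
    if "Example answer: 4." in context:
        logits = {"4": 3.0, "5": 0.0, "four": 1.0}
    else:
        logits = {"4": 1.0, "5": 1.0, "four": 0.5}
    z = sum(math.exp(v) for v in logits.values())
    return {tok: math.exp(v) / z for tok, v in logits.items()}


def reverse_kl(student: dict[str, float], teacher: dict[str, float]) -> float:
    """KL(student || teacher) over the toy vocabulary."""
    return sum(p * math.log(p / teacher[tok]) for tok, p in student.items())


def self_distillation_step(prompt: str, privileged_context: str, rng: random.Random):
    student_ctx = prompt
    teacher_ctx = TEMPLATE.format(prompt=prompt, privileged_context=privileged_context)

    # 1) On-policy rollout: the student samples from the plain prompt.
    p_student = toy_next_token_probs(student_ctx)
    tokens, weights = zip(*p_student.items())
    sampled = rng.choices(tokens, weights=weights, k=1)[0]

    # 2) The same network, now shown the privileged context, scores that position.
    p_teacher = toy_next_token_probs(teacher_ctx)

    # 3) Distillation signal: pull the student toward the teacher's distribution.
    return sampled, p_student, p_teacher, reverse_kl(p_student, p_teacher)


sampled, p_s, p_t, loss = self_distillation_step(
    "Solve 2+2.", "Example answer: 4.", random.Random(0)
)
print(f"sampled={sampled!r} student_p4={p_s['4']:.3f} teacher_p4={p_t['4']:.3f} loss={loss:.3f}")
```

Both views come from the same function. The only difference between student and teacher is the context each one sees.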
Choosing the teacher
teacher_model_kind selects which copy of the model acts as the teacher. "base" (the default) freezes the initial weights as a fixed reference, matching the paper; "live" reuses the current student, giving a zero-lag self-teacher; "ema" maintains an exponential moving average of the student weights, updated every teacher_sync_steps steps at rate teacher_update_rate. Under PEFT, "base" is obtained by disabling the adapter during the teacher forward pass to recover the base weights, and "ema" with pure-LoRA training holds the moving average in a dedicated "teacher" adapter instead of a second model copy. "ema" with a non-pure-LoRA PEFT model (e.g. one using modules_to_save or bias training) is not supported, since a separate EMA copy cannot be parameter-matched to the student.
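The "ema" option follows the usual exponential-moving-average rule. The sketch below makes two assumptions. First, teacher_update_rate is the weight given to the current student, so each sync computes teacher = (1 - rate) * teacher + rate * student. Second, syncs happen on steps divisible by teacher_sync_steps. Plain Python lists stand in for parameter tensors. ema_update and run_ema_teacher are hypothetical helpers, not TRL functions.

```python
def ema_update(teacher: list[float], student: list[float], rate: float) -> list[float]:
    """Elementwise teacher <- (1 - rate) * teacher + rate * student."""
    return [(1.0 - rate) * t + rate * s for t, s in zip(teacher, student)]


def run_ema_teacher(num_steps: int, teacher_sync_steps: int, teacher_update_rate: float) -> list[float]:
    """Hypothetical schedule: sync the teacher on every teacher_sync_steps-th step."""
    teacher = [0.0, 0.0]  # initial weights (what "base" would keep frozen)
    student = [1.0, -1.0]  # pretend the student has moved here and stays put
    for step in range(1, num_steps + 1):
        if step % teacher_sync_steps == 0:
            teacher = ema_update(teacher, student, teacher_update_rate)
    return teacher


print(run_ema_teacher(10, teacher_sync_steps=1, teacher_update_rate=0.05))
print(run_ema_teacher(10, teacher_sync_steps=5, teacher_update_rate=0.05))
```

Under this rule, a rate of 1.0 copies the current student at every sync (the values "live" would use). Never syncing keeps the initial weights (the values "base" would use). Smaller rates trade teacher freshness for stability.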
By default the student generates from the plain prompt; set generate_from_teacher=True to sample from the demonstration-conditioned prompt instead, trading on-policy fidelity for higher-quality rollouts. The distillation objective is set by distillation_mode ("topk_logits" by default, with "full_logits" and "sampled_token" alternatives), distillation_alpha, and distillation_topk; num_loss_tokens_to_skip drops leading completion tokens from the loss. Setting use_liger_kernel=True swaps in a memory-efficient fused JSD loss (Liger) that avoids materializing the full-vocabulary logits; it requires distillation_mode="full_logits" and is incompatible with distillation_is_clip. Training is text-only; generation runs through transformers by default, or vLLM (colocate or server mode) when use_vllm=True.
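The following sketch shows how a top-k logit objective, an interpolation coefficient, and skipped leading tokens could fit together. Several details are assumptions, not the trainer's documented behavior:

- the top-k tokens are taken from the teacher distribution;
- both distributions are renormalized over those k tokens (the trainer may instead keep the remaining mass; see distillation_add_tail);
- distillation_alpha is read as a weight interpolating reverse KL(student || teacher) and forward KL(teacher || student);
- per-token losses are averaged.

The function names are hypothetical, and k plays the role of distillation_topk.

```python
import math


def kl(p: list[float], q: list[float]) -> float:
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def topk_view(p_student: list[float], p_teacher: list[float], k: int):
    """Restrict both distributions to the teacher's top-k token ids and renormalize."""
    idx = sorted(range(len(p_teacher)), key=lambda i: p_teacher[i], reverse=True)[:k]
    s = [p_student[i] for i in idx]
    t = [p_teacher[i] for i in idx]
    zs, zt = sum(s), sum(t)
    return [x / zs for x in s], [x / zt for x in t]


def token_loss(p_student, p_teacher, alpha: float, k: int) -> float:
    s, t = topk_view(p_student, p_teacher, k)
    # Assumed reading of distillation_alpha: interpolate reverse and forward KL.
    return alpha * kl(s, t) + (1.0 - alpha) * kl(t, s)


def sequence_loss(student_dists, teacher_dists, alpha=0.5, k=2, num_loss_tokens_to_skip=0):
    per_token = [token_loss(s, t, alpha, k) for s, t in zip(student_dists, teacher_dists)]
    kept = per_token[num_loss_tokens_to_skip:]  # drop leading completion tokens
    return sum(kept) / len(kept) if kept else 0.0


# Three completion positions over a 4-token vocabulary (illustrative numbers only).
student = [[0.1, 0.2, 0.3, 0.4], [0.25, 0.25, 0.25, 0.25], [0.4, 0.3, 0.2, 0.1]]
teacher = [[0.7, 0.1, 0.1, 0.1], [0.1, 0.6, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]]

print(sequence_loss(student, teacher, alpha=0.5, k=2))
print(sequence_loss(student, teacher, alpha=0.5, k=2, num_loss_tokens_to_skip=2))
```

The third position has identical student and teacher distributions, so skipping the first two tokens leaves a loss of zero.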
Usage
```python
from datasets import Dataset
from trl.experimental.sdft import SDFTConfig, SDFTTrainer

# Each example pairs a student-facing prompt with teacher-only privileged context.
dataset = Dataset.from_dict(
    {
        "prompt": [[{"role": "user", "content": "Solve 2+2."}]],
        "privileged_context": ["Example answer: 4."],
    }
)

training_args = SDFTConfig(
    output_dir="sdft-model",
    distillation_alpha=0.5,
    distillation_mode="topk_logits",  # the default objective
    distillation_topk=5,  # keep only the top 5 tokens in the objective (default: 100)
    max_completion_length=64,
)

trainer = SDFTTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

To generate from the teacher-conditioned prompt instead of the student prompt, set generate_from_teacher=True.
To customize how the teacher prompt is built, set teacher_prompt_template on SDFTConfig.
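The sketch below fills the default template for a plain-text prompt using str.format-style substitution. It also checks that a custom template keeps both placeholders. The substitution mechanism, the helper names, and CUSTOM_TEMPLATE are assumptions for illustration. How the trainer applies the template to conversational prompts is not covered here. If the trainer does use str.format, any literal braces in a custom template would need to be doubled ({{ and }}).

```python
from string import Formatter

DEFAULT_TEMPLATE = "{prompt}\n\n{privileged_context}"  # SDFTConfig default
# Hypothetical custom template that labels the demonstration for the teacher.
CUSTOM_TEMPLATE = "{prompt}\n\nA worked example you may use:\n{privileged_context}"


def template_fields(template: str) -> set[str]:
    """Names of the {placeholders} that appear in a format string."""
    return {name for _, name, _, _ in Formatter().parse(template) if name}


def build_teacher_prompt(prompt: str, privileged_context: str, template: str = DEFAULT_TEMPLATE) -> str:
    missing = {"prompt", "privileged_context"} - template_fields(template)
    if missing:
        raise ValueError(f"template is missing placeholders: {sorted(missing)}")
    return template.format(prompt=prompt, privileged_context=privileged_context)


print(build_teacher_prompt("Solve 2+2.", "Example answer: 4."))
print(build_teacher_prompt("Solve 2+2.", "Example answer: 4.", CUSTOM_TEMPLATE))
# In a real run the custom string would be passed as
# SDFTConfig(..., teacher_prompt_template=CUSTOM_TEMPLATE).
```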
Expected dataset columns
Each example must provide:
- prompt: the student-facing prompt.
- privileged_context: only the extra teacher-only information, such as a demonstration, hint, or privileged feedback.
The trainer's prompt handling supports both standard text prompts and conversational prompts (lists of role/content messages, as in the usage example above).
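You can sanity-check rows before building a Dataset with a small validator like the one below. check_sdft_example is a hypothetical helper, not part of TRL. It assumes privileged_context is a plain string (as in the usage example). It treats a prompt as conversational when it is a non-empty list of role/content messages.

```python
def is_conversational(prompt) -> bool:
    """True for a non-empty list of {"role": ..., "content": ...} messages."""
    return isinstance(prompt, list) and len(prompt) > 0 and all(
        isinstance(m, dict) and {"role", "content"}.issubset(m) for m in prompt
    )


def check_sdft_example(example: dict) -> str:
    """Return "text" or "conversational"; raise if the row lacks the expected columns."""
    missing = {"prompt", "privileged_context"} - set(example)
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    if not isinstance(example["privileged_context"], str):
        raise TypeError("privileged_context should be a string")
    if isinstance(example["prompt"], str):
        return "text"
    if is_conversational(example["prompt"]):
        return "conversational"
    raise TypeError("prompt must be a string or a list of role/content messages")


rows = [
    # Standard text prompt.
    {"prompt": "Solve 2+2.", "privileged_context": "Example answer: 4."},
    # Conversational prompt.
    {
        "prompt": [{"role": "user", "content": "Solve 3+3."}],
        "privileged_context": "Example answer: 6.",
    },
]
print([check_sdft_example(row) for row in rows])  # ['text', 'conversational']
```

The two rows sit side by side only to exercise both branches. A real dataset would normally use a single prompt format throughout.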
Callbacks
The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows.
Shared self-distillation hooks:
- on_self_distillation_batch_prepared: fired when a self-distillation batch is ready. The payload includes prompt_ids and completion_ids, plus old_per_token_logps when importance-sampling clipping inputs are available.
- on_generation_batch_built: fired when a new buffered generation batch is created. The payload includes generate_every and steps_per_generation.
SDFT-specific hook:
- on_generation_prompts_selected: fired when SDFT chooses the prompt source for on-policy generation. The payload includes the selected generation_prompts and the corresponding generation_prompt_text.
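The sketch below records the payloads of the three hooks, for example to inspect which prompts were used for generation. It assumes the hooks are dispatched like standard transformers TrainerCallback events, with the payload keys arriving as keyword arguments next to args, state, and control. SDFTDebugCallback is a hypothetical name. The fallback base class only keeps the snippet runnable without transformers installed. Check the trainer source before relying on the exact calling convention.

```python
try:
    from transformers import TrainerCallback
except ImportError:  # fallback so the sketch runs without transformers
    TrainerCallback = object


class SDFTDebugCallback(TrainerCallback):
    """Hypothetical callback that records SDFT hook payloads for inspection."""

    def __init__(self):
        self.events = []

    def on_generation_prompts_selected(self, args, state, control, **kwargs):
        self.events.append(("prompts_selected", kwargs.get("generation_prompt_text")))

    def on_generation_batch_built(self, args, state, control, **kwargs):
        self.events.append(
            ("batch_built", kwargs.get("generate_every"), kwargs.get("steps_per_generation"))
        )

    def on_self_distillation_batch_prepared(self, args, state, control, **kwargs):
        # old_per_token_logps is only present when importance-sampling clipping inputs exist.
        has_old_logps = kwargs.get("old_per_token_logps") is not None
        self.events.append(("batch_prepared", len(kwargs.get("prompt_ids", [])), has_old_logps))


# Simulated calls with made-up payloads; during training the trainer fires these hooks.
cb = SDFTDebugCallback()
cb.on_generation_prompts_selected(None, None, None, generation_prompts=[], generation_prompt_text=["Solve 2+2."])
cb.on_generation_batch_built(None, None, None, generate_every=4, steps_per_generation=4)
cb.on_self_distillation_batch_prepared(None, None, None, prompt_ids=[[1, 2]], completion_ids=[[3]])
print(cb.events)
```

To attach it to a run, pass callbacks=[SDFTDebugCallback()] to SDFTTrainer.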
Example script
Use trl/experimental/sdft/sdft.py to launch SDFT training from the command line. The script supports any causal LM from the Hub, custom local datasets via --dataset_path, and PEFT/LoRA via the standard ModelConfig flags.
```bash
python trl/experimental/sdft/sdft.py \
    --model_name_or_path Qwen/Qwen3.5-0.8B \
    --dataset_name your-org/your-dataset \
    --output_dir outputs/sdft-qwen3.5-0.8b \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --learning_rate 2e-5 \
    --max_prompt_length 1024 \
    --max_completion_length 512 \
    --generate_from_teacher \
    --teacher_model_kind ema \
    --teacher_sync_steps 1 \
    --teacher_update_rate 0.05 \
    --eval_strategy steps \
    --eval_steps 50 \
    --report_to wandb
```

The original implementation is available at idanshen/Self-Distillation.
SDFTConfig
class trl.experimental.sdft.SDFTConfig
( output_dir: str | None = None, per_device_train_batch_size: int = 8, num_train_epochs: float = 3.0, max_steps: int = -1, learning_rate: float = 5e-05, lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear', lr_scheduler_kwargs: dict | str | None = None, warmup_steps: float = 0, optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused', optim_args: str | None = None, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, optim_target_modules: None | str | list[str] = None, gradient_accumulation_steps: int = 1, average_tokens_across_devices: bool = True, max_grad_norm: float = 1.0, label_smoothing_factor: float = 0.0, bf16: bool | None = None, fp16: bool = False, bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: bool | None = None, gradient_checkpointing: bool = True, gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None, torch_compile: bool = False, torch_compile_backend: str | None = None, torch_compile_mode: str | None = None, use_liger_kernel: bool = False, liger_kernel_config: dict[str, bool] | None = None, use_cache: bool = False, neftune_noise_alpha: float | None = None, torch_empty_cache_steps: int | None = None, auto_find_batch_size: bool = False, logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps', logging_steps: float = 10, logging_first_step: bool = False, log_on_each_node: bool = True, logging_nan_inf_filter: bool = True, include_num_input_tokens_seen: str | bool = 'no', log_level: str = 'passive', log_level_replica: str = 'warning', disable_tqdm: bool | None = None, report_to: None | str | list[str] = 'none', run_name: str | None = None, project: str = 'huggingface', trackio_space_id: str | None = None, trackio_bucket_id: str | None = None, trackio_static_space_id: typing.Union[str, NoneType, typing.Literal[False]] = None, eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no', eval_steps: float | None = None, eval_delay: float = 0,
per_device_eval_batch_size: int = 8, prediction_loss_only: bool = False, eval_on_start: bool = False, eval_do_concat_batches: bool = True, eval_use_gather_object: bool = False, eval_accumulation_steps: int | None = None, include_for_metrics: list = <factory>, batch_eval_metrics: bool = False, save_only_model: bool = False, save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps', save_steps: float = 500, save_on_each_node: bool = False, save_total_limit: int | None = None, enable_jit_checkpoint: bool = False, push_to_hub: bool = False, hub_token: str | None = None, hub_private_repo: bool | None = None, hub_model_id: str | None = None, hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save', hub_always_push: bool = False, hub_revision: str | None = None, load_best_model_at_end: bool = False, metric_for_best_model: str | None = None, greater_is_better: bool | None = None, ignore_data_skip: bool = False, restore_callback_states_from_checkpoint: bool = False, full_determinism: bool = False, seed: int = 42, data_seed: int | None = None, use_cpu: bool = False, accelerator_config: dict | str | None = None, parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None, dataloader_drop_last: bool = False, dataloader_num_workers: int = 0, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, dataloader_prefetch_factor: int | None = None, remove_unused_columns: bool = False, label_names: list[str] | None = None, train_sampling_strategy: str = 'random', length_column_name: str = 'length', ddp_find_unused_parameters: bool | None = None, ddp_bucket_cap_mb: int | None = None, ddp_broadcast_buffers: bool | None = None, ddp_static_graph: bool | None = None, ddp_backend: str | None = None, ddp_timeout: int = 1800, fsdp: str | None = None, fsdp_config: dict[str, typing.Any] | str | None = None, deepspeed: dict | str | None = None, debug: str | list[transformers.debug_utils.DebugOption] = '', skip_memory_metrics: bool = True, do_train: bool = False,
do_eval: bool = False, do_predict: bool = False, resume_from_checkpoint: str | None = None, warmup_ratio: float | None = None, logging_dir: str | None = None, local_rank: int = -1, model_init_kwargs: dict[str, typing.Any] | None = None, disable_dropout: bool = False, max_prompt_length: int | None = 512, num_generations: int = 8, num_generations_eval: int | None = None, max_completion_length: int | None = 256, ds3_gather_for_generation: bool = True, shuffle_dataset: bool = True, generation_batch_size: int | None = None, steps_per_generation: int | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, min_p: float | None = None, generation_kwargs: dict[str, typing.Any] | None = None, chat_template_kwargs: dict[str, typing.Any] | None = None, repetition_penalty: float = 1.0, cache_implementation: str | None = None, use_vllm: bool = False, vllm_mode: str = 'colocate', vllm_model_impl: str = 'vllm', vllm_enable_sleep_mode: bool = False, vllm_server_base_url: str | None = None, vllm_server_host: str = '0.0.0.0', vllm_server_port: int = 8000, vllm_group_port: int = 51216, vllm_server_timeout: float = 240.0, vllm_tensor_parallel_size: int = 1, vllm_gpu_memory_utilization: float = 0.3, vllm_max_model_length: int | None = None, num_iterations: int = 1, teacher_model_kind: str = 'base', teacher_update_rate: float = 0.05, teacher_sync_steps: int = 1, distillation_alpha: float = 0.5, distillation_mode: typing.Literal['sampled_token', 'full_logits', 'topk_logits'] = 'topk_logits', distillation_topk: int | None = 100, distillation_is_clip: float | None = 2.0, distillation_add_tail: bool = False, diagnostics_warning_interval: int = 10, diagnostics_flat_tolerance: float = 1e-08, generate_from_teacher: bool = False, teacher_prompt_template: str = '{prompt}\n\n{privileged_context}', num_loss_tokens_to_skip: int = 0 )
Parameters
- disable_dropout (bool, optional, defaults to False): Whether to disable dropout in the student and teacher models.
- teacher_model_kind (str, optional, defaults to "base"): Which model acts as the teacher. "base" uses the initial student, "live" uses the current student, and "ema" uses an exponentially averaged teacher.
- distillation_alpha (float, optional, defaults to 0.5): Divergence interpolation coefficient for SDFT top-k logit distillation.
- distillation_mode (Literal["sampled_token", "full_logits", "topk_logits"], optional, defaults to "topk_logits"): Distillation objective mode. SDFT defaults to the top-k logit objective.
- distillation_topk (int or None, optional, defaults to 100): Number of top tokens used by the top-k logit objective.
- generate_from_teacher (bool, optional, defaults to False): Whether on-policy generation should use the teacher-conditioned prompt instead of the student prompt.
- teacher_prompt_template (str, optional, defaults to "{prompt}\n\n{privileged_context}"): Template used to combine the student prompt and privileged context into the teacher prompt.
- num_loss_tokens_to_skip (int, optional, defaults to 0): Number of initial completion tokens to exclude from the distillation loss.
Configuration class for SDFTTrainer.
SDFTTrainer
class trl.experimental.sdft.SDFTTrainer
( model: str | PreTrainedModel | nn.Module, args: SDFTConfig | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: PeftConfig | None = None )
Trainer for SDFT-style on-policy self-distillation with explicit teacher prompts.
train
( resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None ) → ~trainer_utils.TrainOutput
Parameters
- resume_from_checkpoint (str or bool, optional): If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
- trial (optuna.Trial or dict[str, Any], optional): The trial run or the hyperparameter dictionary for hyperparameter search.
- ignore_keys_for_eval (list[str], optional): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during training.
Returns
~trainer_utils.TrainOutput
Object containing the global step count, training loss, and metrics.
Main training entry point.
save_model
Will save the model, so you can reload it using from_pretrained().
Will only save from the main process.
push_to_hub
( commit_message: str | None = 'End of training', blocking: bool = True, token: str | None = None, revision: str | None = None, **kwargs )
Parameters
- commit_message (str, optional, defaults to "End of training"): Message to commit while pushing.
- blocking (bool, optional, defaults to True): Whether the function should return only when the git push has finished.
- token (str, optional, defaults to None): Token with write permission to overwrite Trainer's original args.
- revision (str, optional): The git revision to commit from. Defaults to the head of the "main" branch.
- kwargs (dict[str, Any], optional): Additional keyword arguments passed along to ~Trainer.create_model_card.
Upload self.model and self.processing_class to the 🤗 model hub on the repo self.args.hub_model_id.