Trainers API

SFTTrainer

Trainer for supervised fine-tuning (SFT), which trains the model to reproduce the target text in the training dataset.

from legionheto import SFTTrainer

# `model` and `dataset` are assumed to be loaded beforehand.
trainer = SFTTrainer(
    model=model,
    dataset=dataset,
    output_dir="./output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-4,
)

trainer.train()
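For intuition about what SFT optimizes, here is a minimal sketch of the usual SFT objective: the mean negative log-likelihood of the target tokens. It is plain Python, not SFTTrainer's internals. The helper name `sft_loss` and the idea of masking out prompt tokens with `loss_mask` are illustrative assumptions; this page does not say whether the library masks prompts.

```python
import math

def sft_loss(token_logprobs, loss_mask):
    """Mean negative log-likelihood over the tokens selected by loss_mask.

    token_logprobs: log p(token_t | previous tokens) for each token.
    loss_mask: 1 for tokens that count toward the loss (e.g. the response),
               0 for ignored tokens (e.g. the prompt). Hypothetical convention.
    """
    if len(token_logprobs) != len(loss_mask):
        raise ValueError("token_logprobs and loss_mask must have the same length")
    selected = [lp for lp, m in zip(token_logprobs, loss_mask) if m]
    if not selected:
        raise ValueError("loss_mask selects no tokens")
    return -sum(selected) / len(selected)

# Hypothetical example: a 2-token prompt followed by a 2-token response.
logprobs = [math.log(0.9), math.log(0.8), math.log(0.5), math.log(0.25)]
mask = [0, 0, 1, 1]
# Only the response tokens count: this averages -log(0.5) and -log(0.25).
print(round(sft_loss(logprobs, mask), 4))
```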

DPOTrainer

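Trainer for Direct Preference Optimization (DPO), which fine-tunes the model on preference pairs (a preferred and a rejected response) while keeping it close to a frozen reference model. `beta` controls how strongly the model is held to `ref_model`.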

from legionheto import DPOTrainer

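trainer = DPOTrainer(
    model=model,            # policy being optimized
    ref_model=ref_model,    # frozen reference the policy is regularized toward
    dataset=dataset,        # preference data (chosen vs. rejected responses)
    output_dir="./output",
    beta=0.1,               # higher beta keeps the policy closer to ref_model
)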

trainer.train()
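To show how `beta` enters the objective, here is a minimal sketch of the standard DPO loss for a single preference pair, in plain Python. It is not necessarily how DPOTrainer computes its loss. The function names are hypothetical, and each input is assumed to be the summed log-probability of a whole response under the trained policy or the reference model.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Standard DPO loss for one (chosen, rejected) pair.

    Each argument is a summed response log-probability. beta scales the
    margin between the chosen and rejected log-ratios, which acts as the
    strength of the implicit KL penalty toward the reference model.
    """
    chosen_logratio = policy_chosen - ref_chosen        # log(pi/ref) on chosen
    rejected_logratio = policy_rejected - ref_rejected  # log(pi/ref) on rejected
    margin = beta * (chosen_logratio - rejected_logratio)
    return -log_sigmoid(margin)

# Policy identical to the reference: the margin is 0 and the loss is log(2).
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# Policy has moved probability toward the chosen response: loss drops below log(2).
print(dpo_loss(-10.0, -16.0, -12.0, -15.0))
```

Increasing `beta` makes the same log-ratio margin count for more, so the loss saturates sooner and the policy has less incentive to drift far from the reference.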

ORPOTrainer

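Trainer for Odds Ratio Preference Optimization (ORPO), which combines a supervised fine-tuning loss with an odds-ratio preference penalty in a single training stage, so no reference model is required.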

from legionheto import ORPOTrainer

trainer = ORPOTrainer(
    model=model,
    dataset=dataset,
    output_dir="./output",
    beta=0.1,
    lambda_weight=1.0,
)

trainer.train()
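The sketch below follows the published ORPO formulation: an SFT term on the chosen response plus a weighted odds-ratio term that rewards the chosen response having higher odds than the rejected one. It is illustrative plain Python, not ORPOTrainer's code. The weight is called `lam` (λ in the paper); how it maps onto `beta` and `lambda_weight` in ORPOTrainer is not stated on this page, and the helper names are hypothetical.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def log_odds(avg_logprob):
    """log(P / (1 - P)) with P = exp(avg_logprob); requires avg_logprob < 0."""
    if avg_logprob >= 0:
        raise ValueError("average log-probability must be negative")
    return avg_logprob - math.log1p(-math.exp(avg_logprob))

def orpo_loss(chosen_token_logprobs, rejected_token_logprobs, lam=1.0):
    """ORPO loss for one pair, given per-token log-probs of each response."""
    # Length-normalized log-likelihood: log P = mean per-token log-prob.
    avg_chosen = sum(chosen_token_logprobs) / len(chosen_token_logprobs)
    avg_rejected = sum(rejected_token_logprobs) / len(rejected_token_logprobs)
    sft_term = -avg_chosen  # NLL of the chosen response
    # Odds-ratio term: -log sigmoid(log odds(chosen) - log odds(rejected)).
    or_term = -log_sigmoid(log_odds(avg_chosen) - log_odds(avg_rejected))
    return sft_term + lam * or_term

# Hypothetical per-token log-probs for a chosen and a rejected response.
chosen = [math.log(0.6)] * 3
rejected = [math.log(0.3)] * 3
print(orpo_loss(chosen, rejected, lam=1.0))
```

Because the SFT term and the preference term share one loss, a single model is trained in one pass, which is why no `ref_model` appears in the ORPOTrainer example.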

Common Parameters

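All trainers accept the following training parameters:

  • num_train_epochs: Number of passes over the training dataset
  • per_device_train_batch_size: Batch size on each device
  • gradient_accumulation_steps: Number of batches whose gradients are accumulated before each optimizer update
  • learning_rate: Learning rate for the optimizer
  • warmup_steps: Number of steps over which the learning rate is warmed up
  • max_grad_norm: Maximum gradient norm used for gradient clipping
  • weight_decay: Weight decay coefficient
  • lr_scheduler_type: Type of learning-rate scheduler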
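These parameter names match Hugging Face `transformers.TrainingArguments`. Assuming they carry the same meaning here (this page does not confirm it), the sketch below shows the arithmetic behind three of them: the effective batch size, a linear warmup-then-decay schedule (one common `lr_scheduler_type`; whether the library offers it under the name "linear" is an assumption), and clipping by global gradient norm. `num_devices` and all function names are hypothetical.

```python
import math

def effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_devices=1):
    """Examples contributing to each optimizer update."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices

def linear_warmup_lr(step, learning_rate, warmup_steps, total_steps):
    """Ramp linearly from 0 to learning_rate, then decay linearly to 0."""
    if step < warmup_steps:
        return learning_rate * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return learning_rate * remaining / max(1, total_steps - warmup_steps)

def clip_by_global_norm(grads, max_grad_norm):
    """Rescale gradients so their combined L2 norm is at most max_grad_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm == 0.0 or total_norm <= max_grad_norm:
        return list(grads)
    scale = max_grad_norm / total_norm
    return [g * scale for g in grads]

# Batch size 4 with 8 accumulation steps on 2 devices: 64 examples per update.
print(effective_batch_size(4, 8, num_devices=2))
# Halfway through a 100-step warmup, the rate is half of learning_rate.
print(linear_warmup_lr(50, 2e-4, warmup_steps=100, total_steps=1000))
# A gradient of norm 5 is scaled down to norm 1.
print(clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0))
```

Raising `gradient_accumulation_steps` increases the effective batch size without increasing per-device memory, at the cost of fewer optimizer updates per epoch.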