Training Guide¶
This guide covers training Free Transformer models from basic setups to advanced distributed training.
Basic Training¶
Single GPU Training¶
# Train with default config
python examples/train_free.py --config configs/free_transformer.yaml
# Train with custom parameters
python examples/train_free.py \
    --config configs/free_transformer.yaml \
    --batch-size 16 \
    --learning-rate 1e-4 \
    --num-epochs 5
Configuration Files¶
Training configurations are defined in YAML files:
# configs/free_transformer.yaml
model:
  vocab_size: 50000
  hidden_dim: 512
  num_layers: 12
  num_heads: 8
  latent_dim: 32
  max_seq_len: 1024

training:
  batch_size: 32
  learning_rate: 1e-4
  num_epochs: 10
  warmup_steps: 1000
  weight_decay: 0.01

  # Free Transformer specific
  kl_weight: 0.1
  free_bits: 0.5

optimizer:
  type: "adamw"
  betas: [0.9, 0.95]
  eps: 1e-8

scheduler:
  type: "cosine"
  min_lr: 1e-6

data:
  dataset_name: "synthetic"
  max_seq_len: 512
  num_workers: 4
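The training script consumes this file directly. If you need to inspect a config programmatically, it can be read with PyYAML; the loader actually used by examples/train_free.py may differ, so treat this as a minimal sketch:

import yaml

# Parse the YAML file into nested dictionaries
with open("configs/free_transformer.yaml") as f:
    config = yaml.safe_load(f)

model_cfg = config["model"]        # e.g. model_cfg["latent_dim"] == 32
train_cfg = config["training"]     # e.g. train_cfg["kl_weight"] == 0.1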
Loss Components¶
The Free Transformer uses a composite loss function:
Reconstruction Loss¶
Standard cross-entropy loss between the model's next-token predictions and the target tokens.
KL Divergence Loss¶
The KL divergence between the encoder's posterior over the latent variable Z and the prior; it regularizes the latent space.
Free Bits Regularization¶
Gives the latent a budget of "free" nats of KL that is not penalized, which helps prevent posterior collapse.
Total Loss¶
The total loss is the reconstruction loss plus the KL term (after the free-bits adjustment) weighted by kl_weight.
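The exact KL term depends on how the latent is parameterized, which is not reproduced here. As a minimal sketch, assuming kl_per_dim already holds the per-dimension KL divergence of the posterior from the prior (and noting that whether free bits are applied per dimension or to the total KL is implementation-dependent):

import torch
import torch.nn.functional as F

def combine_losses(logits, targets, kl_per_dim, kl_weight=0.1, free_bits=0.5):
    # Reconstruction: token-level cross-entropy over the vocabulary
    recon_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    # Free bits: KL contributions below the threshold are clamped up to it,
    # so the optimizer has no incentive to push them all the way to zero
    kl_loss = torch.clamp(kl_per_dim, min=free_bits).sum(dim=-1).mean()
    # Composite objective
    return recon_loss + kl_weight * kl_loss, recon_loss, kl_loss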
Training Strategies¶
1. Curriculum Learning¶
Start with simpler tasks and gradually increase complexity:
# Phase 1: Small sequences, high KL weight
config.max_seq_len = 128
config.kl_weight = 1.0
# Phase 2: Medium sequences, medium KL weight
config.max_seq_len = 256
config.kl_weight = 0.5
# Phase 3: Full sequences, low KL weight
config.max_seq_len = 512
config.kl_weight = 0.1
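One way to wire these phases together is a simple loop; train_one_phase below is a hypothetical stand-in for whatever training function your setup exposes:

# Hypothetical phase driver; train_one_phase stands in for your training loop.
phases = [
    {"max_seq_len": 128, "kl_weight": 1.0},
    {"max_seq_len": 256, "kl_weight": 0.5},
    {"max_seq_len": 512, "kl_weight": 0.1},
]

for phase in phases:
    config.max_seq_len = phase["max_seq_len"]
    config.kl_weight = phase["kl_weight"]
    train_one_phase(model, config)  # hypothetical helper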
2. KL Annealing¶
Gradually reduce KL weight during training:
def get_kl_weight(step, total_steps, initial_weight=1.0, final_weight=0.1):
    progress = step / total_steps
    return initial_weight * (1 - progress) + final_weight * progress
3. Free Bits Scheduling¶
Adjust free bits threshold over time:
def get_free_bits(step, total_steps, initial_bits=2.0, final_bits=0.5):
    progress = step / total_steps
    return initial_bits * (1 - progress) + final_bits * progress
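Both schedules are typically evaluated once per optimizer step. Passing the values into the loss, as sketched below, assumes compute_loss accepts kl_weight and free_bits keyword arguments, which may not match your implementation:

# Assumes compute_loss accepts kl_weight and free_bits keyword arguments.
for step, batch in enumerate(dataloader):
    kl_weight = get_kl_weight(step, total_steps)
    free_bits = get_free_bits(step, total_steps)
    logits, z_logits = model(batch['input_ids'], mode='training')
    loss = compute_loss(logits, z_logits, batch['input_ids'],
                        kl_weight=kl_weight, free_bits=free_bits)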
Advanced Training¶
Mixed Precision Training¶
Enable automatic mixed precision for faster training:
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with autocast():
        logits, z_logits = model(batch['input_ids'], mode='training')
        loss = compute_loss(logits, z_logits, batch['input_ids'])
    # Scale the loss to avoid fp16 gradient underflow; step() unscales before updating
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Gradient Checkpointing¶
Gradient checkpointing reduces memory usage at the cost of extra computation by recomputing activations during the backward pass.
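The model's own API for enabling checkpointing is not shown here; a minimal sketch using torch.utils.checkpoint, assuming the Transformer blocks are accessible as model.layers:

from torch.utils.checkpoint import checkpoint

# Assumption: the decoder blocks live in model.layers (adjust to your model).
# Each block's activations are recomputed during backward instead of being stored.
def forward_with_checkpointing(model, hidden_states):
    for layer in model.layers:
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    return hidden_states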
Learning Rate Scheduling¶
Use cosine annealing with warmup:
from transformers import get_cosine_schedule_with_warmup

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,
    num_training_steps=total_steps
)
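Call scheduler.step() once after each optimizer step so the learning rate follows the warmup and cosine decay.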
Monitoring and Logging¶
TensorBoard Logging¶
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/free_transformer')
# Log losses
writer.add_scalar('Loss/Reconstruction', recon_loss, step)
writer.add_scalar('Loss/KL', kl_loss, step)
writer.add_scalar('Loss/Total', total_loss, step)
# Log learning rate
writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], step)
Weights & Biases Integration¶
import wandb
wandb.init(project="free-transformer")
wandb.config.update(config)
# Log metrics
wandb.log({
    'loss/reconstruction': recon_loss,
    'loss/kl': kl_loss,
    'loss/total': total_loss,
    'learning_rate': lr
})
Evaluation During Training¶
Perplexity Calculation¶
import torch
import torch.nn.functional as F

def calculate_perplexity(model, dataloader):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for batch in dataloader:
            logits, _ = model(batch['input_ids'], mode='training')
            # Sum token-level cross-entropy over the batch
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                batch['input_ids'].view(-1),
                reduction='sum'
            )
            total_loss += loss.item()
            total_tokens += batch['input_ids'].numel()
    # Perplexity is the exponential of the average per-token loss
    return torch.exp(torch.tensor(total_loss / total_tokens))
Generation Quality¶
def evaluate_generation(model, prompts, max_length=100):
    model.eval()
    generations = []
    for prompt in prompts:
        with torch.no_grad():
            generated = model.generate(
                prompt,
                max_new_tokens=max_length,
                temperature=0.8,
                top_k=40
            )
        generations.append(generated)
    return generations
Troubleshooting¶
Common Issues¶
Posterior Collapse
- Symptoms: KL loss drops to zero; the model ignores the latent variable.
- Solutions: increase free bits, reduce the KL weight, use KL annealing.

Training Instability
- Symptoms: loss spikes, gradient explosions.
- Solutions: gradient clipping, a lower learning rate, longer warmup.

Poor Generation Quality
- Symptoms: repetitive or incoherent text.
- Solutions: adjust the temperature, use top-k sampling, increase the model size.
Debugging Tips¶
- Monitor KL loss: it should stay positive and stable
- Check latent utilization: verify that Z is actually being used (see the sketch after this list)
- Validate gradients: ensure gradients flow through all components
- Compare with a baseline: train a standard Transformer for comparison
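For the latent-utilization check, a quick heuristic is to look at how much the inferred latent varies across different inputs; near-zero variation suggests the posterior has collapsed. The names and shapes below assume the training-mode forward pass shown earlier:

import torch

# Quick posterior-collapse heuristic: if z_logits barely change across inputs,
# the decoder is probably ignoring Z.
with torch.no_grad():
    logits, z_logits = model(batch['input_ids'], mode='training')
    flat = z_logits.float().view(z_logits.size(0), -1)
    print("z_logits std across batch:", flat.std(dim=0).mean().item())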
Best Practices¶
- Start small: Begin with small models and datasets
- Use curriculum learning: Gradually increase complexity
- Monitor closely: Watch for posterior collapse
- Regular evaluation: Check generation quality frequently
- Save checkpoints: Regular saves for recovery
- Ablation studies: Test different hyperparameters
Next Steps¶
- Multi-GPU Training: Scale to multiple GPUs
- Configuration: Detailed config options
- Synthetic Data: Generate training data