Custom Training Examples¶
This page provides advanced examples for customizing Free Transformer training.
Custom Loss Functions¶
Adding Custom Loss Components¶
import torch
import torch.nn.functional as F
from free_transformer.losses import free_transformer_loss
def custom_free_transformer_loss(logits, z_logits, targets, config, **kwargs):
"""Custom loss function with additional components."""
# Standard Free Transformer loss
base_loss = free_transformer_loss(
logits=logits,
z_logits=z_logits,
targets=targets,
latent_dim=config.latent_dim,
kl_weight=config.kl_weight,
free_bits=config.free_bits
)
# Custom loss components
custom_losses = {}
# 1. Diversity loss - encourage different plans for different sequences
if 'diversity_weight' in kwargs:
diversity_loss = compute_diversity_loss(z_logits)
custom_losses['diversity_loss'] = diversity_loss
base_loss['total_loss'] += kwargs['diversity_weight'] * diversity_loss
# 2. Consistency loss - similar inputs should have similar plans
if 'consistency_weight' in kwargs and 'similarity_matrix' in kwargs:
consistency_loss = compute_consistency_loss(z_logits, kwargs['similarity_matrix'])
custom_losses['consistency_loss'] = consistency_loss
base_loss['total_loss'] += kwargs['consistency_weight'] * consistency_loss
# 3. Sparsity loss - encourage sparse plan usage
if 'sparsity_weight' in kwargs:
sparsity_loss = compute_sparsity_loss(z_logits)
custom_losses['sparsity_loss'] = sparsity_loss
base_loss['total_loss'] += kwargs['sparsity_weight'] * sparsity_loss
# Combine all losses
base_loss.update(custom_losses)
return base_loss
def compute_diversity_loss(z_logits):
"""Encourage diversity in latent plans."""
# Convert to probabilities
probs = torch.sigmoid(z_logits)
    # Compute pairwise Euclidean distances between plan probabilities
    distances = torch.cdist(probs, probs, p=2)  # [batch_size, batch_size]
# Encourage larger distances (more diversity)
diversity_loss = -distances.mean()
return diversity_loss
def compute_consistency_loss(z_logits, similarity_matrix):
"""Encourage consistent plans for similar inputs."""
probs = torch.sigmoid(z_logits)
# Compute plan distances
plan_distances = torch.cdist(probs, probs, p=2)
# Consistency loss: similar inputs should have similar plans
consistency_loss = (similarity_matrix * plan_distances).mean()
return consistency_loss
def compute_sparsity_loss(z_logits):
    """Encourage near-binary (confident) plan usage."""
    probs = torch.sigmoid(z_logits)
    # Distance from 0.5 is largest when probabilities saturate at 0 or 1
    sparsity_loss = torch.abs(probs - 0.5).mean()
    return -sparsity_loss  # Negate so minimizing the loss pushes probabilities toward 0 or 1
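A quick sanity check of the three helper terms with random tensors makes the sign conventions explicit. The snippet below is illustrative only and assumes the functions defined above are in scope.

import torch

batch_size, latent_dim = 4, 16
z_logits = torch.randn(batch_size, latent_dim)
similarity = torch.rand(batch_size, batch_size)

print(compute_diversity_loss(z_logits))                # scalar <= 0: more diverse plans -> lower loss
print(compute_consistency_loss(z_logits, similarity))  # scalar >= 0: similar inputs with distant plans -> higher loss
print(compute_sparsity_loss(z_logits))                 # scalar <= 0: near-binary plans -> lower loss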
Using Custom Loss in Training¶
def train_with_custom_loss(model, dataloader, config):
"""Training loop with custom loss function."""
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
for epoch in range(config.num_epochs):
for batch_idx, batch in enumerate(dataloader):
tokens = batch['input_ids']
# Forward pass
logits, z_logits = model(tokens, mode='training')
# Compute similarity matrix (example: based on first few tokens)
similarity_matrix = compute_input_similarity(tokens)
# Custom loss with additional components
loss_dict = custom_free_transformer_loss(
logits=logits,
z_logits=z_logits,
targets=tokens,
config=config,
diversity_weight=0.1,
consistency_weight=0.05,
sparsity_weight=0.02,
similarity_matrix=similarity_matrix
)
total_loss = loss_dict['total_loss']
# Backward pass
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# Logging
if batch_idx % 100 == 0:
print(f"Epoch {epoch}, Batch {batch_idx}")
print(f" Total Loss: {total_loss.item():.4f}")
print(f" Recon Loss: {loss_dict['recon_loss'].item():.4f}")
print(f" KL Loss: {loss_dict['kl_loss'].item():.4f}")
if 'diversity_loss' in loss_dict:
print(f" Diversity Loss: {loss_dict['diversity_loss'].item():.4f}")
def compute_input_similarity(tokens):
"""Compute similarity matrix based on input tokens."""
batch_size = tokens.size(0)
similarity_matrix = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
for j in range(batch_size):
# Simple similarity: overlap in first 10 tokens
overlap = (tokens[i, :10] == tokens[j, :10]).float().mean()
similarity_matrix[i, j] = overlap
return similarity_matrix.to(tokens.device)
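For larger batches, the nested Python loop above can be replaced by a single broadcasted comparison. The vectorized variant below (a sketch with a hypothetical name) computes the same matrix directly on the tokens' device.

def compute_input_similarity_vectorized(tokens, prefix_len=10):
    """Broadcasted equivalent of compute_input_similarity."""
    prefix = tokens[:, :prefix_len]                         # [batch, prefix_len]
    matches = prefix.unsqueeze(1) == prefix.unsqueeze(0)    # [batch, batch, prefix_len]
    return matches.float().mean(dim=-1)                     # [batch, batch], stays on tokens.device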
Custom Training Schedules¶
KL Annealing Schedule¶
import math

class KLAnnealingScheduler:
"""Custom KL weight annealing scheduler."""
def __init__(self, initial_weight=0.0, final_weight=1.0, annealing_steps=10000, schedule_type='linear'):
self.initial_weight = initial_weight
self.final_weight = final_weight
self.annealing_steps = annealing_steps
self.schedule_type = schedule_type
self.step_count = 0
def get_kl_weight(self):
"""Get current KL weight."""
if self.step_count >= self.annealing_steps:
return self.final_weight
progress = self.step_count / self.annealing_steps
if self.schedule_type == 'linear':
weight = self.initial_weight + progress * (self.final_weight - self.initial_weight)
elif self.schedule_type == 'cosine':
            weight = self.initial_weight + 0.5 * (self.final_weight - self.initial_weight) * (1 - math.cos(math.pi * progress))
        elif self.schedule_type == 'exponential':
            # Geometric interpolation; requires initial_weight > 0
            weight = self.initial_weight * (self.final_weight / self.initial_weight) ** progress
else:
raise ValueError(f"Unknown schedule type: {self.schedule_type}")
return float(weight)
def step(self):
"""Update step count."""
self.step_count += 1
# Usage in training loop
kl_scheduler = KLAnnealingScheduler(
initial_weight=0.0,
final_weight=0.1,
annealing_steps=5000,
schedule_type='cosine'
)
for batch in dataloader:
# Get current KL weight
current_kl_weight = kl_scheduler.get_kl_weight()
# Use in loss computation
loss_dict = free_transformer_loss(
logits=logits,
z_logits=z_logits,
targets=targets,
latent_dim=config.latent_dim,
kl_weight=current_kl_weight,
free_bits=config.free_bits
)
# Update scheduler
kl_scheduler.step()
Free Bits Scheduling¶
class FreeBitsScheduler:
"""Dynamic free bits scheduling."""
def __init__(self, initial_bits=2.0, final_bits=0.5, schedule_steps=8000):
self.initial_bits = initial_bits
self.final_bits = final_bits
self.schedule_steps = schedule_steps
self.step_count = 0
self.kl_history = []
def get_free_bits(self, current_kl=None):
"""Get current free bits threshold."""
if current_kl is not None:
self.kl_history.append(current_kl)
# Adaptive scheduling based on KL history
if len(self.kl_history) > 100:
recent_kl = torch.tensor(self.kl_history[-100:]).mean()
# If KL is too low, increase free bits
if recent_kl < 0.1:
return min(self.initial_bits, self.get_scheduled_bits() + 0.5)
# If KL is too high, decrease free bits
elif recent_kl > 2.0:
return max(self.final_bits, self.get_scheduled_bits() - 0.5)
return self.get_scheduled_bits()
def get_scheduled_bits(self):
"""Get scheduled free bits value."""
if self.step_count >= self.schedule_steps:
return self.final_bits
progress = self.step_count / self.schedule_steps
return self.initial_bits + progress * (self.final_bits - self.initial_bits)
def step(self):
"""Update step count."""
self.step_count += 1
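FreeBitsScheduler can be driven the same way as the KL scheduler above. The sketch below assumes the model, dataloader, config, and free_transformer_loss from the earlier examples, and that the loss dictionary exposes a scalar 'kl_loss' as used in the logging example.

free_bits_scheduler = FreeBitsScheduler(initial_bits=2.0, final_bits=0.5, schedule_steps=8000)
last_kl = None

for batch in dataloader:
    logits, z_logits = model(batch['input_ids'], mode='training')

    # Feed the previous step's KL so the adaptive branch can react to it
    current_free_bits = free_bits_scheduler.get_free_bits(current_kl=last_kl)

    loss_dict = free_transformer_loss(
        logits=logits,
        z_logits=z_logits,
        targets=batch['input_ids'],
        latent_dim=config.latent_dim,
        kl_weight=config.kl_weight,
        free_bits=current_free_bits
    )
    loss_dict['total_loss'].backward()
    # ... gradient clipping / optimizer.step() / zero_grad() as in the full loop above ...

    last_kl = loss_dict['kl_loss'].item()
    free_bits_scheduler.step()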
Custom Data Loading¶
Advanced Data Pipeline¶
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
class AdvancedSyntheticDataset(Dataset):
"""Advanced synthetic dataset with custom features."""
def __init__(self, sequences, config):
self.sequences = sequences
self.config = config
self.max_length = config.max_seq_len
# Precompute sequence features
self.sequence_features = self._compute_features()
def _compute_features(self):
"""Compute features for each sequence."""
features = []
for seq in self.sequences:
feature = {
'length': len(seq),
'unique_tokens': len(set(seq)),
'repetition_ratio': self._compute_repetition_ratio(seq),
'entropy': self._compute_entropy(seq),
}
features.append(feature)
return features
def _compute_repetition_ratio(self, sequence):
"""Compute repetition ratio in sequence."""
if len(sequence) < 2:
return 0.0
repeats = sum(1 for i in range(1, len(sequence)) if sequence[i] == sequence[i-1])
return repeats / (len(sequence) - 1)
def _compute_entropy(self, sequence):
"""Compute entropy of sequence."""
if not sequence:
return 0.0
from collections import Counter
counts = Counter(sequence)
total = len(sequence)
entropy = 0
for count in counts.values():
prob = count / total
entropy -= prob * np.log2(prob)
return entropy
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
sequence = self.sequences[idx]
features = self.sequence_features[idx]
        # Truncate or pad sequence
        original_length = len(sequence)
        effective_length = min(original_length, self.max_length)
        if original_length > self.max_length:
            sequence = sequence[:self.max_length]
        else:
            # Pad with special token (assuming 0 is padding)
            sequence = sequence + [0] * (self.max_length - original_length)
        # Cap the attention mask at max_length so it matches input_ids even after truncation
        attention_mask = [1] * effective_length + [0] * (self.max_length - effective_length)
        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'features': features,
            'original_length': original_length
        }
class CurriculumDataLoader:
"""Data loader with curriculum learning."""
def __init__(self, dataset, batch_size, curriculum_stages):
self.dataset = dataset
self.batch_size = batch_size
self.curriculum_stages = curriculum_stages
self.current_stage = 0
self.stage_step = 0
def get_current_dataloader(self):
"""Get data loader for current curriculum stage."""
stage_config = self.curriculum_stages[self.current_stage]
# Filter dataset based on current stage criteria
filtered_indices = self._filter_indices(stage_config)
# Create subset
subset = torch.utils.data.Subset(self.dataset, filtered_indices)
return DataLoader(
subset,
batch_size=self.batch_size,
shuffle=True,
collate_fn=self._collate_fn
)
def _filter_indices(self, stage_config):
"""Filter dataset indices based on stage criteria."""
indices = []
for idx, features in enumerate(self.dataset.sequence_features):
# Check if sequence meets stage criteria
if (features['length'] >= stage_config.get('min_length', 0) and
features['length'] <= stage_config.get('max_length', float('inf')) and
features['entropy'] >= stage_config.get('min_entropy', 0) and
features['entropy'] <= stage_config.get('max_entropy', float('inf'))):
indices.append(idx)
return indices
def _collate_fn(self, batch):
"""Custom collate function."""
input_ids = torch.stack([item['input_ids'] for item in batch])
attention_mask = torch.stack([item['attention_mask'] for item in batch])
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'features': [item['features'] for item in batch],
'original_lengths': [item['original_length'] for item in batch]
}
def advance_stage(self):
"""Advance to next curriculum stage."""
if self.current_stage < len(self.curriculum_stages) - 1:
self.current_stage += 1
self.stage_step = 0
print(f"Advanced to curriculum stage {self.current_stage}")
def step(self):
"""Update step count and check for stage advancement."""
self.stage_step += 1
stage_config = self.curriculum_stages[self.current_stage]
if self.stage_step >= stage_config.get('steps', float('inf')):
self.advance_stage()
# Usage example
curriculum_stages = [
{'min_length': 0, 'max_length': 128, 'min_entropy': 0, 'max_entropy': 2, 'steps': 2000},
{'min_length': 64, 'max_length': 256, 'min_entropy': 1, 'max_entropy': 4, 'steps': 3000},
{'min_length': 128, 'max_length': 512, 'min_entropy': 2, 'max_entropy': 8, 'steps': 5000},
]
dataset = AdvancedSyntheticDataset(sequences, config)
curriculum_loader = CurriculumDataLoader(dataset, batch_size=32, curriculum_stages=curriculum_stages)
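One way to drive the curriculum during training is sketched below; the per-batch training step is the same as in the earlier loops. Note that a stage change only takes effect the next time get_current_dataloader() is called, here at each epoch boundary.

for epoch in range(config.num_epochs):
    dataloader = curriculum_loader.get_current_dataloader()
    for batch in dataloader:
        # ... forward pass, loss computation, optimizer step as shown earlier ...
        curriculum_loader.step()  # may trigger advance_stage() once the stage's step budget is spent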
Custom Model Components¶
Custom Injection Mechanisms¶
import torch.nn as nn
class CrossAttentionInjection(nn.Module):
"""Inject plan using cross-attention mechanism."""
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_dim
self.latent_dim = config.latent_dim
self.num_heads = config.num_heads
# Project plan to hidden dimension
self.plan_projection = nn.Linear(config.latent_dim, config.hidden_dim)
# Cross-attention layers
self.cross_attention = nn.MultiheadAttention(
embed_dim=config.hidden_dim,
num_heads=config.num_heads,
batch_first=True
)
# Layer norm and feedforward
self.norm = nn.LayerNorm(config.hidden_dim)
self.feedforward = nn.Sequential(
nn.Linear(config.hidden_dim, config.hidden_dim * 4),
nn.GELU(),
nn.Linear(config.hidden_dim * 4, config.hidden_dim)
)
def forward(self, decoder_hidden, binary_plan):
"""Apply cross-attention injection."""
batch_size, seq_len, hidden_dim = decoder_hidden.shape
# Project plan to hidden dimension
plan_repr = self.plan_projection(binary_plan) # [batch_size, hidden_dim]
plan_repr = plan_repr.unsqueeze(1) # [batch_size, 1, hidden_dim]
# Cross-attention: decoder hidden as query, plan as key/value
attended, _ = self.cross_attention(
query=decoder_hidden,
key=plan_repr,
value=plan_repr
)
# Residual connection and normalization
output = self.norm(decoder_hidden + attended)
# Feedforward
ff_output = self.feedforward(output)
output = self.norm(output + ff_output)
return output
class MultiLevelInjection(nn.Module):
"""Inject plan at multiple levels with different transformations."""
def __init__(self, config):
super().__init__()
self.num_levels = config.num_injection_levels
# Different projections for each level
self.level_projections = nn.ModuleList([
nn.Linear(config.latent_dim, config.hidden_dim)
for _ in range(self.num_levels)
])
# Level-specific gates
self.level_gates = nn.ModuleList([
nn.Linear(config.hidden_dim, config.hidden_dim)
for _ in range(self.num_levels)
])
# Attention weights for combining levels
self.level_attention = nn.Linear(config.hidden_dim, self.num_levels)
def forward(self, decoder_hidden, binary_plan, level_weights=None):
"""Apply multi-level injection."""
batch_size, seq_len, hidden_dim = decoder_hidden.shape
# Generate level-specific representations
level_reprs = []
for i, (proj, gate) in enumerate(zip(self.level_projections, self.level_gates)):
# Project plan for this level
level_repr = proj(binary_plan) # [batch_size, hidden_dim]
level_repr = level_repr.unsqueeze(1).expand(-1, seq_len, -1)
# Apply level-specific gating
gate_values = torch.sigmoid(gate(decoder_hidden))
gated_repr = gate_values * level_repr
level_reprs.append(gated_repr)
# Compute attention weights for combining levels
if level_weights is None:
attention_logits = self.level_attention(decoder_hidden) # [batch_size, seq_len, num_levels]
level_weights = torch.softmax(attention_logits, dim=-1)
# Combine level representations
combined_repr = torch.zeros_like(decoder_hidden)
for i, level_repr in enumerate(level_reprs):
weight = level_weights[..., i:i+1] # [batch_size, seq_len, 1]
combined_repr += weight * level_repr
return decoder_hidden + combined_repr
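Both injection modules can be smoke-tested with a mocked config. The snippet below uses SimpleNamespace with assumed field names (hidden_dim, latent_dim, num_heads, num_injection_levels) rather than the library's real Config class.

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_dim=256, latent_dim=16, num_heads=8, num_injection_levels=3)

decoder_hidden = torch.randn(2, 128, cfg.hidden_dim)            # [batch, seq_len, hidden_dim]
binary_plan = torch.randint(0, 2, (2, cfg.latent_dim)).float()  # sampled plan bits

cross_out = CrossAttentionInjection(cfg)(decoder_hidden, binary_plan)
multi_out = MultiLevelInjection(cfg)(decoder_hidden, binary_plan)
assert cross_out.shape == multi_out.shape == decoder_hidden.shape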
Custom Encoder Architectures¶
class HierarchicalEncoder(nn.Module):
"""Hierarchical encoder for multi-scale planning."""
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_dim
self.num_levels = config.encoder_levels
# Encoders for different scales
self.local_encoder = self._build_encoder(config, scale='local')
self.global_encoder = self._build_encoder(config, scale='global')
# Fusion mechanism
self.fusion = nn.Sequential(
nn.Linear(config.hidden_dim * 2, config.hidden_dim),
nn.GELU(),
nn.Linear(config.hidden_dim, config.hidden_dim)
)
        # Learned query token for global aggregation (see _encode_global)
        self.global_query = nn.Parameter(torch.randn(1, 1, config.hidden_dim))
        # Output projection
        self.output_projection = nn.Linear(config.hidden_dim, config.latent_dim)
def _build_encoder(self, config, scale):
"""Build encoder for specific scale."""
if scale == 'local':
# Local encoder: smaller receptive field
return nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=config.hidden_dim,
nhead=config.num_heads // 2,
dim_feedforward=config.hidden_dim * 2,
batch_first=True
),
num_layers=2
)
else:
# Global encoder: larger receptive field
return nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=config.hidden_dim,
nhead=config.num_heads,
dim_feedforward=config.hidden_dim * 4,
batch_first=True
),
num_layers=3
)
def forward(self, hidden_states):
"""Forward pass with hierarchical encoding."""
batch_size, seq_len, hidden_dim = hidden_states.shape
# Local encoding: process overlapping windows
local_features = self._encode_local(hidden_states)
# Global encoding: process entire sequence
global_features = self._encode_global(hidden_states)
# Fuse local and global features
fused_features = self.fusion(torch.cat([local_features, global_features], dim=-1))
# Project to latent dimension
latent_repr = self.output_projection(fused_features)
return latent_repr
def _encode_local(self, hidden_states):
"""Encode local features using sliding windows."""
batch_size, seq_len, hidden_dim = hidden_states.shape
        window_size = max(1, min(64, seq_len // 4))  # guard against very short sequences
        stride = max(1, window_size // 2)
local_features = []
for start in range(0, seq_len - window_size + 1, stride):
end = start + window_size
window = hidden_states[:, start:end, :]
# Encode window
encoded_window = self.local_encoder(window)
# Pool to single representation
pooled = encoded_window.mean(dim=1) # [batch_size, hidden_dim]
local_features.append(pooled)
# Combine local features
if local_features:
local_repr = torch.stack(local_features, dim=1).mean(dim=1)
else:
local_repr = hidden_states.mean(dim=1)
return local_repr
    def _encode_global(self, hidden_states):
        """Encode global features using full sequence."""
        # Use the learned query registered in __init__ (creating a new nn.Parameter here
        # would re-randomize it on every forward pass and it would never be trained)
        batch_size = hidden_states.size(0)
        query = self.global_query.expand(batch_size, -1, -1)
# Concatenate query with hidden states
input_with_query = torch.cat([query, hidden_states], dim=1)
# Encode
encoded = self.global_encoder(input_with_query)
# Extract global representation (first token)
global_repr = encoded[:, 0, :]
return global_repr
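A similar shape check works for the hierarchical encoder (again with a mocked config; encoder_levels, num_heads, hidden_dim, and latent_dim are assumed field names). The output is one latent vector per sequence.

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_dim=256, latent_dim=16, num_heads=8, encoder_levels=2)
encoder = HierarchicalEncoder(cfg)

hidden_states = torch.randn(2, 128, cfg.hidden_dim)  # [batch, seq_len, hidden_dim]
latent_repr = encoder(hidden_states)
assert latent_repr.shape == (2, cfg.latent_dim)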
Advanced Training Techniques¶
Adversarial Training¶
class AdversarialTrainer:
"""Adversarial training for Free Transformer."""
def __init__(self, model, discriminator, config):
self.model = model
self.discriminator = discriminator
self.config = config
# Optimizers
self.model_optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
self.disc_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=config.disc_learning_rate)
def train_step(self, batch):
"""Single adversarial training step."""
tokens = batch['input_ids']
# Train discriminator
disc_loss = self._train_discriminator(tokens)
# Train generator (model)
gen_loss = self._train_generator(tokens)
return {
'discriminator_loss': disc_loss,
'generator_loss': gen_loss
}
def _train_discriminator(self, real_tokens):
"""Train discriminator to distinguish real from generated sequences."""
self.disc_optimizer.zero_grad()
# Real sequences
real_logits = self.discriminator(real_tokens)
real_loss = F.binary_cross_entropy_with_logits(
real_logits,
torch.ones_like(real_logits)
)
# Generated sequences
with torch.no_grad():
generated_tokens = self.model.generate(
real_tokens[:, :10],
max_new_tokens=real_tokens.size(1) - 10
)
fake_logits = self.discriminator(generated_tokens)
fake_loss = F.binary_cross_entropy_with_logits(
fake_logits,
torch.zeros_like(fake_logits)
)
# Total discriminator loss
disc_loss = (real_loss + fake_loss) / 2
disc_loss.backward()
self.disc_optimizer.step()
return disc_loss.item()
def _train_generator(self, tokens):
"""Train generator to fool discriminator."""
self.model_optimizer.zero_grad()
# Standard Free Transformer loss
logits, z_logits = self.model(tokens, mode='training')
standard_loss = free_transformer_loss(
logits=logits,
z_logits=z_logits,
targets=tokens,
latent_dim=self.config.latent_dim,
kl_weight=self.config.kl_weight,
free_bits=self.config.free_bits
)
        # Adversarial loss. Note: sampling discrete tokens is not differentiable, so as
        # written this term sends no gradient back to the model; in practice a
        # Gumbel-Softmax relaxation or a REINFORCE-style estimator is needed to pass
        # the adversarial signal through generation.
        generated_tokens = self.model.generate(
            tokens[:, :10],
            max_new_tokens=tokens.size(1) - 10
        )
fake_logits = self.discriminator(generated_tokens)
adversarial_loss = F.binary_cross_entropy_with_logits(
fake_logits,
torch.ones_like(fake_logits) # Want discriminator to think it's real
)
# Combined loss
total_loss = standard_loss['total_loss'] + self.config.adversarial_weight * adversarial_loss
total_loss.backward()
self.model_optimizer.step()
return total_loss.item()
class SimpleDiscriminator(nn.Module):
"""Simple discriminator for adversarial training."""
def __init__(self, vocab_size, hidden_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=8,
dim_feedforward=hidden_dim * 4,
batch_first=True
),
num_layers=4
)
self.classifier = nn.Linear(hidden_dim, 1)
def forward(self, tokens):
"""Forward pass."""
embedded = self.embedding(tokens)
encoded = self.encoder(embedded)
# Global average pooling
pooled = encoded.mean(dim=1)
# Classification
logits = self.classifier(pooled)
return logits
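Wiring the pieces together might look like the sketch below; it assumes config carries vocab_size, learning_rate, disc_learning_rate, adversarial_weight, and the other fields used above, and that model is a Free Transformer with the training and generation interface shown earlier.

discriminator = SimpleDiscriminator(vocab_size=config.vocab_size, hidden_dim=256)
trainer = AdversarialTrainer(model, discriminator, config)

for batch in dataloader:
    losses = trainer.train_step(batch)
    print(f"D loss: {losses['discriminator_loss']:.4f} | G loss: {losses['generator_loss']:.4f}")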
Next Steps¶
- Evaluation Examples: Advanced evaluation techniques
- API Reference: Complete API documentation
- Training Guide: Standard training practices