Latent Planning Mechanism¶
This page provides an in-depth explanation of the latent planning mechanism that makes the Free Transformer unique.
Core Concept¶
Traditional autoregressive models generate tokens sequentially, making decisions based only on the tokens seen so far. The Free Transformer introduces explicit latent planning: the model first creates an abstract "plan" for the entire sequence, then generates tokens to fulfill that plan.
Planning vs Reactive Generation¶
Reactive Generation (Standard Transformers)¶
Each token depends only on previous tokens, leading to:

- Local coherence but potential global inconsistency
- Difficulty with long-range planning
- Limited controllability
Plan-Based Generation (Free Transformer)¶
Full Context → Abstract Plan Z → Token 1, Token 2, ..., Token N
        ↓                ↑          ↑                       ↑
        └────────────────┴──────────┴───────────────────────┘
The model first creates a plan, then generates all tokens conditioned on that plan:

- Global coherence through explicit planning
- Better long-range dependencies
- Controllable generation via plan manipulation
Mathematical Formulation¶
Standard Autoregressive Model¶
A standard decoder factorizes the sequence probability token by token:
\[ P(x_1, \dots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_1, \dots, x_{t-1}) \]
Free Transformer (Conditional VAE)¶
The Free Transformer instead marginalizes over a latent plan:
\[ P(x_1, \dots, x_T) = \sum_{z} P(z) \, P(x_1, \dots, x_T \mid z) = \sum_{z} P(z) \prod_{t=1}^{T} P(x_t \mid x_1, \dots, x_{t-1}, z) \]
Where:

- \(z\) is the latent plan variable
- \(P(z)\) is the prior distribution (uniform for binary plans)
- \(P(x_1, ..., x_T | z)\) is the conditional generation model
Plan Representation¶
Binary Plans¶
The Free Transformer uses binary latent variables: \(z \in \{0, 1\}^d\)
Where \(d\) is the latent dimension (typically 16-64).
Why Binary?¶
- Interpretability: Each bit can represent a discrete choice
- Efficiency: Compact representation
- Controllability: Easy to manipulate specific aspects
- Stability: Avoids posterior collapse issues common with continuous latents
Plan Semantics¶
Each bit in the plan can potentially encode:
- Style: Formal vs informal, technical vs casual
- Structure: Narrative vs expository, linear vs non-linear
- Content: Topic focus, emotional tone
- Length: Short vs long form content
Architecture Components¶
1. Non-Causal Encoder¶
The encoder creates the latent plan from the full sequence:
# Imports shared by the code snippets on this page
import torch
import torch.nn as nn
import torch.nn.functional as F

class NonCausalEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention_layers = nn.ModuleList([
            NonCausalAttention(config) for _ in range(config.encoder_layers)
        ])
        self.learned_query = nn.Parameter(torch.randn(config.hidden_dim))

    def forward(self, hidden_states):
        # Use learned query to aggregate sequence information
        query = self.learned_query.expand(hidden_states.size(0), 1, -1)
        # Non-causal attention over entire sequence
        for layer in self.attention_layers:
            query = layer(query, hidden_states, hidden_states)
        return query.squeeze(1)
Key Features:

- Non-causal attention: Can see the entire sequence
- Learned query: Single vector that aggregates information
- Separate parameters: Independent from decoder
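As a rough shape check, the encoder collapses a whole sequence of hidden states into one summary vector per example. This sketch assumes a `config` object with the fields used above and a working `NonCausalAttention` layer:

    # Illustrative usage; shapes and config fields are assumptions
    batch_size, seq_len = 4, 128
    hidden_states = torch.randn(batch_size, seq_len, config.hidden_dim)

    encoder = NonCausalEncoder(config)
    plan_summary = encoder(hidden_states)   # (batch_size, hidden_dim)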
2. Binary Mapping¶
Converts continuous encoder output to discrete binary plan:
class BinaryMapper(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.projection = nn.Linear(config.hidden_dim, config.latent_dim)
        self.temperature = config.gumbel_temperature

    def forward(self, encoder_output, training=True):
        logits = self.projection(encoder_output)
        if training:
            # Gumbel-Softmax over (bit on, bit off) for differentiable sampling
            binary_soft = F.gumbel_softmax(
                torch.stack([logits, -logits], dim=-1),
                tau=self.temperature,
                hard=True
            )
            return binary_soft[..., 0], logits
        else:
            # Deterministic hard thresholding at inference
            return (logits > 0).float(), logits
Gumbel-Softmax Trick:

- Enables gradient flow through discrete sampling
- Temperature controls how close the relaxation is to truly discrete samples
- With hard=True, the forward pass uses hard samples while gradients flow through the soft relaxation (straight-through estimator)
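A quick sanity check (illustrative, same `config` assumptions as above) shows that the training path returns exact 0/1 values while still carrying gradients back to the projection:

    mapper = BinaryMapper(config)
    encoder_output = torch.randn(4, config.hidden_dim, requires_grad=True)

    plan, logits = mapper(encoder_output, training=True)
    print(plan.unique())                      # hard binary values, e.g. tensor([0., 1.])
    plan.sum().backward()                     # gradients flow via the straight-through relaxation
    print(encoder_output.grad is not None)    # True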
3. Plan Injection¶
Integrates the binary plan into decoder representations:
class PlanInjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.plan_projection = nn.Linear(config.latent_dim, config.hidden_dim)
        self.gate = nn.Linear(config.hidden_dim, config.hidden_dim)

    def forward(self, decoder_hidden, binary_plan):
        # Project plan to hidden dimension
        plan_repr = self.plan_projection(binary_plan)
        # Broadcast to sequence length
        plan_repr = plan_repr.unsqueeze(1).expand(-1, decoder_hidden.size(1), -1)
        # Gated injection
        gate_values = torch.sigmoid(self.gate(decoder_hidden))
        return decoder_hidden + gate_values * plan_repr
Injection Strategies:
1. Additive: hidden + plan_projection(z)
2. Gated: hidden + gate * plan_projection(z) (used)
3. Concatenation: concat(hidden, plan_projection(z))
4. Cross-attention: Plan as keys/values
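For orientation, a minimal sketch (same assumptions as above) wires the three planning components together for a single training-mode pass:

    # Illustrative end-to-end pass over the planning path
    encoder = NonCausalEncoder(config)
    mapper = BinaryMapper(config)
    injector = PlanInjection(config)

    hidden_states = torch.randn(2, 64, config.hidden_dim)      # decoder hidden states
    plan, plan_logits = mapper(encoder(hidden_states), training=True)
    conditioned = injector(hidden_states, plan)                 # (2, 64, hidden_dim)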
Training Dynamics¶
Variational Objective¶
The model is trained to maximize the Evidence Lower Bound (ELBO):
\[ \mathcal{L}_{\text{ELBO}} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \beta \, KL\big(q(z|x) \,\|\, p(z)\big) \]
Where:

- Reconstruction term: \(\mathbb{E}_{q(z|x)}[\log p(x|z)]\) measures how well the model generates given the plan
- Regularization term: \(KL(q(z|x) || p(z))\) keeps the posterior close to the prior
- β-VAE weight: \(\beta\) controls the trade-off between reconstruction and regularization
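A minimal sketch of this objective for the binary plan, assuming per-token decoder logits (`token_logits`, `target_ids`) and the `plan_logits` returned by the BinaryMapper above; β and the reduction details are illustrative:

    def elbo_loss(token_logits, target_ids, plan_logits, beta=1.0):
        # Reconstruction: cross-entropy of next-token predictions given the plan
        recon = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            target_ids.view(-1)
        )
        # KL between Bernoulli(sigmoid(plan_logits)) and the uniform Bernoulli(0.5) prior
        q = torch.sigmoid(plan_logits)
        kl_per_bit = q * torch.log(q / 0.5 + 1e-8) \
            + (1 - q) * torch.log((1 - q) / 0.5 + 1e-8)
        kl = kl_per_bit.sum(dim=-1).mean()
        # Minimizing this is equivalent to maximizing the (β-weighted) ELBO
        return recon + beta * kl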
Free Bits Regularization¶
To prevent posterior collapse, we use free bits. A common formulation clamps the per-bit KL from below:
\[ \mathcal{L}_{KL} = \sum_{i=1}^{d} \max\big(\lambda_{\text{free}}, \; KL(q(z_i|x) \,\|\, p(z_i))\big) \]
Because KL below the threshold is not penalized, the model can encode up to free_bits nats of information per latent dimension "for free", which removes the incentive to collapse the posterior onto the prior.
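In code this amounts to clamping the per-bit KL before summing; a sketch reusing the Bernoulli posterior from above, with `free_bits` as the per-dimension threshold in nats:

    def free_bits_kl(plan_logits, free_bits=0.1):
        # Per-bit KL between the Bernoulli posterior and the uniform prior
        q = torch.sigmoid(plan_logits)
        kl_per_bit = q * torch.log(q / 0.5 + 1e-8) \
            + (1 - q) * torch.log((1 - q) / 0.5 + 1e-8)
        # Clamp each bit's KL from below so the first `free_bits` nats are not penalized
        return torch.clamp(kl_per_bit, min=free_bits).sum(dim=-1).mean()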
Training vs Inference¶
Training Mode:

1. Encode full sequence → latent plan
2. Inject plan into decoder
3. Optimize reconstruction + KL loss

Inference Mode:

1. Sample plan from prior: \(z \sim p(z) = \text{Uniform}(\{0,1\}^d)\)
2. Inject sampled plan into decoder
3. Generate autoregressively
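Sampling from the uniform binary prior at inference time is a one-liner; this sketch assumes the `generate_with_plan` helper used in the examples below:

    # Draw each plan bit independently with probability 0.5 (uniform prior over {0,1}^d)
    plan = torch.bernoulli(torch.full((1, config.latent_dim), 0.5))
    output = model.generate_with_plan(prompt, plan, max_new_tokens=100)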
Plan Analysis and Interpretation¶
Plan Utilization¶
Monitor whether the model actually uses the latent variable:
def analyze_plan_usage(model, dataloader):
    """Analyze how much the model uses the latent plan."""
    kl_divergences = []
    for batch in dataloader:
        with torch.no_grad():
            _, z_logits = model(batch['input_ids'], mode='training')
        # Per-bit KL between the Bernoulli posterior and the uniform prior
        posterior = torch.sigmoid(z_logits)
        prior = torch.full_like(posterior, 0.5)
        kl = posterior * torch.log(posterior / prior + 1e-8) \
            + (1 - posterior) * torch.log((1 - posterior) / (1 - prior) + 1e-8)
        kl_divergences.append(kl.sum(dim=-1))
    kl_divergences = torch.cat(kl_divergences)
    print(f"Mean KL divergence: {kl_divergences.mean():.4f}")
    print(f"Std KL divergence: {kl_divergences.std():.4f}")
    print(f"Min KL divergence: {kl_divergences.min():.4f}")
    print(f"Max KL divergence: {kl_divergences.max():.4f}")
    return kl_divergences
Plan Interpolation¶
Explore the latent space by interpolating between plans:
def interpolate_plans(model, prompt, plan1, plan2, steps=5):
    """Generate text with interpolated plans."""
    generations = []
    for i in range(steps):
        alpha = i / (steps - 1)
        interpolated_plan = (1 - alpha) * plan1 + alpha * plan2
        interpolated_plan = (interpolated_plan > 0.5).float()
        # Generate with interpolated plan
        with torch.no_grad():
            generated = model.generate_with_plan(
                prompt,
                interpolated_plan,
                max_new_tokens=50
            )
        generations.append(generated)
    return generations
Plan Manipulation¶
Control generation by modifying specific plan bits:
def manipulate_plan(model, prompt, bit_index, value):
    """Generate text with specific plan bit set to value."""
    # Sample random plan
    plan = torch.bernoulli(torch.full((1, model.config.latent_dim), 0.5))
    # Set specific bit
    plan[0, bit_index] = value
    # Generate with modified plan
    with torch.no_grad():
        generated = model.generate_with_plan(prompt, plan, max_new_tokens=100)
    return generated

# Example: Compare generations with bit 5 set to 0 vs 1
gen_0 = manipulate_plan(model, prompt, bit_index=5, value=0)
gen_1 = manipulate_plan(model, prompt, bit_index=5, value=1)
Advanced Planning Techniques¶
Hierarchical Planning¶
Use multiple latent variables at different levels:
class HierarchicalPlanner(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.global_encoder = NonCausalEncoder(config)
        self.local_encoders = nn.ModuleList([
            NonCausalEncoder(config) for _ in range(config.num_local_levels)
        ])

    def forward(self, hidden_states):
        # Global plan for entire sequence
        global_plan = self.global_encoder(hidden_states)
        # Local plans for contiguous subsequences
        local_plans = []
        chunk_size = hidden_states.size(1) // len(self.local_encoders)
        for i, encoder in enumerate(self.local_encoders):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = hidden_states[:, start:end]
            local_plan = encoder(chunk)
            local_plans.append(local_plan)
        return global_plan, local_plans
Conditional Planning¶
Condition plans on external information:
class ConditionalPlanner(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = NonCausalEncoder(config)
        self.condition_projection = nn.Linear(config.condition_dim, config.hidden_dim)

    def forward(self, hidden_states, condition):
        # Project condition to hidden space
        condition_repr = self.condition_projection(condition)
        # Add condition to hidden states at every position
        conditioned_hidden = hidden_states + condition_repr.unsqueeze(1)
        # Encode with condition
        plan = self.encoder(conditioned_hidden)
        return plan
Troubleshooting Planning Issues¶
Posterior Collapse¶
Symptoms: KL loss drops to zero, model ignores latent variable

Solutions:

- Increase free bits threshold
- Reduce KL weight (β)
- Use KL annealing (see the sketch below)
- Check encoder capacity
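KL annealing ramps the β weight up from zero so the decoder learns to use the plan before the KL penalty dominates; a minimal linear schedule, with the step count as an illustrative assumption:

    def kl_weight(step, warmup_steps=10_000, beta_max=1.0):
        # Linearly anneal the KL weight from 0 to beta_max over warmup_steps
        return beta_max * min(1.0, step / warmup_steps)

    # Usage inside the training loop (illustrative):
    # loss = recon_loss + kl_weight(global_step) * kl_loss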
Plan Underutilization¶
Symptoms: Low KL divergence, similar generations

Solutions:

- Increase latent dimension
- Improve encoder architecture
- Use stronger regularization
- Check injection mechanism
Training Instability¶
Symptoms: Loss spikes, gradient explosions

Solutions:

- Gradient clipping (see the sketch below)
- Lower learning rate
- Reduce Gumbel temperature
- Use warmup schedule
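Two of these fixes fit in a small helper around the optimizer step; the decay rate and temperature floor below are illustrative assumptions:

    def training_step_safeguards(model, optimizer, mapper, min_tau=0.5, decay=0.9999):
        # Clip gradients before the optimizer step to limit loss spikes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        # Gradually anneal the Gumbel temperature toward a floor
        mapper.temperature = max(min_tau, mapper.temperature * decay)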
Evaluation Metrics¶
Plan Quality Metrics¶
- KL Divergence: Measures plan utilization
- Mutual Information: I(X; Z) between input and plan
- Plan Consistency: Similarity of plans for similar inputs
- Generation Diversity: Variety in outputs for different plans
Implementation¶
def evaluate_planning(model, dataloader):
    """Comprehensive evaluation of planning mechanism."""
    metrics = {
        'kl_divergence': [],
        'plan_entropy': []
    }
    for batch in dataloader:
        with torch.no_grad():
            logits, z_logits = model(batch['input_ids'], mode='training')
        posterior = torch.sigmoid(z_logits)
        # Per-bit KL against the uniform prior, summed over the latent dimension
        kl = posterior * torch.log(posterior / 0.5 + 1e-8) \
            + (1 - posterior) * torch.log((1 - posterior) / 0.5 + 1e-8)
        metrics['kl_divergence'].append(kl.sum(dim=-1))
        # Entropy of the Bernoulli posterior, summed over bits
        entropy = -posterior * torch.log(posterior + 1e-8) \
            - (1 - posterior) * torch.log(1 - posterior + 1e-8)
        metrics['plan_entropy'].append(entropy.sum(dim=-1))
    # Generation diversity requires sampled generations and is computed separately
    return {k: torch.cat(v).mean().item() for k, v in metrics.items()}
Next Steps¶
- Free Transformer Architecture: Complete architecture overview
- Training Guide: How to train with latent planning
- Examples: Practical usage examples