Losses API Reference
free_transformer.losses
Loss functions for Free Transformer training.
FreeTransformerLoss(latent_dim=16, beta_kl=1.0, free_bits=0.3466, ignore_index=-100)
Bases: Module
Wrapper class for Free Transformer loss.
Source code in src/free_transformer/losses.py
compute_kl_divergence(z_logits, latent_dim=16, free_bits=0.3466)
Compute KL divergence with free bits for latent plan Z.
Implements Equation (5) from the paper: (1/T) * sum_t max(0, D_KL(Q(Z_t|S) || P(Z_t)) - kappa), where T is the sequence length and kappa is the free bits budget.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `z_logits` | `Tensor` | Encoder logits [batch, seq_len, latent_dim] | required |
| `latent_dim` | `int` | Dimension of latent space (H) | `16` |
| `free_bits` | `float` | Free bits budget (kappa) | `0.3466` |
Returns:
| Name | Type | Description |
|---|---|---|
| `kl_loss` | `Tensor` | Scalar KL divergence loss |
Source code in src/free_transformer/losses.py
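To make the free-bits clamp concrete, here is a minimal plain-Python sketch of the formula above for a single sequence. It assumes, which this page does not state, that each of the H logits parameterizes an independent Bernoulli bit and that the prior P(Z_t) is uniform, so the per-bit KL is ln 2 minus the bit's entropy. The helper names are hypothetical, and the library's tensor implementation may differ in detail.

```python
import math


def bernoulli_kl_to_uniform(logit: float) -> float:
    """KL(Bernoulli(sigmoid(logit)) || Bernoulli(0.5)) in nats (assumed prior)."""
    # Numerically stable sigmoid: sigmoid(x) == 0.5 * (1 + tanh(x / 2)).
    p = 0.5 * (1.0 + math.tanh(logit / 2.0))
    kl = math.log(2.0)
    for q in (p, 1.0 - p):
        if q > 0.0:  # treat 0 * log(0) as 0
            kl += q * math.log(q)
    return kl


def kl_with_free_bits(z_logits, free_bits=0.3466):
    """Average over positions of max(0, KL_t - free_bits).

    z_logits: list of T positions, each a list of H bit logits (one sequence).
    """
    total = 0.0
    for bit_logits in z_logits:
        # Independent bits: the KL of the joint is the sum of per-bit KLs.
        kl_t = sum(bernoulli_kl_to_uniform(x) for x in bit_logits)
        # Positions whose KL stays under the budget contribute nothing.
        total += max(0.0, kl_t - free_bits)
    return total / len(z_logits)


# All-zero logits match the uniform prior, so the clamped loss is zero.
uninformative = kl_with_free_bits([[0.0] * 16])
# One nearly deterministic bit carries about ln 2 nats, above the budget.
one_bit = kl_with_free_bits([[20.0] + [0.0] * 15])
```

The default `free_bits=0.3466` is numerically ln(2)/2, that is, half a bit expressed in nats. Under the assumptions above, a position can encode up to half a bit about the sequence before the KL term starts to penalize it.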
compute_reconstruction_loss(logits, targets, ignore_index=-100)
Standard cross-entropy loss for token prediction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Model predictions [batch, seq_len, vocab_size] | required |
| `targets` | `Tensor` | Target tokens [batch, seq_len] | required |
| `ignore_index` | `int` | Index to ignore in loss computation | `-100` |
Returns:
| Name | Type | Description |
|---|---|---|
| `loss` | `Tensor` | Scalar loss value |
Source code in src/free_transformer/losses.py
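As a reference for what "cross-entropy with an ignore index" computes, the following plain-Python sketch handles a single sequence. It assumes the scalar is the mean negative log-likelihood over positions whose target is not `ignore_index`; the page does not state the reduction, so treat that as an assumption. The function name is hypothetical and this is not the library's code.

```python
import math


def cross_entropy_ignore(logits, targets, ignore_index=-100):
    """Mean negative log-likelihood over non-ignored positions.

    logits:  list of T rows, each a list of V floats (one sequence).
    targets: list of T ints; entries equal to ignore_index are skipped.
    """
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padding or masked position: no contribution
        # log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[target])
    # Assumption: with no valid targets the mean is undefined, so return NaN.
    return sum(losses) / len(losses) if losses else float("nan")


# Uniform logits over 4 tokens give a loss of ln 4; the second position is ignored.
loss = cross_entropy_ignore([[0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]], [1, -100])
```

Because ignored positions are dropped before averaging, padding a batch with `-100` targets does not dilute the loss.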
compute_vae_loss(logits, targets, z_logits, latent_dim=16, beta_kl=1.0, free_bits=0.3466, ignore_index=-100)
Complete VAE loss for Free Transformer.
Loss = Reconstruction + beta * KL_with_free_bits
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | Model predictions [batch, seq_len, vocab_size] | required |
| `targets` | `Tensor` | Target tokens [batch, seq_len] | required |
| `z_logits` | `Tensor` | Encoder logits [batch, seq_len, latent_dim] | required |
| `latent_dim` | `int` | Latent dimension | `16` |
| `beta_kl` | `float` | Weight for KL term | `1.0` |
| `free_bits` | `float` | Free bits budget | `0.3466` |
| `ignore_index` | `int` | Index to ignore in reconstruction loss | `-100` |
Returns:
| Name | Type | Description |
|---|---|---|
| `total_loss` | `Tensor` | Combined loss |
| `metrics` | `dict` | Dictionary of individual loss components |
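
The combination rule, Loss = Reconstruction + beta * KL_with_free_bits, can be illustrated with scalars. This is a hedged sketch only: the function name and the metric keys (`recon_loss`, `kl_loss`, `total_loss`) are hypothetical, since the page does not list the actual keys of the returned dictionary.

```python
def combine_vae_loss(recon_loss: float, kl_loss: float, beta_kl: float = 1.0):
    """Return (total, metrics) for already-computed scalar loss terms."""
    # kl_loss is assumed to be the free-bits-clamped KL term, so it is >= 0.
    total = recon_loss + beta_kl * kl_loss
    # Hypothetical key names; the library's metrics dict may differ.
    metrics = {"recon_loss": recon_loss, "kl_loss": kl_loss, "total_loss": total}
    return total, metrics


# With beta_kl = 0.1 the KL term contributes a tenth of its raw value.
total, metrics = combine_vae_loss(recon_loss=2.0, kl_loss=0.5, beta_kl=0.1)
```

Returning the components alongside the total makes it possible to log reconstruction and KL separately, which helps when checking whether the KL term sits at zero because the latent stays within the free bits budget.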