
Model API Reference

free_transformer.model.FreeTransformer(config)

Bases: Module

Free Transformer: Conditional VAE-based language model with latent planning.

Implements the architecture from the Free Transformer paper with:

- Split decoder stack (first half for context, second half for generation)
- Non-causal encoder for latent plan inference
- Binary mapper for differentiable discrete sampling
- Injection mechanism for plan integration

Source code in src/free_transformer/model.py
def __init__(self, config):
    super().__init__()
    self.config = config

    # Token embeddings
    self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)

    # First half of decoder blocks (context processing)
    self.first_half_blocks = nn.ModuleList(
        [
            TransformerBlock(
                dim=config.hidden_dim,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                ffn_dim=config.ffn_hidden_dim,
                dropout=config.dropout,
                causal=True,
                use_rope=config.use_rope,
                max_seq_len=config.max_seq_len,
            )
            for _ in range(config.split_layer)
        ]
    )

    # Encoder module (non-causal, for plan inference)
    self.encoder = EncoderBlock(
        dim=config.hidden_dim,
        num_heads=config.num_heads,
        ffn_dim=config.ffn_hidden_dim,
        latent_dim=config.latent_dim,
        dropout=config.dropout,
    )

    # Binary mapper and latent plan handler
    self.latent_plan = LatentPlan(
        latent_dim=config.latent_dim,
        hidden_dim=config.hidden_dim,
    )

    # Injection mechanism
    self.injection = InjectionMechanism(config.hidden_dim)

    # Second half of decoder blocks (generation with plan)
    second_half_layers = config.num_layers - config.split_layer
    if second_half_layers <= 0:
        raise ValueError(
            f"split_layer ({config.split_layer}) must be less than num_layers ({config.num_layers})"
        )

    self.second_half_blocks = nn.ModuleList(
        [
            TransformerBlock(
                dim=config.hidden_dim,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                ffn_dim=config.ffn_hidden_dim,
                dropout=config.dropout,
                causal=True,
                use_rope=config.use_rope,
                max_seq_len=config.max_seq_len,
            )
            for _ in range(second_half_layers)
        ]
    )

    # Final output
    self.norm = RMSNorm(config.hidden_dim)
    self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)

    # Tie weights
    self.output.weight = self.token_embedding.weight

    self.apply(self._init_weights)
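
A minimal instantiation sketch. The constructor only needs a config object exposing the attributes read above; the DemoConfig dataclass and its field values below are illustrative stand-ins, not the package's actual config class.

from dataclasses import dataclass

from free_transformer.model import FreeTransformer

# Hypothetical stand-in for the project's config object: __init__ only
# reads these attributes, so any object that exposes them will do.
@dataclass
class DemoConfig:
    vocab_size: int = 32000
    hidden_dim: int = 512
    num_heads: int = 8
    num_kv_heads: int = 4
    ffn_hidden_dim: int = 2048
    num_layers: int = 8
    split_layer: int = 4      # must be strictly less than num_layers
    latent_dim: int = 4       # Z is one-hot over 2^latent_dim states
    dropout: float = 0.0
    use_rope: bool = True
    max_seq_len: int = 1024

model = FreeTransformer(DemoConfig())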

forward(tokens, mode='training')

Forward pass with mode switching.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| tokens | Tensor | Input token IDs [batch, seq_len] | required |
| mode | Literal['training', 'inference'] | 'training' uses encoder path, 'inference' samples random Z | 'training' |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| logits | Tensor | Output logits [batch, seq_len, vocab_size] |
| z_logits | Optional[Tensor] | Encoder logits for Z (only in training mode) [batch, seq_len, latent_dim] |

Source code in src/free_transformer/model.py
def forward(
    self,
    tokens: torch.Tensor,
    mode: Literal["training", "inference"] = "training",
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Forward pass with mode switching.

    Args:
        tokens: Input token IDs [batch, seq_len]
        mode: 'training' uses encoder path, 'inference' samples random Z

    Returns:
        logits: Output logits [batch, seq_len, vocab_size]
        z_logits: Encoder logits for Z (only in training mode) [batch, seq_len, latent_dim]
    """
    batch_size, seq_len = tokens.shape

    # 1. Embed tokens
    x = self.token_embedding(tokens)

    # 2. Process through first half of decoder
    for block in self.first_half_blocks:
        x = block(x)

    # 3. Generate or sample latent plan Z
    if mode == "training":
        # Encoder path: infer Z from context
        z_logits = self.encoder(x)  # [batch, seq_len, latent_dim]
        z_onehot = self.latent_plan.sample_from_logits(
            z_logits
        )  # [batch, seq_len, 2^latent_dim]
    else:  # inference
        # Sample Z from uniform prior
        z_onehot = self.latent_plan.sample_from_prior(batch_size, seq_len, device=tokens.device)
        z_logits = None

    # 4. Project Z to hidden dimension and inject into decoder
    z_projected = self.latent_plan.project_to_hidden(z_onehot)
    x_with_z = self.injection(x, z_projected)

    # 5. Process through second half of decoder
    # First block uses injected kv, rest use standard self-attention
    x = self.second_half_blocks[0](x, kv_input=x_with_z)
    for block in self.second_half_blocks[1:]:
        x = block(x)

    # 6. Final output projection
    x = self.norm(x)
    logits = self.output(x)

    assert isinstance(logits, torch.Tensor)
    return logits, z_logits
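
A usage sketch for both modes, assuming the model built in the example above; the expected shapes follow the docstring.

import torch

tokens = torch.randint(0, model.config.vocab_size, (2, 128))

# Training path: the non-causal encoder infers Z from the context,
# so encoder logits are returned for the VAE-style training objective.
logits, z_logits = model(tokens, mode="training")
assert logits.shape == (2, 128, model.config.vocab_size)
assert z_logits.shape == (2, 128, model.config.latent_dim)

# Inference path: Z is drawn from the uniform prior and no encoder
# logits are produced.
logits, z_logits = model(tokens, mode="inference")
assert z_logits is None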

generate(prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=None)

Autoregressive generation with random latent plans.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| prompt_tokens | Tensor | Initial tokens [batch, prompt_len] | required |
| max_new_tokens | int | Number of tokens to generate | 100 |
| temperature | float | Sampling temperature | 1.0 |
| top_k | Optional[int] | Top-k filtering | None |

Returns:

| Type | Description |
| ---- | ----------- |
| Tensor | Generated tokens [batch, prompt_len + max_new_tokens] |

Source code in src/free_transformer/model.py
@torch.no_grad()
def generate(
    self,
    prompt_tokens: torch.Tensor,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """
    Autoregressive generation with random latent plans.

    Args:
        prompt_tokens: Initial tokens [batch, prompt_len]
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature
        top_k: Top-k filtering

    Returns:
        Generated tokens [batch, prompt_len + max_new_tokens]
    """
    tokens = prompt_tokens

    for _ in range(max_new_tokens):
        # Forward pass in inference mode
        logits, _ = self.forward(tokens, mode="inference")

        # Get logits for last position
        logits = logits[:, -1, :] / temperature

        # Top-k filtering
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float("-inf")

        # Sample next token
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to sequence
        tokens = torch.cat([tokens, next_token], dim=1)

        # Truncate if exceeds max length
        if tokens.shape[1] > self.config.max_seq_len:
            tokens = tokens[:, -self.config.max_seq_len :]

    return tokens
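
A generation sketch, again assuming the model from the examples above and an already tokenized prompt; the argument values are arbitrary.

import torch

prompt = torch.randint(0, model.config.vocab_size, (1, 16))

model.eval()
out = model.generate(prompt, max_new_tokens=32, temperature=0.8, top_k=50)

# Per the docstring, out has shape [1, 16 + 32] (unless truncated to
# max_seq_len). Each step runs forward() in inference mode, so a fresh
# random latent plan Z is sampled for every generated token.
print(out.shape)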

free_transformer.baseline.TransformerBaseline(config)

Bases: Module

Standard autoregressive Transformer without latent planning.

Source code in src/free_transformer/baseline.py
def __init__(self, config):
    super().__init__()
    self.config = config

    self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)

    self.blocks = nn.ModuleList(
        [
            TransformerBlock(
                dim=config.hidden_dim,
                num_heads=config.num_heads,
                num_kv_heads=config.num_kv_heads,
                ffn_dim=config.ffn_hidden_dim,
                dropout=config.dropout,
                causal=True,
                use_rope=config.use_rope,
                max_seq_len=config.max_seq_len,
            )
            for _ in range(config.num_layers)
        ]
    )

    self.norm = RMSNorm(config.hidden_dim)
    self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
    self.output.weight = self.token_embedding.weight

    self.apply(self._init_weights)
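
Both classes accept the same config object, so a quick sanity check is to compare parameter counts; a sketch assuming the hypothetical DemoConfig from the first example:

from free_transformer.baseline import TransformerBaseline
from free_transformer.model import FreeTransformer

def num_params(module):
    return sum(p.numel() for p in module.parameters())

config = DemoConfig()
baseline = TransformerBaseline(config)
free = FreeTransformer(config)

# The Free Transformer keeps the same num_layers decoder blocks (split
# into two halves) and adds the encoder, latent-plan mapper, and
# injection mechanism on top of the baseline stack.
print(f"baseline: {num_params(baseline):,}  free: {num_params(free):,}")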