
Tokenization

tokenization

Tokenization and compression for DHB invariants (VLA).

Modules:

- vqvae: Basic VQ-VAE tokenizer (DHBTokenizer)
- rvq: Residual VQ for higher capacity (ResidualVQTokenizer)
- hierarchical: Multi-level hierarchical tokenization
- compression: BPE, entropy coding, RLE for token sequences
- fast_tokenizer: FAST-style DCT + BPE tokenizer (no PyTorch needed)
- fsq: Finite Scalar Quantization (no learned codebook)
- register_encoder: Transformer encoder with register tokens
- nested_dropout: Masked nested dropout for ordered token spaces
- oat_decoder: Cross-attention decoder with prefix support
- oat_tokenizer: Combined OAT-style tokenizer

Classes

BPECompressor

Byte-Pair Encoding for token sequences.

Merges frequent token pairs into super-tokens, reducing sequence length while preserving exact recoverability (lossless).

Inspired by FAST (Physical Intelligence, 2025), which achieves ~10x compression on action sequences via DCT + BPE.

Example

compressor = BPECompressor(vocab_size=512, num_merges=100)
compressor.fit(token_corpus)  # List of token sequences
compressed = compressor.encode([1, 2, 1, 2, 3])  # [256, 256, 3] if (1,2)->256
original = compressor.decode(compressed)

Source code in src/dhb_xr/tokenization/compression.py
class BPECompressor:
    """
    Byte-Pair Encoding for token sequences.

    Merges frequent token pairs into super-tokens, reducing sequence length
    while preserving exact recoverability (lossless).

    Inspired by FAST (Physical Intelligence, 2025) which achieves ~10x compression
    on action sequences via DCT + BPE.

    Example:
        >>> compressor = BPECompressor(vocab_size=512, num_merges=100)
        >>> compressor.fit(token_corpus)  # List of token sequences
        >>> compressed = compressor.encode([1, 2, 1, 2, 3])  # [256, 256, 3] if (1,2)->256
        >>> original = compressor.decode(compressed)
    """

    def __init__(self, vocab_size: int = 256, num_merges: int = 100):
        """
        Args:
            vocab_size: Original VQ codebook size (tokens 0 to vocab_size-1)
            num_merges: Number of BPE merges to learn
        """
        self.vocab_size = vocab_size
        self.num_merges = num_merges
        self.merges: Dict[Tuple[int, int], int] = {}  # (a, b) -> merged_token
        self.reverse_merges: Dict[int, Tuple[int, int]] = {}  # merged_token -> (a, b)
        self._fitted = False

    def fit(self, token_sequences: List[List[int]]) -> "BPECompressor":
        """
        Learn BPE merges from a corpus of token sequences.

        Args:
            token_sequences: List of token sequences (each a list of ints)

        Returns:
            self
        """
        # Flatten and count pair frequencies
        all_tokens = []
        for seq in token_sequences:
            all_tokens.extend(list(seq))

        # Iteratively merge most frequent pairs
        current_tokens = list(all_tokens)
        next_id = self.vocab_size

        for _ in range(self.num_merges):
            # Count pairs
            pair_counts = Counter()
            for i in range(len(current_tokens) - 1):
                pair = (current_tokens[i], current_tokens[i + 1])
                pair_counts[pair] += 1

            if not pair_counts:
                break

            # Get most frequent pair
            best_pair = pair_counts.most_common(1)[0][0]
            if pair_counts[best_pair] < 2:
                break  # No benefit from merging singletons

            # Merge
            self.merges[best_pair] = next_id
            self.reverse_merges[next_id] = best_pair

            # Replace in sequence
            new_tokens = []
            i = 0
            while i < len(current_tokens):
                if i < len(current_tokens) - 1 and (current_tokens[i], current_tokens[i + 1]) == best_pair:
                    new_tokens.append(next_id)
                    i += 2
                else:
                    new_tokens.append(current_tokens[i])
                    i += 1

            current_tokens = new_tokens
            next_id += 1

        self._fitted = True
        return self

    def encode(self, tokens: Union[List[int], np.ndarray]) -> List[int]:
        """
        Encode a token sequence using learned BPE merges.

        Args:
            tokens: Original token sequence

        Returns:
            Compressed token sequence
        """
        if not self._fitted:
            raise RuntimeError("BPECompressor must be fitted before encoding")

        tokens = list(tokens)

        # Apply merges in order learned
        for (a, b), merged in self.merges.items():
            new_tokens = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

        return tokens

    def decode(self, tokens: Union[List[int], np.ndarray]) -> List[int]:
        """
        Decode a compressed sequence back to original tokens.

        Args:
            tokens: Compressed token sequence

        Returns:
            Original token sequence
        """
        tokens = list(tokens)

        # Recursively expand merged tokens
        changed = True
        while changed:
            changed = False
            new_tokens = []
            for t in tokens:
                if t in self.reverse_merges:
                    a, b = self.reverse_merges[t]
                    new_tokens.extend([a, b])
                    changed = True
                else:
                    new_tokens.append(t)
            tokens = new_tokens

        return tokens

    def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute compression ratio (original_len / compressed_len)."""
        compressed = self.encode(tokens)
        return len(tokens) / len(compressed) if compressed else 1.0

    @property
    def extended_vocab_size(self) -> int:
        """Total vocabulary size including merged tokens."""
        return self.vocab_size + len(self.merges)

    def get_stats(self) -> Dict:
        """Get compression statistics."""
        return {
            "original_vocab": self.vocab_size,
            "num_merges": len(self.merges),
            "extended_vocab": self.extended_vocab_size,
            "fitted": self._fitted,
        }
Attributes
extended_vocab_size property
extended_vocab_size

Total vocabulary size including merged tokens.

Functions
__init__
__init__(vocab_size=256, num_merges=100)

Parameters:

- vocab_size (int, default 256): Original VQ codebook size (tokens 0 to vocab_size-1)
- num_merges (int, default 100): Number of BPE merges to learn
Source code in src/dhb_xr/tokenization/compression.py
def __init__(self, vocab_size: int = 256, num_merges: int = 100):
    """
    Args:
        vocab_size: Original VQ codebook size (tokens 0 to vocab_size-1)
        num_merges: Number of BPE merges to learn
    """
    self.vocab_size = vocab_size
    self.num_merges = num_merges
    self.merges: Dict[Tuple[int, int], int] = {}  # (a, b) -> merged_token
    self.reverse_merges: Dict[int, Tuple[int, int]] = {}  # merged_token -> (a, b)
    self._fitted = False
compression_ratio
compression_ratio(tokens)

Compute compression ratio (original_len / compressed_len).

Source code in src/dhb_xr/tokenization/compression.py
def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
    """Compute compression ratio (original_len / compressed_len)."""
    compressed = self.encode(tokens)
    return len(tokens) / len(compressed) if compressed else 1.0
decode
decode(tokens)

Decode a compressed sequence back to original tokens.

Parameters:

- tokens (Union[List[int], ndarray], required): Compressed token sequence

Returns:

- List[int]: Original token sequence

Source code in src/dhb_xr/tokenization/compression.py
def decode(self, tokens: Union[List[int], np.ndarray]) -> List[int]:
    """
    Decode a compressed sequence back to original tokens.

    Args:
        tokens: Compressed token sequence

    Returns:
        Original token sequence
    """
    tokens = list(tokens)

    # Recursively expand merged tokens
    changed = True
    while changed:
        changed = False
        new_tokens = []
        for t in tokens:
            if t in self.reverse_merges:
                a, b = self.reverse_merges[t]
                new_tokens.extend([a, b])
                changed = True
            else:
                new_tokens.append(t)
        tokens = new_tokens

    return tokens
encode
encode(tokens)

Encode a token sequence using learned BPE merges.

Parameters:

- tokens (Union[List[int], ndarray], required): Original token sequence

Returns:

- List[int]: Compressed token sequence

Source code in src/dhb_xr/tokenization/compression.py
def encode(self, tokens: Union[List[int], np.ndarray]) -> List[int]:
    """
    Encode a token sequence using learned BPE merges.

    Args:
        tokens: Original token sequence

    Returns:
        Compressed token sequence
    """
    if not self._fitted:
        raise RuntimeError("BPECompressor must be fitted before encoding")

    tokens = list(tokens)

    # Apply merges in order learned
    for (a, b), merged in self.merges.items():
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == a and tokens[i + 1] == b:
                new_tokens.append(merged)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        tokens = new_tokens

    return tokens
fit
fit(token_sequences)

Learn BPE merges from a corpus of token sequences.

Parameters:

- token_sequences (List[List[int]], required): List of token sequences (each a list of ints)

Returns:

- BPECompressor: self

Source code in src/dhb_xr/tokenization/compression.py
def fit(self, token_sequences: List[List[int]]) -> "BPECompressor":
    """
    Learn BPE merges from a corpus of token sequences.

    Args:
        token_sequences: List of token sequences (each a list of ints)

    Returns:
        self
    """
    # Flatten and count pair frequencies
    all_tokens = []
    for seq in token_sequences:
        all_tokens.extend(list(seq))

    # Iteratively merge most frequent pairs
    current_tokens = list(all_tokens)
    next_id = self.vocab_size

    for _ in range(self.num_merges):
        # Count pairs
        pair_counts = Counter()
        for i in range(len(current_tokens) - 1):
            pair = (current_tokens[i], current_tokens[i + 1])
            pair_counts[pair] += 1

        if not pair_counts:
            break

        # Get most frequent pair
        best_pair = pair_counts.most_common(1)[0][0]
        if pair_counts[best_pair] < 2:
            break  # No benefit from merging singletons

        # Merge
        self.merges[best_pair] = next_id
        self.reverse_merges[next_id] = best_pair

        # Replace in sequence
        new_tokens = []
        i = 0
        while i < len(current_tokens):
            if i < len(current_tokens) - 1 and (current_tokens[i], current_tokens[i + 1]) == best_pair:
                new_tokens.append(next_id)
                i += 2
            else:
                new_tokens.append(current_tokens[i])
                i += 1

        current_tokens = new_tokens
        next_id += 1

    self._fitted = True
    return self
get_stats
get_stats()

Get compression statistics.

Source code in src/dhb_xr/tokenization/compression.py
def get_stats(self) -> Dict:
    """Get compression statistics."""
    return {
        "original_vocab": self.vocab_size,
        "num_merges": len(self.merges),
        "extended_vocab": self.extended_vocab_size,
        "fitted": self._fitted,
    }

CausalConv1dEncoder

Bases: Module

Stack of causal convs: (B, T, C) -> (B, T, D).
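A minimal shape-check sketch (illustrative only; it assumes the class is importable from dhb_xr.tokenization.causal_encoder, matching the source path below):

import torch
from dhb_xr.tokenization.causal_encoder import CausalConv1dEncoder  # assumed import path

encoder = CausalConv1dEncoder(in_dim=8, hidden_dim=32, out_dim=16, num_layers=2, kernel_size=3)
x = torch.randn(4, 50, 8)      # (B, T, C)
y = encoder(x)                 # (B, T, D)
assert y.shape == (4, 50, 16)

Because every convolution is causal, y[:, t] depends only on x[:, :t + 1], which keeps the encoder usable in streaming or autoregressive settings.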

Source code in src/dhb_xr/tokenization/causal_encoder.py
class CausalConv1dEncoder(nn.Module):
    """Stack of causal convs: (B, T, C) -> (B, T, D)."""

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        kernel_size: int = 3,
    ):
        super().__init__()
        layers = []
        c_in = in_dim
        for _ in range(num_layers - 1):
            layers.append(CausalConv1d(c_in, hidden_dim, kernel_size))
            layers.append(nn.ReLU())
            c_in = hidden_dim
        layers.append(CausalConv1d(c_in, out_dim, kernel_size))
        self.net = nn.Sequential(*layers)
        self.out_dim = out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C) -> (B, C, T)
        x = x.transpose(1, 2)
        out = self.net(x)
        return out.transpose(1, 2)

DHBTokenizer

Bases: Module

Causal VQ-VAE for invariant sequences: invariants (B, T, C) -> tokens (B, T), reconstructed (B, T, C).
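A minimal training-step sketch (illustrative only; the import path dhb_xr.tokenization.vqvae is assumed from the source path below):

import torch
from dhb_xr.tokenization.vqvae import DHBTokenizer  # assumed import path

tok = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=256)
invariants = torch.randn(4, 50, 8)                   # (B, T, C)
indices, recon, z, z_q = tok(invariants)             # tokens (B, T), reconstruction (B, T, C)
loss = tok.loss(invariants, recon, z, z_q, beta=0.25)
loss.backward()                                      # an optimizer step would follow here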

Source code in src/dhb_xr/tokenization/vqvae.py
class DHBTokenizer(nn.Module):
    """
    Causal VQ-VAE for invariant sequences.
    invariants (B, T, C) -> tokens (B, T), reconstructed (B, T, C).
    """

    def __init__(
        self,
        invariant_dim: int,
        latent_dim: int,
        codebook_size: int,
        num_layers: int = 2,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.encoder = CausalConv1dEncoder(
            invariant_dim, latent_dim, latent_dim, num_layers, kernel_size
        )
        self.vq = VectorQuantizer(codebook_size, latent_dim)
        self.decoder = CausalConv1dEncoder(
            latent_dim, latent_dim, invariant_dim, num_layers, kernel_size
        )
        self.invariant_dim = invariant_dim
        self.latent_dim = latent_dim
        self.codebook_size = codebook_size

    def forward(self, invariants: torch.Tensor) -> tuple:
        z = self.encoder(invariants)
        indices, z_q_st, z_q = self.vq(z)
        reconstructed = self.decoder(z_q_st)
        return indices, reconstructed, z, z_q

    def loss(
        self,
        invariants: torch.Tensor,
        reconstructed: torch.Tensor,
        z: torch.Tensor,
        z_q: torch.Tensor,
        beta: float = 0.25,
    ) -> torch.Tensor:
        rec_loss = F.mse_loss(reconstructed, invariants)
        commitment = F.mse_loss(z, z_q)
        codebook = F.mse_loss(z_q, z.detach())
        return rec_loss + beta * commitment + codebook

    # ---- Flow matching integration API ----

    def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
        """
        Encode invariants to continuous latent space (before quantization).

        This is useful for flow matching which operates in continuous space.

        Args:
            invariants: Input invariant sequences (B, T, C).

        Returns:
            Continuous latent z (B, T, latent_dim).
        """
        return self.encoder(invariants)

    def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode from continuous latent to invariants.

        Bypasses the VQ step, useful for flow matching generation.

        Args:
            z: Continuous latent (B, T, latent_dim).

        Returns:
            Reconstructed invariants (B, T, invariant_dim).
        """
        return self.decoder(z)

    def quantize(self, z: torch.Tensor) -> tuple:
        """
        Quantize continuous latent to discrete tokens.

        Args:
            z: Continuous latent (B, T, latent_dim).

        Returns:
            Tuple of (indices, z_q_st, z_q).
        """
        return self.vq(z)

    def get_codebook_embeddings(self) -> torch.Tensor:
        """
        Get the VQ codebook embeddings.

        Useful for flow matching in embedding space or visualization.

        Returns:
            Codebook embeddings (codebook_size, latent_dim).
        """
        return self.vq.embedding.weight.data

    def embed_tokens(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Convert token indices to embeddings.

        Args:
            indices: Token indices (B, T).

        Returns:
            Token embeddings (B, T, latent_dim).
        """
        return self.vq.embedding(indices)

    def decode_tokens(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Decode token indices to invariants.

        Args:
            indices: Token indices (B, T).

        Returns:
            Reconstructed invariants (B, T, invariant_dim).
        """
        z_q = self.embed_tokens(indices)
        return self.decoder(z_q)
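The flow matching integration API above chains into a simple round trip; a hedged sketch (shapes taken from the docstrings, import path assumed):

import torch
from dhb_xr.tokenization.vqvae import DHBTokenizer  # assumed import path

tok = DHBTokenizer(invariant_dim=8, latent_dim=32, codebook_size=256)
invariants = torch.randn(4, 50, 8)

z = tok.encode_continuous(invariants)        # (B, T, latent_dim), pre-quantization latent
indices, z_q_st, z_q = tok.quantize(z)       # discrete tokens + quantized latents
recon_tokens = tok.decode_tokens(indices)    # (B, T, invariant_dim) from token IDs
recon_latent = tok.decode_from_latent(z)     # bypass VQ, e.g. for flow-matching samples
codebook = tok.get_codebook_embeddings()     # (codebook_size, latent_dim)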
Functions
decode_from_latent
decode_from_latent(z)

Decode from continuous latent to invariants.

Bypasses the VQ step, useful for flow matching generation.

Parameters:

- z (Tensor, required): Continuous latent (B, T, latent_dim).

Returns:

- Tensor: Reconstructed invariants (B, T, invariant_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
    """
    Decode from continuous latent to invariants.

    Bypasses the VQ step, useful for flow matching generation.

    Args:
        z: Continuous latent (B, T, latent_dim).

    Returns:
        Reconstructed invariants (B, T, invariant_dim).
    """
    return self.decoder(z)
decode_tokens
decode_tokens(indices)

Decode token indices to invariants.

Parameters:

- indices (Tensor, required): Token indices (B, T).

Returns:

- Tensor: Reconstructed invariants (B, T, invariant_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def decode_tokens(self, indices: torch.Tensor) -> torch.Tensor:
    """
    Decode token indices to invariants.

    Args:
        indices: Token indices (B, T).

    Returns:
        Reconstructed invariants (B, T, invariant_dim).
    """
    z_q = self.embed_tokens(indices)
    return self.decoder(z_q)
embed_tokens
embed_tokens(indices)

Convert token indices to embeddings.

Parameters:

- indices (Tensor, required): Token indices (B, T).

Returns:

- Tensor: Token embeddings (B, T, latent_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def embed_tokens(self, indices: torch.Tensor) -> torch.Tensor:
    """
    Convert token indices to embeddings.

    Args:
        indices: Token indices (B, T).

    Returns:
        Token embeddings (B, T, latent_dim).
    """
    return self.vq.embedding(indices)
encode_continuous
encode_continuous(invariants)

Encode invariants to continuous latent space (before quantization).

This is useful for flow matching which operates in continuous space.

Parameters:

- invariants (Tensor, required): Input invariant sequences (B, T, C).

Returns:

- Tensor: Continuous latent z (B, T, latent_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
    """
    Encode invariants to continuous latent space (before quantization).

    This is useful for flow matching which operates in continuous space.

    Args:
        invariants: Input invariant sequences (B, T, C).

    Returns:
        Continuous latent z (B, T, latent_dim).
    """
    return self.encoder(invariants)
get_codebook_embeddings
get_codebook_embeddings()

Get the VQ codebook embeddings.

Useful for flow matching in embedding space or visualization.

Returns:

- Tensor: Codebook embeddings (codebook_size, latent_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def get_codebook_embeddings(self) -> torch.Tensor:
    """
    Get the VQ codebook embeddings.

    Useful for flow matching in embedding space or visualization.

    Returns:
        Codebook embeddings (codebook_size, latent_dim).
    """
    return self.vq.embedding.weight.data
quantize
quantize(z)

Quantize continuous latent to discrete tokens.

Parameters:

- z (Tensor, required): Continuous latent (B, T, latent_dim).

Returns:

- tuple: Tuple of (indices, z_q_st, z_q).

Source code in src/dhb_xr/tokenization/vqvae.py
def quantize(self, z: torch.Tensor) -> tuple:
    """
    Quantize continuous latent to discrete tokens.

    Args:
        z: Continuous latent (B, T, latent_dim).

    Returns:
        Tuple of (indices, z_q_st, z_q).
    """
    return self.vq(z)

EntropyCompressor

Entropy coding (Huffman) for token sequences.

Assigns variable-length codes based on token frequencies, achieving near-optimal bits per token relative to the entropy of the token distribution.

For RVQ indices with K=256, naive encoding uses 8 bits/token; entropy coding typically reaches 4-6 bits/token (1.5-2x compression).
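A minimal fit/encode/decode sketch (token values are arbitrary; the import path dhb_xr.tokenization.compression is assumed from the source path below):

from dhb_xr.tokenization.compression import EntropyCompressor  # assumed import path

compressor = EntropyCompressor()
corpus = [[0, 0, 1, 2, 0, 1], [0, 3, 0, 0, 1]]   # frequent token 0 gets a short code
compressor.fit(corpus)
bits = compressor.encode([0, 0, 1, 3])           # binary string, e.g. "0011..."
assert compressor.decode(bits) == [0, 0, 1, 3]   # prefix-free codes make decoding lossless
print(compressor.bits_per_token([0, 0, 1, 3]))   # achieved bits/token for this sequence
print(compressor.theoretical_entropy())          # Shannon bound for the fitted corpus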

Source code in src/dhb_xr/tokenization/compression.py
class EntropyCompressor:
    """
    Entropy coding (Huffman) for token sequences.

    Assigns variable-length codes based on token frequencies,
    achieving near-optimal bits-per-token based on entropy.

    For RVQ indices with K=256, naive encoding = 8 bits/token.
    With entropy coding: typically 4-6 bits/token (1.5-2x compression).
    """

    def __init__(self):
        self.codes: Dict[int, str] = {}  # token -> binary string
        self.reverse_codes: Dict[str, int] = {}  # binary string -> token
        self.frequencies: Dict[int, int] = {}
        self._fitted = False

    def fit(self, token_sequences: List[List[int]]) -> "EntropyCompressor":
        """
        Build Huffman tree from token frequencies.

        Args:
            token_sequences: List of token sequences

        Returns:
            self
        """
        # Count frequencies
        self.frequencies = Counter()
        for seq in token_sequences:
            self.frequencies.update(seq)

        if not self.frequencies:
            self._fitted = True
            return self

        # Build Huffman tree
        heap = [HuffmanNode(token, freq) for token, freq in self.frequencies.items()]
        heapq.heapify(heap)

        while len(heap) > 1:
            left = heapq.heappop(heap)
            right = heapq.heappop(heap)
            merged = HuffmanNode(None, left.freq + right.freq, left, right)
            heapq.heappush(heap, merged)

        # Generate codes
        self.codes = {}
        if heap:
            self._generate_codes(heap[0], "")

        # Handle single-token case
        if len(self.codes) == 1:
            token = list(self.codes.keys())[0]
            self.codes[token] = "0"

        self.reverse_codes = {v: k for k, v in self.codes.items()}
        self._fitted = True
        return self

    def _generate_codes(self, node: HuffmanNode, code: str):
        """Recursively generate Huffman codes."""
        if node.token is not None:
            self.codes[node.token] = code if code else "0"
            return
        if node.left:
            self._generate_codes(node.left, code + "0")
        if node.right:
            self._generate_codes(node.right, code + "1")

    def encode(self, tokens: Union[List[int], np.ndarray]) -> str:
        """
        Encode tokens to binary string.

        Args:
            tokens: Token sequence

        Returns:
            Binary string (e.g., "0110101...")
        """
        if not self._fitted:
            raise RuntimeError("EntropyCompressor must be fitted before encoding")
        return "".join(self.codes.get(t, "") for t in tokens)

    def decode(self, binary_string: str) -> List[int]:
        """
        Decode binary string back to tokens.

        Args:
            binary_string: Encoded binary string

        Returns:
            Original token sequence
        """
        tokens = []
        current = ""
        for bit in binary_string:
            current += bit
            if current in self.reverse_codes:
                tokens.append(self.reverse_codes[current])
                current = ""
        return tokens

    def bits_per_token(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute average bits per token."""
        encoded = self.encode(tokens)
        return len(encoded) / len(tokens) if tokens else 0.0

    def theoretical_entropy(self) -> float:
        """Compute theoretical entropy H = -sum(p * log2(p))."""
        total = sum(self.frequencies.values())
        if total == 0:
            return 0.0
        entropy = 0.0
        for freq in self.frequencies.values():
            p = freq / total
            if p > 0:
                entropy -= p * np.log2(p)
        return entropy

    def get_stats(self) -> Dict:
        """Get compression statistics."""
        return {
            "unique_tokens": len(self.codes),
            "theoretical_entropy": self.theoretical_entropy(),
            "avg_code_length": np.mean([len(c) for c in self.codes.values()]) if self.codes else 0,
            "fitted": self._fitted,
        }
Functions
bits_per_token
bits_per_token(tokens)

Compute average bits per token.

Source code in src/dhb_xr/tokenization/compression.py
def bits_per_token(self, tokens: Union[List[int], np.ndarray]) -> float:
    """Compute average bits per token."""
    encoded = self.encode(tokens)
    return len(encoded) / len(tokens) if tokens else 0.0
decode
decode(binary_string)

Decode binary string back to tokens.

Parameters:

- binary_string (str, required): Encoded binary string

Returns:

- List[int]: Original token sequence

Source code in src/dhb_xr/tokenization/compression.py
def decode(self, binary_string: str) -> List[int]:
    """
    Decode binary string back to tokens.

    Args:
        binary_string: Encoded binary string

    Returns:
        Original token sequence
    """
    tokens = []
    current = ""
    for bit in binary_string:
        current += bit
        if current in self.reverse_codes:
            tokens.append(self.reverse_codes[current])
            current = ""
    return tokens
encode
encode(tokens)

Encode tokens to binary string.

Parameters:

- tokens (Union[List[int], ndarray], required): Token sequence

Returns:

- str: Binary string (e.g., "0110101...")

Source code in src/dhb_xr/tokenization/compression.py
def encode(self, tokens: Union[List[int], np.ndarray]) -> str:
    """
    Encode tokens to binary string.

    Args:
        tokens: Token sequence

    Returns:
        Binary string (e.g., "0110101...")
    """
    if not self._fitted:
        raise RuntimeError("EntropyCompressor must be fitted before encoding")
    return "".join(self.codes.get(t, "") for t in tokens)
fit
fit(token_sequences)

Build Huffman tree from token frequencies.

Parameters:

- token_sequences (List[List[int]], required): List of token sequences

Returns:

- EntropyCompressor: self

Source code in src/dhb_xr/tokenization/compression.py
def fit(self, token_sequences: List[List[int]]) -> "EntropyCompressor":
    """
    Build Huffman tree from token frequencies.

    Args:
        token_sequences: List of token sequences

    Returns:
        self
    """
    # Count frequencies
    self.frequencies = Counter()
    for seq in token_sequences:
        self.frequencies.update(seq)

    if not self.frequencies:
        self._fitted = True
        return self

    # Build Huffman tree
    heap = [HuffmanNode(token, freq) for token, freq in self.frequencies.items()]
    heapq.heapify(heap)

    while len(heap) > 1:
        left = heapq.heappop(heap)
        right = heapq.heappop(heap)
        merged = HuffmanNode(None, left.freq + right.freq, left, right)
        heapq.heappush(heap, merged)

    # Generate codes
    self.codes = {}
    if heap:
        self._generate_codes(heap[0], "")

    # Handle single-token case
    if len(self.codes) == 1:
        token = list(self.codes.keys())[0]
        self.codes[token] = "0"

    self.reverse_codes = {v: k for k, v in self.codes.items()}
    self._fitted = True
    return self
get_stats
get_stats()

Get compression statistics.

Source code in src/dhb_xr/tokenization/compression.py
def get_stats(self) -> Dict:
    """Get compression statistics."""
    return {
        "unique_tokens": len(self.codes),
        "theoretical_entropy": self.theoretical_entropy(),
        "avg_code_length": np.mean([len(c) for c in self.codes.values()]) if self.codes else 0,
        "fitted": self._fitted,
    }
theoretical_entropy
theoretical_entropy()

Compute theoretical entropy H = -sum(p * log2(p)).

Source code in src/dhb_xr/tokenization/compression.py
def theoretical_entropy(self) -> float:
    """Compute theoretical entropy H = -sum(p * log2(p))."""
    total = sum(self.frequencies.values())
    if total == 0:
        return 0.0
    entropy = 0.0
    for freq in self.frequencies.values():
        p = freq / total
        if p > 0:
            entropy -= p * np.log2(p)
    return entropy

FASTTokenizer

FAST-style tokenizer: DCT frequency compression + BPE.

Converts continuous invariant (or action) chunks into a compact sequence of discrete tokens using the Discrete Cosine Transform for energy compaction and Byte-Pair Encoding for further compression.

Example

tokenizer = FASTTokenizer(scale=10.0, vocab_size=1024, num_merges=200)
# Fit BPE on a corpus of invariant chunks
corpus = [np.random.randn(50, 8) for _ in range(100)]
tokenizer.fit(corpus)
# Tokenize a single chunk
tokens = tokenizer.encode(np.random.randn(50, 8))
recon = tokenizer.decode(tokens, time_horizon=50, dim=8)

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
class FASTTokenizer:
    """
    FAST-style tokenizer: DCT frequency compression + BPE.

    Converts continuous invariant (or action) chunks into a compact sequence of
    discrete tokens using the Discrete Cosine Transform for energy compaction
    and Byte-Pair Encoding for further compression.

    Example:
        >>> tokenizer = FASTTokenizer(scale=10.0, vocab_size=1024, num_merges=200)
        >>> # Fit BPE on a corpus of invariant chunks
        >>> corpus = [np.random.randn(50, 8) for _ in range(100)]
        >>> tokenizer.fit(corpus)
        >>> # Tokenize a single chunk
        >>> tokens = tokenizer.encode(np.random.randn(50, 8))
        >>> recon = tokenizer.decode(tokens, time_horizon=50, dim=8)
    """

    def __init__(
        self,
        scale: float = 10.0,
        vocab_size: int = 1024,
        num_merges: int = 200,
    ):
        """
        Args:
            scale: Scaling factor for DCT coefficients before rounding.
                   Higher values preserve more detail but increase alphabet size.
            vocab_size: Maximum BPE vocabulary size (includes initial alphabet + merges).
            num_merges: Number of BPE merge operations to learn.
        """
        if not HAS_SCIPY_FFT:
            raise ImportError(
                "FASTTokenizer requires scipy. Install with: pip install scipy"
            )

        self.scale = scale
        self.vocab_size = vocab_size
        self.num_merges = num_merges

        # BPE state
        self.merges: Dict[Tuple[int, int], int] = {}
        self.reverse_merges: Dict[int, Tuple[int, int]] = {}
        self._fitted = False

        # Metadata for decoding
        self._min_token: int = 0
        self._time_horizon: Optional[int] = None
        self._dim: Optional[int] = None

    # ------------------------------------------------------------------
    # DCT helpers
    # ------------------------------------------------------------------

    def _dct_encode(self, chunk: np.ndarray) -> np.ndarray:
        """Apply DCT along time axis, scale, and round.

        Args:
            chunk: (T, D) continuous invariant/action chunk.

        Returns:
            (T, D) integer DCT coefficients.
        """
        coeffs = dct(chunk, axis=0, norm="ortho")
        return np.around(coeffs * self.scale).astype(int)

    def _dct_decode(self, int_coeffs: np.ndarray) -> np.ndarray:
        """Inverse: un-scale and apply IDCT.

        Args:
            int_coeffs: (T, D) integer DCT coefficients.

        Returns:
            (T, D) reconstructed continuous chunk.
        """
        coeffs = int_coeffs.astype(float) / self.scale
        return idct(coeffs, axis=0, norm="ortho")

    # ------------------------------------------------------------------
    # BPE helpers
    # ------------------------------------------------------------------

    def _flatten_to_tokens(self, int_coeffs: np.ndarray) -> List[int]:
        """Flatten integer DCT coefficients to a token sequence.

        Shifts values so that the minimum token is 0.

        Args:
            int_coeffs: (T, D) integer coefficients.

        Returns:
            List of non-negative integer tokens.
        """
        flat = int_coeffs.flatten()
        shifted = (flat - self._min_token).tolist()
        return [max(0, t) for t in shifted]

    def _unflatten_from_tokens(
        self,
        tokens: List[int],
        time_horizon: int,
        dim: int,
    ) -> np.ndarray:
        """Reshape flat token list back to (T, D) integer coefficients.

        Args:
            tokens: Flat token list (after BPE decode).
            time_horizon: Number of timesteps T.
            dim: Invariant/action dimension D.

        Returns:
            (T, D) integer coefficient array.
        """
        arr = np.array(tokens, dtype=int) + self._min_token
        expected = time_horizon * dim
        if len(arr) < expected:
            arr = np.pad(arr, (0, expected - len(arr)))
        elif len(arr) > expected:
            arr = arr[:expected]
        return arr.reshape(time_horizon, dim)

    def _bpe_encode(self, tokens: List[int]) -> List[int]:
        """Apply learned BPE merges to a flat token sequence."""
        for (a, b), merged in self.merges.items():
            new_tokens = []
            i = 0
            while i < len(tokens):
                if (
                    i < len(tokens) - 1
                    and tokens[i] == a
                    and tokens[i + 1] == b
                ):
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
        return tokens

    def _bpe_decode(self, tokens: List[int]) -> List[int]:
        """Recursively expand BPE merges back to base tokens."""
        changed = True
        while changed:
            changed = False
            new_tokens = []
            for t in tokens:
                if t in self.reverse_merges:
                    a, b = self.reverse_merges[t]
                    new_tokens.extend([a, b])
                    changed = True
                else:
                    new_tokens.append(t)
            tokens = new_tokens
        return tokens

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(self, chunks: List[np.ndarray]) -> "FASTTokenizer":
        """
        Learn BPE merges from a corpus of invariant/action chunks.

        Args:
            chunks: List of (T_i, D) arrays. Chunks may have different lengths
                    but must share the same dimension D.

        Returns:
            self (fitted tokenizer).
        """
        # Compute DCT + quantize for all chunks
        all_int_coeffs = [self._dct_encode(c) for c in chunks]

        # Determine global min/max for alphabet
        all_flat = np.concatenate([c.flatten() for c in all_int_coeffs])
        self._min_token = int(all_flat.min())
        max_token = int(all_flat.max())
        alphabet_size = max_token - self._min_token + 1

        if alphabet_size > self.vocab_size:
            import warnings

            warnings.warn(
                f"Initial alphabet size ({alphabet_size}) exceeds vocab_size "
                f"({self.vocab_size}). Consider increasing vocab_size or scale."
            )

        # Build flat token corpus
        corpus_tokens: List[int] = []
        for c in all_int_coeffs:
            corpus_tokens.extend(self._flatten_to_tokens(c))

        # Learn BPE merges
        current = list(corpus_tokens)
        next_id = alphabet_size  # Merge IDs start after initial alphabet

        for _ in range(self.num_merges):
            # Count pairs
            pair_counts: Dict[Tuple[int, int], int] = Counter()
            for i in range(len(current) - 1):
                pair_counts[(current[i], current[i + 1])] += 1

            if not pair_counts:
                break

            best_pair = max(pair_counts, key=pair_counts.get)
            if pair_counts[best_pair] < 2:
                break

            self.merges[best_pair] = next_id
            self.reverse_merges[next_id] = best_pair

            # Apply merge
            new_current = []
            i = 0
            while i < len(current):
                if (
                    i < len(current) - 1
                    and (current[i], current[i + 1]) == best_pair
                ):
                    new_current.append(next_id)
                    i += 2
                else:
                    new_current.append(current[i])
                    i += 1
            current = new_current
            next_id += 1

            if next_id >= self.vocab_size:
                break

        self._fitted = True

        # Cache typical dims from corpus
        if chunks:
            self._time_horizon = chunks[0].shape[0]
            self._dim = chunks[0].shape[1]

        return self

    def encode(self, chunk: np.ndarray) -> List[int]:
        """
        Tokenize a single invariant/action chunk.

        Args:
            chunk: (T, D) continuous chunk.

        Returns:
            List of BPE-compressed token IDs.
        """
        if not self._fitted:
            raise RuntimeError("FASTTokenizer must be fitted before encoding. Call .fit() first.")

        # Cache dimensions for decoding
        self._time_horizon = chunk.shape[0]
        self._dim = chunk.shape[1]

        int_coeffs = self._dct_encode(chunk)
        flat_tokens = self._flatten_to_tokens(int_coeffs)
        compressed = self._bpe_encode(flat_tokens)
        return compressed

    def encode_batch(self, chunks: np.ndarray) -> List[List[int]]:
        """
        Tokenize a batch of chunks.

        Args:
            chunks: (B, T, D) batch of chunks.

        Returns:
            List of B token sequences.
        """
        if chunks.ndim == 2:
            return [self.encode(chunks)]
        return [self.encode(chunks[i]) for i in range(chunks.shape[0])]

    def decode(
        self,
        tokens: List[int],
        time_horizon: Optional[int] = None,
        dim: Optional[int] = None,
    ) -> np.ndarray:
        """
        Decode tokens back to a continuous chunk.

        Args:
            tokens: BPE-compressed token sequence.
            time_horizon: Number of timesteps T (uses cached value if None).
            dim: Invariant/action dimension D (uses cached value if None).

        Returns:
            (T, D) reconstructed chunk.
        """
        T = time_horizon or self._time_horizon
        D = dim or self._dim
        if T is None or D is None:
            raise ValueError(
                "time_horizon and dim must be provided (or cached from encode)."
            )

        flat_tokens = self._bpe_decode(list(tokens))
        int_coeffs = self._unflatten_from_tokens(flat_tokens, T, D)
        return self._dct_decode(int_coeffs)

    def decode_batch(
        self,
        token_sequences: List[List[int]],
        time_horizon: Optional[int] = None,
        dim: Optional[int] = None,
    ) -> np.ndarray:
        """
        Decode a batch of token sequences.

        Args:
            token_sequences: List of B token sequences.
            time_horizon: Number of timesteps T.
            dim: Invariant/action dimension D.

        Returns:
            (B, T, D) batch of reconstructed chunks.
        """
        decoded = [self.decode(ts, time_horizon, dim) for ts in token_sequences]
        return np.stack(decoded)

    def compression_ratio(self, chunk: np.ndarray) -> float:
        """Compute compression ratio for a chunk.

        Returns:
            original_values / compressed_tokens.
        """
        tokens = self.encode(chunk)
        original = chunk.size  # T * D float values
        return original / len(tokens) if tokens else 1.0

    def reconstruction_error(self, chunk: np.ndarray) -> float:
        """Compute MSE reconstruction error for a chunk."""
        tokens = self.encode(chunk)
        recon = self.decode(tokens, chunk.shape[0], chunk.shape[1])
        return float(np.mean((chunk - recon) ** 2))

    def get_stats(self, chunk: Optional[np.ndarray] = None) -> Dict:
        """Get tokenizer statistics.

        Args:
            chunk: Optional chunk to compute per-chunk statistics.

        Returns:
            Dictionary of statistics.
        """
        stats = {
            "scale": self.scale,
            "vocab_size": self.vocab_size,
            "num_merges_learned": len(self.merges),
            "min_token": self._min_token,
            "fitted": self._fitted,
        }
        if chunk is not None:
            stats["compression_ratio"] = self.compression_ratio(chunk)
            stats["mse"] = self.reconstruction_error(chunk)
        return stats
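The scale parameter trades reconstruction fidelity against token count; a hedged sketch probing that trade-off on random data (numbers are only indicative, and the import path is assumed from the source path above):

import numpy as np
from dhb_xr.tokenization.fast_tokenizer import FASTTokenizer  # assumed import path

corpus = [np.random.randn(50, 8) for _ in range(100)]
chunk = np.random.randn(50, 8)                    # (T, D) invariant chunk

for scale in (5.0, 10.0, 20.0):
    tok = FASTTokenizer(scale=scale, vocab_size=1024, num_merges=200).fit(corpus)
    ratio = tok.compression_ratio(chunk)          # original values per compressed token
    mse = tok.reconstruction_error(chunk)         # lower = more faithful
    print(f"scale={scale}: ratio={ratio:.2f}, mse={mse:.6f}")

Higher scale keeps more DCT detail (lower MSE) at the cost of a larger initial alphabet and typically longer token sequences.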
Functions
__init__
__init__(scale=10.0, vocab_size=1024, num_merges=200)

Parameters:

- scale (float, default 10.0): Scaling factor for DCT coefficients before rounding. Higher values preserve more detail but increase alphabet size.
- vocab_size (int, default 1024): Maximum BPE vocabulary size (includes initial alphabet + merges).
- num_merges (int, default 200): Number of BPE merge operations to learn.
Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def __init__(
    self,
    scale: float = 10.0,
    vocab_size: int = 1024,
    num_merges: int = 200,
):
    """
    Args:
        scale: Scaling factor for DCT coefficients before rounding.
               Higher values preserve more detail but increase alphabet size.
        vocab_size: Maximum BPE vocabulary size (includes initial alphabet + merges).
        num_merges: Number of BPE merge operations to learn.
    """
    if not HAS_SCIPY_FFT:
        raise ImportError(
            "FASTTokenizer requires scipy. Install with: pip install scipy"
        )

    self.scale = scale
    self.vocab_size = vocab_size
    self.num_merges = num_merges

    # BPE state
    self.merges: Dict[Tuple[int, int], int] = {}
    self.reverse_merges: Dict[int, Tuple[int, int]] = {}
    self._fitted = False

    # Metadata for decoding
    self._min_token: int = 0
    self._time_horizon: Optional[int] = None
    self._dim: Optional[int] = None
compression_ratio
compression_ratio(chunk)

Compute compression ratio for a chunk.

Returns:

- float: original_values / compressed_tokens.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def compression_ratio(self, chunk: np.ndarray) -> float:
    """Compute compression ratio for a chunk.

    Returns:
        original_values / compressed_tokens.
    """
    tokens = self.encode(chunk)
    original = chunk.size  # T * D float values
    return original / len(tokens) if tokens else 1.0
decode
decode(tokens, time_horizon=None, dim=None)

Decode tokens back to a continuous chunk.

Parameters:

- tokens (List[int], required): BPE-compressed token sequence.
- time_horizon (Optional[int], default None): Number of timesteps T (uses cached value if None).
- dim (Optional[int], default None): Invariant/action dimension D (uses cached value if None).

Returns:

- ndarray: (T, D) reconstructed chunk.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def decode(
    self,
    tokens: List[int],
    time_horizon: Optional[int] = None,
    dim: Optional[int] = None,
) -> np.ndarray:
    """
    Decode tokens back to a continuous chunk.

    Args:
        tokens: BPE-compressed token sequence.
        time_horizon: Number of timesteps T (uses cached value if None).
        dim: Invariant/action dimension D (uses cached value if None).

    Returns:
        (T, D) reconstructed chunk.
    """
    T = time_horizon or self._time_horizon
    D = dim or self._dim
    if T is None or D is None:
        raise ValueError(
            "time_horizon and dim must be provided (or cached from encode)."
        )

    flat_tokens = self._bpe_decode(list(tokens))
    int_coeffs = self._unflatten_from_tokens(flat_tokens, T, D)
    return self._dct_decode(int_coeffs)
decode_batch
decode_batch(token_sequences, time_horizon=None, dim=None)

Decode a batch of token sequences.

Parameters:

- token_sequences (List[List[int]], required): List of B token sequences.
- time_horizon (Optional[int], default None): Number of timesteps T.
- dim (Optional[int], default None): Invariant/action dimension D.

Returns:

- ndarray: (B, T, D) batch of reconstructed chunks.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def decode_batch(
    self,
    token_sequences: List[List[int]],
    time_horizon: Optional[int] = None,
    dim: Optional[int] = None,
) -> np.ndarray:
    """
    Decode a batch of token sequences.

    Args:
        token_sequences: List of B token sequences.
        time_horizon: Number of timesteps T.
        dim: Invariant/action dimension D.

    Returns:
        (B, T, D) batch of reconstructed chunks.
    """
    decoded = [self.decode(ts, time_horizon, dim) for ts in token_sequences]
    return np.stack(decoded)
encode
encode(chunk)

Tokenize a single invariant/action chunk.

Parameters:

- chunk (ndarray, required): (T, D) continuous chunk.

Returns:

- List[int]: List of BPE-compressed token IDs.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def encode(self, chunk: np.ndarray) -> List[int]:
    """
    Tokenize a single invariant/action chunk.

    Args:
        chunk: (T, D) continuous chunk.

    Returns:
        List of BPE-compressed token IDs.
    """
    if not self._fitted:
        raise RuntimeError("FASTTokenizer must be fitted before encoding. Call .fit() first.")

    # Cache dimensions for decoding
    self._time_horizon = chunk.shape[0]
    self._dim = chunk.shape[1]

    int_coeffs = self._dct_encode(chunk)
    flat_tokens = self._flatten_to_tokens(int_coeffs)
    compressed = self._bpe_encode(flat_tokens)
    return compressed
encode_batch
encode_batch(chunks)

Tokenize a batch of chunks.

Parameters:

- chunks (ndarray, required): (B, T, D) batch of chunks.

Returns:

- List[List[int]]: List of B token sequences.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def encode_batch(self, chunks: np.ndarray) -> List[List[int]]:
    """
    Tokenize a batch of chunks.

    Args:
        chunks: (B, T, D) batch of chunks.

    Returns:
        List of B token sequences.
    """
    if chunks.ndim == 2:
        return [self.encode(chunks)]
    return [self.encode(chunks[i]) for i in range(chunks.shape[0])]
fit
fit(chunks)

Learn BPE merges from a corpus of invariant/action chunks.

Parameters:

- chunks (List[ndarray], required): List of (T_i, D) arrays. Chunks may have different lengths but must share the same dimension D.

Returns:

- FASTTokenizer: self (fitted tokenizer).

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def fit(self, chunks: List[np.ndarray]) -> "FASTTokenizer":
    """
    Learn BPE merges from a corpus of invariant/action chunks.

    Args:
        chunks: List of (T_i, D) arrays. Chunks may have different lengths
                but must share the same dimension D.

    Returns:
        self (fitted tokenizer).
    """
    # Compute DCT + quantize for all chunks
    all_int_coeffs = [self._dct_encode(c) for c in chunks]

    # Determine global min/max for alphabet
    all_flat = np.concatenate([c.flatten() for c in all_int_coeffs])
    self._min_token = int(all_flat.min())
    max_token = int(all_flat.max())
    alphabet_size = max_token - self._min_token + 1

    if alphabet_size > self.vocab_size:
        import warnings

        warnings.warn(
            f"Initial alphabet size ({alphabet_size}) exceeds vocab_size "
            f"({self.vocab_size}). Consider increasing vocab_size or scale."
        )

    # Build flat token corpus
    corpus_tokens: List[int] = []
    for c in all_int_coeffs:
        corpus_tokens.extend(self._flatten_to_tokens(c))

    # Learn BPE merges
    current = list(corpus_tokens)
    next_id = alphabet_size  # Merge IDs start after initial alphabet

    for _ in range(self.num_merges):
        # Count pairs
        pair_counts: Dict[Tuple[int, int], int] = Counter()
        for i in range(len(current) - 1):
            pair_counts[(current[i], current[i + 1])] += 1

        if not pair_counts:
            break

        best_pair = max(pair_counts, key=pair_counts.get)
        if pair_counts[best_pair] < 2:
            break

        self.merges[best_pair] = next_id
        self.reverse_merges[next_id] = best_pair

        # Apply merge
        new_current = []
        i = 0
        while i < len(current):
            if (
                i < len(current) - 1
                and (current[i], current[i + 1]) == best_pair
            ):
                new_current.append(next_id)
                i += 2
            else:
                new_current.append(current[i])
                i += 1
        current = new_current
        next_id += 1

        if next_id >= self.vocab_size:
            break

    self._fitted = True

    # Cache typical dims from corpus
    if chunks:
        self._time_horizon = chunks[0].shape[0]
        self._dim = chunks[0].shape[1]

    return self
get_stats
get_stats(chunk=None)

Get tokenizer statistics.

Parameters:

- chunk (Optional[ndarray], default None): Optional chunk to compute per-chunk statistics.

Returns:

- Dict: Dictionary of statistics.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def get_stats(self, chunk: Optional[np.ndarray] = None) -> Dict:
    """Get tokenizer statistics.

    Args:
        chunk: Optional chunk to compute per-chunk statistics.

    Returns:
        Dictionary of statistics.
    """
    stats = {
        "scale": self.scale,
        "vocab_size": self.vocab_size,
        "num_merges_learned": len(self.merges),
        "min_token": self._min_token,
        "fitted": self._fitted,
    }
    if chunk is not None:
        stats["compression_ratio"] = self.compression_ratio(chunk)
        stats["mse"] = self.reconstruction_error(chunk)
    return stats
reconstruction_error
reconstruction_error(chunk)

Compute MSE reconstruction error for a chunk.

Source code in src/dhb_xr/tokenization/fast_tokenizer.py
def reconstruction_error(self, chunk: np.ndarray) -> float:
    """Compute MSE reconstruction error for a chunk."""
    tokens = self.encode(chunk)
    recon = self.decode(tokens, chunk.shape[0], chunk.shape[1])
    return float(np.mean((chunk - recon) ** 2))

FSQ

Bases: Module

Finite Scalar Quantization.

Maps continuous latents to discrete codes by bounding (tanh) and rounding. No learned codebook -- the codebook is implicitly defined by the levels.

Example

fsq = FSQ(levels=[8, 5, 5, 5])
# Effective codebook size: 8 * 5 * 5 * 5 = 1000
z = torch.randn(2, 10, 4)  # (B, T, D) where D = len(levels)
z_q, indices = fsq(z)
# z_q: (2, 10, 4) quantized, indices: (2, 10) codebook indices
z_recon = fsq.indices_to_embedding(indices)
assert torch.allclose(z_q, z_recon)

Parameters:

- levels (List[int], required): List of quantization levels per dimension. E.g., [8, 5, 5, 5] -> codebook size 1000.
- drop_quant_p (float, default 0.0): During training, probability of skipping quantization per sample (quantization dropout for regularization).
Source code in src/dhb_xr/tokenization/fsq.py
class FSQ(nn.Module):
    """
    Finite Scalar Quantization.

    Maps continuous latents to discrete codes by bounding (tanh) and rounding.
    No learned codebook -- the codebook is implicitly defined by the levels.

    Example:
        >>> fsq = FSQ(levels=[8, 5, 5, 5])
        >>> # Effective codebook size: 8 * 5 * 5 * 5 = 1000
        >>> z = torch.randn(2, 10, 4)  # (B, T, D) where D = len(levels)
        >>> z_q, indices = fsq(z)
        >>> # z_q: (2, 10, 4) quantized, indices: (2, 10) codebook indices
        >>> z_recon = fsq.indices_to_embedding(indices)
        >>> assert torch.allclose(z_q, z_recon)

    Args:
        levels: List of quantization levels per dimension.
                E.g., [8, 5, 5, 5] -> codebook size 1000.
        drop_quant_p: During training, probability of skipping quantization
                      per sample (quantization dropout for regularization).
    """

    def __init__(
        self,
        levels: List[int],
        drop_quant_p: float = 0.0,
    ):
        super().__init__()

        _levels = torch.tensor(levels, dtype=torch.int32)
        self.register_buffer("_levels", _levels, persistent=False)

        _basis = torch.cumprod(
            torch.tensor([1] + levels[:-1], dtype=torch.int32), dim=0
        )
        self.register_buffer("_basis", _basis, persistent=False)

        self.dim = len(levels)
        self.codebook_size = int(_levels.prod().item())
        self.drop_quant_p = drop_quant_p

        # Build implicit codebook for lookup
        implicit_codebook = self.indices_to_embedding(
            torch.arange(self.codebook_size)
        )
        self.register_buffer(
            "implicit_codebook", implicit_codebook, persistent=False
        )

    def __repr__(self) -> str:
        return (
            f"FSQ(levels={self._levels.tolist()}, "
            f"codebook_size={self.codebook_size}, "
            f"drop_quant_p={self.drop_quant_p})"
        )

    @property
    def latent_dim(self) -> int:
        """Dimension of the latent space (= number of FSQ levels)."""
        return self.dim

    def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
        """Bound z to the valid range for each level via tanh.

        Args:
            z: (..., D) unbounded latent vectors.

        Returns:
            (..., D) bounded latent vectors in [-half_l, half_l].
        """
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: Tensor) -> Tensor:
        """Quantize z: bound, round (STE), normalize to [-1, 1].

        Args:
            z: (..., D) unbounded latent vectors.

        Returns:
            (..., D) quantized and normalized vectors.
        """
        bounded = self.bound(z)

        if self.training and self.drop_quant_p > 0.0:
            # Quantization dropout: skip quantization for some samples
            zhat = round_ste(bounded)
            B = z.shape[0]
            mask = torch.bernoulli(
                torch.full((B,), self.drop_quant_p, device=z.device)
            )
            mask = mask.view(B, *([1] * (z.ndim - 1)))
            # mask=1 -> keep original (skip quant), mask=0 -> use quantized
            quantized = bounded * mask + zhat * (1 - mask)
        else:
            quantized = round_ste(bounded)

        # Normalize to [-1, 1]
        half_width = self._levels // 2
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
        """Convert normalized codes to integer codes."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
        """Convert integer codes to normalized codes."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def codes_to_indices(self, zhat: Tensor) -> Tensor:
        """Convert quantized codes to flat codebook indices.

        Args:
            zhat: (..., D) quantized codes (normalized).

        Returns:
            (...) integer indices.
        """
        assert zhat.shape[-1] == self.dim
        zhat_int = self._scale_and_shift(zhat)
        return (zhat_int * self._basis).sum(dim=-1).to(torch.int64)

    def indices_to_embedding(self, indices: Tensor) -> Tensor:
        """Convert flat codebook indices to normalized code vectors.

        Args:
            indices: (...) integer indices.

        Returns:
            (..., D) normalized code vectors.
        """
        indices = indices.unsqueeze(-1)
        codes_non_centered = (indices // self._basis) % self._levels
        return self._scale_and_shift_inverse(codes_non_centered.float())

    @autocast(device_type="cuda", enabled=False)
    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """Quantize and return codes + indices.

        Args:
            z: (..., D) continuous latent vectors.

        Returns:
            Tuple of:
            - z_q: (..., D) quantized vectors (normalized to [-1, 1]).
            - indices: (...) flat codebook indices.
        """
        assert z.shape[-1] == self.dim, (
            f"Expected last dim {self.dim}, got {z.shape[-1]}"
        )
        z_q = self.quantize(z.float())
        indices = self.codes_to_indices(z_q)
        return z_q, indices.long()
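
A minimal roundtrip sketch for the class above. The constructor call mirrors how OATTokenizer instantiates FSQ later on this page, and the import path is assumed from the source location shown.

import torch
from dhb_xr.tokenization.fsq import FSQ  # assumed import path

fsq = FSQ(levels=[8, 5, 5, 5])               # codebook_size = 8 * 5 * 5 * 5 = 1000
z = torch.randn(4, 10, 4)                    # last dim must equal len(levels)
z_q, indices = fsq(z)                        # z_q in [-1, 1], indices in [0, 999]
z_back = fsq.indices_to_embedding(indices)   # exact lookup of the quantized codes
assert torch.allclose(z_q, z_back)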
Attributes
latent_dim property
latent_dim

Dimension of the latent space (= number of FSQ levels).

Functions
bound
bound(z, eps=0.001)

Bound z to the valid range for each level via tanh.

Parameters:

- z (Tensor): (..., D) unbounded latent vectors. Required.

Returns:

- Tensor: (..., D) bounded latent vectors in [-half_l, half_l].

Source code in src/dhb_xr/tokenization/fsq.py
def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
    """Bound z to the valid range for each level via tanh.

    Args:
        z: (..., D) unbounded latent vectors.

    Returns:
        (..., D) bounded latent vectors in [-half_l, half_l].
    """
    half_l = (self._levels - 1) * (1 + eps) / 2
    offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
    shift = (offset / half_l).atanh()
    return (z + shift).tanh() * half_l - offset
codes_to_indices
codes_to_indices(zhat)

Convert quantized codes to flat codebook indices.

Parameters:

- zhat (Tensor): (..., D) quantized codes (normalized). Required.

Returns:

- Tensor: (...) integer indices.

Source code in src/dhb_xr/tokenization/fsq.py
def codes_to_indices(self, zhat: Tensor) -> Tensor:
    """Convert quantized codes to flat codebook indices.

    Args:
        zhat: (..., D) quantized codes (normalized).

    Returns:
        (...) integer indices.
    """
    assert zhat.shape[-1] == self.dim
    zhat_int = self._scale_and_shift(zhat)
    return (zhat_int * self._basis).sum(dim=-1).to(torch.int64)
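
For intuition, _basis holds the mixed-radix place values [1, L0, L0*L1, ...]. With levels [8, 5, 5, 5] the basis is [1, 8, 40, 200], so the integer code (3, 2, 4, 1) maps to the flat index 3*1 + 2*8 + 4*40 + 1*200 = 379, and indices_to_embedding inverts this exactly via (379 // basis) % levels = (3, 2, 4, 1).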
forward
forward(z)

Quantize and return codes + indices.

Parameters:

- z (Tensor): (..., D) continuous latent vectors. Required.

Returns:

- Tuple[Tensor, Tensor]:
  - z_q: (..., D) quantized vectors (normalized to [-1, 1]).
  - indices: (...) flat codebook indices.
Source code in src/dhb_xr/tokenization/fsq.py
@autocast(device_type="cuda", enabled=False)
def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
    """Quantize and return codes + indices.

    Args:
        z: (..., D) continuous latent vectors.

    Returns:
        Tuple of:
        - z_q: (..., D) quantized vectors (normalized to [-1, 1]).
        - indices: (...) flat codebook indices.
    """
    assert z.shape[-1] == self.dim, (
        f"Expected last dim {self.dim}, got {z.shape[-1]}"
    )
    z_q = self.quantize(z.float())
    indices = self.codes_to_indices(z_q)
    return z_q, indices.long()
indices_to_embedding
indices_to_embedding(indices)

Convert flat codebook indices to normalized code vectors.

Parameters:

- indices (Tensor): (...) integer indices. Required.

Returns:

- Tensor: (..., D) normalized code vectors.

Source code in src/dhb_xr/tokenization/fsq.py
def indices_to_embedding(self, indices: Tensor) -> Tensor:
    """Convert flat codebook indices to normalized code vectors.

    Args:
        indices: (...) integer indices.

    Returns:
        (..., D) normalized code vectors.
    """
    indices = indices.unsqueeze(-1)
    codes_non_centered = (indices // self._basis) % self._levels
    return self._scale_and_shift_inverse(codes_non_centered.float())
quantize
quantize(z)

Quantize z: bound, round (STE), normalize to [-1, 1].

Parameters:

- z (Tensor): (..., D) unbounded latent vectors. Required.

Returns:

- Tensor: (..., D) quantized and normalized vectors.

Source code in src/dhb_xr/tokenization/fsq.py
def quantize(self, z: Tensor) -> Tensor:
    """Quantize z: bound, round (STE), normalize to [-1, 1].

    Args:
        z: (..., D) unbounded latent vectors.

    Returns:
        (..., D) quantized and normalized vectors.
    """
    bounded = self.bound(z)

    if self.training and self.drop_quant_p > 0.0:
        # Quantization dropout: skip quantization for some samples
        zhat = round_ste(bounded)
        B = z.shape[0]
        mask = torch.bernoulli(
            torch.full((B,), self.drop_quant_p, device=z.device)
        )
        mask = mask.view(B, *([1] * (z.ndim - 1)))
        # mask=1 -> keep original (skip quant), mask=0 -> use quantized
        quantized = bounded * mask + zhat * (1 - mask)
    else:
        quantized = round_ste(bounded)

    # Normalize to [-1, 1]
    half_width = self._levels // 2
    return quantized / half_width
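
quantize relies on round_ste, which is defined at module level (or imported) in fsq.py and not reproduced in this section. A typical straight-through implementation, stated here as an assumption rather than a verbatim copy of the module, is:

def round_ste(z: Tensor) -> Tensor:
    # Round in the forward pass; let gradients pass through unchanged (identity backward).
    return z + (z.round() - z).detach()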

HierarchicalTokenizer

Bases: Module

Hierarchical RVQ with variable-rate output.

Provides coarse-to-fine tokenization:

- Level 0: Low-frequency global structure (high compression)
- Levels 1-N: Residual details (configurable refinement)

For inference, can truncate to fewer levels for faster/coarser output.

Example

tokenizer = HierarchicalTokenizer(
    invariant_dim=8, latent_dim=32, codebook_size=256, num_levels=4
)
tokens, recon = tokenizer(invariants)

# Coarse only (4x fewer tokens)
tokens_coarse, recon_coarse = tokenizer(invariants, max_level=1)

Source code in src/dhb_xr/tokenization/hierarchical.py
class HierarchicalTokenizer(nn.Module):
    """
    Hierarchical RVQ with variable-rate output.

    Provides coarse-to-fine tokenization:
    - Level 0: Low-frequency global structure (high compression)
    - Level 1-N: Residual details (configurable refinement)

    For inference, can truncate to fewer levels for faster/coarser output.

    Example:
        >>> tokenizer = HierarchicalTokenizer(
        ...     invariant_dim=8, latent_dim=32, codebook_size=256, num_levels=4
        ... )
        >>> tokens, recon = tokenizer(invariants)
        >>> 
        >>> # Coarse only (4x fewer tokens)
        >>> tokens_coarse, recon_coarse = tokenizer(invariants, max_level=1)
    """

    def __init__(
        self,
        invariant_dim: int,
        latent_dim: int,
        codebook_size: int,
        num_levels: int = 4,
        temporal_downsample: int = 2,
        num_layers: int = 2,
    ):
        """
        Args:
            invariant_dim: DHB invariant dimension (typically 8)
            latent_dim: Latent embedding dimension
            codebook_size: VQ codebook size per level
            num_levels: Number of hierarchy levels
            temporal_downsample: Downsample factor between levels
            num_layers: Conv layers per encoder/decoder
        """
        super().__init__()

        self.invariant_dim = invariant_dim
        self.latent_dim = latent_dim
        self.codebook_size = codebook_size
        self.num_levels = num_levels
        self.temporal_downsample = temporal_downsample

        # Per-level encoders (progressively downsample)
        self.encoders = nn.ModuleList()
        self.vqs = nn.ModuleList()
        self.decoders = nn.ModuleList()

        for level in range(num_levels):
            # Encoder: downsample temporally at each level
            if level == 0:
                enc = CausalConv1dEncoder(
                    invariant_dim, latent_dim, latent_dim, num_layers
                )
            else:
                enc = nn.Sequential(
                    CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
                    TemporalDownsample(temporal_downsample),
                )
            self.encoders.append(enc)

            # VQ at each level
            self.vqs.append(VectorQuantizer(codebook_size, latent_dim))

            # Decoder: upsample to match previous level
            if level == 0:
                dec = CausalConv1dEncoder(
                    latent_dim, latent_dim, invariant_dim, num_layers
                )
            else:
                dec = nn.Sequential(
                    TemporalUpsample(temporal_downsample),
                    CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
                )
            self.decoders.append(dec)

        # Final projection back to invariant space
        self.final_proj = nn.Linear(latent_dim, invariant_dim)

    def forward(
        self, 
        invariants: torch.Tensor, 
        max_level: int = None,
        return_all_levels: bool = False,
    ) -> tuple:
        """
        Hierarchical encoding and decoding.

        Args:
            invariants: (B, T, invariant_dim) input
            max_level: Stop at this level (None = all levels)
            return_all_levels: Return tokens/recon at each level

        Returns:
            all_tokens: List of (B, T_l) tokens per level
            reconstructed: (B, T, invariant_dim) reconstruction
            level_info: Optional dict with per-level details
        """
        B, T, C = invariants.shape
        max_level = max_level or self.num_levels

        all_tokens = []
        all_z = []
        all_z_q = []
        level_info = {}

        # Encode through hierarchy
        x = invariants
        for level in range(max_level):
            z = self.encoders[level](x if level == 0 else z_residual)
            indices, z_q_st, z_q = self.vqs[level](z)

            all_tokens.append(indices)
            all_z.append(z)
            all_z_q.append(z_q)

            if level < max_level - 1:
                z_residual = z - z_q.detach()

            level_info[f"level_{level}"] = {
                "shape": tuple(z.shape),
                "tokens": indices.shape[-1],
            }

        # Decode through hierarchy (reverse order)
        reconstructed = torch.zeros_like(invariants)
        for level in reversed(range(max_level)):
            dec_out = self.decoders[level](all_z_q[level])

            # Match temporal dimension
            if dec_out.shape[1] > reconstructed.shape[1]:
                dec_out = dec_out[:, :reconstructed.shape[1], :]
            elif dec_out.shape[1] < reconstructed.shape[1]:
                # Upsample to match
                dec_out = F.interpolate(
                    dec_out.transpose(1, 2), 
                    size=reconstructed.shape[1],
                    mode='linear',
                    align_corners=True
                ).transpose(1, 2)

            if level == 0:
                reconstructed = dec_out
            else:
                reconstructed = reconstructed + self.final_proj(dec_out)

        if return_all_levels:
            return all_tokens, reconstructed, level_info
        return all_tokens, reconstructed

    def loss(
        self,
        invariants: torch.Tensor,
        reconstructed: torch.Tensor,
        all_z: list,
        all_z_q: list,
        beta: float = 0.25,
        level_weights: list = None,
    ) -> torch.Tensor:
        """
        Compute hierarchical loss with per-level weighting.

        Args:
            invariants: Original input
            reconstructed: Reconstruction
            all_z: Latents at each level
            all_z_q: Quantized latents at each level
            beta: Commitment loss weight
            level_weights: Optional weights per level (default: exponential decay)

        Returns:
            Total loss
        """
        # Reconstruction loss
        rec_loss = F.mse_loss(reconstructed, invariants)

        # Per-level VQ losses
        if level_weights is None:
            level_weights = [0.5 ** i for i in range(len(all_z))]

        commitment = 0
        codebook = 0
        for i, (z, z_q) in enumerate(zip(all_z, all_z_q)):
            commitment += level_weights[i] * F.mse_loss(z, z_q.detach())
            codebook += level_weights[i] * F.mse_loss(z_q, z.detach())

        return rec_loss + beta * commitment + codebook

    def get_compression_stats(self, T: int, max_level: int = None) -> dict:
        """
        Compute compression statistics.

        Args:
            T: Original sequence length
            max_level: Number of levels to use

        Returns:
            Compression statistics
        """
        max_level = max_level or self.num_levels

        total_tokens = 0
        for level in range(max_level):
            level_T = T // (self.temporal_downsample ** level)
            total_tokens += level_T

        original_values = T * self.invariant_dim
        token_values = total_tokens  # Each token is one index

        return {
            "original_values": original_values,
            "total_tokens": total_tokens,
            "tokens_per_level": [T // (self.temporal_downsample ** l) for l in range(max_level)],
            "compression_ratio": original_values / token_values if token_values > 0 else 1,
            "bits_per_value": (total_tokens * np.log2(self.codebook_size)) / original_values,
        }
Functions
__init__
__init__(
    invariant_dim, latent_dim, codebook_size, num_levels=4, temporal_downsample=2, num_layers=2
)

Parameters:

- invariant_dim (int): DHB invariant dimension (typically 8). Required.
- latent_dim (int): Latent embedding dimension. Required.
- codebook_size (int): VQ codebook size per level. Required.
- num_levels (int): Number of hierarchy levels. Default: 4.
- temporal_downsample (int): Downsample factor between levels. Default: 2.
- num_layers (int): Conv layers per encoder/decoder. Default: 2.
Source code in src/dhb_xr/tokenization/hierarchical.py
def __init__(
    self,
    invariant_dim: int,
    latent_dim: int,
    codebook_size: int,
    num_levels: int = 4,
    temporal_downsample: int = 2,
    num_layers: int = 2,
):
    """
    Args:
        invariant_dim: DHB invariant dimension (typically 8)
        latent_dim: Latent embedding dimension
        codebook_size: VQ codebook size per level
        num_levels: Number of hierarchy levels
        temporal_downsample: Downsample factor between levels
        num_layers: Conv layers per encoder/decoder
    """
    super().__init__()

    self.invariant_dim = invariant_dim
    self.latent_dim = latent_dim
    self.codebook_size = codebook_size
    self.num_levels = num_levels
    self.temporal_downsample = temporal_downsample

    # Per-level encoders (progressively downsample)
    self.encoders = nn.ModuleList()
    self.vqs = nn.ModuleList()
    self.decoders = nn.ModuleList()

    for level in range(num_levels):
        # Encoder: downsample temporally at each level
        if level == 0:
            enc = CausalConv1dEncoder(
                invariant_dim, latent_dim, latent_dim, num_layers
            )
        else:
            enc = nn.Sequential(
                CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
                TemporalDownsample(temporal_downsample),
            )
        self.encoders.append(enc)

        # VQ at each level
        self.vqs.append(VectorQuantizer(codebook_size, latent_dim))

        # Decoder: upsample to match previous level
        if level == 0:
            dec = CausalConv1dEncoder(
                latent_dim, latent_dim, invariant_dim, num_layers
            )
        else:
            dec = nn.Sequential(
                TemporalUpsample(temporal_downsample),
                CausalConv1dEncoder(latent_dim, latent_dim, latent_dim, num_layers),
            )
        self.decoders.append(dec)

    # Final projection back to invariant space
    self.final_proj = nn.Linear(latent_dim, invariant_dim)
forward
forward(invariants, max_level=None, return_all_levels=False)

Hierarchical encoding and decoding.

Parameters:

- invariants (Tensor): (B, T, invariant_dim) input. Required.
- max_level (int): Stop at this level (None = all levels). Default: None.
- return_all_levels (bool): Return tokens/recon at each level. Default: False.

Returns:

- all_tokens: List of (B, T_l) tokens per level.
- reconstructed: (B, T, invariant_dim) reconstruction.
- level_info: Dict with per-level details (returned only when return_all_levels=True).

Source code in src/dhb_xr/tokenization/hierarchical.py
def forward(
    self, 
    invariants: torch.Tensor, 
    max_level: int = None,
    return_all_levels: bool = False,
) -> tuple:
    """
    Hierarchical encoding and decoding.

    Args:
        invariants: (B, T, invariant_dim) input
        max_level: Stop at this level (None = all levels)
        return_all_levels: Return tokens/recon at each level

    Returns:
        all_tokens: List of (B, T_l) tokens per level
        reconstructed: (B, T, invariant_dim) reconstruction
        level_info: Optional dict with per-level details
    """
    B, T, C = invariants.shape
    max_level = max_level or self.num_levels

    all_tokens = []
    all_z = []
    all_z_q = []
    level_info = {}

    # Encode through hierarchy
    x = invariants
    for level in range(max_level):
        z = self.encoders[level](x if level == 0 else z_residual)
        indices, z_q_st, z_q = self.vqs[level](z)

        all_tokens.append(indices)
        all_z.append(z)
        all_z_q.append(z_q)

        if level < max_level - 1:
            z_residual = z - z_q.detach()

        level_info[f"level_{level}"] = {
            "shape": tuple(z.shape),
            "tokens": indices.shape[-1],
        }

    # Decode through hierarchy (reverse order)
    reconstructed = torch.zeros_like(invariants)
    for level in reversed(range(max_level)):
        dec_out = self.decoders[level](all_z_q[level])

        # Match temporal dimension
        if dec_out.shape[1] > reconstructed.shape[1]:
            dec_out = dec_out[:, :reconstructed.shape[1], :]
        elif dec_out.shape[1] < reconstructed.shape[1]:
            # Upsample to match
            dec_out = F.interpolate(
                dec_out.transpose(1, 2), 
                size=reconstructed.shape[1],
                mode='linear',
                align_corners=True
            ).transpose(1, 2)

        if level == 0:
            reconstructed = dec_out
        else:
            reconstructed = reconstructed + self.final_proj(dec_out)

    if return_all_levels:
        return all_tokens, reconstructed, level_info
    return all_tokens, reconstructed
get_compression_stats
get_compression_stats(T, max_level=None)

Compute compression statistics.

Parameters:

- T (int): Original sequence length. Required.
- max_level (int): Number of levels to use. Default: None.

Returns:

- dict: Compression statistics.

Source code in src/dhb_xr/tokenization/hierarchical.py
def get_compression_stats(self, T: int, max_level: int = None) -> dict:
    """
    Compute compression statistics.

    Args:
        T: Original sequence length
        max_level: Number of levels to use

    Returns:
        Compression statistics
    """
    max_level = max_level or self.num_levels

    total_tokens = 0
    for level in range(max_level):
        level_T = T // (self.temporal_downsample ** level)
        total_tokens += level_T

    original_values = T * self.invariant_dim
    token_values = total_tokens  # Each token is one index

    return {
        "original_values": original_values,
        "total_tokens": total_tokens,
        "tokens_per_level": [T // (self.temporal_downsample ** l) for l in range(max_level)],
        "compression_ratio": original_values / token_values if token_values > 0 else 1,
        "bits_per_value": (total_tokens * np.log2(self.codebook_size)) / original_values,
    }
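
As a concrete check of these formulas: with T=100, temporal_downsample=2, num_levels=4, invariant_dim=8, and codebook_size=256, tokens_per_level is [100, 50, 25, 12], total_tokens is 187, original_values is 800, compression_ratio is about 4.3, and bits_per_value = 187 * log2(256) / 800 ≈ 1.87.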
loss
loss(invariants, reconstructed, all_z, all_z_q, beta=0.25, level_weights=None)

Compute hierarchical loss with per-level weighting.

Parameters:

- invariants (Tensor): Original input. Required.
- reconstructed (Tensor): Reconstruction. Required.
- all_z (list): Latents at each level. Required.
- all_z_q (list): Quantized latents at each level. Required.
- beta (float): Commitment loss weight. Default: 0.25.
- level_weights (list): Optional weights per level (default: exponential decay). Default: None.

Returns:

- Tensor: Total loss.

Source code in src/dhb_xr/tokenization/hierarchical.py
def loss(
    self,
    invariants: torch.Tensor,
    reconstructed: torch.Tensor,
    all_z: list,
    all_z_q: list,
    beta: float = 0.25,
    level_weights: list = None,
) -> torch.Tensor:
    """
    Compute hierarchical loss with per-level weighting.

    Args:
        invariants: Original input
        reconstructed: Reconstruction
        all_z: Latents at each level
        all_z_q: Quantized latents at each level
        beta: Commitment loss weight
        level_weights: Optional weights per level (default: exponential decay)

    Returns:
        Total loss
    """
    # Reconstruction loss
    rec_loss = F.mse_loss(reconstructed, invariants)

    # Per-level VQ losses
    if level_weights is None:
        level_weights = [0.5 ** i for i in range(len(all_z))]

    commitment = 0
    codebook = 0
    for i, (z, z_q) in enumerate(zip(all_z, all_z_q)):
        commitment += level_weights[i] * F.mse_loss(z, z_q.detach())
        codebook += level_weights[i] * F.mse_loss(z_q, z.detach())

    return rec_loss + beta * commitment + codebook

MaskedNestedDropout

Bases: Module

Nested dropout module that replaces trailing tokens with a learnable mask token during training.

During training

Randomly sample keep_k in [1, N], replace tokens beyond keep_k with a learnable mask token.

During evaluation

If eval_keep_k is provided, mask tokens beyond keep_k. Otherwise, pass through without masking (use all tokens).

Parameters:

- dim (int): Embedding dimension of the mask token. Required.
- mode (str): Sampling strategy for keep_k during training. Default: "uniform".
  - "disable": No dropout (pass-through).
  - "uniform": Uniform probability across all prefix lengths.
  - "pow2": Only sample power-of-2 prefix lengths.
  - "linear_biased": Linear bias toward longer prefixes.
  - "quadratic_biased": Quadratic bias toward longer prefixes.
  - "cubic_biased": Cubic bias toward longer prefixes.
Example

dropout = MaskedNestedDropout(dim=64, mode="uniform")
x = torch.randn(2, 8, 64)  # (B, K, D)

# Training: some trailing tokens replaced with mask
dropout.train()
y = dropout(x)

# Eval with prefix: decode with first 4 tokens only
dropout.eval()
y = dropout(x, eval_keep_k=[4, 4])

Source code in src/dhb_xr/tokenization/nested_dropout.py
class MaskedNestedDropout(nn.Module):
    """
    Nested dropout module that replaces trailing tokens with a learnable
    mask token during training.

    During training:
        Randomly sample keep_k in [1, N], replace tokens beyond keep_k
        with a learnable mask token.

    During evaluation:
        If eval_keep_k is provided, mask tokens beyond keep_k.
        Otherwise, pass through without masking (use all tokens).

    Args:
        dim: Embedding dimension of the mask token.
        mode: Sampling strategy for keep_k during training.
            - "disable": No dropout (pass-through).
            - "uniform": Uniform probability across all prefix lengths.
            - "pow2": Only sample power-of-2 prefix lengths.
            - "linear_biased": Linear bias toward longer prefixes.
            - "quadratic_biased": Quadratic bias toward longer prefixes.
            - "cubic_biased": Cubic bias toward longer prefixes.

    Example:
        >>> dropout = MaskedNestedDropout(dim=64, mode="uniform")
        >>> x = torch.randn(2, 8, 64)  # (B, K, D)
        >>> # Training: some trailing tokens replaced with mask
        >>> dropout.train()
        >>> y = dropout(x)
        >>> # Eval with prefix: decode with first 4 tokens only
        >>> dropout.eval()
        >>> y = dropout(x, eval_keep_k=[4, 4])
    """

    def __init__(
        self,
        dim: int,
        mode: str = "uniform",
    ):
        super().__init__()
        self.dim = dim
        self.mode = mode

        if self.mode != "disable":
            self.mask_token = nn.Parameter(
                torch.randn(dim), requires_grad=True
            )
            trunc_normal_(self.mask_token, std=0.02)

    def _sample_keep_k(
        self,
        batch_size: int,
        num_tokens: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Sample the number of tokens to keep per batch element.

        Args:
            batch_size: Batch size B.
            num_tokens: Total number of tokens N.
            device: Device for tensor creation.

        Returns:
            (B,) tensor of keep_k values in [1, N].
        """
        if self.mode == "uniform":
            return torch.randint(
                1, num_tokens + 1, (batch_size,), device=device
            )

        elif self.mode == "pow2":
            # Power-of-2 values up to num_tokens
            pow2_vals = []
            v = 1
            while v <= num_tokens:
                pow2_vals.append(v)
                v *= 2
            pow2_vals = torch.tensor(pow2_vals, device=device)
            idx = torch.randint(0, len(pow2_vals), (batch_size,), device=device)
            return pow2_vals[idx]

        elif self.mode.endswith("_biased"):
            power_map = {
                "linear_biased": 1.0,
                "quadratic_biased": 2.0,
                "cubic_biased": 3.0,
            }
            power = power_map.get(self.mode, 1.0)
            weights = (
                torch.arange(1, num_tokens + 1, dtype=torch.float32, device=device)
                ** power
            )
            weights = weights / weights.sum()
            indices = torch.multinomial(weights, batch_size, replacement=True)
            return indices + 1  # 0-indexed -> 1-indexed

        else:
            raise ValueError(f"Unknown sampling mode: {self.mode}")

    def forward(
        self,
        x: torch.Tensor,
        eval_keep_k: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        Apply nested dropout.

        Args:
            x: (B, N, D) token sequence.
            eval_keep_k: Optional list of B integers specifying how many
                         tokens to keep per sample during evaluation.
                         If None during eval, all tokens are kept.

        Returns:
            (B, N, D) token sequence with trailing tokens masked.
        """
        if self.mode == "disable":
            return x

        B, N, D = x.shape
        x = x.clone()  # Don't modify input in-place

        if self.training:
            keep_ks = self._sample_keep_k(B, N, x.device)
            # Create mask: True for positions that should be masked
            positions = torch.arange(N, device=x.device).unsqueeze(0)  # (1, N)
            mask = positions >= keep_ks.unsqueeze(1)  # (B, N)
            x[mask] = self.mask_token
        elif eval_keep_k is not None:
            keep_ks = torch.tensor(eval_keep_k, device=x.device)
            positions = torch.arange(N, device=x.device).unsqueeze(0)
            mask = positions >= keep_ks.unsqueeze(1)
            x[mask] = self.mask_token

        return x
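
For the biased modes, the keep_k distribution follows directly from the weights in _sample_keep_k: with N = 8 tokens and mode="quadratic_biased", P(keep_k = k) = k^2 / 204 (since 1^2 + ... + 8^2 = 204), so the full prefix k = 8 is drawn with probability about 0.31 while k = 1 is drawn with probability about 0.005, biasing training toward longer, higher-fidelity prefixes.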
Functions
forward
forward(x, eval_keep_k=None)

Apply nested dropout.

Parameters:

- x (Tensor): (B, N, D) token sequence. Required.
- eval_keep_k (Optional[List[int]]): Optional list of B integers specifying how many tokens to keep per sample during evaluation. If None during eval, all tokens are kept. Default: None.

Returns:

- Tensor: (B, N, D) token sequence with trailing tokens masked.

Source code in src/dhb_xr/tokenization/nested_dropout.py
def forward(
    self,
    x: torch.Tensor,
    eval_keep_k: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    Apply nested dropout.

    Args:
        x: (B, N, D) token sequence.
        eval_keep_k: Optional list of B integers specifying how many
                     tokens to keep per sample during evaluation.
                     If None during eval, all tokens are kept.

    Returns:
        (B, N, D) token sequence with trailing tokens masked.
    """
    if self.mode == "disable":
        return x

    B, N, D = x.shape
    x = x.clone()  # Don't modify input in-place

    if self.training:
        keep_ks = self._sample_keep_k(B, N, x.device)
        # Create mask: True for positions that should be masked
        positions = torch.arange(N, device=x.device).unsqueeze(0)  # (1, N)
        mask = positions >= keep_ks.unsqueeze(1)  # (B, N)
        x[mask] = self.mask_token
    elif eval_keep_k is not None:
        keep_ks = torch.tensor(eval_keep_k, device=x.device)
        positions = torch.arange(N, device=x.device).unsqueeze(0)
        mask = positions >= keep_ks.unsqueeze(1)
        x[mask] = self.mask_token

    return x

OATDecoder

Bases: Module

Cross-attention decoder for register-token latents.

Decodes (B, K, latent_dim) register latents -> (B, T, output_dim).

Architecture:

- Learned positional queries of length T (output timesteps)
- nn.TransformerDecoder with cross-attention to register latents
- Linear head to project to output dimension

Works with MaskedNestedDropout: during training/eval, some trailing register tokens may be replaced with a mask token, enabling variable-quality prefix decoding.

Parameters:

- output_dim (int): Dimension of output features (e.g., 8 for DHB invariants). Required.
- output_horizon (int): Number of output timesteps T. Required.
- emb_dim (int): Internal transformer embedding dimension. Default: 64.
- latent_dim (int): Dimension of input register latents. Default: 16.
- latent_horizon (int): Number of register tokens K. Default: 8.
- depth (int): Number of transformer decoder layers. Default: 4.
- num_heads (int): Number of attention heads. Default: 4.
- dropout (float): Dropout rate. Default: 0.1.
- use_causal_decoder (bool): If True, apply causal mask to output queries (for autoregressive generation). Default: False.
Example

dec = OATDecoder(output_dim=8, output_horizon=50, emb_dim=64,
                 latent_dim=16, latent_horizon=8, depth=4)
latents = torch.randn(2, 8, 16)
recon = dec(latents)  # (2, 50, 8)

# Prefix decoding: only use first 4 register tokens
recon_coarse = dec(latents, eval_keep_k=[4, 4])

Source code in src/dhb_xr/tokenization/oat_decoder.py
class OATDecoder(nn.Module):
    """
    Cross-attention decoder for register-token latents.

    Decodes (B, K, latent_dim) register latents -> (B, T, output_dim).

    Architecture:
    - Learned positional queries of length T (output timesteps)
    - nn.TransformerDecoder with cross-attention to register latents
    - Linear head to project to output dimension

    Works with MaskedNestedDropout: during training/eval, some trailing
    register tokens may be replaced with a mask token, enabling
    variable-quality prefix decoding.

    Args:
        output_dim: Dimension of output features (e.g., 8 for DHB invariants).
        output_horizon: Number of output timesteps T.
        emb_dim: Internal transformer embedding dimension.
        latent_dim: Dimension of input register latents.
        latent_horizon: Number of register tokens K.
        depth: Number of transformer decoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout rate.
        use_causal_decoder: If True, apply causal mask to output queries
                            (for autoregressive generation).

    Example:
        >>> dec = OATDecoder(output_dim=8, output_horizon=50, emb_dim=64,
        ...                  latent_dim=16, latent_horizon=8, depth=4)
        >>> latents = torch.randn(2, 8, 16)
        >>> recon = dec(latents)  # (2, 50, 8)
        >>> # Prefix decoding: only use first 4 register tokens
        >>> recon_coarse = dec(latents, eval_keep_k=[4, 4])
    """

    def __init__(
        self,
        output_dim: int,
        output_horizon: int,
        emb_dim: int = 64,
        latent_dim: int = 16,
        latent_horizon: int = 8,
        depth: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        use_causal_decoder: bool = False,
    ):
        super().__init__()

        self.output_dim = output_dim
        self.output_horizon = output_horizon
        self.latent_horizon = latent_horizon
        self.emb_dim = emb_dim
        self.use_causal_decoder = use_causal_decoder

        # Learned positional queries for each output timestep
        self.query_pos = nn.Parameter(
            torch.randn(1, output_horizon, emb_dim)
        )
        nn.init.trunc_normal_(self.query_pos, std=0.02)

        # Project register latents to embedding dim
        self.latent_proj = nn.Linear(latent_dim, emb_dim)

        # Positional encoding for latent tokens
        self.latent_pos = nn.Parameter(
            torch.randn(1, latent_horizon, emb_dim)
        )
        nn.init.trunc_normal_(self.latent_pos, std=0.02)

        # Transformer decoder (cross-attention from queries to latents)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=4 * emb_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=depth
        )

        # Output projection
        self.head = nn.Linear(emb_dim, output_dim)

    def forward(
        self,
        latents: torch.Tensor,
        eval_keep_k: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        Decode register latents to output sequence.

        Args:
            latents: (B, K, latent_dim) register token latents.
                     May have trailing tokens masked by MaskedNestedDropout.
            eval_keep_k: Optional list of B integers. If provided, only
                         the first keep_k latent tokens per sample are
                         considered "real" (rest are masked). This is
                         handled upstream by MaskedNestedDropout; this
                         parameter is kept for API consistency.

        Returns:
            (B, T, output_dim) reconstructed output sequence.
        """
        B = latents.shape[0]

        # Project latents and add positional encoding
        memory = self.latent_proj(latents)  # (B, K, emb_dim)
        memory = memory + self.latent_pos[:, : latents.shape[1], :]

        # Expand positional queries
        queries = self.query_pos.expand(B, -1, -1)  # (B, T, emb_dim)

        # Optionally apply causal mask on queries
        tgt_mask = None
        if self.use_causal_decoder:
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                self.output_horizon, device=latents.device
            )

        # Cross-attention decoding
        out = self.decoder(
            queries,
            memory,
            tgt_mask=tgt_mask,
            tgt_is_causal=self.use_causal_decoder,
        )

        # Project to output dim
        return self.head(out)  # (B, T, output_dim)
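
Note that forward accepts eval_keep_k only for API consistency and does not apply it (see the docstring above); prefix masking is expected to happen upstream, which is what OATTokenizer.decode below does by passing the latents through its trained MaskedNestedDropout before this decoder. A sketch of that pattern, where nd is a hypothetical handle to the tokenizer's trained dropout module rather than a freshly initialized one (so the mask token matches training):

nd.eval()
recon_coarse = dec(nd(latents, eval_keep_k=[4, 4]))  # only the first 4 registers are treated as real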
Functions
forward
forward(latents, eval_keep_k=None)

Decode register latents to output sequence.

Parameters:

- latents (Tensor): (B, K, latent_dim) register token latents. May have trailing tokens masked by MaskedNestedDropout. Required.
- eval_keep_k (Optional[List[int]]): Optional list of B integers. If provided, only the first keep_k latent tokens per sample are considered "real" (rest are masked). This is handled upstream by MaskedNestedDropout; this parameter is kept for API consistency. Default: None.

Returns:

- Tensor: (B, T, output_dim) reconstructed output sequence.

Source code in src/dhb_xr/tokenization/oat_decoder.py
def forward(
    self,
    latents: torch.Tensor,
    eval_keep_k: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    Decode register latents to output sequence.

    Args:
        latents: (B, K, latent_dim) register token latents.
                 May have trailing tokens masked by MaskedNestedDropout.
        eval_keep_k: Optional list of B integers. If provided, only
                     the first keep_k latent tokens per sample are
                     considered "real" (rest are masked). This is
                     handled upstream by MaskedNestedDropout; this
                     parameter is kept for API consistency.

    Returns:
        (B, T, output_dim) reconstructed output sequence.
    """
    B = latents.shape[0]

    # Project latents and add positional encoding
    memory = self.latent_proj(latents)  # (B, K, emb_dim)
    memory = memory + self.latent_pos[:, : latents.shape[1], :]

    # Expand positional queries
    queries = self.query_pos.expand(B, -1, -1)  # (B, T, emb_dim)

    # Optionally apply causal mask on queries
    tgt_mask = None
    if self.use_causal_decoder:
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            self.output_horizon, device=latents.device
        )

    # Cross-attention decoding
    out = self.decoder(
        queries,
        memory,
        tgt_mask=tgt_mask,
        tgt_is_causal=self.use_causal_decoder,
    )

    # Project to output dim
    return self.head(out)  # (B, T, output_dim)

OATTokenizer

Bases: Module

OAT-style tokenizer with register encoding, FSQ, nested dropout, and cross-attention decoding.

Produces an ordered sequence of K discrete tokens from T input timesteps. Any prefix of k <= K tokens can be decoded to a valid reconstruction.

Parameters:

- input_dim (int): Dimension of input features (e.g., 8 for DHB invariants). Default: 8.
- input_horizon (int): Number of input timesteps T. Default: 50.
- emb_dim (int): Internal transformer embedding dimension. Default: 64.
- latent_dim (int): FSQ latent dimension (= number of FSQ levels). Default: 4.
- num_registers (int): Number of register tokens K (compression factor). Default: 8.
- fsq_levels (Optional[List[int]]): List of FSQ quantization levels per dimension. Default: None.
- encoder_depth (int): Number of transformer encoder layers. Default: 4.
- decoder_depth (int): Number of transformer decoder layers. Default: 4.
- num_heads (int): Number of attention heads. Default: 4.
- dropout (float): Dropout rate. Default: 0.1.
- nested_dropout_mode (str): Sampling mode for nested dropout. Default: "uniform".
- drop_quant_p (float): FSQ quantization dropout probability. Default: 0.0.
- use_causal_decoder (bool): Use causal mask in decoder queries. Default: False.
Example

tok = OATTokenizer(
    input_dim=8, input_horizon=50, emb_dim=64, latent_dim=4,
    num_registers=8, fsq_levels=[8, 5, 5, 5],
    encoder_depth=4, decoder_depth=4
)
x = torch.randn(2, 50, 8)
loss = tok(x)
print(f"Training loss: {loss.item():.4f}")

# Tokenize
tokens = tok.tokenize(x)  # (2, 8)

# Detokenize with prefix
recon = tok.detokenize(tokens[:, :4])  # use first 4 tokens

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
class OATTokenizer(nn.Module):
    """
    OAT-style tokenizer with register encoding, FSQ, nested dropout, and
    cross-attention decoding.

    Produces an ordered sequence of K discrete tokens from T input timesteps.
    Any prefix of k <= K tokens can be decoded to a valid reconstruction.

    Args:
        input_dim: Dimension of input features (e.g., 8 for DHB invariants).
        input_horizon: Number of input timesteps T.
        emb_dim: Internal transformer embedding dimension.
        latent_dim: FSQ latent dimension (= number of FSQ levels).
        num_registers: Number of register tokens K (compression factor).
        fsq_levels: List of FSQ quantization levels per dimension.
        encoder_depth: Number of transformer encoder layers.
        decoder_depth: Number of transformer decoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout rate.
        nested_dropout_mode: Sampling mode for nested dropout.
        drop_quant_p: FSQ quantization dropout probability.
        use_causal_decoder: Use causal mask in decoder queries.

    Example:
        >>> tok = OATTokenizer(
        ...     input_dim=8, input_horizon=50, emb_dim=64, latent_dim=4,
        ...     num_registers=8, fsq_levels=[8, 5, 5, 5],
        ...     encoder_depth=4, decoder_depth=4
        ... )
        >>> x = torch.randn(2, 50, 8)
        >>> loss = tok(x)
        >>> print(f"Training loss: {loss.item():.4f}")
        >>> # Tokenize
        >>> tokens = tok.tokenize(x)  # (2, 8)
        >>> # Detokenize with prefix
        >>> recon = tok.detokenize(tokens[:, :4])  # use first 4 tokens
    """

    def __init__(
        self,
        input_dim: int = 8,
        input_horizon: int = 50,
        emb_dim: int = 64,
        latent_dim: int = 4,
        num_registers: int = 8,
        fsq_levels: Optional[List[int]] = None,
        encoder_depth: int = 4,
        decoder_depth: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        nested_dropout_mode: str = "uniform",
        drop_quant_p: float = 0.0,
        use_causal_decoder: bool = False,
    ):
        super().__init__()

        if fsq_levels is None:
            fsq_levels = [8, 5, 5, 5]  # Default: codebook size 1000

        assert len(fsq_levels) == latent_dim, (
            f"len(fsq_levels)={len(fsq_levels)} must equal latent_dim={latent_dim}"
        )

        self.input_dim = input_dim
        self.input_horizon = input_horizon
        self.num_registers = num_registers

        # Encoder: compress T timesteps -> K register tokens
        self.encoder = RegisterEncoder(
            input_dim=input_dim,
            emb_dim=emb_dim,
            latent_dim=latent_dim,
            num_registers=num_registers,
            depth=encoder_depth,
            num_heads=num_heads,
            dropout=dropout,
            max_seq_len=input_horizon + 64,
        )

        # Quantizer: FSQ (no learned codebook)
        self.quantizer = FSQ(
            levels=fsq_levels,
            drop_quant_p=drop_quant_p,
        )

        # Nested dropout: force coarse-to-fine ordering
        self.nested_dropout = MaskedNestedDropout(
            dim=latent_dim,
            mode=nested_dropout_mode,
        )

        # Decoder: cross-attention from positional queries to register latents
        self.decoder = OATDecoder(
            output_dim=input_dim,
            output_horizon=input_horizon,
            emb_dim=emb_dim,
            latent_dim=latent_dim,
            latent_horizon=num_registers,
            depth=decoder_depth,
            num_heads=num_heads,
            dropout=dropout,
            use_causal_decoder=use_causal_decoder,
        )

    @property
    def codebook_size(self) -> int:
        """Effective codebook size from FSQ levels."""
        return self.quantizer.codebook_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Training forward pass: encode -> quantize -> dropout -> decode -> loss.

        Args:
            x: (B, T, input_dim) input invariant/action sequence.

        Returns:
            Scalar MSE reconstruction loss.
        """
        # Encode to register latents
        latents = self.encoder(x)  # (B, K, latent_dim)

        # FSQ quantization
        latents_q, _indices = self.quantizer(latents)

        # Nested dropout (forces ordering during training)
        latents_dropped = self.nested_dropout(latents_q)

        # Decode
        recon = self.decoder(latents_dropped)  # (B, T, input_dim)

        # MSE loss (no VQ losses needed with FSQ)
        loss = F.mse_loss(recon, x)
        return loss

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode input to quantized latents and token indices.

        Args:
            x: (B, T, input_dim) input sequence.

        Returns:
            Tuple of:
            - latents_q: (B, K, latent_dim) quantized register latents.
            - indices: (B, K) discrete token indices.
        """
        latents = self.encoder(x)
        latents_q, indices = self.quantizer(latents)
        return latents_q, indices

    def decode(
        self,
        latents: torch.Tensor,
        eval_keep_k: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        Decode quantized latents (with optional prefix masking).

        Args:
            latents: (B, K, latent_dim) quantized register latents.
            eval_keep_k: Optional list of B integers for prefix decoding.

        Returns:
            (B, T, input_dim) reconstructed sequence.
        """
        # Apply nested dropout for prefix masking
        latents_masked = self.nested_dropout(latents, eval_keep_k=eval_keep_k)
        return self.decoder(latents_masked)

    def tokenize(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert input to discrete token indices.

        Args:
            x: (B, T, input_dim) input sequence.

        Returns:
            (B, K) token indices.
        """
        _, indices = self.encode(x)
        return indices

    def detokenize(
        self,
        tokens: torch.Tensor,
        eval_keep_k: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        Convert token indices back to reconstructed sequence.

        Supports prefix decoding: pass a subset of tokens (first k columns)
        to get a coarser reconstruction.

        Args:
            tokens: (B, K) or (B, k) token indices where k <= K.
            eval_keep_k: Optional list specifying keep_k per sample.
                         If tokens has fewer columns than num_registers,
                         keep_k is inferred automatically.

        Returns:
            (B, T, input_dim) reconstructed sequence.
        """
        B, K_actual = tokens.shape

        # Pad tokens to full register length if needed
        if K_actual < self.num_registers:
            pad = torch.zeros(
                B,
                self.num_registers - K_actual,
                dtype=tokens.dtype,
                device=tokens.device,
            )
            tokens_full = torch.cat([tokens, pad], dim=1)
            if eval_keep_k is None:
                eval_keep_k = [K_actual] * B
        else:
            tokens_full = tokens

        # Look up embeddings
        latents = self.quantizer.indices_to_embedding(tokens_full)

        return self.decode(latents, eval_keep_k=eval_keep_k)

    def autoencode(
        self,
        x: torch.Tensor,
        eval_keep_k: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """
        Full encode-decode roundtrip.

        Args:
            x: (B, T, input_dim) input sequence.
            eval_keep_k: Optional prefix lengths for anytime decoding.

        Returns:
            (B, T, input_dim) reconstructed sequence.
        """
        latents_q, _ = self.encode(x)
        return self.decode(latents_q, eval_keep_k=eval_keep_k)

    def get_optimizer(
        self,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        betas: Tuple[float, float] = (0.9, 0.999),
    ) -> torch.optim.Optimizer:
        """Create AdamW optimizer with proper weight decay groups."""
        decay_params = [
            p
            for n, p in self.named_parameters()
            if p.requires_grad and p.dim() >= 2
        ]
        nodecay_params = [
            p
            for n, p in self.named_parameters()
            if p.requires_grad and p.dim() < 2
        ]
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        return torch.optim.AdamW(
            optim_groups, lr=learning_rate, betas=betas
        )

    def get_compression_stats(self) -> Dict:
        """Get compression statistics."""
        return {
            "input_horizon": self.input_horizon,
            "input_dim": self.input_dim,
            "num_registers": self.num_registers,
            "compression_ratio": self.input_horizon / self.num_registers,
            "codebook_size": self.codebook_size,
            "fsq_levels": self.quantizer._levels.tolist(),
            "total_params": sum(
                p.numel() for p in self.parameters()
            ),
        }
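
A minimal training-loop sketch for the class above. The import path is assumed from the source location shown; the batch size, step count, and random data are placeholders, not part of the module.

import torch
from dhb_xr.tokenization.oat_tokenizer import OATTokenizer  # assumed import path

tok = OATTokenizer(input_dim=8, input_horizon=50, num_registers=8)
opt = tok.get_optimizer(learning_rate=1e-4)

tok.train()
for step in range(100):                  # placeholder schedule
    x = torch.randn(16, 50, 8)           # stand-in for a batch of DHB invariant windows
    loss = tok(x)                        # scalar MSE reconstruction loss (forward above)
    opt.zero_grad()
    loss.backward()
    opt.step()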
Attributes
codebook_size property
codebook_size

Effective codebook size from FSQ levels.

Functions
autoencode
autoencode(x, eval_keep_k=None)

Full encode-decode roundtrip.

Parameters:

- x (Tensor): (B, T, input_dim) input sequence. Required.
- eval_keep_k (Optional[List[int]]): Optional prefix lengths for anytime decoding. Default: None.

Returns:

- Tensor: (B, T, input_dim) reconstructed sequence.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def autoencode(
    self,
    x: torch.Tensor,
    eval_keep_k: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    Full encode-decode roundtrip.

    Args:
        x: (B, T, input_dim) input sequence.
        eval_keep_k: Optional prefix lengths for anytime decoding.

    Returns:
        (B, T, input_dim) reconstructed sequence.
    """
    latents_q, _ = self.encode(x)
    return self.decode(latents_q, eval_keep_k=eval_keep_k)
decode
decode(latents, eval_keep_k=None)

Decode quantized latents (with optional prefix masking).

Parameters:

- latents (Tensor): (B, K, latent_dim) quantized register latents. Required.
- eval_keep_k (Optional[List[int]]): Optional list of B integers for prefix decoding. Default: None.

Returns:

- Tensor: (B, T, input_dim) reconstructed sequence.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def decode(
    self,
    latents: torch.Tensor,
    eval_keep_k: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    Decode quantized latents (with optional prefix masking).

    Args:
        latents: (B, K, latent_dim) quantized register latents.
        eval_keep_k: Optional list of B integers for prefix decoding.

    Returns:
        (B, T, input_dim) reconstructed sequence.
    """
    # Apply nested dropout for prefix masking
    latents_masked = self.nested_dropout(latents, eval_keep_k=eval_keep_k)
    return self.decoder(latents_masked)
detokenize
detokenize(tokens, eval_keep_k=None)

Convert token indices back to reconstructed sequence.

Supports prefix decoding: pass a subset of tokens (first k columns) to get a coarser reconstruction.

Parameters:

- tokens (Tensor): (B, K) or (B, k) token indices where k <= K. Required.
- eval_keep_k (Optional[List[int]]): Optional list specifying keep_k per sample. If tokens has fewer columns than num_registers, keep_k is inferred automatically. Default: None.

Returns:

- Tensor: (B, T, input_dim) reconstructed sequence.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def detokenize(
    self,
    tokens: torch.Tensor,
    eval_keep_k: Optional[List[int]] = None,
) -> torch.Tensor:
    """
    Convert token indices back to reconstructed sequence.

    Supports prefix decoding: pass a subset of tokens (first k columns)
    to get a coarser reconstruction.

    Args:
        tokens: (B, K) or (B, k) token indices where k <= K.
        eval_keep_k: Optional list specifying keep_k per sample.
                     If tokens has fewer columns than num_registers,
                     keep_k is inferred automatically.

    Returns:
        (B, T, input_dim) reconstructed sequence.
    """
    B, K_actual = tokens.shape

    # Pad tokens to full register length if needed
    if K_actual < self.num_registers:
        pad = torch.zeros(
            B,
            self.num_registers - K_actual,
            dtype=tokens.dtype,
            device=tokens.device,
        )
        tokens_full = torch.cat([tokens, pad], dim=1)
        if eval_keep_k is None:
            eval_keep_k = [K_actual] * B
    else:
        tokens_full = tokens

    # Look up embeddings
    latents = self.quantizer.indices_to_embedding(tokens_full)

    return self.decode(latents, eval_keep_k=eval_keep_k)
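
Continuing with a trained tok and a batch x as in the class example above, prefix (anytime) decoding looks like this; when a shorter token matrix is passed, keep_k is inferred and the padded registers are masked out.

tok.eval()
with torch.no_grad():
    tokens = tok.tokenize(x)                      # (B, 8)
    recon_full = tok.detokenize(tokens)           # all 8 registers
    recon_coarse = tok.detokenize(tokens[:, :2])  # first 2 registers; keep_k inferred as 2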
encode
encode(x)

Encode input to quantized latents and token indices.

Parameters:

- x (Tensor): (B, T, input_dim) input sequence. Required.

Returns:

- Tuple[Tensor, Tensor]:
  - latents_q: (B, K, latent_dim) quantized register latents.
  - indices: (B, K) discrete token indices.
Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode input to quantized latents and token indices.

    Args:
        x: (B, T, input_dim) input sequence.

    Returns:
        Tuple of:
        - latents_q: (B, K, latent_dim) quantized register latents.
        - indices: (B, K) discrete token indices.
    """
    latents = self.encoder(x)
    latents_q, indices = self.quantizer(latents)
    return latents_q, indices
forward
forward(x)

Training forward pass: encode -> quantize -> dropout -> decode -> loss.

Parameters:

- x (Tensor): (B, T, input_dim) input invariant/action sequence. Required.

Returns:

- Tensor: Scalar MSE reconstruction loss.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Training forward pass: encode -> quantize -> dropout -> decode -> loss.

    Args:
        x: (B, T, input_dim) input invariant/action sequence.

    Returns:
        Scalar MSE reconstruction loss.
    """
    # Encode to register latents
    latents = self.encoder(x)  # (B, K, latent_dim)

    # FSQ quantization
    latents_q, _indices = self.quantizer(latents)

    # Nested dropout (forces ordering during training)
    latents_dropped = self.nested_dropout(latents_q)

    # Decode
    recon = self.decoder(latents_dropped)  # (B, T, input_dim)

    # MSE loss (no VQ losses needed with FSQ)
    loss = F.mse_loss(recon, x)
    return loss
get_compression_stats
get_compression_stats()

Get compression statistics.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def get_compression_stats(self) -> Dict:
    """Get compression statistics."""
    return {
        "input_horizon": self.input_horizon,
        "input_dim": self.input_dim,
        "num_registers": self.num_registers,
        "compression_ratio": self.input_horizon / self.num_registers,
        "codebook_size": self.codebook_size,
        "fsq_levels": self.quantizer._levels.tolist(),
        "total_params": sum(
            p.numel() for p in self.parameters()
        ),
    }
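
With the defaults used throughout this page (input_horizon=50, input_dim=8, num_registers=8, fsq_levels=[8, 5, 5, 5]), compression_ratio is 50 / 8 = 6.25 and codebook_size is 1000, so one window of 50 x 8 = 400 float values is represented by 8 tokens of log2(1000) ≈ 10 bits each, roughly 80 bits in total.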
get_optimizer
get_optimizer(learning_rate=0.0001, weight_decay=0.01, betas=(0.9, 0.999))

Create AdamW optimizer with proper weight decay groups.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def get_optimizer(
    self,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    betas: Tuple[float, float] = (0.9, 0.999),
) -> torch.optim.Optimizer:
    """Create AdamW optimizer with proper weight decay groups."""
    decay_params = [
        p
        for n, p in self.named_parameters()
        if p.requires_grad and p.dim() >= 2
    ]
    nodecay_params = [
        p
        for n, p in self.named_parameters()
        if p.requires_grad and p.dim() < 2
    ]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(
        optim_groups, lr=learning_rate, betas=betas
    )
tokenize
tokenize(x)

Convert input to discrete token indices.

Parameters:

- x (Tensor): (B, T, input_dim) input sequence. Required.

Returns:

- Tensor: (B, K) token indices.

Source code in src/dhb_xr/tokenization/oat_tokenizer.py
def tokenize(self, x: torch.Tensor) -> torch.Tensor:
    """
    Convert input to discrete token indices.

    Args:
        x: (B, T, input_dim) input sequence.

    Returns:
        (B, K) token indices.
    """
    _, indices = self.encode(x)
    return indices

ProgressiveTokenizer

Bases: Module

Progressive refinement tokenizer.

Outputs can be truncated at any level for variable-rate decoding:

- 1 level: ~4x compression, coarse motion
- 2 levels: ~2x compression, medium detail
- 4 levels: ~1x compression, full fidelity

Ideal for streaming/bandwidth-adaptive applications.

Source code in src/dhb_xr/tokenization/hierarchical.py
class ProgressiveTokenizer(nn.Module):
    """
    Progressive refinement tokenizer.

    Outputs can be truncated at any level for variable-rate decoding:
    - 1 level: ~4x compression, coarse motion
    - 2 levels: ~2x compression, medium detail
    - 4 levels: ~1x compression, full fidelity

    Ideal for streaming/bandwidth-adaptive applications.
    """

    def __init__(
        self,
        invariant_dim: int,
        latent_dim: int,
        codebook_size: int,
        num_refinements: int = 3,
    ):
        super().__init__()

        self.base_tokenizer = nn.Sequential(
            CausalConv1dEncoder(invariant_dim, latent_dim, latent_dim, 2),
        )
        self.base_vq = VectorQuantizer(codebook_size, latent_dim)
        self.base_decoder = CausalConv1dEncoder(latent_dim, latent_dim, invariant_dim, 2)

        # Refinement stages
        self.refinements = nn.ModuleList()
        self.refine_vqs = nn.ModuleList()
        self.refine_decoders = nn.ModuleList()

        for _ in range(num_refinements):
            self.refinements.append(
                CausalConv1dEncoder(invariant_dim, latent_dim, latent_dim, 1)
            )
            self.refine_vqs.append(VectorQuantizer(codebook_size, latent_dim))
            self.refine_decoders.append(
                CausalConv1dEncoder(latent_dim, latent_dim, invariant_dim, 1)
            )

    def forward(self, invariants: torch.Tensor, num_refine: int = None) -> tuple:
        """
        Progressive tokenization.

        Args:
            invariants: (B, T, D) input
            num_refine: Number of refinement levels (0 = base only)

        Returns:
            all_tokens: List of token tensors
            reconstructed: Final reconstruction
        """
        if num_refine is None:
            num_refine = len(self.refinements)

        # Base encoding
        z_base = self.base_tokenizer(invariants)
        tokens_base, z_q_st, z_q = self.base_vq(z_base)
        recon = self.base_decoder(z_q_st)

        all_tokens = [tokens_base]

        # Progressive refinements
        residual = invariants - recon
        for i in range(min(num_refine, len(self.refinements))):
            z_ref = self.refinements[i](residual)
            tokens_ref, z_ref_st, z_ref_q = self.refine_vqs[i](z_ref)

            all_tokens.append(tokens_ref)

            recon = recon + self.refine_decoders[i](z_ref_st)
            residual = invariants - recon

        return all_tokens, recon
Functions
forward
forward(invariants, num_refine=None)

Progressive tokenization.

Parameters:

- invariants (Tensor): (B, T, D) input. [required]
- num_refine (int): Number of refinement levels (0 = base only). [default: None]

Returns:

- all_tokens: List of token tensors.
- reconstructed: Final reconstruction.

Source code in src/dhb_xr/tokenization/hierarchical.py
def forward(self, invariants: torch.Tensor, num_refine: int = None) -> tuple:
    """
    Progressive tokenization.

    Args:
        invariants: (B, T, D) input
        num_refine: Number of refinement levels (0 = base only)

    Returns:
        all_tokens: List of token tensors
        reconstructed: Final reconstruction
    """
    if num_refine is None:
        num_refine = len(self.refinements)

    # Base encoding
    z_base = self.base_tokenizer(invariants)
    tokens_base, z_q_st, z_q = self.base_vq(z_base)
    recon = self.base_decoder(z_q_st)

    all_tokens = [tokens_base]

    # Progressive refinements
    residual = invariants - recon
    for i in range(min(num_refine, len(self.refinements))):
        z_ref = self.refinements[i](residual)
        tokens_ref, z_ref_st, z_ref_q = self.refine_vqs[i](z_ref)

        all_tokens.append(tokens_ref)

        recon = recon + self.refine_decoders[i](z_ref_st)
        residual = invariants - recon

    return all_tokens, recon
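
A minimal sketch of variable-rate use, built only from the constructor and forward shown above (dimensions are illustrative):

import torch
from dhb_xr.tokenization.hierarchical import ProgressiveTokenizer

tok = ProgressiveTokenizer(invariant_dim=8, latent_dim=16, codebook_size=64, num_refinements=3)
x = torch.randn(4, 50, 8)  # (B, T, D)

# Full fidelity: base codes plus all refinement levels
all_tokens, recon_full = tok(x)

# Bandwidth-constrained: base codes only (coarse motion)
base_tokens, recon_coarse = tok(x, num_refine=0)

# all_tokens has 1 + num_refine entries of token indices
assert len(base_tokens) == 1 and len(all_tokens) == 4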

RLECompressor

Run-Length Encoding for token sequences.

Compresses repeated tokens (common in static/low-motion segments). Uses a special "repeat" token followed by (token, count) pairs.

Example

[5, 5, 5, 5, 3, 3] -> [(REPEAT, 5, 4), (REPEAT, 3, 2)]
or simplified: [RLE_MARKER, 5, 4, RLE_MARKER, 3, 2]

Source code in src/dhb_xr/tokenization/compression.py
class RLECompressor:
    """
    Run-Length Encoding for token sequences.

    Compresses repeated tokens (common in static/low-motion segments).
    Uses a special "repeat" token followed by (token, count) pairs.

    Example:
        [5, 5, 5, 5, 3, 3] -> [(REPEAT, 5, 4), (REPEAT, 3, 2)]
        or simplified: [RLE_MARKER, 5, 4, RLE_MARKER, 3, 2]
    """

    def __init__(self, min_run: int = 3, max_count: int = 255):
        """
        Args:
            min_run: Minimum run length to compress (shorter runs kept as-is)
            max_count: Maximum count per run (limits encoding overhead)
        """
        self.min_run = min_run
        self.max_count = max_count
        self.RLE_MARKER = -1  # Special marker (will be shifted to valid range)

    def encode(self, tokens: Union[List[int], np.ndarray]) -> List[Tuple[int, int]]:
        """
        Encode tokens with run-length encoding.

        Args:
            tokens: Token sequence

        Returns:
            List of (token, count) tuples
        """
        if len(tokens) == 0:
            return []

        tokens = list(tokens)
        encoded = []

        i = 0
        while i < len(tokens):
            current = tokens[i]
            count = 1

            # Count consecutive identical tokens
            while i + count < len(tokens) and tokens[i + count] == current and count < self.max_count:
                count += 1

            encoded.append((current, count))
            i += count

        return encoded

    def decode(self, encoded: List[Tuple[int, int]]) -> List[int]:
        """
        Decode RLE-encoded sequence.

        Args:
            encoded: List of (token, count) tuples

        Returns:
            Original token sequence
        """
        tokens = []
        for token, count in encoded:
            tokens.extend([token] * count)
        return tokens

    def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute compression ratio."""
        encoded = self.encode(tokens)
        # Each (token, count) pair = 2 values vs count original tokens
        compressed_size = len(encoded) * 2
        original_size = len(tokens)
        return original_size / compressed_size if compressed_size > 0 else 1.0

    def get_stats(self, tokens: Union[List[int], np.ndarray]) -> Dict:
        """Get RLE statistics for a sequence."""
        encoded = self.encode(tokens)
        run_lengths = [count for _, count in encoded]
        return {
            "num_runs": len(encoded),
            "avg_run_length": np.mean(run_lengths) if run_lengths else 0,
            "max_run_length": max(run_lengths) if run_lengths else 0,
            "compression_ratio": self.compression_ratio(tokens),
        }
Functions
__init__
__init__(min_run=3, max_count=255)

Parameters:

- min_run (int): Minimum run length to compress (shorter runs kept as-is). [default: 3]
- max_count (int): Maximum count per run (limits encoding overhead). [default: 255]
Source code in src/dhb_xr/tokenization/compression.py
def __init__(self, min_run: int = 3, max_count: int = 255):
    """
    Args:
        min_run: Minimum run length to compress (shorter runs kept as-is)
        max_count: Maximum count per run (limits encoding overhead)
    """
    self.min_run = min_run
    self.max_count = max_count
    self.RLE_MARKER = -1  # Special marker (will be shifted to valid range)
compression_ratio
compression_ratio(tokens)

Compute compression ratio.

Source code in src/dhb_xr/tokenization/compression.py
def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
    """Compute compression ratio."""
    encoded = self.encode(tokens)
    # Each (token, count) pair = 2 values vs count original tokens
    compressed_size = len(encoded) * 2
    original_size = len(tokens)
    return original_size / compressed_size if compressed_size > 0 else 1.0
decode
decode(encoded)

Decode RLE-encoded sequence.

Parameters:

- encoded (List[Tuple[int, int]]): List of (token, count) tuples. [required]

Returns:

- List[int]: Original token sequence.

Source code in src/dhb_xr/tokenization/compression.py
def decode(self, encoded: List[Tuple[int, int]]) -> List[int]:
    """
    Decode RLE-encoded sequence.

    Args:
        encoded: List of (token, count) tuples

    Returns:
        Original token sequence
    """
    tokens = []
    for token, count in encoded:
        tokens.extend([token] * count)
    return tokens
encode
encode(tokens)

Encode tokens with run-length encoding.

Parameters:

- tokens (Union[List[int], ndarray]): Token sequence. [required]

Returns:

- List[Tuple[int, int]]: List of (token, count) tuples.

Source code in src/dhb_xr/tokenization/compression.py
def encode(self, tokens: Union[List[int], np.ndarray]) -> List[Tuple[int, int]]:
    """
    Encode tokens with run-length encoding.

    Args:
        tokens: Token sequence

    Returns:
        List of (token, count) tuples
    """
    if len(tokens) == 0:
        return []

    tokens = list(tokens)
    encoded = []

    i = 0
    while i < len(tokens):
        current = tokens[i]
        count = 1

        # Count consecutive identical tokens
        while i + count < len(tokens) and tokens[i + count] == current and count < self.max_count:
            count += 1

        encoded.append((current, count))
        i += count

    return encoded
get_stats
get_stats(tokens)

Get RLE statistics for a sequence.

Source code in src/dhb_xr/tokenization/compression.py
def get_stats(self, tokens: Union[List[int], np.ndarray]) -> Dict:
    """Get RLE statistics for a sequence."""
    encoded = self.encode(tokens)
    run_lengths = [count for _, count in encoded]
    return {
        "num_runs": len(encoded),
        "avg_run_length": np.mean(run_lengths) if run_lengths else 0,
        "max_run_length": max(run_lengths) if run_lengths else 0,
        "compression_ratio": self.compression_ratio(tokens),
    }
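
A short lossless round-trip sketch using only the methods shown above (values are illustrative):

from dhb_xr.tokenization.compression import RLECompressor

rle = RLECompressor(min_run=3, max_count=255)
tokens = [5, 5, 5, 5, 3, 3, 7]

encoded = rle.encode(tokens)      # [(5, 4), (3, 2), (7, 1)]
recovered = rle.decode(encoded)   # expands each (token, count) pair back out
assert recovered == tokens

print(rle.get_stats(tokens))      # num_runs, avg/max run length, compression_ratio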

RegisterEncoder

Bases: Module

Transformer encoder with register tokens.

Compresses (B, T, input_dim) -> (B, num_registers, latent_dim).

The key architectural innovation: learnable register tokens are appended to the embedded input, processed through a transformer with a causal-last attention mask, and then extracted as the compressed representation.

Parameters:

Name Type Description Default
input_dim int

Dimension of input features (e.g., 8 for DHB invariants).

required
emb_dim int

Internal transformer embedding dimension.

64
latent_dim int

Output dimension per register token.

16
num_registers int

Number of register tokens (K). Controls compression.

8
depth int

Number of transformer encoder layers.

4
num_heads int

Number of attention heads.

4
dropout float

Dropout rate.

0.1
max_seq_len int

Maximum input sequence length.

512
Example

enc = RegisterEncoder(input_dim=8, emb_dim=64, latent_dim=16, ... num_registers=8, depth=4) x = torch.randn(2, 50, 8) # (B, T=50, D=8) z = enc(x) # (B, 8, 16) -- 50 timesteps -> 8 tokens

Source code in src/dhb_xr/tokenization/register_encoder.py
class RegisterEncoder(nn.Module):
    """
    Transformer encoder with register tokens.

    Compresses (B, T, input_dim) -> (B, num_registers, latent_dim).

    The key architectural innovation: learnable register tokens are appended
    to the embedded input, processed through a transformer with a causal-last
    attention mask, and then extracted as the compressed representation.

    Args:
        input_dim: Dimension of input features (e.g., 8 for DHB invariants).
        emb_dim: Internal transformer embedding dimension.
        latent_dim: Output dimension per register token.
        num_registers: Number of register tokens (K). Controls compression.
        depth: Number of transformer encoder layers.
        num_heads: Number of attention heads.
        dropout: Dropout rate.
        max_seq_len: Maximum input sequence length.

    Example:
        >>> enc = RegisterEncoder(input_dim=8, emb_dim=64, latent_dim=16,
        ...                       num_registers=8, depth=4)
        >>> x = torch.randn(2, 50, 8)  # (B, T=50, D=8)
        >>> z = enc(x)                   # (B, 8, 16) -- 50 timesteps -> 8 tokens
    """

    def __init__(
        self,
        input_dim: int,
        emb_dim: int = 64,
        latent_dim: int = 16,
        num_registers: int = 8,
        depth: int = 4,
        num_heads: int = 4,
        dropout: float = 0.1,
        max_seq_len: int = 512,
    ):
        super().__init__()

        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.num_registers = num_registers

        # Input projection
        self.input_proj = nn.Linear(input_dim, emb_dim)

        # Positional encoding for input tokens
        self.pos_enc = SinusoidalPositionalEncoding(emb_dim, max_seq_len)

        # Learnable register tokens
        self.registers = nn.Parameter(torch.randn(num_registers, emb_dim))
        nn.init.trunc_normal_(self.registers, std=0.02)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=4 * emb_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=depth
        )

        # Output projection (register embeddings -> latent dim)
        self.head = nn.Linear(emb_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input sequence to register token latents.

        Args:
            x: (B, T, input_dim) input invariant/action sequence.

        Returns:
            (B, num_registers, latent_dim) compressed representation.
        """
        B, T, _ = x.shape

        # Project input and add positional encoding
        x_emb = self.input_proj(x)  # (B, T, emb_dim)
        x_emb = self.pos_enc(x_emb)

        # Append register tokens
        reg = self.registers.unsqueeze(0).expand(B, -1, -1)
        combined = torch.cat([x_emb, reg], dim=1)  # (B, T+K, emb_dim)

        # Create causal-last mask
        mask = create_causal_last_mask(T, self.num_registers, str(x.device))

        # Transformer forward
        out = self.transformer(combined, mask=mask)

        # Extract register tokens and project
        latents = self.head(out[:, T:])  # (B, K, latent_dim)

        return latents
Functions
forward
forward(x)

Encode input sequence to register token latents.

Parameters:

- x (Tensor): (B, T, input_dim) input invariant/action sequence. [required]

Returns:

- Tensor: (B, num_registers, latent_dim) compressed representation.

Source code in src/dhb_xr/tokenization/register_encoder.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode input sequence to register token latents.

    Args:
        x: (B, T, input_dim) input invariant/action sequence.

    Returns:
        (B, num_registers, latent_dim) compressed representation.
    """
    B, T, _ = x.shape

    # Project input and add positional encoding
    x_emb = self.input_proj(x)  # (B, T, emb_dim)
    x_emb = self.pos_enc(x_emb)

    # Append register tokens
    reg = self.registers.unsqueeze(0).expand(B, -1, -1)
    combined = torch.cat([x_emb, reg], dim=1)  # (B, T+K, emb_dim)

    # Create causal-last mask
    mask = create_causal_last_mask(T, self.num_registers, str(x.device))

    # Transformer forward
    out = self.transformer(combined, mask=mask)

    # Extract register tokens and project
    latents = self.head(out[:, T:])  # (B, K, latent_dim)

    return latents
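
The create_causal_last_mask helper is not reproduced on this page. The sketch below is only an illustration of one plausible causal-last scheme (input positions attend causally among themselves and ignore the registers, while register positions attend to everything), using PyTorch's boolean convention where True means "do not attend"; the actual helper in register_encoder.py may differ.

import torch

def causal_last_mask_sketch(T: int, K: int) -> torch.Tensor:
    """Illustrative (T+K, T+K) boolean mask: True = may NOT attend (assumption)."""
    n = T + K
    mask = torch.zeros(n, n, dtype=torch.bool)
    # Input rows 0..T-1: causal over inputs, no access to register columns
    mask[:T, :T] = torch.triu(torch.ones(T, T), diagonal=1).bool()
    mask[:T, T:] = True
    # Register rows T..T+K-1: unmasked, so they can summarize the full sequence
    return mask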

ResidualVQTokenizer

Bases: Module

RVQ: multiple codebooks on residual. invariants (B, T, C) -> list of (B, T) tokens, (B, T, C) reconstructed.

Source code in src/dhb_xr/tokenization/rvq.py
class ResidualVQTokenizer(nn.Module):
    """RVQ: multiple codebooks on residual. invariants (B, T, C) -> list of (B, T) tokens, (B, T, C) reconstructed."""

    def __init__(
        self,
        invariant_dim: int,
        latent_dim: int,
        codebook_size: int,
        num_codebooks: int = 2,
        num_layers: int = 2,
    ):
        super().__init__()
        self.encoder = CausalConv1dEncoder(
            invariant_dim, latent_dim, latent_dim, num_layers
        )
        self.vqs = nn.ModuleList([
            VectorQuantizer(codebook_size, latent_dim) for _ in range(num_codebooks)
        ])
        self.decoder = CausalConv1dEncoder(
            latent_dim, latent_dim, invariant_dim, num_layers
        )
        self.num_codebooks = num_codebooks

    def forward(self, invariants: torch.Tensor) -> tuple:
        z = self.encoder(invariants)
        residuals = z
        all_indices = []
        z_sum = torch.zeros_like(z)
        for vq in self.vqs:
            indices, z_q_st, z_q = vq(residuals)
            all_indices.append(indices)
            z_sum = z_sum + z_q_st
            residuals = residuals - z_q.detach()
        reconstructed = self.decoder(z_sum)
        return all_indices, reconstructed, z, z_sum

    # ---- Flow matching integration API ----

    def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
        """
        Encode invariants to continuous latent space (before quantization).

        Args:
            invariants: Input invariant sequences (B, T, C).

        Returns:
            Continuous latent z (B, T, latent_dim).
        """
        return self.encoder(invariants)

    def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode from continuous latent to invariants.

        Args:
            z: Continuous latent (B, T, latent_dim).

        Returns:
            Reconstructed invariants (B, T, invariant_dim).
        """
        return self.decoder(z)

    def quantize(self, z: torch.Tensor, num_codebooks: int = None) -> tuple:
        """
        Quantize continuous latent using RVQ.

        Args:
            z: Continuous latent (B, T, latent_dim).
            num_codebooks: Number of codebooks to use (default: all).

        Returns:
            Tuple of (all_indices, z_sum) where all_indices is list of (B, T).
        """
        if num_codebooks is None:
            num_codebooks = self.num_codebooks

        residuals = z
        all_indices = []
        z_sum = torch.zeros_like(z)

        for i, vq in enumerate(self.vqs[:num_codebooks]):
            indices, z_q_st, z_q = vq(residuals)
            all_indices.append(indices)
            z_sum = z_sum + z_q_st
            residuals = residuals - z_q.detach()

        return all_indices, z_sum

    def encode_partial(
        self,
        invariants: torch.Tensor,
        num_codebooks: int,
    ) -> tuple:
        """
        Encode with partial RVQ (for hierarchical VFM).

        Uses only the first num_codebooks codebooks.

        Args:
            invariants: Input invariant sequences (B, T, C).
            num_codebooks: Number of codebooks to use.

        Returns:
            Tuple of (all_indices, z_sum, reconstructed).
        """
        z = self.encoder(invariants)
        all_indices, z_sum = self.quantize(z, num_codebooks)
        reconstructed = self.decoder(z_sum)
        return all_indices, z_sum, reconstructed

    def get_codebook_embeddings(self, codebook_idx: int = 0) -> torch.Tensor:
        """
        Get codebook embeddings for a specific codebook.

        Args:
            codebook_idx: Index of the codebook (0 to num_codebooks-1).

        Returns:
            Codebook embeddings (codebook_size, latent_dim).
        """
        return self.vqs[codebook_idx].embedding.weight.data

    def get_all_codebook_embeddings(self) -> list:
        """
        Get all codebook embeddings.

        Returns:
            List of codebook embeddings, each (codebook_size, latent_dim).
        """
        return [vq.embedding.weight.data for vq in self.vqs]

    def embed_tokens(self, all_indices: list) -> torch.Tensor:
        """
        Convert RVQ token indices to summed embeddings.

        Args:
            all_indices: List of token indices, each (B, T).

        Returns:
            Summed embeddings (B, T, latent_dim).
        """
        z_sum = None
        for i, (indices, vq) in enumerate(zip(all_indices, self.vqs)):
            z_q = vq.embedding(indices)
            if z_sum is None:
                z_sum = z_q
            else:
                z_sum = z_sum + z_q
        return z_sum

    def decode_tokens(self, all_indices: list) -> torch.Tensor:
        """
        Decode RVQ token indices to invariants.

        Args:
            all_indices: List of token indices, each (B, T).

        Returns:
            Reconstructed invariants (B, T, invariant_dim).
        """
        z_sum = self.embed_tokens(all_indices)
        return self.decoder(z_sum)
Functions
decode_from_latent
decode_from_latent(z)

Decode from continuous latent to invariants.

Parameters:

- z (Tensor): Continuous latent (B, T, latent_dim). [required]

Returns:

- Tensor: Reconstructed invariants (B, T, invariant_dim).

Source code in src/dhb_xr/tokenization/rvq.py
def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
    """
    Decode from continuous latent to invariants.

    Args:
        z: Continuous latent (B, T, latent_dim).

    Returns:
        Reconstructed invariants (B, T, invariant_dim).
    """
    return self.decoder(z)
decode_tokens
decode_tokens(all_indices)

Decode RVQ token indices to invariants.

Parameters:

- all_indices (list): List of token indices, each (B, T). [required]

Returns:

- Tensor: Reconstructed invariants (B, T, invariant_dim).

Source code in src/dhb_xr/tokenization/rvq.py
def decode_tokens(self, all_indices: list) -> torch.Tensor:
    """
    Decode RVQ token indices to invariants.

    Args:
        all_indices: List of token indices, each (B, T).

    Returns:
        Reconstructed invariants (B, T, invariant_dim).
    """
    z_sum = self.embed_tokens(all_indices)
    return self.decoder(z_sum)
embed_tokens
embed_tokens(all_indices)

Convert RVQ token indices to summed embeddings.

Parameters:

- all_indices (list): List of token indices, each (B, T). [required]

Returns:

- Tensor: Summed embeddings (B, T, latent_dim).

Source code in src/dhb_xr/tokenization/rvq.py
def embed_tokens(self, all_indices: list) -> torch.Tensor:
    """
    Convert RVQ token indices to summed embeddings.

    Args:
        all_indices: List of token indices, each (B, T).

    Returns:
        Summed embeddings (B, T, latent_dim).
    """
    z_sum = None
    for i, (indices, vq) in enumerate(zip(all_indices, self.vqs)):
        z_q = vq.embedding(indices)
        if z_sum is None:
            z_sum = z_q
        else:
            z_sum = z_sum + z_q
    return z_sum
encode_continuous
encode_continuous(invariants)

Encode invariants to continuous latent space (before quantization).

Parameters:

- invariants (Tensor): Input invariant sequences (B, T, C). [required]

Returns:

- Tensor: Continuous latent z (B, T, latent_dim).

Source code in src/dhb_xr/tokenization/rvq.py
def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
    """
    Encode invariants to continuous latent space (before quantization).

    Args:
        invariants: Input invariant sequences (B, T, C).

    Returns:
        Continuous latent z (B, T, latent_dim).
    """
    return self.encoder(invariants)
encode_partial
encode_partial(invariants, num_codebooks)

Encode with partial RVQ (for hierarchical VFM).

Uses only the first num_codebooks codebooks.

Parameters:

- invariants (Tensor): Input invariant sequences (B, T, C). [required]
- num_codebooks (int): Number of codebooks to use. [required]

Returns:

- tuple: Tuple of (all_indices, z_sum, reconstructed).

Source code in src/dhb_xr/tokenization/rvq.py
def encode_partial(
    self,
    invariants: torch.Tensor,
    num_codebooks: int,
) -> tuple:
    """
    Encode with partial RVQ (for hierarchical VFM).

    Uses only the first num_codebooks codebooks.

    Args:
        invariants: Input invariant sequences (B, T, C).
        num_codebooks: Number of codebooks to use.

    Returns:
        Tuple of (all_indices, z_sum, reconstructed).
    """
    z = self.encoder(invariants)
    all_indices, z_sum = self.quantize(z, num_codebooks)
    reconstructed = self.decoder(z_sum)
    return all_indices, z_sum, reconstructed
get_all_codebook_embeddings
get_all_codebook_embeddings()

Get all codebook embeddings.

Returns:

- list: List of codebook embeddings, each (codebook_size, latent_dim).

Source code in src/dhb_xr/tokenization/rvq.py
def get_all_codebook_embeddings(self) -> list:
    """
    Get all codebook embeddings.

    Returns:
        List of codebook embeddings, each (codebook_size, latent_dim).
    """
    return [vq.embedding.weight.data for vq in self.vqs]
get_codebook_embeddings
get_codebook_embeddings(codebook_idx=0)

Get codebook embeddings for a specific codebook.

Parameters:

- codebook_idx (int): Index of the codebook (0 to num_codebooks-1). [default: 0]

Returns:

- Tensor: Codebook embeddings (codebook_size, latent_dim).

Source code in src/dhb_xr/tokenization/rvq.py
def get_codebook_embeddings(self, codebook_idx: int = 0) -> torch.Tensor:
    """
    Get codebook embeddings for a specific codebook.

    Args:
        codebook_idx: Index of the codebook (0 to num_codebooks-1).

    Returns:
        Codebook embeddings (codebook_size, latent_dim).
    """
    return self.vqs[codebook_idx].embedding.weight.data
quantize
quantize(z, num_codebooks=None)

Quantize continuous latent using RVQ.

Parameters:

- z (Tensor): Continuous latent (B, T, latent_dim). [required]
- num_codebooks (int): Number of codebooks to use (default: all). [default: None]

Returns:

- tuple: Tuple of (all_indices, z_sum) where all_indices is a list of (B, T).

Source code in src/dhb_xr/tokenization/rvq.py
def quantize(self, z: torch.Tensor, num_codebooks: int = None) -> tuple:
    """
    Quantize continuous latent using RVQ.

    Args:
        z: Continuous latent (B, T, latent_dim).
        num_codebooks: Number of codebooks to use (default: all).

    Returns:
        Tuple of (all_indices, z_sum) where all_indices is list of (B, T).
    """
    if num_codebooks is None:
        num_codebooks = self.num_codebooks

    residuals = z
    all_indices = []
    z_sum = torch.zeros_like(z)

    for i, vq in enumerate(self.vqs[:num_codebooks]):
        indices, z_q_st, z_q = vq(residuals)
        all_indices.append(indices)
        z_sum = z_sum + z_q_st
        residuals = residuals - z_q.detach()

    return all_indices, z_sum
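
A minimal usage sketch built only from the constructor and methods shown above (dimensions are illustrative):

import torch
from dhb_xr.tokenization.rvq import ResidualVQTokenizer

tok = ResidualVQTokenizer(invariant_dim=8, latent_dim=16, codebook_size=64, num_codebooks=2)
invariants = torch.randn(4, 50, 8)  # (B, T, C)

# Full forward: per-codebook indices plus reconstruction
all_indices, reconstructed, z, z_sum = tok(invariants)

# Decode straight from stored indices (e.g., after transmission)
decoded = tok.decode_tokens(all_indices)

# Coarse-only pass using just the first codebook (hierarchical VFM use case)
coarse_indices, coarse_latent, coarse_recon = tok.encode_partial(invariants, num_codebooks=1)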

TokenCompressor

Unified compression pipeline for DHB-Token sequences.

Combines multiple compression methods for optimal results:
1. RLE for static segments (lossless, good for low-motion)
2. BPE for pattern merging (lossless, 2-4x on invariant patterns)
3. Entropy coding for final bitstream (lossless, 1.5-2x additional)

Overall achievable: 3-8x compression on typical DHB-Token sequences.

Example

compressor = TokenCompressor(vocab_size=256)
compressor.fit(training_sequences)

# Compress
compressed = compressor.compress(tokens)
print(f"Ratio: {compressor.compression_ratio(tokens):.1f}x")

# Decompress (lossless)
recovered = compressor.decompress(compressed)
assert recovered == list(tokens)

Source code in src/dhb_xr/tokenization/compression.py
class TokenCompressor:
    """
    Unified compression pipeline for DHB-Token sequences.

    Combines multiple compression methods for optimal results:
    1. RLE for static segments (lossless, good for low-motion)
    2. BPE for pattern merging (lossless, 2-4x on invariant patterns)
    3. Entropy coding for final bitstream (lossless, 1.5-2x additional)

    Overall achievable: 3-8x compression on typical DHB-Token sequences.

    Example:
        >>> compressor = TokenCompressor(vocab_size=256)
        >>> compressor.fit(training_sequences)
        >>> 
        >>> # Compress
        >>> compressed = compressor.compress(tokens)
        >>> print(f"Ratio: {compressor.compression_ratio(tokens):.1f}x")
        >>> 
        >>> # Decompress (lossless)
        >>> recovered = compressor.decompress(compressed)
        >>> assert recovered == list(tokens)
    """

    def __init__(
        self,
        vocab_size: int = 256,
        use_rle: bool = True,
        use_bpe: bool = True,
        use_entropy: bool = True,
        bpe_merges: int = 100,
        rle_min_run: int = 3,
    ):
        """
        Args:
            vocab_size: VQ codebook size
            use_rle: Enable run-length encoding
            use_bpe: Enable byte-pair encoding
            use_entropy: Enable entropy (Huffman) coding
            bpe_merges: Number of BPE merges to learn
            rle_min_run: Minimum run length for RLE
        """
        self.vocab_size = vocab_size
        self.use_rle = use_rle
        self.use_bpe = use_bpe
        self.use_entropy = use_entropy

        self.rle = RLECompressor(min_run=rle_min_run) if use_rle else None
        self.bpe = BPECompressor(vocab_size=vocab_size, num_merges=bpe_merges) if use_bpe else None
        self.entropy = EntropyCompressor() if use_entropy else None

        self._fitted = False

    def fit(self, token_sequences: List[List[int]]) -> "TokenCompressor":
        """
        Fit all compression stages on training data.

        Args:
            token_sequences: List of token sequences

        Returns:
            self
        """
        current_sequences = [list(seq) for seq in token_sequences]

        # Stage 1: Learn BPE merges
        if self.bpe:
            self.bpe.fit(current_sequences)
            current_sequences = [self.bpe.encode(seq) for seq in current_sequences]

        # Stage 2: Learn entropy codes (after BPE)
        if self.entropy:
            self.entropy.fit(current_sequences)

        self._fitted = True
        return self

    def compress(self, tokens: Union[List[int], np.ndarray]) -> Dict:
        """
        Compress a token sequence.

        Args:
            tokens: Token sequence

        Returns:
            Dict with compressed data and metadata
        """
        if not self._fitted:
            raise RuntimeError("TokenCompressor must be fitted before compressing")

        tokens = list(tokens)
        original_len = len(tokens)

        result = {
            "original_length": original_len,
            "stages": {},
        }

        # Stage 1: RLE (optional, applied first for static segments)
        if self.rle:
            rle_encoded = self.rle.encode(tokens)
            # Flatten for next stage
            tokens = []
            for token, count in rle_encoded:
                tokens.extend([token, count])
            result["stages"]["rle"] = {
                "length": len(rle_encoded),
                "ratio": original_len / len(rle_encoded) if rle_encoded else 1.0,
            }
            result["rle_data"] = rle_encoded

        # Stage 2: BPE
        if self.bpe:
            bpe_encoded = self.bpe.encode(tokens)
            result["stages"]["bpe"] = {
                "length": len(bpe_encoded),
                "ratio": len(tokens) / len(bpe_encoded) if bpe_encoded else 1.0,
            }
            tokens = bpe_encoded

        # Stage 3: Entropy coding
        if self.entropy:
            binary = self.entropy.encode(tokens)
            result["stages"]["entropy"] = {
                "bits": len(binary),
                "bytes": len(binary) // 8 + (1 if len(binary) % 8 else 0),
                "bits_per_original_token": len(binary) / original_len if original_len > 0 else 0,
            }
            result["binary"] = binary
        else:
            result["tokens"] = tokens

        # Compute overall ratio
        if self.entropy:
            compressed_bits = len(result["binary"])
            original_bits = original_len * np.ceil(np.log2(self.vocab_size + 1))
            result["overall_ratio"] = original_bits / compressed_bits if compressed_bits > 0 else 1.0
        else:
            result["overall_ratio"] = original_len / len(result.get("tokens", tokens))

        return result

    def decompress(self, compressed: Dict) -> List[int]:
        """
        Decompress back to original tokens.

        Args:
            compressed: Output from compress()

        Returns:
            Original token sequence
        """
        # Reverse entropy coding
        if self.entropy and "binary" in compressed:
            tokens = self.entropy.decode(compressed["binary"])
        else:
            tokens = compressed.get("tokens", [])

        # Reverse BPE
        if self.bpe:
            tokens = self.bpe.decode(tokens)

        # Reverse RLE
        if self.rle and "rle_data" in compressed:
            # Reconstruct from (token, count) pairs
            tokens = self.rle.decode(compressed["rle_data"])

        return tokens

    def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
        """Compute overall compression ratio for a sequence."""
        compressed = self.compress(tokens)
        return compressed["overall_ratio"]

    def get_stats(self) -> Dict:
        """Get overall compression statistics."""
        stats = {
            "fitted": self._fitted,
            "stages": [],
        }
        if self.rle:
            stats["stages"].append("RLE")
        if self.bpe:
            stats["stages"].append("BPE")
            stats["bpe"] = self.bpe.get_stats()
        if self.entropy:
            stats["stages"].append("Entropy")
            stats["entropy"] = self.entropy.get_stats()
        return stats
Functions
__init__
__init__(
    vocab_size=256, use_rle=True, use_bpe=True, use_entropy=True, bpe_merges=100, rle_min_run=3
)

Parameters:

- vocab_size (int): VQ codebook size. [default: 256]
- use_rle (bool): Enable run-length encoding. [default: True]
- use_bpe (bool): Enable byte-pair encoding. [default: True]
- use_entropy (bool): Enable entropy (Huffman) coding. [default: True]
- bpe_merges (int): Number of BPE merges to learn. [default: 100]
- rle_min_run (int): Minimum run length for RLE. [default: 3]
Source code in src/dhb_xr/tokenization/compression.py
def __init__(
    self,
    vocab_size: int = 256,
    use_rle: bool = True,
    use_bpe: bool = True,
    use_entropy: bool = True,
    bpe_merges: int = 100,
    rle_min_run: int = 3,
):
    """
    Args:
        vocab_size: VQ codebook size
        use_rle: Enable run-length encoding
        use_bpe: Enable byte-pair encoding
        use_entropy: Enable entropy (Huffman) coding
        bpe_merges: Number of BPE merges to learn
        rle_min_run: Minimum run length for RLE
    """
    self.vocab_size = vocab_size
    self.use_rle = use_rle
    self.use_bpe = use_bpe
    self.use_entropy = use_entropy

    self.rle = RLECompressor(min_run=rle_min_run) if use_rle else None
    self.bpe = BPECompressor(vocab_size=vocab_size, num_merges=bpe_merges) if use_bpe else None
    self.entropy = EntropyCompressor() if use_entropy else None

    self._fitted = False
compress
compress(tokens)

Compress a token sequence.

Parameters:

- tokens (Union[List[int], ndarray]): Token sequence. [required]

Returns:

- Dict: Dict with compressed data and metadata.

Source code in src/dhb_xr/tokenization/compression.py
def compress(self, tokens: Union[List[int], np.ndarray]) -> Dict:
    """
    Compress a token sequence.

    Args:
        tokens: Token sequence

    Returns:
        Dict with compressed data and metadata
    """
    if not self._fitted:
        raise RuntimeError("TokenCompressor must be fitted before compressing")

    tokens = list(tokens)
    original_len = len(tokens)

    result = {
        "original_length": original_len,
        "stages": {},
    }

    # Stage 1: RLE (optional, applied first for static segments)
    if self.rle:
        rle_encoded = self.rle.encode(tokens)
        # Flatten for next stage
        tokens = []
        for token, count in rle_encoded:
            tokens.extend([token, count])
        result["stages"]["rle"] = {
            "length": len(rle_encoded),
            "ratio": original_len / len(rle_encoded) if rle_encoded else 1.0,
        }
        result["rle_data"] = rle_encoded

    # Stage 2: BPE
    if self.bpe:
        bpe_encoded = self.bpe.encode(tokens)
        result["stages"]["bpe"] = {
            "length": len(bpe_encoded),
            "ratio": len(tokens) / len(bpe_encoded) if bpe_encoded else 1.0,
        }
        tokens = bpe_encoded

    # Stage 3: Entropy coding
    if self.entropy:
        binary = self.entropy.encode(tokens)
        result["stages"]["entropy"] = {
            "bits": len(binary),
            "bytes": len(binary) // 8 + (1 if len(binary) % 8 else 0),
            "bits_per_original_token": len(binary) / original_len if original_len > 0 else 0,
        }
        result["binary"] = binary
    else:
        result["tokens"] = tokens

    # Compute overall ratio
    if self.entropy:
        compressed_bits = len(result["binary"])
        original_bits = original_len * np.ceil(np.log2(self.vocab_size + 1))
        result["overall_ratio"] = original_bits / compressed_bits if compressed_bits > 0 else 1.0
    else:
        result["overall_ratio"] = original_len / len(result.get("tokens", tokens))

    return result
compression_ratio
compression_ratio(tokens)

Compute overall compression ratio for a sequence.

Source code in src/dhb_xr/tokenization/compression.py
def compression_ratio(self, tokens: Union[List[int], np.ndarray]) -> float:
    """Compute overall compression ratio for a sequence."""
    compressed = self.compress(tokens)
    return compressed["overall_ratio"]
decompress
decompress(compressed)

Decompress back to original tokens.

Parameters:

- compressed (Dict): Output from compress(). [required]

Returns:

- List[int]: Original token sequence.

Source code in src/dhb_xr/tokenization/compression.py
def decompress(self, compressed: Dict) -> List[int]:
    """
    Decompress back to original tokens.

    Args:
        compressed: Output from compress()

    Returns:
        Original token sequence
    """
    # Reverse entropy coding
    if self.entropy and "binary" in compressed:
        tokens = self.entropy.decode(compressed["binary"])
    else:
        tokens = compressed.get("tokens", [])

    # Reverse BPE
    if self.bpe:
        tokens = self.bpe.decode(tokens)

    # Reverse RLE
    if self.rle and "rle_data" in compressed:
        # Reconstruct from (token, count) pairs
        tokens = self.rle.decode(compressed["rle_data"])

    return tokens
fit
fit(token_sequences)

Fit all compression stages on training data.

Parameters:

- token_sequences (List[List[int]]): List of token sequences. [required]

Returns:

- TokenCompressor: self.

Source code in src/dhb_xr/tokenization/compression.py
def fit(self, token_sequences: List[List[int]]) -> "TokenCompressor":
    """
    Fit all compression stages on training data.

    Args:
        token_sequences: List of token sequences

    Returns:
        self
    """
    current_sequences = [list(seq) for seq in token_sequences]

    # Stage 1: Learn BPE merges
    if self.bpe:
        self.bpe.fit(current_sequences)
        current_sequences = [self.bpe.encode(seq) for seq in current_sequences]

    # Stage 2: Learn entropy codes (after BPE)
    if self.entropy:
        self.entropy.fit(current_sequences)

    self._fitted = True
    return self
get_stats
get_stats()

Get overall compression statistics.

Source code in src/dhb_xr/tokenization/compression.py
def get_stats(self) -> Dict:
    """Get overall compression statistics."""
    stats = {
        "fitted": self._fitted,
        "stages": [],
    }
    if self.rle:
        stats["stages"].append("RLE")
    if self.bpe:
        stats["stages"].append("BPE")
        stats["bpe"] = self.bpe.get_stats()
    if self.entropy:
        stats["stages"].append("Entropy")
        stats["entropy"] = self.entropy.get_stats()
    return stats

TokenReuser

Token reuse detector for inference acceleration.

Inspired by FlashVLA (2025): Skip decoding when tokens are stable/repeated, reusing previous outputs. Provides 2-5x effective speedup in long-horizon tasks.

Works by detecting:
1. Exact token repeats (static segments)
2. Token sequences matching known patterns (from database)
3. Low-variance token regions (approximate reuse)

Source code in src/dhb_xr/tokenization/compression.py
class TokenReuser:
    """
    Token reuse detector for inference acceleration.

    Inspired by FlashVLA (2025): Skip decoding when tokens are stable/repeated,
    reusing previous outputs. Provides 2-5x effective speedup in long-horizon tasks.

    Works by detecting:
    1. Exact token repeats (static segments)
    2. Token sequences matching known patterns (from database)
    3. Low-variance token regions (approximate reuse)
    """

    def __init__(self, window_size: int = 5, similarity_threshold: float = 0.9):
        """
        Args:
            window_size: Window for detecting stable regions
            similarity_threshold: Threshold for approximate matching
        """
        self.window_size = window_size
        self.similarity_threshold = similarity_threshold
        self.pattern_cache: Dict[tuple, np.ndarray] = {}  # pattern -> cached output

    def detect_stable_regions(self, tokens: Union[List[int], np.ndarray]) -> List[Tuple[int, int, bool]]:
        """
        Detect regions where tokens are stable (can reuse previous decoding).

        Args:
            tokens: Token sequence

        Returns:
            List of (start, end, is_stable) tuples
        """
        tokens = np.array(tokens)
        n = len(tokens)

        if n < self.window_size:
            return [(0, n, False)]

        regions = []
        i = 0

        while i < n:
            # Check if next window_size tokens are identical
            end = min(i + self.window_size, n)
            window = tokens[i:end]

            if len(set(window)) == 1:  # All same
                # Extend stable region
                stable_start = i
                while end < n and tokens[end] == tokens[i]:
                    end += 1
                regions.append((stable_start, end, True))
                i = end
            else:
                regions.append((i, i + 1, False))
                i += 1

        # Merge adjacent non-stable regions
        merged = []
        for start, end, is_stable in regions:
            if merged and not merged[-1][2] and not is_stable:
                merged[-1] = (merged[-1][0], end, False)
            else:
                merged.append((start, end, is_stable))

        return merged

    def compute_reuse_potential(self, tokens: Union[List[int], np.ndarray]) -> Dict:
        """
        Analyze reuse potential for a sequence.

        Args:
            tokens: Token sequence

        Returns:
            Statistics on reuse potential
        """
        regions = self.detect_stable_regions(tokens)

        total_len = len(tokens)
        stable_len = sum(end - start for start, end, is_stable in regions if is_stable)

        return {
            "total_tokens": total_len,
            "stable_tokens": stable_len,
            "reuse_fraction": stable_len / total_len if total_len > 0 else 0,
            "num_regions": len(regions),
            "num_stable_regions": sum(1 for _, _, s in regions if s),
            "potential_speedup": total_len / (total_len - stable_len + len([r for r in regions if r[2]])) if total_len > stable_len else 1.0,
        }
Functions
__init__
__init__(window_size=5, similarity_threshold=0.9)

Parameters:

- window_size (int): Window for detecting stable regions. [default: 5]
- similarity_threshold (float): Threshold for approximate matching. [default: 0.9]
Source code in src/dhb_xr/tokenization/compression.py
def __init__(self, window_size: int = 5, similarity_threshold: float = 0.9):
    """
    Args:
        window_size: Window for detecting stable regions
        similarity_threshold: Threshold for approximate matching
    """
    self.window_size = window_size
    self.similarity_threshold = similarity_threshold
    self.pattern_cache: Dict[tuple, np.ndarray] = {}  # pattern -> cached output
compute_reuse_potential
compute_reuse_potential(tokens)

Analyze reuse potential for a sequence.

Parameters:

- tokens (Union[List[int], ndarray]): Token sequence. [required]

Returns:

- Dict: Statistics on reuse potential.

Source code in src/dhb_xr/tokenization/compression.py
def compute_reuse_potential(self, tokens: Union[List[int], np.ndarray]) -> Dict:
    """
    Analyze reuse potential for a sequence.

    Args:
        tokens: Token sequence

    Returns:
        Statistics on reuse potential
    """
    regions = self.detect_stable_regions(tokens)

    total_len = len(tokens)
    stable_len = sum(end - start for start, end, is_stable in regions if is_stable)

    return {
        "total_tokens": total_len,
        "stable_tokens": stable_len,
        "reuse_fraction": stable_len / total_len if total_len > 0 else 0,
        "num_regions": len(regions),
        "num_stable_regions": sum(1 for _, _, s in regions if s),
        "potential_speedup": total_len / (total_len - stable_len + len([r for r in regions if r[2]])) if total_len > stable_len else 1.0,
    }
detect_stable_regions
detect_stable_regions(tokens)

Detect regions where tokens are stable (can reuse previous decoding).

Parameters:

- tokens (Union[List[int], ndarray]): Token sequence. [required]

Returns:

- List[Tuple[int, int, bool]]: List of (start, end, is_stable) tuples.

Source code in src/dhb_xr/tokenization/compression.py
def detect_stable_regions(self, tokens: Union[List[int], np.ndarray]) -> List[Tuple[int, int, bool]]:
    """
    Detect regions where tokens are stable (can reuse previous decoding).

    Args:
        tokens: Token sequence

    Returns:
        List of (start, end, is_stable) tuples
    """
    tokens = np.array(tokens)
    n = len(tokens)

    if n < self.window_size:
        return [(0, n, False)]

    regions = []
    i = 0

    while i < n:
        # Check if next window_size tokens are identical
        end = min(i + self.window_size, n)
        window = tokens[i:end]

        if len(set(window)) == 1:  # All same
            # Extend stable region
            stable_start = i
            while end < n and tokens[end] == tokens[i]:
                end += 1
            regions.append((stable_start, end, True))
            i = end
        else:
            regions.append((i, i + 1, False))
            i += 1

    # Merge adjacent non-stable regions
    merged = []
    for start, end, is_stable in regions:
        if merged and not merged[-1][2] and not is_stable:
            merged[-1] = (merged[-1][0], end, False)
        else:
            merged.append((start, end, is_stable))

    return merged
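
A brief sketch of the reuse analysis on a toy sequence, using only the methods shown above:

from dhb_xr.tokenization.compression import TokenReuser

reuser = TokenReuser(window_size=5)
tokens = [7] * 12 + [3, 9, 3] + [2] * 8  # long static runs around a short dynamic burst

regions = reuser.detect_stable_regions(tokens)   # [(start, end, is_stable), ...]
stats = reuser.compute_reuse_potential(tokens)
print(stats["reuse_fraction"], stats["potential_speedup"])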

Functions

compress_token_sequence

compress_token_sequence(tokens, vocab_size=256, method='bpe', **kwargs)

Compress a token sequence with specified method.

Parameters:

- tokens (Union[List[int], ndarray]): Token sequence. [required]
- vocab_size (int): VQ codebook size. [default: 256]
- method (str): "bpe", "entropy", "rle", or "full". [default: 'bpe']
- **kwargs: Additional arguments for compressor.

Returns:

- Dict: Compression result dict.

Source code in src/dhb_xr/tokenization/compression.py
def compress_token_sequence(
    tokens: Union[List[int], np.ndarray],
    vocab_size: int = 256,
    method: str = "bpe",
    **kwargs
) -> Dict:
    """
    Compress a token sequence with specified method.

    Args:
        tokens: Token sequence
        vocab_size: VQ codebook size
        method: "bpe", "entropy", "rle", or "full"
        **kwargs: Additional arguments for compressor

    Returns:
        Compression result dict
    """
    if method == "bpe":
        compressor = BPECompressor(vocab_size=vocab_size, **kwargs)
        compressor.fit([list(tokens)])  # Self-fit for single sequence
        encoded = compressor.encode(tokens)
        return {
            "encoded": encoded,
            "ratio": len(tokens) / len(encoded),
            "method": "bpe",
        }
    elif method == "entropy":
        compressor = EntropyCompressor()
        compressor.fit([list(tokens)])
        binary = compressor.encode(tokens)
        return {
            "binary": binary,
            "bits_per_token": len(binary) / len(tokens),
            "ratio": (len(tokens) * 8) / len(binary),  # vs 8-bit naive
            "method": "entropy",
        }
    elif method == "rle":
        compressor = RLECompressor(**kwargs)
        encoded = compressor.encode(tokens)
        return {
            "encoded": encoded,
            "ratio": compressor.compression_ratio(tokens),
            "stats": compressor.get_stats(tokens),
            "method": "rle",
        }
    elif method == "full":
        compressor = TokenCompressor(vocab_size=vocab_size, **kwargs)
        compressor.fit([list(tokens)])
        return compressor.compress(tokens)
    else:
        raise ValueError(f"Unknown method: {method}")
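
The convenience function above wraps the individual compressors; a quick sketch follows. Note that for the "bpe", "entropy", and "full" methods it fits on the single sequence passed in, so ratios on tiny inputs are not representative.

from dhb_xr.tokenization.compression import compress_token_sequence

tokens = [5, 5, 5, 5, 5, 5, 1, 2, 1, 2, 1, 2]

rle_result = compress_token_sequence(tokens, method="rle")
print(rle_result["ratio"], rle_result["stats"]["num_runs"])

bpe_result = compress_token_sequence(tokens, vocab_size=256, method="bpe", num_merges=50)
print(bpe_result["ratio"])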

Overview

VQ-VAE-based tokenization for discrete DHB invariant representations.

Main Classes

DHBTokenizer

DHBTokenizer

Bases: Module

Causal VQ-VAE for invariant sequences. invariants (B, T, C) -> tokens (B, T), reconstructed (B, T, C).

Source code in src/dhb_xr/tokenization/vqvae.py
class DHBTokenizer(nn.Module):
    """
    Causal VQ-VAE for invariant sequences.
    invariants (B, T, C) -> tokens (B, T), reconstructed (B, T, C).
    """

    def __init__(
        self,
        invariant_dim: int,
        latent_dim: int,
        codebook_size: int,
        num_layers: int = 2,
        kernel_size: int = 3,
    ):
        super().__init__()
        self.encoder = CausalConv1dEncoder(
            invariant_dim, latent_dim, latent_dim, num_layers, kernel_size
        )
        self.vq = VectorQuantizer(codebook_size, latent_dim)
        self.decoder = CausalConv1dEncoder(
            latent_dim, latent_dim, invariant_dim, num_layers, kernel_size
        )
        self.invariant_dim = invariant_dim
        self.latent_dim = latent_dim
        self.codebook_size = codebook_size

    def forward(self, invariants: torch.Tensor) -> tuple:
        z = self.encoder(invariants)
        indices, z_q_st, z_q = self.vq(z)
        reconstructed = self.decoder(z_q_st)
        return indices, reconstructed, z, z_q

    def loss(
        self,
        invariants: torch.Tensor,
        reconstructed: torch.Tensor,
        z: torch.Tensor,
        z_q: torch.Tensor,
        beta: float = 0.25,
    ) -> torch.Tensor:
        rec_loss = F.mse_loss(reconstructed, invariants)
        commitment = F.mse_loss(z, z_q)
        codebook = F.mse_loss(z_q, z.detach())
        return rec_loss + beta * commitment + codebook

    # ---- Flow matching integration API ----

    def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
        """
        Encode invariants to continuous latent space (before quantization).

        This is useful for flow matching which operates in continuous space.

        Args:
            invariants: Input invariant sequences (B, T, C).

        Returns:
            Continuous latent z (B, T, latent_dim).
        """
        return self.encoder(invariants)

    def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode from continuous latent to invariants.

        Bypasses the VQ step, useful for flow matching generation.

        Args:
            z: Continuous latent (B, T, latent_dim).

        Returns:
            Reconstructed invariants (B, T, invariant_dim).
        """
        return self.decoder(z)

    def quantize(self, z: torch.Tensor) -> tuple:
        """
        Quantize continuous latent to discrete tokens.

        Args:
            z: Continuous latent (B, T, latent_dim).

        Returns:
            Tuple of (indices, z_q_st, z_q).
        """
        return self.vq(z)

    def get_codebook_embeddings(self) -> torch.Tensor:
        """
        Get the VQ codebook embeddings.

        Useful for flow matching in embedding space or visualization.

        Returns:
            Codebook embeddings (codebook_size, latent_dim).
        """
        return self.vq.embedding.weight.data

    def embed_tokens(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Convert token indices to embeddings.

        Args:
            indices: Token indices (B, T).

        Returns:
            Token embeddings (B, T, latent_dim).
        """
        return self.vq.embedding(indices)

    def decode_tokens(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Decode token indices to invariants.

        Args:
            indices: Token indices (B, T).

        Returns:
            Reconstructed invariants (B, T, invariant_dim).
        """
        z_q = self.embed_tokens(indices)
        return self.decoder(z_q)

Functions

decode_from_latent

decode_from_latent(z)

Decode from continuous latent to invariants.

Bypasses the VQ step, useful for flow matching generation.

Parameters:

- z (Tensor): Continuous latent (B, T, latent_dim). [required]

Returns:

- Tensor: Reconstructed invariants (B, T, invariant_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def decode_from_latent(self, z: torch.Tensor) -> torch.Tensor:
    """
    Decode from continuous latent to invariants.

    Bypasses the VQ step, useful for flow matching generation.

    Args:
        z: Continuous latent (B, T, latent_dim).

    Returns:
        Reconstructed invariants (B, T, invariant_dim).
    """
    return self.decoder(z)

decode_tokens

decode_tokens(indices)

Decode token indices to invariants.

Parameters:

- indices (Tensor): Token indices (B, T). [required]

Returns:

- Tensor: Reconstructed invariants (B, T, invariant_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def decode_tokens(self, indices: torch.Tensor) -> torch.Tensor:
    """
    Decode token indices to invariants.

    Args:
        indices: Token indices (B, T).

    Returns:
        Reconstructed invariants (B, T, invariant_dim).
    """
    z_q = self.embed_tokens(indices)
    return self.decoder(z_q)

embed_tokens

embed_tokens(indices)

Convert token indices to embeddings.

Parameters:

- indices (Tensor): Token indices (B, T). [required]

Returns:

- Tensor: Token embeddings (B, T, latent_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def embed_tokens(self, indices: torch.Tensor) -> torch.Tensor:
    """
    Convert token indices to embeddings.

    Args:
        indices: Token indices (B, T).

    Returns:
        Token embeddings (B, T, latent_dim).
    """
    return self.vq.embedding(indices)

encode_continuous

encode_continuous(invariants)

Encode invariants to continuous latent space (before quantization).

This is useful for flow matching which operates in continuous space.

Parameters:

- invariants (Tensor): Input invariant sequences (B, T, C). [required]

Returns:

- Tensor: Continuous latent z (B, T, latent_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def encode_continuous(self, invariants: torch.Tensor) -> torch.Tensor:
    """
    Encode invariants to continuous latent space (before quantization).

    This is useful for flow matching which operates in continuous space.

    Args:
        invariants: Input invariant sequences (B, T, C).

    Returns:
        Continuous latent z (B, T, latent_dim).
    """
    return self.encoder(invariants)

get_codebook_embeddings

get_codebook_embeddings()

Get the VQ codebook embeddings.

Useful for flow matching in embedding space or visualization.

Returns:

- Tensor: Codebook embeddings (codebook_size, latent_dim).

Source code in src/dhb_xr/tokenization/vqvae.py
def get_codebook_embeddings(self) -> torch.Tensor:
    """
    Get the VQ codebook embeddings.

    Useful for flow matching in embedding space or visualization.

    Returns:
        Codebook embeddings (codebook_size, latent_dim).
    """
    return self.vq.embedding.weight.data

quantize

quantize(z)

Quantize continuous latent to discrete tokens.

Parameters:

- z (Tensor): Continuous latent (B, T, latent_dim). [required]

Returns:

- tuple: Tuple of (indices, z_q_st, z_q).

Source code in src/dhb_xr/tokenization/vqvae.py
def quantize(self, z: torch.Tensor) -> tuple:
    """
    Quantize continuous latent to discrete tokens.

    Args:
        z: Continuous latent (B, T, latent_dim).

    Returns:
        Tuple of (indices, z_q_st, z_q).
    """
    return self.vq(z)

Usage Example

import torch
from dhb_xr.tokenization.vqvae import DHBTokenizer

# Create tokenizer
tokenizer = DHBTokenizer(
    invariant_dim=8,
    latent_dim=16,
    codebook_size=64
)

# Tokenize invariants (forward returns indices, reconstruction, and latents)
invariants = torch.randn(10, 20, 8)  # (batch, time, features)
tokens, reconstructed, z, z_q = tokenizer(invariants)

# Decode token indices back to invariants
decoded = tokenizer.decode_tokens(tokens)
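
For training, the forward pass and the loss method shown above can be combined into a single step. A minimal sketch continuing from the tokenizer created above; the optimizer choice and dimensions are illustrative, not prescribed by the library.

optimizer = torch.optim.Adam(tokenizer.parameters(), lr=1e-3)

for step in range(100):
    invariants = torch.randn(10, 20, 8)                      # (B, T, C)
    indices, reconstructed, z, z_q = tokenizer(invariants)   # forward pass
    loss = tokenizer.loss(invariants, reconstructed, z, z_q, beta=0.25)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()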