# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Implementation of 2D Rotary Position Embeddings (RoPE).
#
# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.
#
# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    This class efficiently manages the generation of spatial coordinates for
    patches in a 2D grid, caching results to avoid redundant computations.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors, keyed
            by (height, width, device).
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        # Keyed by (height, width, device): caching by (height, width) alone
        # would return a tensor living on whichever device was requested
        # first, even when a later caller asks for a different device.
        self.position_cache: Dict[Tuple[int, int, torch.device], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing (y, x)
            coordinates for each position in the grid, repeated for each
            batch item.
        """
        cache_key = (height, width, device)
        if cache_key not in self.position_cache:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            # Row-major enumeration of every (y, x) pair in the grid.
            self.position_cache[cache_key] = torch.cartesian_prod(y_coords, x_coords)

        cached_positions = self.position_cache[cache_key]
        # .clone() so callers cannot mutate the cached tensor in place.
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on
    their 2D spatial positions. The feature dimension is split in half: the
    first half is rotated according to the y (vertical) coordinate and the
    second half according to the x (horizontal) coordinate.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
        frequency_cache: Cache for storing precomputed frequency components.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        # NOTE(review): scaling_factor is stored but never applied in the
        # computations below; kept for interface compatibility — confirm
        # whether it should scale inv_freq.
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length (exclusive upper bound on the
                position indices that will be looked up).
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape (seq_len, dim).
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Inverse frequency bands: 1 / base^(2i/dim), i = 0 .. dim/2 - 1.
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Outer product: one angle per (position, frequency) pair.
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Cast before cos/sin so the trig runs in the caller's dtype,
            # then duplicate the bands so each table spans the full dim
            # (matching the (x1, x2) layout used by _rotate_features).
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            self.frequency_cache[cache_key] = (angles.cos(), angles.sin())

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dims.

        Splits the last dimension in half and maps (x1, x2) -> (-x2, x1),
        i.e. a 90-degree rotation of each (x1[i], x2[i]) pair.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor of the same shape as the input.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one spatial dimension.

        Args:
            tokens: Input token features, shape (batch, heads, n_tokens, dim).
            positions: Integer position indices, shape (batch, n_tokens).
            cos_comp: Cosine table indexed by position.
            sin_comp: Sine table indexed by position.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up the per-position rows; insert a head axis for broadcasting.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Standard RoPE rotation: x * cos + rotate(x) * sin.
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2)
                containing the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position
            embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are
                malformed.
        """
        # dim must be divisible by 4: it is halved per spatial direction, and
        # each half is split again into rotation pairs. (Checking only
        # dim % 2 == 0 let dim % 4 == 2 through, which later failed with an
        # opaque broadcasting error inside _apply_1d_rope.)
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Feature dimension allotted to each spatial direction.
        feature_dim = tokens.size(-1) // 2

        # The tables must cover the largest coordinate actually used.
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # First half of the features encodes y, second half encodes x.
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each spatial dimension.
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features.
        return torch.cat((vertical_features, horizontal_features), dim=-1)