Skip to content

vllm.model_executor.layers.rotary_embedding.grok1_scaling_rope

Grok1ScalingRotaryEmbedding

Bases: RotaryEmbedding

Scale the RotaryEmbedding in a way similar to the YaRN method (https://arxiv.org/pdf/2309.00071).

Source code in vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
class Grok1ScalingRotaryEmbedding(RotaryEmbedding):
    """Scale the RotaryEmbedding similarly to the YaRN method.

    Reference: https://arxiv.org/pdf/2309.00071

    ``extra_method`` selects how the inverse frequencies are scaled:
      - ``"original"``: plain, unscaled RoPE frequencies.
      - ``"yarn"`` / ``"yarn_linear"``: linear blend of interpolated and
        extrapolated frequencies (classic YaRN).
      - ``"yarn_log"``: blend performed in log-frequency space.
      - ``"theta_scale"``: rescale the RoPE base using the original and
        scaled context lengths.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extra_method: str = "yarn_log",
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # All scaling attributes must be assigned before super().__init__,
        # because the parent constructor builds the cos/sin cache, which
        # reads them via _compute_inv_freq / _compute_cos_sin_cache.
        self.scaling_factor = scaling_factor
        self.extra_method = extra_method
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        # NOTE: mscale is computed for YaRN parity but is currently NOT
        # applied in _compute_cos_sin_cache (see note there).
        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Compute inverse frequencies per the configured extra_method.

        Args:
            scaling_factor: context-length scaling factor (YaRN's ``s``).

        Returns:
            Float tensor of shape ``[rotary_dim // 2]`` with one inverse
            frequency per rotary dimension pair.

        Raises:
            ValueError: if ``self.extra_method`` is not recognized.
        """
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation.
        # mask == 1 -> pure extrapolation; mask == 0 -> pure interpolation.
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        if self.extra_method == "original":
            inv_freq = inv_freq_extrapolation
        elif self.extra_method in ("yarn", "yarn_linear"):
            # Linear blend between interpolated and extrapolated frequencies.
            inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) +
                        inv_freq_extrapolation * inv_freq_mask)
        elif self.extra_method == "yarn_log":
            # Blend in log space (geometric interpolation of frequencies).
            inv_freq = torch.exp(
                torch.log(inv_freq_extrapolation) * inv_freq_mask +
                torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask))
        elif self.extra_method == "theta_scale":
            # Rescale the RoPE base by the ratio of log scaled/original
            # context length (expressed in wavelength units of 2*pi).
            # NOTE(review): this branch reads self.scaling_factor rather
            # than the scaling_factor argument; identical at the current
            # call site, but confirm if new callers appear.
            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
            theta_scale_exponent = self.base**(
                math.log(self.max_position_embeddings * self.scaling_factor /
                         (2 * math.pi)) /
                math.log(self.max_position_embeddings / (2 * math.pi)))
            # `exponents` is already a float32 tensor, so this expression
            # yields a float32 tensor directly; the previous
            # torch.tensor(...) wrapper made a redundant copy and emitted
            # a UserWarning.
            inv_freq = 1.0 / (theta_scale_exponent**(exponents /
                                                     self.rotary_dim))
        else:
            raise ValueError(
                f"Unknown extrapolation method: {self.extra_method}")
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin lookup table over all scaled positions.

        Returns:
            Tensor of shape
            ``[max_position_embeddings * scaling_factor, rotary_dim]``,
            cos values in the first half of the last dim, sin in the
            second half.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        # NOTE: self.mscale is deliberately NOT applied to cos/sin here;
        # the multiplication was previously present but has been disabled.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

attn_factor instance-attribute

attn_factor = attn_factor

beta_fast instance-attribute

beta_fast = beta_fast

beta_slow instance-attribute

beta_slow = beta_slow

extra_method instance-attribute

extra_method = extra_method

extrapolation_factor instance-attribute

extrapolation_factor = extrapolation_factor

mscale instance-attribute

mscale = float(
    yarn_get_mscale(scaling_factor) * attn_factor
)

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: dtype,
    *,
    extra_method: str = "yarn_log",
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: torch.dtype,
    *,
    extra_method: str = "yarn_log",
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
) -> None:
    """Initialize the YaRN-style scaled rotary embedding.

    Scaling attributes are stored before calling the parent
    constructor, which builds the cos/sin cache and reads them.
    """
    self.extra_method = extra_method
    self.extrapolation_factor = extrapolation_factor
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    self.scaling_factor = scaling_factor
    self.attn_factor = attn_factor
    # Magnitude correction ("mscale") per the YaRN paper, corrected
    # for interpolation.
    self.mscale = float(attn_factor * yarn_get_mscale(scaling_factor))
    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style, dtype)

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                     dtype=torch.float32)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    # cos = freqs.cos() * self.mscale
    # sin = freqs.sin() * self.mscale
    cos = freqs.cos()
    sin = freqs.sin()
    cache = torch.cat((cos, sin), dim=-1)
    return cache

_compute_inv_freq

_compute_inv_freq(scaling_factor: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Compute inverse frequencies per the configured extra_method.

    Args:
        scaling_factor: context-length scaling factor (YaRN's ``s``).

    Returns:
        Float tensor of shape ``[rotary_dim // 2]`` with one inverse
        frequency per rotary dimension pair.

    Raises:
        ValueError: if ``self.extra_method`` is not recognized.
    """
    pos_freqs = self.base**(
        torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
        self.rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = yarn_find_correction_range(
        self.beta_fast,
        self.beta_slow,
        self.rotary_dim,
        self.base,
        self.max_position_embeddings,
    )
    # Get n-d rotational scaling corrected for extrapolation.
    # mask == 1 -> pure extrapolation; mask == 0 -> pure interpolation.
    inv_freq_mask = (1 - yarn_linear_ramp_mask(
        low, high, self.rotary_dim // 2,
        dtype=torch.float)) * self.extrapolation_factor
    if self.extra_method == "original":
        inv_freq = inv_freq_extrapolation
    elif self.extra_method in ("yarn", "yarn_linear"):
        # Linear blend between interpolated and extrapolated frequencies.
        inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) +
                    inv_freq_extrapolation * inv_freq_mask)
    elif self.extra_method == "yarn_log":
        # Blend in log space (geometric interpolation of frequencies).
        inv_freq = torch.exp(
            torch.log(inv_freq_extrapolation) * inv_freq_mask +
            torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask))
    elif self.extra_method == "theta_scale":
        # Rescale the RoPE base by the ratio of log scaled/original
        # context length (expressed in wavelength units of 2*pi).
        # NOTE(review): this branch reads self.scaling_factor rather than
        # the scaling_factor argument; identical at the current call
        # site, but confirm if new callers appear.
        exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
        theta_scale_exponent = self.base**(
            math.log(self.max_position_embeddings * self.scaling_factor /
                     (2 * math.pi)) /
            math.log(self.max_position_embeddings / (2 * math.pi)))
        # `exponents` is already a float32 tensor, so this expression
        # yields a float32 tensor directly; the previous torch.tensor(...)
        # wrapper made a redundant copy and emitted a UserWarning.
        inv_freq = 1.0 / (theta_scale_exponent**(exponents /
                                                 self.rotary_dim))
    else:
        raise ValueError(
            f"Unknown extrapolation method: {self.extra_method}")
    return inv_freq