class Grok1ScalingRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding scaled for long context, similar to the YaRN method.

    Reference: https://arxiv.org/pdf/2309.00071

    Supported ``extra_method`` values:
      * ``"original"``: plain RoPE inverse frequencies (no interpolation).
      * ``"yarn"`` / ``"yarn_linear"``: linear blend of interpolated and
        extrapolated frequencies (standard YaRN).
      * ``"yarn_log"``: the same blend performed in log-frequency space
        (geometric rather than arithmetic mean of the two frequency sets).
      * ``"theta_scale"``: rescale the RoPE base so the rotation wavelengths
        span the scaled context window.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extra_method: str = "yarn_log",
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extra_method = extra_method
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation (YaRN
        # "mscale").  NOTE(review): mscale is currently NOT applied in
        # _compute_cos_sin_cache; the attribute is kept because callers
        # outside this class may read it.
        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
        # Superclass __init__ triggers _compute_cos_sin_cache, so all
        # attributes above must be set first.
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Compute per-dimension inverse frequencies for the configured
        extrapolation method.

        Args:
            scaling_factor: context-length scaling factor applied to the
                interpolated frequencies.

        Returns:
            Float tensor of shape ``[rotary_dim // 2]``.

        Raises:
            ValueError: if ``self.extra_method`` is not one of the
                supported method names.
        """
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation.
        # inv_freq_mask weights the extrapolated (unscaled) frequencies:
        # ~1 for dims below `low` and ~0 above `high`, per the YaRN ramp.
        inv_freq_mask = (1 - yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor

        if self.extra_method == "original":
            inv_freq = inv_freq_extrapolation
        elif self.extra_method in ("yarn", "yarn_linear"):
            # Standard YaRN: arithmetic blend of the two frequency sets.
            inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) +
                        inv_freq_extrapolation * inv_freq_mask)
        elif self.extra_method == "yarn_log":
            # Blend in log space, i.e. a per-dim geometric mean.
            inv_freq = torch.exp(
                torch.log(inv_freq_extrapolation) * inv_freq_mask +
                torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask))
        elif self.extra_method == "theta_scale":
            # Rescale the RoPE base so the longest rotation period covers
            # the scaled context length.
            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
            theta_scale_exponent = self.base**(
                math.log(self.max_position_embeddings * self.scaling_factor /
                         (2 * math.pi)) /
                math.log(self.max_position_embeddings / (2 * math.pi)))
            # `exponents` is already a float32 tensor, so this expression is
            # a float32 tensor; do not wrap it in torch.tensor(), which
            # copies and emits a UserWarning on tensor input.
            inv_freq = 1.0 / (theta_scale_exponent
                              **(exponents / self.rotary_dim))
        else:
            raise ValueError(
                f"Unknown extrapolation method: {self.extra_method}")
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin lookup cache for the scaled position range.

        Returns:
            Float32 tensor of shape ``[num_positions, rotary_dim]`` laid out
            as ``cat([cos, sin], dim=-1)``, where ``num_positions`` is
            ``max_position_embeddings * scaling_factor``.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        # NOTE(review): unlike standard YaRN, self.mscale is intentionally
        # not multiplied into cos/sin here — confirm against the model's
        # reference implementation before re-enabling magnitude scaling.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache