Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

模型概览

LLaMA 模型是目前最流行和性能最强大的开源模型之一,基于 LLaMA 所构造的模型生态可以覆盖绝大部分模型使用场景。本节将介绍LLaMA的模型结构及代码实现。

与在之前文章中所介绍的 Transformer架构(爱吃牛油果的璐璐:万字长文全面解析transformer(二更,附代码实现))不同的地方包括采用了前置层归一化(Pre-normalization)并使用RMSNorm归一化函数(Normalizing Function)、激活函数更换为SwiGLU,并使用了旋转位置嵌入(RoP),整体Transformer架构与GPT-2类似,如下图所示。

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?
LLaMA模型结构
Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?
基本的transformer

接下来,将分别介绍RMSNorm归一化函数、SwiGLU激活函数和旋转位置嵌入(RoPE)的具体内容和实现。

RMSNorm 归一化函数

为了使得模型训练过程更加稳定,GPT-2相较于GPT就引入了前置层归一化方法,将第一个层归一化移动到多头自注意力层之前,第二个层归一化也移动到了全连接层之前,同时残差连接的位置也调整到了多头自注意力层与全连接层之后。层归一化中也采用了RMSNorm归一化函数。针对输入向量a,RMSNorm 函数计算公式如下:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

此外,RMSNorm 还可以引入可学习的缩放因子 gi 和偏移参数 bi,从而得到:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

RMSNorm 在HuggingFace Transformer库中代码实现如下所示:

class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): “”” LlamaRMSNorm is equivalent to T5LayerNorm “”” super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps # eps 防止取倒数之后分母为 0 def forward(self, hidden_states): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) # weight 是末尾乘的可训练参数, 即 g_i return (self.weight * hidden_states).to(input_dtype)

SwiGLU 激活函数

SwiGLU激活函数是Shazeer在文献中提出,并在PaLM等模中进行了广泛应用,并且取得了不错的效果,相较于 ReLU 函数在大部分评测中都有不少提升。在 LLaMA 中全连接层使用带有 SwiGLU激活函数的的计算公式如下:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

其中,σ(x) 是 Sigmoid 函数。下图给出了 Swish 激活函数在参数 β 不同取值下的形状。可以看到当 β 趋近于 0 时,Swish 函数趋近于线性函数 y = x,当 β 趋近于无穷大时,Swish 函数趋近于 ReLU 函数,β 取值为 1 时,Swish 函数是光滑且非单调。在 HuggingFace 的 Transformer 库中Swish函数使用silu函数代替。

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

旋转位置嵌入(RoPE)

位置编码上,使用旋转位置嵌入(Rotary Positional Embeddings,RoPE)代替原有的绝对位置编码。RoPE借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。其目标是通过下述运算来给q,k 添加绝对位置信息:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

经过上述操作后, qm~\tilde{q_{m}}kn~\tilde{k_{n}} 就带有位置m和n 的绝对位置信息。详细的证明和求解过程可以参考文献[52],最终可以得到二维情况下用复数表示的RoPE:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

根据复数乘法的几何意义,上述变换实际上是对应向量旋转,所以位置向量称为“旋转式位置编码”。还可以使用矩阵形式表示:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

根据内积满足线性叠加的性质,任意偶数维的 RoPE,都可以表示为二维情形的拼接,即:

Meta 发布 AI 大型语言模型 LLaMA,其中都有哪些值得关注的亮点设计?

由于上述矩阵 Rn 具有稀疏性,因此可以使用逐位相乘 ⊗ 操作进一步加快计算速度。RoPE 在HuggingFace Transformer库中代码实现如下所示:

class LlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() inv_freq = 1.0 / (base **(torch.arange(0, dim, 2).float().to(device) / dim)) self.register_buffer(“inv_freq”, inv_freq) # Build here to make `torch.jit.trace` work. self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.einsum(“i,j->ij”, t, self.inv_freq) # Different from paper, but it uses a different permutation # in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) dtype = torch.get_default_dtype() self.register_buffer(“cos_cached”, emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer(“sin_cached”, emb.sin()[None, None, :, :].to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] # This `if` block is unlikely to be run after we build sin/cos in `__init__`. # Keep the logic here just in case. if seq_len > self.max_seq_len_cached: self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) freqs = torch.einsum(“i,j->ij”, t, self.inv_freq) # Different from paper, but it uses a different permutation # in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.register_buffer(“cos_cached”, emb.cos()[None, None, :, :].to(x.dtype), persistent=False) self.register_buffer(“sin_cached”, emb.sin()[None, None, :, :].to(x.dtype), persistent=False) return ( self.cos_cached[:, :, :seq_len, …].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, …].to(dtype=x.dtype), ) def rotate_half(x): “””Rotates half the hidden dims of the input.””” x1 = x[…, :x.shape[-1] // 2] x2 = x[…, x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed

上一篇 2024年4月28日 15:03:26
下一篇 2024年4月28日

相关推荐