【Tranformer-GPT】使用注意力机制进行类GPT模型训练

By e2hang at 2025-11-01 • 0人收藏 • 94人看过

一、相关原理

    类GPT,把输入的token作为文章的开头,进行自回归输入,最终输出接下来的文本。相比Transformer(翻译),GPT只需要用到Decoder,相比之下比较好写


二、具体效果

    现在一共训练了55万个Batch,取其中一些输出作为训练效果的体现,具体如下:

·batch-14000; loss-5.2

320c02ffd735775dd27705ee1d1242e6.png

词汇表还没有多少,正在学习基本语句


·batch-25000; loss-3.5126d2f53936bae91a19a65b571eed8f6.png

句子结构基本正确,但是逻辑欠缺,出现一些不明所以的句子;对部分词语理解有误


·batch-48000; loss-3.0

764335f5b1cc210feb53d98ae7d94e85.png

对部分词的理解欠缺,句子上下文衔接不连贯


·batch-69000; loss-2.5

517caa8175379d5ffe91e304f46442a4.png

逻辑转换莫名其妙,出现不明所以的人物,词性未完全理解


·batch-97000; loss-2.2

84ce2186db4b1e657449ed786fa3cc87.png

基本语句流畅,句间逻辑有很大问题,并且重复的词比较多


·batch-117000; loss-2.0

70e56d8f64e2e48939a047442d419ed9.png

句子语法正确,词义理解基本正确,重复的词汇减少,但是句间逻辑依然有比较大的问题


·batch-197000; loss-1.75

9036cdd29a36d7632097ec97f40b1f93.png

还在学习句子之间的逻辑,明显比上面好,但是出现莫名其妙的转折点


·batch-384000; loss-1.55

89952d488929c6be00ae80b0ed4d4762.png

句子之间的逻辑明显好很多,形容词增加,句子成分更加复杂


·batch-425000; loss-1.5

d746f254faf24d890e0f70488e903758.png

生成的故事已经比较有逻辑了,中间可能有些断断续续,突然出现了一些莫名其妙的内容,但是至少已经很连贯了


·batch-514000; loss-1.45

1bfb8025841c971381b891c15e4199b8.png

对于长故事的创造力很强,但是句子之间还是缺乏逻辑,以及连接的时候会有很多问题。GPT还在天马行空的想象吧!


对上面的内容进行一个总结:

1、loss的折线图如下所示,逐渐趋缓并且达到一个瓶颈期

f9973dab3cbddd4caeda4bcde025586e.png

2、学习内容、产出文本解读

    在多次实验中,生成参数保持一致,唯一的变量是所使用的模型。根据 Transformer 的结构特点,每个注意力头(Multi-Head Attention)都会捕捉不同层面的上下文关系。从生成结果可以看出,GPT 已经能够较好地理解词语的基本含义以及上下文之间的常见组合。然而,它在深入理解整体语义、以及在自回归生成过程中有效记忆较远的上下文内容方面,仍存在一定的不足。


三、实现代码(仅展示部分重要模块)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
 
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x, mask=None):
        B, T, C = x.shape
 
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
 
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
 
        attn = F.softmax(scores.float(), dim=-1).type_as(scores)
        attn = self.dropout(attn)
 
        output = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(output)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Model(nn.Module):
    def __init__(self, vocab_size, d_model=512, n_heads=8, num_layers=12,
                 max_seq_len=128, d_ff=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
 
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
 
        self.blocks = nn.ModuleList([
            Block(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)
        ])
 
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
 
        self.embedding.weight = self.head.weight
 
        self.apply(self._init_weights)
 
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
 
    def forward(self, idx, targets=None):
        B, T = idx.shape
        device = idx.device
 
        tok_emb = self.embedding(idx)
        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos)
        x = self.dropout(tok_emb + pos_emb)
 
        mask = torch.tril(torch.ones(T, T, device=device)).view(1, 1, T, T)
 
        for block in self.blocks:
            x = block(x, mask)
 
        x = self.ln_f(x)
        logits = self.head(x)
 
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = torch.clamp(loss, 0, 15)
 
        return logits, loss


登录后方可回帖

登 录
信息栏
欢迎来到滑稽社论坛!注册会员即可发帖!

你好啊

Loading...