jingyaogong committed
Commit dcca4a1 · verified · 1 Parent(s): 1107462

Upload 3 files

Files changed (3):
  1. LMConfig.py +10 -7
  2. config.json +32 -31
  3. model.py +146 -195
LMConfig.py CHANGED
@@ -15,7 +15,8 @@ class LMConfig(PretrainedConfig):
             hidden_dim: int = None,
             multiple_of: int = 64,
             norm_eps: float = 1e-5,
-            max_seq_len: int = 512,
+            max_seq_len: int = 8192,
+            rope_theta: int = 1e4,
             dropout: float = 0.0,
             flash_attn: bool = True,
             ####################################################
@@ -23,13 +24,14 @@ class LMConfig(PretrainedConfig):
             # When use_moe is false, the following is invalid
             ####################################################
             use_moe: bool = False,
-            num_experts_per_tok=2,
-            n_routed_experts=4,
+            ####################################################
+            num_experts_per_tok: int = 2,
+            n_routed_experts: int = 4,
             n_shared_experts: bool = True,
-            scoring_func='softmax',
-            aux_loss_alpha=0.01,
-            seq_aux=True,
-            norm_topk_prob=True,
+            scoring_func: str = 'softmax',
+            aux_loss_alpha: float = 0.1,
+            seq_aux: bool = True,
+            norm_topk_prob: bool = True,
             **kwargs,
     ):
         self.dim = dim
@@ -41,6 +43,7 @@ class LMConfig(PretrainedConfig):
         self.multiple_of = multiple_of
         self.norm_eps = norm_eps
         self.max_seq_len = max_seq_len
+        self.rope_theta = rope_theta
         self.dropout = dropout
         self.flash_attn = flash_attn
         ####################################################
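
For reference, a minimal sketch of constructing the updated config; the import path and the non-default values (taken from config.json) are assumptions, not part of this commit:

    from LMConfig import LMConfig

    # New defaults from this commit: max_seq_len=8192, rope_theta=1e4,
    # typed MoE fields, and aux_loss_alpha=0.1.
    config = LMConfig(dim=768, n_layers=16, n_heads=16, n_kv_heads=8,
                      vocab_size=6400, use_moe=False)
    print(config.max_seq_len, config.rope_theta, config.aux_loss_alpha)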
config.json CHANGED
@@ -1,31 +1,32 @@
-{
-  "architectures": [
-    "Transformer"
-  ],
-  "auto_map": {
-    "AutoConfig": "LMConfig.LMConfig",
-    "AutoModelForCausalLM": "model.Transformer"
-  },
-  "aux_loss_alpha": 0.01,
-  "dim": 768,
-  "dropout": 0.0,
-  "flash_attn": true,
-  "hidden_dim": null,
-  "max_seq_len": 512,
-  "model_type": "minimind",
-  "multiple_of": 64,
-  "n_heads": 16,
-  "n_kv_heads": 8,
-  "n_layers": 16,
-  "n_routed_experts": 4,
-  "n_shared_experts": true,
-  "norm_eps": 1e-05,
-  "norm_topk_prob": true,
-  "num_experts_per_tok": 2,
-  "scoring_func": "softmax",
-  "seq_aux": true,
-  "torch_dtype": "float32",
-  "transformers_version": "4.37.2",
-  "use_moe": false,
-  "vocab_size": 6400
-}
+{
+  "architectures": [
+    "MiniMindLM"
+  ],
+  "auto_map": {
+    "AutoConfig": "LMConfig.LMConfig",
+    "AutoModelForCausalLM": "model.MiniMindLM"
+  },
+  "aux_loss_alpha": 0.01,
+  "dim": 768,
+  "dropout": 0.0,
+  "flash_attn": true,
+  "hidden_dim": null,
+  "max_seq_len": 512,
+  "model_type": "minimind",
+  "multiple_of": 64,
+  "n_heads": 16,
+  "n_kv_heads": 8,
+  "n_layers": 16,
+  "n_routed_experts": 4,
+  "n_shared_experts": true,
+  "norm_eps": 1e-05,
+  "norm_topk_prob": true,
+  "rope_theta": 10000.0,
+  "num_experts_per_tok": 2,
+  "scoring_func": "softmax",
+  "seq_aux": true,
+  "torch_dtype": "float32",
+  "transformers_version": "4.37.2",
+  "use_moe": false,
+  "vocab_size": 6400
+}
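
Because architectures and auto_map now name MiniMindLM, the checkpoint loads through the transformers custom-code path; a minimal sketch, with a placeholder path standing in for this repository:

    from transformers import AutoConfig, AutoModelForCausalLM

    # "path/to/this-repo" is a placeholder for a local clone or the Hub repo id.
    config = AutoConfig.from_pretrained("path/to/this-repo", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("path/to/this-repo", trust_remote_code=True)
    print(type(model).__name__)  # resolves to model.MiniMindLM per auto_map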
model.py CHANGED
@@ -4,7 +4,7 @@ import inspect
 import time
 
 from .LMConfig import LMConfig
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, List
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -19,15 +19,11 @@ class RMSNorm(torch.nn.Module):
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
 
-def precompute_pos_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
@@ -76,71 +72,69 @@ class Attention(nn.Module):
         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
-        self.k_cache, self.v_cache = None, None
         self.attn_dropout = nn.Dropout(args.dropout)
         self.resid_dropout = nn.Dropout(args.dropout)
         self.dropout = args.dropout
         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
-
         # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
         mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
         mask = torch.triu(mask, diagonal=1)
-        self.register_buffer("mask", mask)
-
-    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
-        bsz, seqlen, _ = x.shape
+        self.register_buffer("mask", mask, persistent=False)
+
+    def forward(self,
+                x: torch.Tensor,
+                pos_cis: torch.Tensor,
+                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                use_cache=False):
+        bsz, seq_len, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
 
         xq, xk = apply_rotary_emb(xq, xk, pos_cis)
-
-        # more efficient kv_cache implementation
-        if kv_cache and self.eval():
-            if seqlen == 1 and all(cache is not None for cache in (self.k_cache, self.v_cache)):
-                xk = torch.cat((self.k_cache, xk), dim=1)
-                xv = torch.cat((self.v_cache, xv), dim=1)
-            self.k_cache, self.v_cache = xk, xv
-
-        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
-        xq = xq.transpose(1, 2)
-        xk = xk.transpose(1, 2)
-        xv = xv.transpose(1, 2)
-
-        if self.flash and seqlen != 1:
-            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
-                                                                      dropout_p=self.dropout if self.training else 0.0,
-                                                                      is_causal=True)
+        # kv_cache implementation
+        if past_key_value is not None:
+            xk = torch.cat([past_key_value[0], xk], dim=1)
+            xv = torch.cat([past_key_value[1], xv], dim=1)
+        past_kv = (xk, xv) if use_cache else None
+
+        xq, xk, xv = (
+            xq.transpose(1, 2),
+            repeat_kv(xk, self.n_rep).transpose(1, 2),
+            repeat_kv(xv, self.n_rep).transpose(1, 2)
+        )
+        if self.flash and seq_len != 1:
+            dropout_p = self.dropout if self.training else 0.0
+            output = F.scaled_dot_product_attention(
+                xq, xk, xv,
+                attn_mask=None,
+                dropout_p=dropout_p,
+                is_causal=True
+            )
         else:
-            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
-            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
+            scores += self.mask[:, :, :seq_len, :seq_len]
             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
             scores = self.attn_dropout(scores)
-            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-        output = self.resid_dropout(output)
-        return output
+            output = scores @ xv
+
+        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
+        output = self.resid_dropout(self.wo(output))
+        return output, past_kv
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
+    def __init__(self, config: LMConfig):
         super().__init__()
-        if hidden_dim is None:
-            hidden_dim = 4 * dim
+        if config.hidden_dim is None:
+            hidden_dim = 4 * config.dim
             hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-        self.dropout = nn.Dropout(dropout)
+            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.dropout = nn.Dropout(config.dropout)
 
     def forward(self, x):
         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
@@ -168,7 +162,6 @@ class MoEGate(nn.Module):
 
     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
-
         hidden_states = hidden_states.view(-1, h)
         logits = F.linear(hidden_states, self.weight, None)
         if self.scoring_func == 'softmax':
@@ -200,7 +193,7 @@ class MoEGate(nn.Module):
             fi = ce * self.n_routed_experts
             aux_loss = (Pi * fi).sum() * self.alpha
         else:
-            aux_loss = None
+            aux_loss = 0
         return topk_idx, topk_weight, aux_loss
 
 
@@ -209,50 +202,35 @@ class MOEFeedForward(nn.Module):
         super().__init__()
         self.config = config
         self.experts = nn.ModuleList([
-            FeedForward(
-                dim=config.dim,
-                hidden_dim=config.hidden_dim,
-                multiple_of=config.multiple_of,
-                dropout=config.dropout,
-            )
+            FeedForward(config)
            for _ in range(config.n_routed_experts)
         ])
-
         self.gate = MoEGate(config)
         if config.n_shared_experts is not None:
-            self.shared_experts = FeedForward(
-                dim=config.dim,
-                hidden_dim=config.hidden_dim,
-                multiple_of=config.multiple_of,
-                dropout=config.dropout,
-            )
+            self.shared_experts = FeedForward(config)
 
     def forward(self, x):
         identity = x
         orig_shape = x.shape
         bsz, seq_len, _ = x.shape
-
         # select experts with the gating mechanism
         topk_idx, topk_weight, aux_loss = self.gate(x)
-
         x = x.view(-1, x.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
-
         if self.training:
             # in training mode, repeat the input data
             x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
             y = torch.empty_like(x, dtype=torch.float16)
             for i, expert in enumerate(self.experts):
-                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i])
+                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # keep dtypes consistent
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
             y = y.view(*orig_shape)
         else:
             # in inference mode, only route to the selected experts
             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
-
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)
-
+        self.aux_loss = aux_loss
         return y
 
     @torch.no_grad()
@@ -271,7 +249,7 @@ class MOEFeedForward(nn.Module):
             expert = self.experts[i]
             exp_token_idx = token_idxs[start_idx:end_idx]
             expert_tokens = x[exp_token_idx]
-            expert_out = expert(expert_tokens)
+            expert_out = expert(expert_tokens).to(expert_cache.dtype)
             expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
             # use scatter_add_ to perform the sum
             expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
@@ -279,146 +257,119 @@ class MOEFeedForward(nn.Module):
         return expert_cache
 
 
-class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: LMConfig):
+class MiniMindBlock(nn.Module):
+    def __init__(self, layer_id: int, config: LMConfig):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args)
+        self.n_heads = config.n_heads
+        self.dim = config.dim
+        self.head_dim = config.dim // config.n_heads
+        self.attention = Attention(config)
 
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-
-        if args.use_moe:
-            self.feed_forward = MOEFeedForward(args)
-        else:
-            self.feed_forward = FeedForward(
-                dim=args.dim,
-                hidden_dim=args.hidden_dim,
-                multiple_of=args.multiple_of,
-                dropout=args.dropout,
-            )
-
-    def forward(self, x, pos_cis, kv_cache=False):
-        h = x + self.attention(self.attention_norm(x), pos_cis, kv_cache)
+        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
+
+    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
+        h_attn, past_kv = self.attention(
+            self.attention_norm(x),
+            pos_cis,
+            past_key_value=past_key_value,
+            use_cache=use_cache
+        )
+        h = x + h_attn
         out = h + self.feed_forward(self.ffn_norm(h))
-        return out
+        return out, past_kv
 
 
-class Transformer(PreTrainedModel):
+class MiniMindLM(PreTrainedModel):
     config_class = LMConfig
-    last_loss: Optional[torch.Tensor]
 
     def __init__(self, params: LMConfig = None):
-        super().__init__(params)
-        if not params:
-            params = LMConfig()
-        self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
-
+        self.params = params or LMConfig()
+        super().__init__(self.params)
+        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
         self.dropout = nn.Dropout(params.dropout)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(self.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
-        pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
-        self.register_buffer("pos_cis", pos_cis, persistent=False)
-
-        self.apply(self._init_weights)
-
-        for pn, p in self.named_parameters():
-            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
-                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
-
-        self.last_loss = None
+        self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
+                                                           theta=params.rope_theta), persistent=False)
        self.OUT = CausalLMOutputWithPast()
 
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-
-    def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
-                kv_cache=False, **keyargs):
-        current_idx = 0
-        if 'input_ids' in keyargs:
-            tokens = keyargs['input_ids']
-        if 'attention_mask' in keyargs:
-            targets = keyargs['attention_mask']
-        if 'current_idx' in keyargs:
-            current_idx = int(keyargs['current_idx'])
-
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
-        h = self.dropout(h)
-        pos_cis = self.pos_cis[current_idx:current_idx + seqlen]
-        for idx, layer in enumerate(self.layers):
-            h = layer(h, pos_cis, kv_cache)
-
-        h = self.norm(h)
-
-        if targets is not None:
-            logits = self.output(h)
-            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-        else:
-            logits = self.output(h[:, [-1], :])
-            self.last_loss = None
-
+    def forward(self,
+                input_ids: Optional[torch.Tensor] = None,
+                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+                use_cache: bool = False,
+                **args):
+        past_key_values = past_key_values or [None] * len(self.layers)
+        start_pos = args.get('start_pos', 0)
+        h = self.dropout(self.tok_embeddings(input_ids))
+        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
+        past_kvs = []
+        for l, layer in enumerate(self.layers):
+            h, past_kv = layer(
+                h, pos_cis,
+                past_key_value=past_key_values[l],
+                use_cache=use_cache
+            )
+            past_kvs.append(past_kv)
+        logits = self.output(self.norm(h))
+        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
        self.OUT.__setitem__('logits', logits)
-        self.OUT.__setitem__('last_loss', self.last_loss)
+        self.OUT.__setitem__('aux_loss', aux_loss)
+        self.OUT.__setitem__('past_key_values', past_kvs)
        return self.OUT
 
     @torch.inference_mode()
-    def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
-        # rp: repetition_penalty
-        index = idx.shape[1]
-        init_inference = True
-        while idx.shape[1] < max_new_tokens - 1:
-            if init_inference or not kv_cache:
-                inference_res, init_inference = self(idx, kv_cache=kv_cache), False
+    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+        # streaming generation
+        if stream:
+            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+
+        # non-streaming generation
+        generated = []
+        for i in range(input_ids.size(0)):
+            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
+            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            tokens_list = [tokens[:, -1:] for tokens in out]
+            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
+            full_sequence = torch.cat([non_pad, gen], dim=-1)
+            generated.append(full_sequence)
+        max_length = max(seq.size(1) for seq in generated)
+        generated = [
+            torch.cat(
+                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
+                dim=-1)
+            for seq in generated
+        ]
+        return torch.cat(generated, dim=0)
+
+    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+        start, first_seq, past_kvs = input_ids.shape[1], True, None
+        while input_ids.shape[1] < max_new_tokens - 1:
+            if first_seq or not use_cache:
+                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
             else:
-                inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
-
-            logits = inference_res.logits
-            logits = logits[:, -1, :]
-
-            for token in set(idx.tolist()[0]):
-                logits[:, token] /= rp
-
-            if temperature == 0.0:
-                _, idx_next = torch.topk(logits, k=1, dim=-1)
-            else:
-                logits = logits / temperature
-                if top_k is not None:
-                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                    logits[logits < v[:, [-1]]] = -float('Inf')
-
-                probs = F.softmax(logits, dim=-1)
-                idx_next = torch.multinomial(probs, num_samples=1, generator=None)
-
-            if idx_next == eos:
+                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
+                           start_pos=input_ids.shape[1] - 1)
+            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
+            logits[:, list(set(input_ids.tolist()[0]))] /= rp
+            logits /= (temperature + 1e-9)
+            if top_p is not None and top_p < 1.0:
+                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+                sorted_probs = F.softmax(sorted_logits, dim=-1)
+                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+                sorted_indices_to_remove[:, 0] = False
+                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                logits[indices_to_remove] = -float('Inf')
+            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
+            yield input_ids[:, start:]
+            if input_ids_next.item() == eos_token_id:
                 break
-
-            idx = torch.cat((idx, idx_next), dim=1)
-            if stream:
-                yield idx[:, index:]
-
-        if not stream:
-            yield idx[:, index:]
-
-    @torch.inference_mode()
-    def eval_answer(self, idx):
-        idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
-        inference_res = self(idx_cond)
-        logits = inference_res.logits
-        logits = logits[:, -1, :]
-        return logits
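
To show the reworked inference interface end to end, here is a minimal sketch of manual KV-cache decoding plus generate(); it assumes the repo's files are importable as a package (model.py imports LMConfig relatively), here hypothetically named minimind, and uses dummy token ids:

    import torch
    from minimind.LMConfig import LMConfig
    from minimind.model import MiniMindLM

    model = MiniMindLM(LMConfig(dim=768, n_layers=16, n_heads=16,
                                n_kv_heads=8, vocab_size=6400)).eval()
    prompt = torch.randint(1, 6400, (1, 16))

    # Prefill: run the full prompt once, keeping the per-layer (k, v) tuples.
    out = model(prompt, use_cache=True)
    past_kvs = out.past_key_values

    # Decode step: feed only the newest token and offset RoPE via start_pos.
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    out = model(next_token, past_key_values=past_kvs, use_cache=True,
                start_pos=prompt.size(1))

    # Or let generate() drive the same loop (stream=True yields partial ids).
    ids = model.generate(prompt, eos_token_id=2, max_new_tokens=32, stream=False)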