Confusion in the model architecture
#4 · opened by Ink
If I do

for i, layer in enumerate(model.modules()):
    print(i, ":\t", layer)

I get the first two entries as:
0 : RavenForCausalLM(
  (transformer): ModuleDict(
    (wte): Embedding(65536, 5280)
    (prelude): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (adapter): Linear(in_features=10560, out_features=5280, bias=False)
    (core_block): ModuleList(
      (0-3): 4 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (coda): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=5280, out_features=65536, bias=False)
)
-------------------
1 : ModuleDict(
  (wte): Embedding(65536, 5280)
  (prelude): ModuleList(
    (0-1): 2 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (adapter): Linear(in_features=10560, out_features=5280, bias=False)
  (core_block): ModuleList(
    (0-3): 4 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (coda): ModuleList(
    (0-1): 2 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (ln_f): RMSNorm()
)
..
..
where 0: and 1: are the enumeration indices from model.modules().
I am a little confused. Why are the entire prelude, core, and coda modules repeated for each item in model.modules()? Shouldn't it just be modules 0,1: prelude, modules 2,3,4,5: core, and modules 6,7: coda?
Just a little confused about how this architecture is structured.
Hi! What is your use case for model.modules()? You're getting confusing references because you are iterating over a nested ModuleDict: model.modules() yields every submodule recursively, so entry 0 is the full RavenForCausalLM and entry 1 is the transformer ModuleDict nested inside it, each printed together with all of its children. The blocks exist only once; they just show up again for every container that holds them.
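If you only want the top-level stages rather than every nested submodule, named_children() (which is non-recursive) is usually what you're after. A minimal sketch, assuming model is the already-loaded RavenForCausalLM printed above (the expected output is taken from that printout):

# named_children() lists only the direct submodules, not their contents
for name, module in model.named_children():
    print(name, "->", type(module).__name__)
# transformer -> ModuleDict
# lm_head -> Linear

# one level deeper, each stage appears exactly once
for name, module in model.transformer.named_children():
    print(name, "->", type(module).__name__)
# wte -> Embedding, prelude -> ModuleList, adapter -> Linear,
# core_block -> ModuleList, coda -> ModuleList, ln_f -> RMSNorm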
The layout of the modules follows the usual Hugging Face layout, separating the transformer from the model head:
self.transformer = torch.nn.ModuleDict(
    dict(
        wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
        prelude=prelude,
        adapter=adapter,
        core_block=core_block,
        coda=coda,
        ln_f=RMSNorm(config.n_embd, eps=config.norm_eps),
    )
)
self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
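So the per-stage block counts you expected (2 prelude, 4 core, 2 coda) are there; you can confirm them directly on the ModuleLists instead of walking model.modules(). A small sketch, again assuming model is the loaded instance, with expected counts read off the printout above:

print(len(model.transformer.prelude))     # 2 SandwichBlocks
print(len(model.transformer.core_block))  # 4 SandwichBlocks
print(len(model.transformer.coda))        # 2 SandwichBlocks

# count SandwichBlocks across the whole model without importing the class
# (it lives in the remote modeling code), using one block's type as reference
block_type = type(model.transformer.prelude[0])
print(sum(isinstance(m, block_type) for m in model.modules()))  # 8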