Confusion in the model architecture

#4
by Ink - opened

If I do:

for i, layer in enumerate(model.modules()):
    print(i,":\t",layer)

I get the first two layers as:

0 :      RavenForCausalLM(
  (transformer): ModuleDict(
    (wte): Embedding(65536, 5280)
    (prelude): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False) 
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)  
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (adapter): Linear(in_features=10560, out_features=5280, bias=False)   
    (core_block): ModuleList(
      (0-3): 4 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (coda): ModuleList(
      (0-1): 2 x SandwichBlock(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
          (proj): Linear(in_features=5280, out_features=5280, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): GatedMLP(
          (fc): Linear(in_features=5280, out_features=35840, bias=False)
          (proj): Linear(in_features=17920, out_features=5280, bias=False)
          (nonlin): SiLU()
        )
        (norm_3): RMSNorm()
        (norm_4): RMSNorm()
      )
    )
    (ln_f): RMSNorm()
  )
  (lm_head): Linear(in_features=5280, out_features=65536, bias=False)
)
-------------------
1 :      ModuleDict(
  (wte): Embedding(65536, 5280)
  (prelude): ModuleList(
    (0-1): 2 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (adapter): Linear(in_features=10560, out_features=5280, bias=False)
  (core_block): ModuleList(
    (0-3): 4 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (coda): ModuleList(
    (0-1): 2 x SandwichBlock(
      (norm_1): RMSNorm()
      (attn): CausalSelfAttention(
        (Wqkv): Linear(in_features=5280, out_features=15840, bias=False)
        (proj): Linear(in_features=5280, out_features=5280, bias=False)
      )
      (norm_2): RMSNorm()
      (mlp): GatedMLP(
        (fc): Linear(in_features=5280, out_features=35840, bias=False)
        (proj): Linear(in_features=17920, out_features=5280, bias=False)
        (nonlin): SiLU()
      )
      (norm_3): RMSNorm()
      (norm_4): RMSNorm()
    )
  )
  (ln_f): RMSNorm()
)
..
..

where 0: and 1: are the enumeration indices from model.modules().
I am a little confused. Why are the entire prelude, core, and coda modules repeated for each item in model.modules()? Shouldn't it just be modules 0-1: prelude, modules 2-5: core, and modules 6-7: coda?

Just a little confused on how this architecture should be structured.

Tom Goldstein's Lab at University of Maryland, College Park:

Hi! What is your use case for model.modules()? You're getting confusing output because model.modules() yields every submodule recursively, including the transformer ModuleDict, and printing a module prints its entire subtree: entry 0 is the whole model, entry 1 is the transformer ModuleDict it contains, and the nested blocks then show up again at each deeper level.
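
If you just want the high-level structure, named_children() gives the non-recursive view. A minimal sketch, assuming model is the loaded checkpoint from your snippet above:

# Top level: only the transformer ModuleDict and the lm_head.
for name, child in model.named_children():
    print(name, "->", type(child).__name__)

# One level down: the wte / prelude / adapter / core_block / coda / ln_f layout.
for name, child in model.transformer.named_children():
    print(name, "->", type(child).__name__)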

The layout of the modules follows Hugging Face conventions, separating the transformer from the language-model head:

        self.transformer = torch.nn.ModuleDict(
            dict(
                wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
                prelude=prelude,
                adapter=adapter,
                core_block=core_block,
                coda=coda,
                ln_f=RMSNorm(config.n_embd, eps=config.norm_eps),
            )
        )
        self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)

(see e.g. https://github.com/seal-rg/recurrent-pretraining/blob/bfd495b8ccc77b5c63674717c168a4b62b3c3e2a/recpre/raven_modeling_minimal.py#L331)
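
For a quick sanity check of the layout (a sketch, assuming the checkpoint is loaded into model as in the file linked above), the individual stages are reachable through the transformer ModuleDict:

prelude = model.transformer["prelude"]      # ModuleList of 2 SandwichBlocks
core = model.transformer["core_block"]      # ModuleList of 4 SandwichBlocks, looped during the recurrent forward pass
coda = model.transformer["coda"]            # ModuleList of 2 SandwichBlocks

print(len(prelude), len(core), len(coda))   # 2 4 2

so the "module 0-7" numbering you expected corresponds to indexing into these ModuleLists, not to the positions returned by model.modules().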
