Parallelization support
#5
by
yigitbekir
- opened
raven_modeling_minimal.py
CHANGED
@@ -492,7 +492,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
|
|
492 |
attn_maps: dict = {},
|
493 |
return_attn: bool = False,
|
494 |
):
|
495 |
-
x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
|
496 |
for idx, block in enumerate(self.transformer.core_block, start=1):
|
497 |
x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
|
498 |
attn_maps[block_idx + idx] = attn_map
|
|
|
492 |
attn_maps: dict = {},
|
493 |
return_attn: bool = False,
|
494 |
):
|
495 |
+
x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1))
|
496 |
for idx, block in enumerate(self.transformer.core_block, start=1):
|
497 |
x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
|
498 |
attn_maps[block_idx + idx] = attn_map
|