Parallelization support: move `input_embeds` to the device of `x` before concatenation

#5
Files changed (1)
  1. raven_modeling_minimal.py +1 -1
raven_modeling_minimal.py CHANGED
@@ -492,7 +492,7 @@ class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin):
492
  attn_maps: dict = {},
493
  return_attn: bool = False,
494
  ):
495
- x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
496
  for idx, block in enumerate(self.transformer.core_block, start=1):
497
  x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
498
  attn_maps[block_idx + idx] = attn_map
 
492
  attn_maps: dict = {},
493
  return_attn: bool = False,
494
  ):
495
+ x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1))
496
  for idx, block in enumerate(self.transformer.core_block, start=1):
497
  x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=return_attn)
498
  attn_maps[block_idx + idx] = attn_map