ELM · Text Generation · English

Commit 896f67a (verified) · dev-slx committed · Parent(s): 2ed7102

Update elm/model.py

Files changed (1):
  1. elm/model.py (+5, -7)
elm/model.py CHANGED

@@ -1,4 +1,4 @@
-# Copyright (c) 2024, SliceX AI, Inc. All Rights Reserved.
+# Copyright (c) 2024, SliceX AI, Inc.
 
 import copy
 import inspect
@@ -100,15 +100,12 @@ class ELM(torch.nn.Module):
         else:
             x = self.slice_transformer.drop(tok_emb)
 
-        tlayer_id = 0
         ignore_index_id = -100
         loss = torch.zeros(1).to(device)
         loss_denom = 0
 
         for tlayer in self.slice_transformer.h:
             x = tlayer(x, attention_mask=attention_mask)
-
-            tlayer_id += 1
 
         x = self.slice_transformer.ln_f(x)
 
@@ -133,9 +130,8 @@ class ELM(torch.nn.Module):
     def get_num_params(self, non_embedding=True):
         """
         Return the number of parameters in the model.
-        For non-embedding count (default), the position embeddings get subtracted.
-        This assumes parameter tying between input and final layer embeddings. Oherwise
-        If there is no parameter sharing , set the flag to False to include parameters for both layers.
+        For non-embedding count (default), subtract position embeddings if parameter tying applies.
+        If there is no parameter sharing, set the flag to False to include parameters for both input/output layers.
         """
         n_params = sum(p.numel() for p in self.parameters())
         if non_embedding and not self.model_args.use_rotary_embeddings:
@@ -342,6 +338,8 @@ def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None
         model_args = ModelArgs(**model_config_dict)
 
     dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+    if not torch.cuda.is_available():
+        dtype = torch.bfloat16
 
     model = ELM(model_args=model_args).to(dtype=dtype)
 
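A note on the docstring hunk: a minimal, hedged sketch of the counting behavior the new docstring describes, restated as a standalone function. The attributes model_args.use_rotary_embeddings and slice_transformer come from the diff; the position-embedding attribute name wpe is an assumption, since the rest of the function body is not shown here.

    import torch

    def count_params(model: torch.nn.Module, non_embedding: bool = True) -> int:
        # Count every parameter in the model.
        n_params = sum(p.numel() for p in model.parameters())
        # With rotary embeddings there is no learned position table, so nothing to subtract.
        if non_embedding and not model.model_args.use_rotary_embeddings:
            # "wpe" is an assumed name for the position-embedding table; the real
            # attribute name is not visible in this diff.
            n_params -= model.slice_transformer.wpe.weight.numel()
        return n_params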
 
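The last hunk changes how init_elm_model picks a dtype. As a standalone sketch (the helper name select_dtype is hypothetical; only the dtype logic is taken from the diff): prefer bfloat16 on bf16-capable CUDA devices, otherwise fall back to float16, and with this commit force bfloat16 whenever CUDA is absent, presumably because float16 compute is poorly supported on CPU.

    import torch

    def select_dtype(device: str = "cuda") -> torch.dtype:
        # Mirror of init_elm_model's dtype selection after this commit (sketch only).
        dtype = (
            torch.bfloat16
            if device == "cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )
        # New in this commit: on machines without CUDA, always use bfloat16.
        if not torch.cuda.is_available():
            dtype = torch.bfloat16
        return dtype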