Update elm/model.py
elm/model.py  +5 -7
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, SliceX AI, Inc.
+# Copyright (c) 2024, SliceX AI, Inc.
 
 import copy
 import inspect
@@ -100,15 +100,12 @@ class ELM(torch.nn.Module):
         else:
             x = self.slice_transformer.drop(tok_emb)
 
-        tlayer_id = 0
         ignore_index_id = -100
         loss = torch.zeros(1).to(device)
         loss_denom = 0
 
         for tlayer in self.slice_transformer.h:
             x = tlayer(x, attention_mask=attention_mask)
-
-            tlayer_id += 1
 
         x = self.slice_transformer.ln_f(x)
 
@@ -133,9 +130,8 @@ class ELM(torch.nn.Module):
     def get_num_params(self, non_embedding=True):
         """
         Return the number of parameters in the model.
-        For non-embedding count (default),
-
-        If there is no parameter sharing , set the flag to False to include parameters for both layers.
+        For non-embedding count (default), subtract position embeddings if parameter tying applies.
+        If there is no parameter sharing, set the flag to False to include parameters for both input/output layers.
         """
         n_params = sum(p.numel() for p in self.parameters())
         if non_embedding and not self.model_args.use_rotary_embeddings:
@@ -342,6 +338,8 @@ def init_elm_model(model_args=ModelArgs(), device="cuda", model_config_dict=None
         model_args = ModelArgs(**model_config_dict)
 
     dtype = torch.bfloat16 if device=="cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+    if not torch.cuda.is_available():
+        dtype = torch.bfloat16
 
     model = ELM(model_args=model_args).to(dtype=dtype)
 
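For reference, a minimal sketch of the dtype selection in init_elm_model as it reads after this change. The helper name select_dtype is hypothetical and introduced here only for illustration; in the actual file the logic runs inline inside init_elm_model.

    import torch

    def select_dtype(device: str = "cuda") -> torch.dtype:
        # Prefer bfloat16 on CUDA when the GPU supports it, otherwise float16.
        dtype = (
            torch.bfloat16
            if device == "cuda" and torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )
        # Added by this commit: without CUDA, fall back to bfloat16 rather than float16.
        if not torch.cuda.is_available():
            dtype = torch.bfloat16
        return dtype

Net effect of the two added lines: CPU-only environments now get torch.bfloat16 instead of torch.float16. The commit itself does not state a motivation, but bfloat16 generally has broader CPU operator coverage in PyTorch than float16.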