Spaces:
Paused
Paused
app.py
CHANGED
@@ -15,6 +15,7 @@ print("Done")
|
|
15 |
|
16 |
|
17 |
def create_medusa_heads(model_id: str):
|
|
|
18 |
training_args = [
|
19 |
"--model_name_or_path", model_id,
|
20 |
"--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
|
@@ -37,7 +38,10 @@ def create_medusa_heads(model_id: str):
|
|
37 |
"--medusa_num_heads", "3",
|
38 |
"--medusa_num_layers", "1",
|
39 |
]
|
40 |
-
|
|
|
|
|
|
|
41 |
|
42 |
# Upload the medusa heads to the Hub
|
43 |
repo_id = f"medusa-{model_id}"
|
|
|
15 |
|
16 |
|
17 |
def create_medusa_heads(model_id: str):
|
18 |
+
parser = distributed_run.get_args_parser()
|
19 |
training_args = [
|
20 |
"--model_name_or_path", model_id,
|
21 |
"--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
|
|
|
38 |
"--medusa_num_heads", "3",
|
39 |
"--medusa_num_layers", "1",
|
40 |
]
|
41 |
+
args = parser.parse_args(
|
42 |
+
["training_script", "medusa/medusa/train/train.py", "training_script_args"] + training_args
|
43 |
+
)
|
44 |
+
distributed_run.run(args)
|
45 |
|
46 |
# Upload the medusa heads to the Hub
|
47 |
repo_id = f"medusa-{model_id}"
|