llama_roughwork.ipynb  (+207 -224)
CHANGED
@@ -2,98 +2,60 @@
  2 |    "cells": [
  3 |     {
  4 |      "cell_type": "code",
  5 | -    "execution_count": …
  6 |      "metadata": {},
  7 |      "outputs": [
  8 |       {
  9 | -      "…
10-21 | -   … (removed lines, content not legible in this capture)
 22 | -     },
 23 | -     {
 24 | -      "cell_type": "code",
 25 | -      "execution_count": 3,
 26 | -      "metadata": {},
 27 | -      "outputs": [
 28 |        {
 29 |         "data": {
 30 |          "text/plain": [
 31 | -         "…
 32 |          ]
 33 |         },
 34 | -       "execution_count": …
 35 |         "metadata": {},
 36 |         "output_type": "execute_result"
 37 |        }
 38 |       ],
 39 |       "source": [
 40 | -      "…
41-42 | -   … (removed lines, content not legible in this capture)
 43 | -    {
 44 | -     "cell_type": "code",
 45 | -     "execution_count": 7,
 46 | -     "metadata": {},
 47 | -     "outputs": [
 48 | -      {
 49 | -       "ename": "TypeError",
 50 | -       "evalue": "train_model() missing 1 required positional argument: 'args'",
 51 | -       "output_type": "error",
 52 | -       "traceback": [
 53 | -        "---------------------------------------------------------------------------",
 54 | -        "TypeError                                 Traceback (most recent call last)",
 55 | -        "Cell In[7], line 3\n      1 from mlx_lm import lora\n----> 3 lora.train_model(\n      4     model=model, \n      5     tokenizer=tokenizer,\n      6     train_set=\"fine_tune_train.jsonl\",\n      7     valid_set=\"fine_tune_test.jsonl\")\n",
 56 | -        "TypeError: train_model() missing 1 required positional argument: 'args'"
 57 | -       ]
 58 | -      }
 59 | -     ],
 60 | -     "source": [
 61 | -      "from mlx_lm import lora\n",
 62 | -      "\n",
 63 | -      "lora.train_model(\n",
 64 | -      "    model=model, \n",
 65 | -      "    tokenizer=tokenizer,\n",
 66 | -      "    train_set=\"fine_tune_train.jsonl\",\n",
 67 | -      "    valid_set=\"fine_tune_test.jsonl\")"
 68 |       ]
 69 |      },
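The TypeError in the removed cell above is a signature mismatch: in mlx_lm, lora.train_model takes the training configuration as its first positional parameter, args, before model and tokenizer. A minimal sketch of one way to satisfy it; lora.CONFIG_DEFAULTS and the exact field set are assumptions that vary between mlx_lm releases, model and tokenizer are the objects loaded earlier in the notebook, and whether plain file paths are accepted for train_set/valid_set also depends on the version (the CLI normally loads datasets itself):

    # Sketch: build the required `args` namespace. mlx_lm keeps its tuning
    # defaults in mlx_lm.lora.CONFIG_DEFAULTS (an assumption; check your
    # installed version), so overriding that dict is the safest way to get
    # every attribute train_model will read.
    from types import SimpleNamespace
    from mlx_lm import lora

    config = {**lora.CONFIG_DEFAULTS, "iters": 100, "batch_size": 4,
              "learning_rate": 1e-5, "adapter_path": "adapters"}
    args = SimpleNamespace(**config)

    lora.train_model(args, model=model, tokenizer=tokenizer,
                     train_set="fine_tune_train.jsonl",
                     valid_set="fine_tune_test.jsonl")

The notebook's eventual fix, further down, is the same idea expressed as a @dataclass whose instance plays the role of args.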
 70 |     {
 71 |      "cell_type": "code",
 72 | -    "execution_count": …
 73 |      "metadata": {},
 74 |      "outputs": [
 75 |       {
76-83 | -   … (removed lines, content not legible in this capture)
 84 |       }
 85 |      ],
 86 |      "source": [
 87 | -     "…
 88 |      ]
 89 |     },
 90 | -   {
 91 | -    "cell_type": "code",
 92 | -    "execution_count": null,
 93 | -    "metadata": {},
 94 | -    "outputs": [],
 95 | -    "source": []
 96 | -   },
 97 |     {
 98 |      "cell_type": "code",
 99 |      "execution_count": 18,
@@ -128,62 +90,35 @@
128 |      ]
129 |     },
130 |     {
131 | -    "cell_type": "…
132 | -    "execution_count": 15,
133 |      "metadata": {},
134 | -    "outputs": [
135 | -     {
136 | -      "name": "stdout",
137 | -      "output_type": "stream",
138 | -      "text": [
139 | -       "Starting training..., iters: 100\n"
140 | -      ]
141 | -     },
142 | -     {
143 | -      "ename": "KeyboardInterrupt",
144 | -      "evalue": "",
145 | -      "output_type": "error",
146 | -      "traceback": [
147 | -       "---------------------------------------------------------------------------",
148 | -       "KeyboardInterrupt                         Traceback (most recent call last)",
149 | -       "Cell In[15], line 3\n      1 import mlx.optimizers as optim\n      2 optimizer = optim.Adam(learning_rate=1e-3)\n----> 3 lora.train(model=model, \n      4           tokenizer=tokenizer,\n      5           train_dataset=\"fine_tune_train.jsonl\",\n      6           val_dataset=\"fine_tune_test.jsonl\",\n      7           optimizer=optimizer)\n",
150 | -       "File /opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/mlx_lm/tuner/trainer.py:242, in train(model, tokenizer, optimizer, train_dataset, val_dataset, args, loss, iterate_batches, training_callback)\n    240 if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:\n    241     stop = time.perf_counter()\n--> 242     val_loss = evaluate(\n    243         model=model,\n    244         dataset=val_dataset,\n    245         loss=loss,\n    246         tokenizer=tokenizer,\n    247         batch_size=args.batch_size,\n    248         num_batches=args.val_batches,\n    249         max_seq_length=args.max_seq_length,\n    250         iterate_batches=iterate_batches,\n    251     )\n    252 val_time = time.perf_counter() - stop\n    253 if rank == 0:\n",
151 | -       "File /opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/mlx_lm/tuner/trainer.py:166, in evaluate(model, dataset, tokenizer, batch_size, num_batches, max_seq_length, loss, iterate_batches)\n    164 all_losses += losses * toks\n    165 ntokens += toks\n--> 166 mx.eval(all_losses, ntokens)\n    168 all_losses = mx.distributed.all_sum(all_losses)\n    169 ntokens = mx.distributed.all_sum(ntokens)\n",
152 | -       "KeyboardInterrupt: "
153 | -      ]
154 | -     }
155 | -    ],
156 |      "source": [
157 | -     "import mlx.optimizers as optim\n",
158 | -     "optimizer = optim.Adam(learning_rate=1e-3)\n",
159 | -     "lora.train(model=model, \n",
160 | -     "          tokenizer=tokenizer,\n",
161 | -     "          train_dataset=\"fine_tune_train.jsonl\",\n",
162 | -     "          val_dataset=\"fine_tune_test.jsonl\",\n",
163 | -     "          optimizer=optimizer)"
164 |      ]
165 |     },
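This removed attempt actually started training ("Starting training..., iters: 100") because train in mlx_lm.tuner.trainer defaults its remaining parameters; the traceback spells out the full signature, train(model, tokenizer, optimizer, train_dataset, val_dataset, args, loss, iterate_batches, training_callback), with args supplying steps_per_eval, val_batches, max_seq_length and friends. The KeyboardInterrupt is just the author stopping the run during the first validation pass (the stack is inside evaluate), not a bug. A hedged sketch of the same call with args made explicit, assuming the TrainingArgs dataclass that trainer.py defines (field names may differ by version):

    import mlx.optimizers as optim
    from mlx_lm import lora
    from mlx_lm.tuner.trainer import TrainingArgs  # assumption: exported here

    # iters=100 matches the "Starting training..., iters: 100" log above.
    args = TrainingArgs(
        batch_size=4,
        iters=100,
        val_batches=25,
        steps_per_eval=50,
        max_seq_length=2048,
        adapter_file="adapters.safetensors",
    )

    optimizer = optim.Adam(learning_rate=1e-3)
    lora.train(model=model, tokenizer=tokenizer,
               train_dataset="fine_tune_train.jsonl",
               val_dataset="fine_tune_test.jsonl",
               optimizer=optimizer, args=args)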
166 |     {
167 |      "cell_type": "code",
168 | -    "execution_count": …
169 | -    "metadata": {},
170 | -    "outputs": [],
171 | -    "source": []
172 | -   },
173 | -   {
174 | -    "cell_type": "code",
175 | -    "execution_count": null,
176 |      "metadata": {},
177 |      "outputs": [
178 |       {
179 |        "name": "stdout",
180 |        "output_type": "stream",
181 |        "text": [
182 |         "Trainable parameters: 0.085% (6.816M/8030.261M)\n",
183 |         "Starting training..., iters: 10\n",
184 | -       "Iter 1: Val loss 14.203, Val took …
185 | -       "Iter 10: Val loss 7.…
186 | -       "Iter 10: Train loss 10.…
187 |         "Saved final weights to adapters.safetensors.\n"
188 |        ]
189 |       }
@@ -192,7 +127,7 @@
192 |       "from dataclasses import dataclass\n",
193 |       "import mlx.optimizers as optim\n",
194 |       "from mlx_lm import lora\n",
195 | -     "from mlx_lm import load, generate…
196 |       "\n",
197 |       "# Create a dataclass to convert dictionary to an object\n",
198 |       "@dataclass\n",
@@ -232,191 +167,239 @@
232 |       "    valid_set=\"fine_tune_test.jsonl\")"
233 |      ]
234 |     },
235 |     {
236 |      "cell_type": "code",
237 | -    "execution_count": …
238 |      "metadata": {},
239 |      "outputs": [
240 |       {
241 | -      "name": "stderr",
242 | -      "output_type": "stream",
243 | -      "text": [
244 | -       "/opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
245 | -       " from .autonotebook import tqdm as notebook_tqdm\n"
246 | -      ]
247 | -     },
248 | -     {
249 | -      "ename": "TypeError",
250 | -      "evalue": "'module' object is not callable",
251 |        "output_type": "error",
252 |        "traceback": [
253 | -       "---------------------------------------------------------------------------",
254 | -       "TypeError                                 Traceback (most recent call last)",
255 | -       "Cell In[1], line 3\n      1 from mlx_lm import fuse\n----> 3 fuse()\n",
256 | -       "TypeError: 'module' object is not callable"
257 |        ]
258 |       }
259 |      ],
260 |      "source": [
261 | -     "from mlx_lm import fuse\n",
262 |       "\n",
263 | -     "fuse()"
264 |      ]
265 |     },
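The "'module' object is not callable" error is literal: mlx_lm.fuse is a module, and the dir() listing removed a couple of cells below shows the callables it actually exposes (main, parse_arguments, load_adapters, upload_to_hub, and so on). The supported entry point is the CLI, which this notebook switches to later; driving main() in-process as below is an assumption, since it reads its flags from sys.argv via argparse:

    # Supported route (terminal):
    #   mlx_lm.fuse --model <base> --adapter-path <adapters> --save-path <out>
    # In-process alternative (assumption: main() parses sys.argv):
    import sys
    from mlx_lm import fuse

    sys.argv = ["mlx_lm.fuse",
                "--model", "mlx-community/Meta-Llama-3-8B-Instruct-8bit",
                "--adapter-path", "adapters",
                "--save-path", "./fine_tuned/"]
    fuse.main()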
266 |     {
267 | -    "cell_type": "…
268 | -    "execution_count": null,
269 |      "metadata": {},
270 | -    "outputs": [],
271 |      "source": [
272 | -     "…
273 | -     "    --model mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\n",
274 | -     "    --upload-repo mlx-community/my-lora-mistral-7b \\\n",
275 | -     "    --hf-path mistralai/Mistral-7B-v0.1"
276 |      ]
277 |     },
278 |     {
279 |      "cell_type": "code",
280 | -    "execution_count": …
281 |      "metadata": {},
282 |      "outputs": [
283 |       {
284 |        "data": {
285 |         "text/plain": [
286 | -        "[…
287 | -        " 'DoRALinear',\n",
288 | -        " 'LoRAEmbedding',\n",
289 | -        " 'LoRALinear',\n",
290 | -        " 'LoRASwitchLinear',\n",
291 | -        " 'Path',\n",
292 | -        " '__builtins__',\n",
293 | -        " '__cached__',\n",
294 | -        " '__doc__',\n",
295 | -        " '__file__',\n",
296 | -        " '__loader__',\n",
297 | -        " '__name__',\n",
298 | -        " '__package__',\n",
299 | -        " '__spec__',\n",
300 | -        " 'argparse',\n",
301 | -        " 'convert_to_gguf',\n",
302 | -        " 'dequantize',\n",
303 | -        " 'fetch_from_hub',\n",
304 | -        " 'get_model_path',\n",
305 | -        " 'glob',\n",
306 | -        " 'load_adapters',\n",
307 | -        " 'main',\n",
308 | -        " 'parse_arguments',\n",
309 | -        " 'save_config',\n",
310 | -        " 'save_weights',\n",
311 | -        " 'shutil',\n",
312 | -        " 'tree_flatten',\n",
313 | -        " 'tree_unflatten',\n",
314 | -        " 'upload_to_hub']"
315 |         ]
316 |        },
317 | -      "execution_count": 26,
318 |        "metadata": {},
319 | -      "output_type": "execute_result"
320 |       }
321 |      ],
322 |      "source": [
323 | -     "mlx_lm…
324 | -     "…
325 | -     "…
326 | -     "…
327 |      ]
328 |     },
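The cell's source is truncated in this capture, but the output is recognizably dir() applied to the fuse module; a one-line reconstruction, an assumption rather than the notebook's exact code:

    # Inspect the public surface of the fuse module; 'main' and
    # 'parse_arguments' confirm it is meant to be driven as a script.
    from mlx_lm import fuse
    dir(fuse)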
329 |     {
330 |      "cell_type": "code",
331 | -    "execution_count": …
332 |      "metadata": {},
333 |      "outputs": [
334 |       {
335 |        "data": {
336 |         "text/plain": [
337 | -        "'…
338 |         ]
339 |        },
340 | -      "execution_count": …
341 |        "metadata": {},
342 |        "output_type": "execute_result"
343 |       }
344 |      ],
345 |      "source": [
346 | -     "generate(model…
347 |      ]
348 |     },
349 |     {
350 |      "cell_type": "code",
351 |      "execution_count": null,
352 |      "metadata": {},
353 |      "outputs": [
354 |       {
355 |        "data": {
356 |         "text/plain": [
357 | -        "…
358 |         ]
359 |        },
360 | -      "execution_count": 4,
361 |        "metadata": {},
362 | -      "output_type": "execute_result"
363 | -     }
364 | -    ],
365 | -    "source": [
366 | -     "generate(model=model, tokenizer=tokenizer, prompt=\"role:Tell me something surprising about f1, content:\")"
367 | -    ]
368 | -   },
369 | -   {
370 | -    "cell_type": "code",
371 | -    "execution_count": 24,
372 | -    "metadata": {},
373 | -    "outputs": [
374 |       {
375-382 | - … (removed lines, content not legible in this capture)
383 |      ]
384 | -    }
385 | -   ],
386 | -   "source": [
387 | -    "from mlx_lm import fuse\n",
388 | -    "\n",
389 | -    "fuse(model)"
390 | -   ]
391 | -  },
392 | -  {
393 | -   "cell_type": "code",
394 | -   "execution_count": 2,
395 | -   "metadata": {},
396 | -   "outputs": [
397 |     {
398 |      "name": "stderr",
399 |      "output_type": "stream",
400 |      "text": [
401-403 | - … (removed lines, content not legible in this capture)
404 |      ]
405 |     }
406 |    ],
407 |    "source": [
408-419 | - … (removed lines, content not legible in this capture)
420 |    ]
421 |   }
422 |  ],
  2 |    "cells": [
  3 |     {
  4 |      "cell_type": "code",
  5 | +    "execution_count": 6,
  6 |      "metadata": {},
  7 |      "outputs": [
  8 |       {
  9 | +      "data": {
 10 | +       "application/vnd.jupyter.widget-view+json": {
 11 | +        "model_id": "3155f84dc9cb452f993d5535dc11f344",
 12 | +        "version_major": 2,
 13 | +        "version_minor": 0
 14 | +       },
 15 | +       "text/plain": [
 16 | +        "Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
 17 | +       ]
 18 | +      },
 19 | +      "metadata": {},
 20 | +      "output_type": "display_data"
 21 | +     },
 22 |       {
 23 |        "data": {
 24 |         "text/plain": [
 25 | +        "\"The best car in F1 is a matter of personal opinion, as different drivers and teams have different strengths and weaknesses. However, based on recent performance and dominance, the Mercedes AMG F1 W11 is often considered one of the best cars in F1. The W11 has been a dominant force in the sport, winning 17 out of 20 races in the 2020 season and securing the constructors' championship. Its impressive performance is due to its powerful engine, advanced aerodynamics,\""
 26 |         ]
 27 |        },
 28 | +      "execution_count": 6,
 29 |        "metadata": {},
 30 |        "output_type": "execute_result"
 31 |       }
 32 |      ],
 33 |      "source": [
 34 | +     "from mlx_lm import load, generate\n",
 35 | +     "model_1, tokenizer_1 = load(\"mlx-community/Meta-Llama-3-8B-Instruct-8bit\")\n",
 36 | +     "\n"
 37 |      ]
 38 |     },
 39 |     {
 40 |      "cell_type": "code",
 41 | +    "execution_count": 13,
 42 |      "metadata": {},
 43 |      "outputs": [
 44 |       {
 45 | +      "data": {
 46 | +       "text/plain": [
 47 | +        "'How many r in strawberry, role:How many r in strawberry, content:How many r in strawberry, role:How many r in strawberry, content:How many r in strawberry, role:How many r in strawberry, content:How many r in strawberry, role:How many r in strawberry, content:How many r in strawberry, role:How many r in strawberry, content:How many r in strawberry, role:How many r in strawberry, content:How many r in'"
 48 | +       ]
 49 | +      },
 50 | +      "execution_count": 13,
 51 | +      "metadata": {},
 52 | +      "output_type": "execute_result"
 53 |       }
 54 |      ],
 55 |      "source": [
 56 | +     "generate(model=model_1, tokenizer=tokenizer_1, prompt=\"role:How many r in strawberry, content:\")"
 57 |      ]
 58 |     },
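The looping completion above is the usual symptom of prompting an instruct-tuned model without its chat template: Llama-3-Instruct expects <|start_header_id|>-framed turns, and a bare "role:... content:" string just gets echoed until the token budget runs out. The fix the notebook applies later is apply_chat_template; condensed here using the model_1/tokenizer_1 pair just loaded:

    # Wrap the question in the tokenizer's chat template so the instruct
    # model sees the turn format it was trained on.
    messages = [{"role": "user", "content": "How many r in strawberry?"}]
    prompt = tokenizer_1.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    response = generate(model=model_1, tokenizer=tokenizer_1, prompt=prompt)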
 59 |     {
 60 |      "cell_type": "code",
 61 |      "execution_count": 18,
 90 |      ]
 91 |     },
 92 |     {
 93 | +    "cell_type": "markdown",
 94 |      "metadata": {},
 95 |      "source": [
 96 | +     "## Fine-tuning the model using the LoRA technique"
 97 |      ]
 98 |     },
 99 |     {
100 |      "cell_type": "code",
101 | +    "execution_count": 1,
102 |      "metadata": {},
103 |      "outputs": [
104 | +     {
105 | +      "name": "stderr",
106 | +      "output_type": "stream",
107 | +      "text": [
108 | +       "/opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
109 | +       " from .autonotebook import tqdm as notebook_tqdm\n",
110 | +       "Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 93503.59it/s]\n"
111 | +      ]
112 | +     },
113 |       {
114 |        "name": "stdout",
115 |        "output_type": "stream",
116 |        "text": [
117 |         "Trainable parameters: 0.085% (6.816M/8030.261M)\n",
118 |         "Starting training..., iters: 10\n",
119 | +       "Iter 1: Val loss 14.203, Val took 14.567s\n",
120 | +       "Iter 10: Val loss 7.762, Val took 2.556s\n",
121 | +       "Iter 10: Train loss 10.280, Learning Rate 1.000e-05, It/sec 8.496, Tokens/sec 67.967, Trained Tokens 80, Peak mem 9.347 GB\n",
122 |         "Saved final weights to adapters.safetensors.\n"
123 |        ]
124 |       }
127 |       "from dataclasses import dataclass\n",
128 |       "import mlx.optimizers as optim\n",
129 |       "from mlx_lm import lora\n",
130 | +     "from mlx_lm import load, generate\n",
131 |       "\n",
132 |       "# Create a dataclass to convert dictionary to an object\n",
133 |       "@dataclass\n",
167 |       "    valid_set=\"fine_tune_test.jsonl\")"
168 |      ]
169 |     },
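This is the run that succeeded, and the visible fragments explain why: the "# Create a dataclass to convert dictionary to an object" comment plus the train_model(... valid_set=...) tail imply the required args is now a dataclass instance. The cell body between lines 133 and 167 is collapsed by the diff viewer; a hedged sketch of the pattern it implies, with field names assumed from mlx_lm's LoRA defaults and from the log above (iters: 10, learning rate 1.000e-05):

    from dataclasses import dataclass
    from mlx_lm import lora, load

    # Assumed reconstruction of the collapsed cell: attributes mirror the
    # config keys mlx_lm's LoRA tuner reads off the args object.
    @dataclass
    class TrainArgs:
        batch_size: int = 4
        iters: int = 10               # "Starting training..., iters: 10"
        val_batches: int = 25
        learning_rate: float = 1e-5   # "Learning Rate 1.000e-05"
        steps_per_report: int = 10
        steps_per_eval: int = 200
        max_seq_length: int = 2048
        adapter_path: str = "."       # adapters.safetensors lands here
        lora_layers: int = 16

    model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-8bit")
    lora.train_model(TrainArgs(), model=model, tokenizer=tokenizer,
                     train_set="fine_tune_train.jsonl",
                     valid_set="fine_tune_test.jsonl")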
170 | +   {
171 | +    "cell_type": "markdown",
172 | +    "metadata": {},
173 | +    "source": [
174 | +     "## Integrating fine-tuned LoRA weights with the base model weights and uploading to Hugging Face"
175 | +    ]
176 | +   },
177 |     {
178 |      "cell_type": "code",
179 | +    "execution_count": null,
180 |      "metadata": {},
181 |      "outputs": [
182 |       {
183 | +      "ename": "SyntaxError",
184 | +      "evalue": "invalid decimal literal (3768910078.py, line 2)",
185 |        "output_type": "error",
186 |        "traceback": [
187 | +       "  Cell In[7], line 2\n    --model mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\n         ^\nSyntaxError: invalid decimal literal\n"
188 |        ]
189 |       }
190 |      ],
191 |      "source": [
192 | +     "# In terminal\n",
193 |       "\n",
194 | +     "mlx_lm.fuse \\\n",
195 | +     "    --model mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\n",
196 | +     "    --upload-repo Rafii/f1llama \\\n",
197 | +     "    --hf-path mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\n",
198 | +     "    --adapter-path /Users/rafa/f1llama/ \\\n",
199 | +     "    --save-path ./fine_tuned/"
200 |      ]
201 |     },
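The SyntaxError is expected: this is a shell command pasted into a Python cell, so the parser chokes on the flags (the "8B" in the model name reads as an invalid decimal literal). Run it in a terminal as the comment says, or keep it in the notebook behind Jupyter's shell escape; same flags as the cell above, where mlx_lm.fuse is the console script mlx-lm installs and python -m mlx_lm.fuse is equivalent:

    # Jupyter shell escape: the leading "!" hands the line to the shell,
    # so the CLI flags never reach the Python parser.
    !mlx_lm.fuse \
        --model mlx-community/Meta-Llama-3-8B-Instruct-8bit \
        --upload-repo Rafii/f1llama \
        --hf-path mlx-community/Meta-Llama-3-8B-Instruct-8bit \
        --adapter-path /Users/rafa/f1llama/ \
        --save-path ./fine_tuned/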
202 |     {
203 | +    "cell_type": "markdown",
204 |      "metadata": {},
205 |      "source": [
206 | +     "## Using my model from Hugging Face"
207 |      ]
208 |     },
209 |     {
210 |      "cell_type": "code",
211 | +    "execution_count": 1,
212 |      "metadata": {},
213 |      "outputs": [
214 |       {
215 |        "data": {
216 | +       "application/vnd.jupyter.widget-view+json": {
217 | +        "model_id": "f214f5af0a304c2e8f09cc7fa20d424b",
218 | +        "version_major": 2,
219 | +        "version_minor": 0
220 | +       },
221 |         "text/plain": [
222 | +        "Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
223 |         ]
224 |        },
225 |        "metadata": {},
226 | +      "output_type": "display_data"
227 | +     },
228 | +     {
229 | +      "name": "stdout",
230 | +      "output_type": "stream",
231 | +      "text": [
232 | +       "==========\n",
233 | +       "Prompt: <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
234 | +       "\n",
235 | +       "hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
236 | +       "\n",
237 | +       "\n",
238 | +       "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?\n",
239 | +       "==========\n",
240 | +       "Prompt: 11 tokens, 0.961 tokens-per-sec\n",
241 | +       "Generation: 26 tokens, 8.244 tokens-per-sec\n",
242 | +       "Peak memory: 9.066 GB\n"
243 | +      ]
244 |       }
245 |      ],
246 |      "source": [
247 | +     "from mlx_lm import load, generate\n",
248 | +     "\n",
249 | +     "model, tokenizer = load(\"Rafii/f1llama\")\n",
250 | +     "\n",
251 | +     "prompt=\"hello\"\n",
252 | +     "\n",
253 | +     "if hasattr(tokenizer, \"apply_chat_template\") and tokenizer.chat_template is not None:\n",
254 | +     "    messages = [{\"role\": \"user\", \"content\": prompt}]\n",
255 | +     "    prompt = tokenizer.apply_chat_template(\n",
256 | +     "        messages, tokenize=False, add_generation_prompt=True\n",
257 | +     "    )\n",
258 | +     "\n",
259 | +     "response = generate(model, tokenizer, prompt=prompt, verbose=True)\n"
260 |      ]
261 |     },
262 |     {
263 |      "cell_type": "code",
264 | +    "execution_count": 2,
265 |      "metadata": {},
266 |      "outputs": [
267 | +     {
268 | +      "name": "stdout",
269 | +      "output_type": "stream",
270 | +      "text": [
271 | +       "==========\n",
272 | +       "Prompt: How many r in strawberry\n",
273 | +       "?\n",
274 | +       "Answer: There are 2 r's in the word \"strawberry\"....more\n",
275 | +       "How many s in strawberry?\n",
276 | +       "Answer: There is 1 s in the word \"strawberry\"....more\n",
277 | +       "How many t in strawberry?\n",
278 | +       "Answer: There is 1 t in the word \"strawberry\"....more\n",
279 | +       "How many w in strawberry?\n",
280 | +       "Answer: There is 1 w in the word \"strawberry\"....more\n",
281 | +       "How many a in strawberry?\n",
282 | +       "Answer:\n",
283 | +       "==========\n",
284 | +       "Prompt: 5 tokens, 34.517 tokens-per-sec\n",
285 | +       "Generation: 100 tokens, 11.103 tokens-per-sec\n",
286 | +       "Peak memory: 9.066 GB\n"
287 | +      ]
288 | +     },
289 |       {
290 |        "data": {
291 |         "text/plain": [
292 | +        "'?\\nAnswer: There are 2 r\\'s in the word \"strawberry\"....more\\nHow many s in strawberry?\\nAnswer: There is 1 s in the word \"strawberry\"....more\\nHow many t in strawberry?\\nAnswer: There is 1 t in the word \"strawberry\"....more\\nHow many w in strawberry?\\nAnswer: There is 1 w in the word \"strawberry\"....more\\nHow many a in strawberry?\\nAnswer:'"
293 |         ]
294 |        },
295 | +      "execution_count": 2,
296 |        "metadata": {},
297 |        "output_type": "execute_result"
298 |       }
299 |      ],
300 |      "source": [
301 | +     "generate(model, tokenizer, prompt = \"How many r in strawberry\", verbose=True)\n"
302 |      ]
303 |     },
304 |     {
305 |      "cell_type": "code",
306 |      "execution_count": null,
307 |      "metadata": {},
308 | +    "outputs": [],
309 | +    "source": []
310 | +   },
311 | +   {
312 | +    "cell_type": "code",
313 | +    "execution_count": 3,
314 | +    "metadata": {},
315 |      "outputs": [
316 |       {
317 |        "data": {
318 | +       "application/vnd.jupyter.widget-view+json": {
319 | +        "model_id": "32fb768e7a124f24a8c81dbe9676e637",
320 | +        "version_major": 2,
321 | +        "version_minor": 0
322 | +       },
323 |         "text/plain": [
324 | +        "Fetching 7 files: 0%| | 0/7 [00:00<?, ?it/s]"
325 |         ]
326 |        },
327 |        "metadata": {},
328 | +      "output_type": "display_data"
329 | +     },
330 |       {
331 | +      "name": "stdout",
332 | +      "output_type": "stream",
333 | +      "text": [
334 | +       "==========\n",
335 | +       "Prompt: <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n",
336 | +       "\n",
337 | +       "hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
338 | +       "\n",
339 | +       "\n",
340 | +       "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?\n",
341 | +       "==========\n",
342 | +       "Prompt: 11 tokens, 3.317 tokens-per-sec\n",
343 | +       "Generation: 26 tokens, 10.258 tokens-per-sec\n",
344 | +       "Peak memory: 18.050 GB\n"
345 |        ]
346 | +     },
347 |       {
348 |        "name": "stderr",
349 |        "output_type": "stream",
350 |        "text": [
351 | +       "2024-12-08 20:54:55.898 WARNING streamlit.runtime.scriptrunner_utils.script_run_context: Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n",
352 | +       "2024-12-08 20:54:55.973 \n",
353 | +       "  Warning: to view this Streamlit app on a browser, run it with the following\n",
354 | +       "  command:\n",
355 | +       "\n",
356 | +       "    streamlit run /opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/ipykernel_launcher.py [ARGUMENTS]\n",
357-369 | +  … (thirteen near-identical lines: repeated "missing ScriptRunContext" warnings and one "Session state does not function when running a script without `streamlit run`" notice, timestamps 20:54:55.974-55.980)
370 |        ]
371 |       }
372 |      ],
373 |      "source": [
374 | +     "import streamlit as st\n",
375 | +     "from mlx_lm import load, generate\n",
376 | +     "\n",
377 | +     "# Load your model and tokenizer\n",
378 | +     "\n",
379 | +     "model, tokenizer = load(\"Rafii/f1llama\")\n",
380 | +     "\n",
381 | +     "prompt=\"hello\"\n",
382 | +     "\n",
383 | +     "if hasattr(tokenizer, \"apply_chat_template\") and tokenizer.chat_template is not None:\n",
384 | +     "    messages = [{\"role\": \"user\", \"content\": prompt}]\n",
385 | +     "    prompt = tokenizer.apply_chat_template(\n",
386 | +     "        messages, tokenize=False, add_generation_prompt=True\n",
387 | +     "    )\n",
388 | +     "\n",
389 | +     "response = generate(model, tokenizer, prompt=prompt, verbose=True)\n",
390 | +     "\n",
391 | +     "st.title(\"Your Model Interface\")\n",
392 | +     "\n",
393 | +     "# User input\n",
394 | +     "user_input = st.text_input(\"Enter text:\")\n",
395 | +     "\n",
396 | +     "if st.button(\"Submit\"):\n",
397 | +     "    # Generate a completion for the user's input\n",
398 | +     "    response = generate(model, tokenizer, prompt=user_input, verbose=True)\n",
399 | +     "\n",
400 | +     "    st.write(response)"
403 |      ]
404 |     }
405 |    ],
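The stderr wall above is Streamlit objecting to being driven from a Jupyter kernel ("missing ScriptRunContext", "Session state does not function when running a script without `streamlit run`"). To actually serve the interface, the cell body belongs in a standalone script; a minimal sketch, saved as e.g. app.py (the file name is an assumption) and launched with `streamlit run app.py`:

    # app.py - minimal Streamlit front-end for the fused model (sketch).
    import streamlit as st
    from mlx_lm import load, generate

    @st.cache_resource          # keep the ~9 GB model loaded across reruns
    def get_model():
        return load("Rafii/f1llama")

    model, tokenizer = get_model()

    st.title("f1llama")
    user_input = st.text_input("Enter text:")

    if st.button("Submit") and user_input:
        messages = [{"role": "user", "content": user_input}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        st.write(generate(model, tokenizer, prompt=prompt))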