feat: different model completion (#31)
* feat: use different models for instruction and completion
* refactor to support specific completion model
* improve comments
* add tokenizer_id
* make improvements and fix structured generation for other providers
* fix temperature issue
* Update src/synthetic_dataset_generator/constants.py
Co-authored-by: David Berenstein <[email protected]>
* apply feedback
* merging fix
---------
Co-authored-by: David Berenstein <[email protected]>
- README.md +3 -1
- examples/hf-serverless-deployment.py +1 -1
- examples/hf-serverless-different-model-for-completion.py +16 -0
- examples/ollama-different-model-for-completion.py +26 -0
- src/synthetic_dataset_generator/apps/chat.py +36 -5
- src/synthetic_dataset_generator/apps/rag.py +23 -3
- src/synthetic_dataset_generator/apps/textcat.py +1 -1
- src/synthetic_dataset_generator/constants.py +47 -23
- src/synthetic_dataset_generator/pipelines/base.py +35 -13
- src/synthetic_dataset_generator/pipelines/chat.py +6 -6
- src/synthetic_dataset_generator/pipelines/rag.py +1 -1
- src/synthetic_dataset_generator/pipelines/textcat.py +1 -1
README.md
CHANGED
@@ -86,12 +86,14 @@ You can set the following environment variables to customize the generation proc
 Optionally, you can use different API providers and models.
 
 - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
-- `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the
+- `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the `HF_TOKEN` environment variable.
 - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
 - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
 - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
 - `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.
 
+To use a specific model exclusively for generating completions, set the corresponding environment variables by appending `_COMPLETION` to the ones mentioned earlier. For example, you can use `MODEL_COMPLETION` and `OPENAI_BASE_URL_COMPLETION`.
+
 SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
 
 - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
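For context, a minimal sketch of the configuration that new paragraph describes, assuming an OpenAI-compatible provider; the key, URLs, and model names below are placeholders rather than values from this PR:

```python
import os

# Key shared by both endpoints; per the README, it falls back to HF_TOKEN when unset.
os.environ["API_KEY"] = "sk-..."

# Model and endpoint used for instruction generation.
os.environ["MODEL"] = "gpt-4o-mini"
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"

# Appending `_COMPLETION` overrides the model/endpoint only for completion generation.
os.environ["MODEL_COMPLETION"] = "gpt-4o"
os.environ["OPENAI_BASE_URL_COMPLETION"] = "https://api.openai.com/v1/"
```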
examples/hf-serverless-deployment.py
CHANGED
@@ -9,7 +9,7 @@ import os
 from synthetic_dataset_generator import launch
 
 os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
-os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use model for generation
 os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model
 
 launch()
examples/hf-serverless-different-model-for-completion.py
ADDED
@@ -0,0 +1,16 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use model for instruction generation
+os.environ["MODEL_COMPLETION"] = "meta-llama/Llama-3.1-70B-Instruct"  # use model for completion generation
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model
+
+launch()
examples/ollama-different-model-for-completion.py
ADDED
@@ -0,0 +1,26 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+#     "synthetic-dataset-generator",
+# ]
+# ///
+# ollama serve
+# ollama run llama3.2
+# ollama run llama3.2:1b
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["OLLAMA_BASE_URL"] = (
+    "http://127.0.0.1:11434/"  # in this case, the same base url for both models
+)
+
+os.environ["MODEL"] = "llama3.2"  # model for instruction generation
+os.environ["MODEL_COMPLETION"] = "llama3.2:1b"  # model for completion generation
+
+os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-1B-Instruct"  # tokenizer for instruction generation
+os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-3B-Instruct"  # tokenizer for completion generation
+
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template required for instruction generation
+
+launch()
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -28,6 +28,7 @@ from synthetic_dataset_generator.constants import (
     BASE_URL,
     DEFAULT_BATCH_SIZE,
     MODEL,
+    MODEL_COMPLETION,
     SFT_AVAILABLE,
 )
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
@@ -148,6 +149,7 @@ def generate_dataset_from_prompt(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -155,7 +157,10 @@
     progress(0.0, desc="(1/2) Generating instructions")
     magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
     response_generator = get_response_generator(
-        system_prompt,
+        system_prompt=system_prompt,
+        num_turns=num_turns,
+        temperature=temperature or temperature_completion,
+        is_sample=is_sample,
     )
     total_steps: int = num_rows * 2
     batch_size = DEFAULT_BATCH_SIZE
@@ -266,6 +271,7 @@ def generate_dataset_from_seed(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -278,13 +284,18 @@
         temperature=temperature, is_sample=is_sample
     )
     response_generator = get_response_generator(
-        system_prompt=None,
+        system_prompt=None,
+        num_turns=1,
+        temperature=temperature or temperature_completion,
+        is_sample=is_sample,
     )
     follow_up_generator_instruction = get_follow_up_generator(
         type="instruction", temperature=temperature, is_sample=is_sample
     )
     follow_up_generator_response = get_follow_up_generator(
-        type="response",
+        type="response",
+        temperature=temperature or temperature_completion,
+        is_sample=is_sample,
     )
     steps = 2 * num_turns
     total_steps: int = num_rows * steps
@@ -402,6 +413,7 @@ def generate_dataset(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -411,6 +423,7 @@
             num_turns=num_turns,
             num_rows=num_rows,
             temperature=temperature,
+            temperature_completion=temperature_completion,
             is_sample=is_sample,
         )
     else:
@@ -420,6 +433,7 @@
             num_turns=num_turns,
             num_rows=num_rows,
             temperature=temperature,
+            temperature_completion=temperature_completion,
            is_sample=is_sample,
         )
     return dataframe
@@ -468,6 +482,7 @@ def push_dataset(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     pipeline_code: str = "",
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
@@ -491,6 +506,7 @@
         num_turns=num_turns,
         num_rows=num_rows,
         temperature=temperature,
+        temperature_completion=temperature_completion
     )
     push_dataset_to_hub(
         dataframe=dataframe,
@@ -651,6 +667,11 @@ def hide_pipeline_code_visibility():
     return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+def show_temperature_completion():
+    if MODEL != MODEL_COMPLETION:
+        return {temperature_completion: gr.Slider(value=0.9, visible=True)}
+
+
 ######################
 # Gradio UI
 ######################
@@ -808,11 +829,20 @@ with gr.Blocks() as app:
             temperature = gr.Slider(
                 label="Temperature",
                 minimum=0.1,
-                maximum=1,
+                maximum=1.5,
                 value=0.9,
                 step=0.1,
                 interactive=True,
             )
+            temperature_completion = gr.Slider(
+                label="Temperature for completion",
+                minimum=0.1,
+                maximum=1.5,
+                value=None,
+                step=0.1,
+                interactive=True,
+                visible=False,
+            )
             private = gr.Checkbox(
                 label="Private dataset",
                 value=False,
@@ -944,6 +974,7 @@
             num_turns,
             num_rows,
             temperature,
+            temperature_completion,
             pipeline_code,
         ],
         outputs=[success_message],
@@ -976,7 +1007,7 @@
         inputs=[dataframe],
         outputs=[system_prompt, document_column, num_turns, dataframe],
     )
-
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
+    app.load(fn=show_temperature_completion, outputs=[temperature_completion])
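A short sketch (not part of the diff) of how the `or` fallback above resolves the two temperatures; the values are illustrative:

```python
temperature = 0.9
temperature_completion = None

# `a or b` evaluates to the first truthy operand, so with the ordering used in
# chat.py the completion generator keeps the instruction temperature (0.9) even
# when a completion temperature is provided:
print(temperature or temperature_completion)   # 0.9

temperature_completion = 0.3
print(temperature or temperature_completion)   # 0.9
print(temperature_completion or temperature)   # 0.3 (ordering used in rag.py below)
```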
src/synthetic_dataset_generator/apps/rag.py
CHANGED
@@ -24,7 +24,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
@@ -132,6 +132,7 @@ def generate_dataset(
     reranking: bool = False,
     num_rows: int = 10,
     temperature: float = 0.7,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ):
@@ -155,7 +156,7 @@
         is_sample=is_sample,
     )
     response_generator = get_response_generator(
-        temperature=temperature, is_sample=is_sample
+        temperature=temperature_completion or temperature, is_sample=is_sample
     )
     if reranking:
         reranking_generator = get_sentence_pair_generator(
@@ -320,6 +321,7 @@ def push_dataset(
     retrieval_reranking: list[str],
     num_rows: int,
     temperature: float,
+    temperature_completion: float,
     pipeline_code: str,
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
@@ -347,6 +349,8 @@
         reranking=reranking,
         num_rows=num_rows,
         temperature=temperature,
+        temperature_completion=temperature_completion,
+        is_sample=True,
     )
     push_dataset_to_hub(
         dataframe, org_name, repo_name, oauth_token, private, pipeline_code
@@ -512,6 +516,11 @@ def hide_pipeline_code_visibility():
     return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+def show_temperature_completion():
+    if MODEL != MODEL_COMPLETION:
+        return {temperature_completion: gr.Slider(value=0.9, visible=True)}
+
+
 ######################
 # Gradio UI
 ######################
@@ -645,11 +654,20 @@ with gr.Blocks() as app:
             temperature = gr.Slider(
                 label="Temperature",
                 minimum=0.1,
-                maximum=1,
+                maximum=1.5,
                 value=0.7,
                 step=0.1,
                 interactive=True,
             )
+            temperature_completion = gr.Slider(
+                label="Temperature for completion",
+                minimum=0.1,
+                maximum=1.5,
+                value=None,
+                step=0.1,
+                interactive=True,
+                visible=False,
+            )
             private = gr.Checkbox(
                 label="Private dataset",
                 value=False,
@@ -779,6 +797,7 @@
             retrieval_reranking,
             num_rows,
             temperature,
+            temperature_completion,
             pipeline_code,
         ],
         outputs=[success_message],
@@ -815,3 +834,4 @@
     app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
     app.load(fn=get_random_repo_name, outputs=[repo_name])
+    app.load(fn=show_temperature_completion, outputs=[temperature_completion])
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -532,7 +532,7 @@ with gr.Blocks() as app:
             temperature = gr.Slider(
                 label="Temperature",
                 minimum=0.1,
-                maximum=1,
+                maximum=1.5,
                 value=0.8,
                 step=0.1,
                 interactive=True,
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -3,10 +3,6 @@ import warnings
 
 import argilla as rg
 
-# Tasks
-TEXTCAT_TASK = "text_classification"
-SFT_TASK = "supervised_fine_tuning"
-
 # Inference
 MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
 MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
@@ -20,28 +16,56 @@ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
 HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
 VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
 
-base_urls = [
-]
-if len(base_urls) > 1:
-    raise ValueError(
-        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
-    )
-BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
+# Just used in case of selecting a different model for completions
+MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
+TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
+OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
+OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
+HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
+    "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
+)
+VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)
+
+base_urls = [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
+base_urls_completion = [
+    OPENAI_BASE_URL_COMPLETION,
+    OLLAMA_BASE_URL_COMPLETION,
+    HUGGINGFACE_BASE_URL_COMPLETION,
+    VLLM_BASE_URL_COMPLETION,
+]
+
+
+# Validate the configuration of the model and base URLs.
+def validate_configuration(base_urls, model, env_context=""):
+    huggingface_url = base_urls[2]
+    if huggingface_url and model:
+        raise ValueError(
+            f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
+            "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
+        )
+
+    if not model and any(base_urls):
+        raise ValueError(
+            f"`MODEL{env_context}` is not set. Please provide a model id for inference."
+        )
+
+    active_urls = [url for url in base_urls if url]
+    if len(active_urls) > 1:
+        raise ValueError(
+            f"Multiple base URLs are provided: {', '.join(active_urls)}. "
+            "Only one base URL can be set at a time."
+        )
+
+
+validate_configuration(base_urls, MODEL)
+validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")
+
+BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
+BASE_URL_COMPLETION = (
+    OPENAI_BASE_URL_COMPLETION
+    or OLLAMA_BASE_URL_COMPLETION
+    or HUGGINGFACE_BASE_URL_COMPLETION
+    or VLLM_BASE_URL_COMPLETION
+)
+
 # API Keys
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
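As a rough illustration (not code from the PR), this is how the new `validate_configuration` check fires for the completion-side variables; the endpoint URL and model id below are hypothetical:

```python
# Hypothetical environment: a serverless model id *and* a dedicated endpoint URL
# for completions at the same time, which the check rejects.
base_urls_completion = [None, None, "https://xyz.endpoints.huggingface.cloud", None]
MODEL_COMPLETION = "meta-llama/Llama-3.1-70B-Instruct"

# Raises ValueError: `HUGGINGFACE_BASE_URL_COMPLETION` and `MODEL_COMPLETION`
# cannot be set at the same time.
validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")
```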
src/synthetic_dataset_generator/pipelines/base.py
CHANGED
@@ -8,11 +8,17 @@ from synthetic_dataset_generator.constants import (
     API_KEYS,
     DEFAULT_BATCH_SIZE,
     HUGGINGFACE_BASE_URL,
+    HUGGINGFACE_BASE_URL_COMPLETION,
     MODEL,
+    MODEL_COMPLETION,
     OLLAMA_BASE_URL,
+    OLLAMA_BASE_URL_COMPLETION,
     OPENAI_BASE_URL,
+    OPENAI_BASE_URL_COMPLETION,
     TOKENIZER_ID,
+    TOKENIZER_ID_COMPLETION,
     VLLM_BASE_URL,
+    VLLM_BASE_URL_COMPLETION,
 )
 
 TOKEN_INDEX = 0
@@ -73,12 +79,20 @@ def _get_llm_class() -> str:
     return "InferenceEndpointsLLM"
 
 
-def _get_llm(use_magpie_template=False, **kwargs):
+def _get_llm(
+    structured_output: dict = None,
+    use_magpie_template: str = False,
+    is_completion: bool = False,
+    **kwargs,
+):
+    model = MODEL_COMPLETION if is_completion else MODEL
+    tokenizer_id = TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID or model
     if OPENAI_BASE_URL:
         llm = OpenAILLM(
-            model=
-            base_url=OPENAI_BASE_URL,
+            model=model,
+            base_url=OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
             api_key=_get_next_api_key(),
+            structured_output=structured_output,
             **kwargs,
         )
     if "generation_kwargs" in kwargs:
@@ -108,19 +122,25 @@ def _get_llm(use_magpie_template=False, **kwargs):
         kwargs["generation_kwargs"] = {}
         kwargs["generation_kwargs"]["options"] = options
         llm = OllamaLLM(
-            model=
-            host=OLLAMA_BASE_URL,
-            tokenizer_id=
+            model=model,
+            host=OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+            tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
     elif HUGGINGFACE_BASE_URL:
         kwargs["generation_kwargs"]["do_sample"] = True
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            base_url=
+            base_url=(
+                HUGGINGFACE_BASE_URL_COMPLETION
+                if is_completion
+                else HUGGINGFACE_BASE_URL
+            ),
+            tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
     elif VLLM_BASE_URL:
@@ -128,19 +148,21 @@ def _get_llm(use_magpie_template=False, **kwargs):
         if "do_sample" in kwargs["generation_kwargs"]:
             del kwargs["generation_kwargs"]["do_sample"]
         llm = ClientvLLM(
-            base_url=VLLM_BASE_URL,
-            model=
-            tokenizer=
+            base_url=VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+            model=model,
+            tokenizer=tokenizer_id,
             api_key=_get_next_api_key(),
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
     else:
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            tokenizer_id=
-            model_id=
+            tokenizer_id=tokenizer_id,
+            model_id=model,
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
 
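A minimal usage sketch (not from the diff): the same `_get_llm` helper now serves both roles, and only calls made with `is_completion=True` are routed to the `*_COMPLETION` model, tokenizer, and base URL, while `structured_output` is forwarded to whichever provider class ends up being selected. The `generation_kwargs` values are illustrative:

```python
# Instruction generation keeps the default MODEL / TOKENIZER_ID / base URL.
instruction_llm = _get_llm(generation_kwargs={"temperature": 0.9})

# Completion generation picks up MODEL_COMPLETION, TOKENIZER_ID_COMPLETION and
# the matching *_COMPLETION base URL when they are set.
completion_llm = _get_llm(is_completion=True, generation_kwargs={"temperature": 0.7})
```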
src/synthetic_dataset_generator/pipelines/chat.py
CHANGED
@@ -245,7 +245,7 @@ def get_response_generator(
         "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
     }
     response_generator = TextGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
         system_prompt=system_prompt,
         output_mappings={"generation": "completion"},
         input_mappings={"instruction": "prompt"},
@@ -256,7 +256,7 @@
         "max_new_tokens": MAX_NUM_TOKENS,
     }
     response_generator = ChatGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
         output_mappings={"generation": "completion"},
         input_mappings={"conversation": "messages"},
     )
@@ -281,7 +281,7 @@ def get_follow_up_generator(type: str, temperature: float, is_sample: bool):
         "max_new_tokens": MAX_NUM_TOKENS,
     }
     follow_up_generator = ChatGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
     )
     follow_up_generator.load()
     return follow_up_generator
@@ -336,7 +336,7 @@ def generate_pipeline_code_seed(
 # Requirements: `pip install distilabel[hf-inference-endpoints]`
 from distilabel.models import {_get_llm_class()}
 from distilabel.pipeline import Pipeline
-from distilabel.steps import KeepColumns{", LoadDataFromDicts" if input_type != "dataset-input" else ""}{", LoadDataFromHub" if input_type == "dataset-input" else ""}
+from distilabel.steps import KeepColumns{", LoadDataFromDicts" if input_type != "dataset-input" else ""}{", LoadDataFromHub" if input_type == "dataset-input" else ""}{", StepInput, step" if num_turns > 1 else ""}
 from distilabel.steps.tasks import GenerateSentencePair, TextGeneration {", ChatGeneration" if num_turns > 1 else ""}
 """
@@ -455,10 +455,10 @@
     keep_columns = KeepColumns(columns=["messages"])
     """
     code += "load_the_dataset >> instruction_generator >> response_generator >> prepare_messages"
-
+
     for i in range(1, num_turns + 1):
         code += f" >> follow_up_instruction_{i} >> format_instruction_{i} >> follow_up_response_{i} >> format_response_{i}"
-
+
     code += " >> keep_columns"
 
     code += """
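The extra `StepInput, step` import only appears in the generated code when `num_turns > 1`, because the multi-turn pipeline defines small formatting steps with distilabel's functional `@step` decorator. A sketch of that pattern (the step name and columns are illustrative, not taken from this PR):

```python
from distilabel.steps import StepInput, step


# Fold the latest assistant generation back into the running messages list so
# the next follow-up task sees the full conversation.
@step(inputs=["messages", "generation"], outputs=["messages"])
def format_response(inputs: StepInput):
    for item in inputs:
        item["messages"].append({"role": "assistant", "content": item["generation"]})
    yield inputs
```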
src/synthetic_dataset_generator/pipelines/rag.py
CHANGED
@@ -121,7 +121,7 @@ def get_response_generator(temperature: float, is_sample: bool):
         "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
     }
     text_generator = TextGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
         system_prompt=SYSTEM_PROMPT_RAG,
         template=RAG_TEMPLATE,
         columns=["context", "question"],
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
@@ -109,7 +109,7 @@ def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: b
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-    llm = _get_llm(generation_kwargs=generation_kwargs)
+    llm = _get_llm(is_completion=True, generation_kwargs=generation_kwargs)
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,