Sara Han and davidberenstein1957 committed
Commit 3b7b628 · unverified · 1 parent: b2669f7

feat: different model completion (#31)

* feat: use different models for instruction and completion

* refactor to support specific completion model

* improve comments

* add tokenizer_id

* make improvements and fix structured generation for other providers

* fix temperature issue

* Update src/synthetic_dataset_generator/constants.py

Co-authored-by: David Berenstein <[email protected]>

* apply feedback

* merging fix

---------

Co-authored-by: David Berenstein <[email protected]>

README.md CHANGED
@@ -86,12 +86,14 @@ You can set the following environment variables to customize the generation proc
 Optionally, you can use different API providers and models.
 
 - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
-- `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+- `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the `HF_TOKEN` environment variable.
 - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
 - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
 - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
 - `VLLM_BASE_URL`: The base URL for any VLLM compatible API, e.g. `http://localhost:8000/`.
 
+To use a specific model exclusively for generating completions, set the corresponding environment variables by appending `_COMPLETION` to the ones mentioned earlier. For example, you can use `MODEL_COMPLETION` and `OPENAI_BASE_URL_COMPLETION`.
+
 SFT and Chat Data generation is not supported with OpenAI Endpoints. Additionally, you need to configure it per model family based on their prompt templates using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
 
 - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
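
To point both generators at an OpenAI-compatible endpoint, the same `_COMPLETION` convention applies; each `_COMPLETION` variable falls back to its base counterpart, so only the values that differ need to be set. A minimal sketch in the spirit of the examples below, for tasks supported with OpenAI endpoints (the model names here are placeholders, not part of this commit):

import os

from synthetic_dataset_generator import launch

os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/" # any OpenAI compatible API
os.environ["API_KEY"] = "sk-..." # key for the generation API
os.environ["MODEL"] = "gpt-4o-mini" # placeholder model for instruction generation
os.environ["MODEL_COMPLETION"] = "gpt-4o" # placeholder model for completion generation

launch()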
examples/hf-serverless-deployment.py CHANGED
@@ -9,7 +9,7 @@ import os
 from synthetic_dataset_generator import launch
 
 os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
-os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use model for generation
 os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
 
 launch()
examples/hf-serverless-different-model-for-completion.py ADDED
@@ -0,0 +1,16 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+# "synthetic-dataset-generator",
+# ]
+# ///
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
+os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use model for instruction generation
+os.environ["MODEL_COMPLETION"] = "meta-llama/Llama-3.1-70B-Instruct" # use model for completion generation
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
+
+launch()
examples/ollama-different-model-for-completion.py ADDED
@@ -0,0 +1,26 @@
+# /// script
+# requires-python = ">=3.11,<3.12"
+# dependencies = [
+# "synthetic-dataset-generator",
+# ]
+# ///
+# ollama serve
+# ollama run llama3.2
+# ollama run llama3.2:1b
+import os
+
+from synthetic_dataset_generator import launch
+
+os.environ["OLLAMA_BASE_URL"] = (
+    "http://127.0.0.1:11434/" # in this case, the same base url for both models
+)
+
+os.environ["MODEL"] = "llama3.2" # model for instruction generation
+os.environ["MODEL_COMPLETION"] = "llama3.2:1b" # model for completion generation
+
+os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-1B-Instruct" # tokenizer for instruction generation
+os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-3B-Instruct" # tokenizer for completion generation
+
+os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # magpie template required for instruction generation
+
+launch()
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -28,6 +28,7 @@ from synthetic_dataset_generator.constants import (
     BASE_URL,
     DEFAULT_BATCH_SIZE,
     MODEL,
+    MODEL_COMPLETION,
     SFT_AVAILABLE,
 )
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
@@ -148,6 +149,7 @@ def generate_dataset_from_prompt(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -155,7 +157,10 @@
     progress(0.0, desc="(1/2) Generating instructions")
     magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
     response_generator = get_response_generator(
-        system_prompt, num_turns, temperature, is_sample
+        system_prompt=system_prompt,
+        num_turns=num_turns,
+        temperature=temperature or temperature_completion,
+        is_sample=is_sample,
     )
     total_steps: int = num_rows * 2
     batch_size = DEFAULT_BATCH_SIZE
@@ -266,6 +271,7 @@ def generate_dataset_from_seed(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -278,13 +284,18 @@
         temperature=temperature, is_sample=is_sample
     )
     response_generator = get_response_generator(
-        system_prompt=None, num_turns=1, temperature=temperature, is_sample=is_sample
+        system_prompt=None,
+        num_turns=1,
+        temperature=temperature or temperature_completion,
+        is_sample=is_sample,
     )
     follow_up_generator_instruction = get_follow_up_generator(
         type="instruction", temperature=temperature, is_sample=is_sample
     )
     follow_up_generator_response = get_follow_up_generator(
-        type="response", temperature=temperature, is_sample=is_sample
+        type="response",
+        temperature=temperature or temperature_completion,
+        is_sample=is_sample,
     )
     steps = 2 * num_turns
     total_steps: int = num_rows * steps
@@ -402,6 +413,7 @@ def generate_dataset(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
@@ -411,6 +423,7 @@
             num_turns=num_turns,
             num_rows=num_rows,
             temperature=temperature,
+            temperature_completion=temperature_completion,
             is_sample=is_sample,
         )
     else:
@@ -420,6 +433,7 @@
             num_turns=num_turns,
             num_rows=num_rows,
             temperature=temperature,
+            temperature_completion=temperature_completion,
             is_sample=is_sample,
         )
     return dataframe
@@ -468,6 +482,7 @@ def push_dataset(
     num_turns: int = 1,
     num_rows: int = 10,
     temperature: float = 0.9,
+    temperature_completion: Union[float, None] = None,
     pipeline_code: str = "",
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
@@ -491,6 +506,7 @@
         num_turns=num_turns,
         num_rows=num_rows,
         temperature=temperature,
+        temperature_completion=temperature_completion
     )
     push_dataset_to_hub(
         dataframe=dataframe,
@@ -651,6 +667,11 @@ def hide_pipeline_code_visibility():
     return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+def show_temperature_completion():
+    if MODEL != MODEL_COMPLETION:
+        return {temperature_completion: gr.Slider(value=0.9, visible=True)}
+
+
 ######################
 # Gradio UI
 ######################
@@ -808,11 +829,20 @@ with gr.Blocks() as app:
                 temperature = gr.Slider(
                     label="Temperature",
                     minimum=0.1,
-                    maximum=1,
+                    maximum=1.5,
                     value=0.9,
                     step=0.1,
                     interactive=True,
                 )
+                temperature_completion = gr.Slider(
+                    label="Temperature for completion",
+                    minimum=0.1,
+                    maximum=1.5,
+                    value=None,
+                    step=0.1,
+                    interactive=True,
+                    visible=False,
+                )
                 private = gr.Checkbox(
                     label="Private dataset",
                     value=False,
@@ -944,6 +974,7 @@ with gr.Blocks() as app:
             num_turns,
             num_rows,
             temperature,
+            temperature_completion,
            pipeline_code,
         ],
         outputs=[success_message],
@@ -976,7 +1007,7 @@ with gr.Blocks() as app:
         inputs=[dataframe],
         outputs=[system_prompt, document_column, num_turns, dataframe],
     )
-
    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
+   app.load(fn=show_temperature_completion, outputs=[temperature_completion])
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -24,7 +24,7 @@ from synthetic_dataset_generator.apps.base import (
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
-from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
+from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
 from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
 from synthetic_dataset_generator.pipelines.embeddings import (
     get_embeddings,
@@ -132,6 +132,7 @@ def generate_dataset(
     reranking: bool = False,
     num_rows: int = 10,
     temperature: float = 0.7,
+    temperature_completion: Union[float, None] = None,
     is_sample: bool = False,
     progress=gr.Progress(),
 ):
@@ -155,7 +156,7 @@
         is_sample=is_sample,
     )
     response_generator = get_response_generator(
-        temperature=temperature, is_sample=is_sample
+        temperature = temperature_completion or temperature , is_sample=is_sample
     )
     if reranking:
         reranking_generator = get_sentence_pair_generator(
@@ -320,6 +321,7 @@ def push_dataset(
     retrieval_reranking: list[str],
     num_rows: int,
     temperature: float,
+    temperature_completion: float,
     pipeline_code: str,
     oauth_token: Union[gr.OAuthToken, None] = None,
     progress=gr.Progress(),
@@ -347,6 +349,8 @@
         reranking=reranking,
         num_rows=num_rows,
         temperature=temperature,
+        temperature_completion=temperature_completion,
+        is_sample=True,
     )
     push_dataset_to_hub(
         dataframe, org_name, repo_name, oauth_token, private, pipeline_code
@@ -512,6 +516,11 @@ def hide_pipeline_code_visibility():
     return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+def show_temperature_completion():
+    if MODEL != MODEL_COMPLETION:
+        return {temperature_completion: gr.Slider(value=0.9, visible=True)}
+
+
 ######################
 # Gradio UI
 ######################
@@ -645,11 +654,20 @@ with gr.Blocks() as app:
                 temperature = gr.Slider(
                     label="Temperature",
                     minimum=0.1,
-                    maximum=1,
+                    maximum=1.5,
                     value=0.7,
                     step=0.1,
                     interactive=True,
                 )
+                temperature_completion = gr.Slider(
+                    label="Temperature for completion",
+                    minimum=0.1,
+                    maximum=1.5,
+                    value=None,
+                    step=0.1,
+                    interactive=True,
+                    visible=False,
+                )
                 private = gr.Checkbox(
                     label="Private dataset",
                     value=False,
@@ -779,6 +797,7 @@
             retrieval_reranking,
             num_rows,
             temperature,
+            temperature_completion,
            pipeline_code,
         ],
         outputs=[success_message],
@@ -815,3 +834,4 @@
    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
+   app.load(fn=show_temperature_completion, outputs=[temperature_completion])
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -532,7 +532,7 @@ with gr.Blocks() as app:
                 temperature = gr.Slider(
                     label="Temperature",
                     minimum=0.1,
-                    maximum=1,
+                    maximum=1.5,
                     value=0.8,
                     step=0.1,
                     interactive=True,
src/synthetic_dataset_generator/constants.py CHANGED
@@ -3,10 +3,6 @@ import warnings
 
 import argilla as rg
 
-# Tasks
-TEXTCAT_TASK = "text_classification"
-SFT_TASK = "supervised_fine_tuning"
-
 # Inference
 MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
 MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
@@ -20,28 +16,56 @@ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
 HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
 VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
 
-# check if model is set correctly
-if HUGGINGFACE_BASE_URL and MODEL:
-    raise ValueError(
-        "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
-    )
-if not MODEL:
-    if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL:
-        raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
-
-# Check if multiple base URLs are provided
-base_urls = [
-    url
-    for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
-    if url
+# Just used in case of selecting a different model for completions
+MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
+TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
+OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
+OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
+HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
+    "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
+)
+VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)
+
+base_urls = [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
+base_urls_completion = [
+    OPENAI_BASE_URL_COMPLETION,
+    OLLAMA_BASE_URL_COMPLETION,
+    HUGGINGFACE_BASE_URL_COMPLETION,
+    VLLM_BASE_URL_COMPLETION,
 ]
-if len(base_urls) > 1:
-    raise ValueError(
-        f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
-    )
-BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
 
 
+# Validate the configuration of the model and base URLs.
+def validate_configuration(base_urls, model, env_context=""):
+    huggingface_url = base_urls[2]
+    if huggingface_url and model:
+        raise ValueError(
+            f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
+            "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
+        )
+
+    if not model and any(base_urls):
+        raise ValueError(
+            f"`MODEL{env_context}` is not set. Please provide a model id for inference."
+        )
+
+    active_urls = [url for url in base_urls if url]
+    if len(active_urls) > 1:
+        raise ValueError(
+            f"Multiple base URLs are provided: {', '.join(active_urls)}. "
+            "Only one base URL can be set at a time."
+        )
+validate_configuration(base_urls, MODEL)
+validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")
+
+BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
+BASE_URL_COMPLETION = (
+    OPENAI_BASE_URL_COMPLETION
+    or OLLAMA_BASE_URL_COMPLETION
+    or HUGGINGFACE_BASE_URL_COMPLETION
+    or VLLM_BASE_URL_COMPLETION
+)
+
 # API Keys
 HF_TOKEN = os.getenv("HF_TOKEN")
 if not HF_TOKEN:
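
Because every `_COMPLETION` constant above is read with `os.getenv(name, default)` and defaults to its base counterpart, exporting only the variables that differ is enough. A small standalone illustration of that fallback, using hypothetical values rather than anything taken from this commit:

import os

os.environ["MODEL"] = "llama3.2" # instruction model
os.environ["OLLAMA_BASE_URL"] = "http://127.0.0.1:11434/"
os.environ["MODEL_COMPLETION"] = "llama3.2:1b" # only the completion model is overridden

MODEL = os.getenv("MODEL")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)

print(MODEL_COMPLETION) # llama3.2:1b
print(OLLAMA_BASE_URL_COMPLETION) # http://127.0.0.1:11434/ (inherited from OLLAMA_BASE_URL)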
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -8,11 +8,17 @@ from synthetic_dataset_generator.constants import (
     API_KEYS,
     DEFAULT_BATCH_SIZE,
     HUGGINGFACE_BASE_URL,
+    HUGGINGFACE_BASE_URL_COMPLETION,
     MODEL,
+    MODEL_COMPLETION,
     OLLAMA_BASE_URL,
+    OLLAMA_BASE_URL_COMPLETION,
     OPENAI_BASE_URL,
+    OPENAI_BASE_URL_COMPLETION,
     TOKENIZER_ID,
+    TOKENIZER_ID_COMPLETION,
     VLLM_BASE_URL,
+    VLLM_BASE_URL_COMPLETION,
 )
 
 TOKEN_INDEX = 0
@@ -73,12 +79,20 @@ def _get_llm_class() -> str:
     return "InferenceEndpointsLLM"
 
 
-def _get_llm(use_magpie_template=False, **kwargs):
+def _get_llm(
+    structured_output: dict = None,
+    use_magpie_template: str = False,
+    is_completion: bool = False,
+    **kwargs,
+):
+    model = MODEL_COMPLETION if is_completion else MODEL
+    tokenizer_id = TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID or model
     if OPENAI_BASE_URL:
         llm = OpenAILLM(
-            model=MODEL,
-            base_url=OPENAI_BASE_URL,
+            model=model,
+            base_url=OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
             api_key=_get_next_api_key(),
+            structured_output=structured_output,
             **kwargs,
         )
         if "generation_kwargs" in kwargs:
@@ -108,19 +122,25 @@ def _get_llm(use_magpie_template=False, **kwargs):
             kwargs["generation_kwargs"] = {}
         kwargs["generation_kwargs"]["options"] = options
         llm = OllamaLLM(
-            model=MODEL,
-            host=OLLAMA_BASE_URL,
-            tokenizer_id=TOKENIZER_ID or MODEL,
+            model=model,
+            host=OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
+            tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
     elif HUGGINGFACE_BASE_URL:
         kwargs["generation_kwargs"]["do_sample"] = True
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            base_url=HUGGINGFACE_BASE_URL,
-            tokenizer_id=TOKENIZER_ID or MODEL,
+            base_url=(
+                HUGGINGFACE_BASE_URL_COMPLETION
+                if is_completion
+                else HUGGINGFACE_BASE_URL
+            ),
+            tokenizer_id=tokenizer_id,
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
     elif VLLM_BASE_URL:
@@ -128,19 +148,21 @@ def _get_llm(use_magpie_template=False, **kwargs):
         if "do_sample" in kwargs["generation_kwargs"]:
             del kwargs["generation_kwargs"]["do_sample"]
         llm = ClientvLLM(
-            base_url=VLLM_BASE_URL,
-            model=MODEL,
-            tokenizer=TOKENIZER_ID or MODEL,
+            base_url=VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
+            model=model,
+            tokenizer=tokenizer_id,
             api_key=_get_next_api_key(),
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
     else:
         llm = InferenceEndpointsLLM(
             api_key=_get_next_api_key(),
-            tokenizer_id=TOKENIZER_ID or MODEL,
-            model_id=MODEL,
+            tokenizer_id=tokenizer_id,
+            model_id=model,
             use_magpie_template=use_magpie_template,
+            structured_output=structured_output,
             **kwargs,
         )
 
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -245,7 +245,7 @@ def get_response_generator(
         "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
     }
     response_generator = TextGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
         system_prompt=system_prompt,
         output_mappings={"generation": "completion"},
         input_mappings={"instruction": "prompt"},
@@ -256,7 +256,7 @@
         "max_new_tokens": MAX_NUM_TOKENS,
     }
     response_generator = ChatGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
         output_mappings={"generation": "completion"},
         input_mappings={"conversation": "messages"},
     )
@@ -281,7 +281,7 @@ def get_follow_up_generator(type: str, temperature: float, is_sample: bool):
         "max_new_tokens": MAX_NUM_TOKENS,
     }
     follow_up_generator = ChatGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
     )
     follow_up_generator.load()
     return follow_up_generator
@@ -336,7 +336,7 @@ def generate_pipeline_code_seed(
 # Requirements: `pip install distilabel[hf-inference-endpoints]`
 from distilabel.models import {_get_llm_class()}
 from distilabel.pipeline import Pipeline
-from distilabel.steps import KeepColumns{", LoadDataFromDicts" if input_type != "dataset-input" else ""}{", LoadDataFromHub" if input_type == "dataset-input" else ""}
+from distilabel.steps import KeepColumns{", LoadDataFromDicts" if input_type != "dataset-input" else ""}{", LoadDataFromHub" if input_type == "dataset-input" else ""}{", StepInput, step" if num_turns > 1 else ""}
 from distilabel.steps.tasks import GenerateSentencePair, TextGeneration {", ChatGeneration" if num_turns > 1 else ""}
 """
@@ -455,10 +455,10 @@
     keep_columns = KeepColumns(columns=["messages"])
     """
     code += "load_the_dataset >> instruction_generator >> response_generator >> prepare_messages"
-
+
     for i in range(1, num_turns + 1):
         code += f" >> follow_up_instruction_{i} >> format_instruction_{i} >> follow_up_response_{i} >> format_response_{i}"
-
+
     code += " >> keep_columns"
 
     code += """
src/synthetic_dataset_generator/pipelines/rag.py CHANGED
@@ -121,7 +121,7 @@ def get_response_generator(temperature: float, is_sample: bool):
         "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
     }
     text_generator = TextGeneration(
-        llm=_get_llm(generation_kwargs=generation_kwargs),
+        llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
         system_prompt=SYSTEM_PROMPT_RAG,
         template=RAG_TEMPLATE,
         columns=["context", "question"],
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -109,7 +109,7 @@ def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: b
         "temperature": 0.01,
         "max_new_tokens": MAX_NUM_TOKENS,
     }
-    llm = _get_llm(generation_kwargs=generation_kwargs)
+    llm = _get_llm(is_completion=True, generation_kwargs=generation_kwargs)
     labeller_generator = TextClassification(
         llm=llm,
         context=system_prompt,