refactor: add local save to README and improve layout
Browse files
README.md
CHANGED
@@ -104,6 +104,10 @@ Optionally, you can also push your datasets to Argilla for further curation by s
|
|
104 |
- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
|
105 |
- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
|
106 |
|
|
|
|
|
|
|
|
|
107 |
### Argilla integration
|
108 |
|
109 |
Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
|
|
|
104 |
- `ARGILLA_API_KEY`: Your Argilla API key to push your datasets to Argilla.
|
105 |
- `ARGILLA_API_URL`: Your Argilla API URL to push your datasets to Argilla.
|
106 |
|
107 |
+
To save the generated datasets to a local directory instead of pushing them to the Hugging Face Hub, set the following environment variable:
|
108 |
+
|
109 |
+
- `SAVE_LOCAL_DIR`: The local directory to save the generated datasets to.
|
110 |
+
|
111 |
### Argilla integration
|
112 |
|
113 |
Argilla is an open source tool for data curation. It allows you to annotate and review datasets, and push curated datasets to the Hugging Face Hub. You can easily get started with Argilla by following the [quickstart guide](https://docs.argilla.io/latest/getting_started/quickstart/).
|
src/synthetic_dataset_generator/apps/chat.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
import ast
|
2 |
import json
|
|
|
3 |
import random
|
4 |
import uuid
|
5 |
-
import os
|
6 |
from typing import Dict, List, Union
|
7 |
|
8 |
import argilla as rg
|
@@ -30,8 +30,8 @@ from synthetic_dataset_generator.constants import (
|
|
30 |
DEFAULT_BATCH_SIZE,
|
31 |
MODEL,
|
32 |
MODEL_COMPLETION,
|
33 |
-
SFT_AVAILABLE,
|
34 |
SAVE_LOCAL_DIR,
|
|
|
35 |
)
|
36 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
37 |
from synthetic_dataset_generator.pipelines.chat import (
|
@@ -266,6 +266,7 @@ def generate_dataset_from_prompt(
|
|
266 |
progress(1.0, desc="Dataset generation completed")
|
267 |
return dataframe
|
268 |
|
|
|
269 |
def generate_dataset_from_seed(
|
270 |
dataframe: pd.DataFrame,
|
271 |
document_column: str,
|
@@ -369,7 +370,9 @@ def generate_dataset_from_seed(
|
|
369 |
follow_up_instructions = list(
|
370 |
follow_up_generator_instruction.process(inputs=conversations_batch)
|
371 |
)
|
372 |
-
for conv, follow_up in zip(
|
|
|
|
|
373 |
conv["messages"].append(
|
374 |
{"role": "user", "content": follow_up["generation"]}
|
375 |
)
|
@@ -667,7 +670,7 @@ def save_local(
|
|
667 |
num_turns=num_turns,
|
668 |
num_rows=num_rows,
|
669 |
temperature=temperature,
|
670 |
-
temperature_completion=temperature_completion
|
671 |
)
|
672 |
local_dataset = Dataset.from_pandas(dataframe)
|
673 |
output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
|
@@ -710,12 +713,30 @@ def hide_pipeline_code_visibility():
|
|
710 |
def show_temperature_completion():
|
711 |
if MODEL != MODEL_COMPLETION:
|
712 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
713 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
714 |
def show_save_local():
|
|
|
715 |
return {
|
716 |
-
btn_save_local: gr.Button(visible=True),
|
717 |
csv_file: gr.File(visible=True),
|
718 |
-
json_file: gr.File(visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
719 |
}
|
720 |
|
721 |
|
@@ -902,12 +923,20 @@ with gr.Blocks() as app:
|
|
902 |
btn_save_local = gr.Button(
|
903 |
"Save locally", variant="primary", scale=2, visible=False
|
904 |
)
|
905 |
-
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
906 |
-
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
907 |
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
908 |
success_message = gr.Markdown(
|
909 |
-
visible=
|
910 |
-
min_height=
|
911 |
)
|
912 |
with gr.Accordion(
|
913 |
"Customize your pipeline with distilabel",
|
@@ -1005,6 +1034,9 @@ with gr.Blocks() as app:
|
|
1005 |
fn=validate_push_to_hub,
|
1006 |
inputs=[org_name, repo_name],
|
1007 |
outputs=[success_message],
|
|
|
|
|
|
|
1008 |
).success(
|
1009 |
fn=hide_success_message,
|
1010 |
outputs=[success_message],
|
@@ -1050,8 +1082,19 @@ with gr.Blocks() as app:
|
|
1050 |
inputs=[],
|
1051 |
outputs=[pipeline_code_ui],
|
1052 |
)
|
1053 |
-
|
1054 |
btn_save_local.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1055 |
save_local,
|
1056 |
inputs=[
|
1057 |
search_in,
|
@@ -1065,7 +1108,22 @@ with gr.Blocks() as app:
|
|
1065 |
repo_name,
|
1066 |
temperature_completion,
|
1067 |
],
|
1068 |
-
outputs=[csv_file, json_file]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1069 |
)
|
1070 |
|
1071 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
@@ -1081,4 +1139,4 @@ with gr.Blocks() as app:
|
|
1081 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
1082 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
1083 |
if SAVE_LOCAL_DIR is not None:
|
1084 |
-
app.load(fn=
|
|
|
1 |
import ast
|
2 |
import json
|
3 |
+
import os
|
4 |
import random
|
5 |
import uuid
|
|
|
6 |
from typing import Dict, List, Union
|
7 |
|
8 |
import argilla as rg
|
|
|
30 |
DEFAULT_BATCH_SIZE,
|
31 |
MODEL,
|
32 |
MODEL_COMPLETION,
|
|
|
33 |
SAVE_LOCAL_DIR,
|
34 |
+
SFT_AVAILABLE,
|
35 |
)
|
36 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
37 |
from synthetic_dataset_generator.pipelines.chat import (
|
|
|
266 |
progress(1.0, desc="Dataset generation completed")
|
267 |
return dataframe
|
268 |
|
269 |
+
|
270 |
def generate_dataset_from_seed(
|
271 |
dataframe: pd.DataFrame,
|
272 |
document_column: str,
|
|
|
370 |
follow_up_instructions = list(
|
371 |
follow_up_generator_instruction.process(inputs=conversations_batch)
|
372 |
)
|
373 |
+
for conv, follow_up in zip(
|
374 |
+
conversations_batch, follow_up_instructions[0]
|
375 |
+
):
|
376 |
conv["messages"].append(
|
377 |
{"role": "user", "content": follow_up["generation"]}
|
378 |
)
|
|
|
670 |
num_turns=num_turns,
|
671 |
num_rows=num_rows,
|
672 |
temperature=temperature,
|
673 |
+
temperature_completion=temperature_completion,
|
674 |
)
|
675 |
local_dataset = Dataset.from_pandas(dataframe)
|
676 |
output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
|
|
|
713 |
def show_temperature_completion():
|
714 |
if MODEL != MODEL_COMPLETION:
|
715 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
716 |
+
|
717 |
+
|
718 |
+
def show_save_local_button():
|
719 |
+
return {btn_save_local: gr.Button(visible=True)}
|
720 |
+
|
721 |
+
|
722 |
+
def hide_save_local_button():
|
723 |
+
return {btn_save_local: gr.Button(visible=False)}
|
724 |
+
|
725 |
+
|
726 |
def show_save_local():
|
727 |
+
gr.update(success_message, min_height=0)
|
728 |
return {
|
|
|
729 |
csv_file: gr.File(visible=True),
|
730 |
+
json_file: gr.File(visible=True),
|
731 |
+
success_message: success_message
|
732 |
+
}
|
733 |
+
|
734 |
+
def hide_save_local():
|
735 |
+
gr.update(success_message, min_height=100)
|
736 |
+
return {
|
737 |
+
csv_file: gr.File(visible=False),
|
738 |
+
json_file: gr.File(visible=False),
|
739 |
+
success_message: success_message,
|
740 |
}
|
741 |
|
742 |
|
|
|
923 |
btn_save_local = gr.Button(
|
924 |
"Save locally", variant="primary", scale=2, visible=False
|
925 |
)
|
|
|
|
|
926 |
with gr.Column(scale=3):
|
927 |
+
csv_file = gr.File(
|
928 |
+
label="CSV",
|
929 |
+
elem_classes="datasets",
|
930 |
+
visible=False,
|
931 |
+
)
|
932 |
+
json_file = gr.File(
|
933 |
+
label="JSON",
|
934 |
+
elem_classes="datasets",
|
935 |
+
visible=False,
|
936 |
+
)
|
937 |
success_message = gr.Markdown(
|
938 |
+
visible=False,
|
939 |
+
min_height=0 # don't remove this otherwise progress is not visible
|
940 |
)
|
941 |
with gr.Accordion(
|
942 |
"Customize your pipeline with distilabel",
|
|
|
1034 |
fn=validate_push_to_hub,
|
1035 |
inputs=[org_name, repo_name],
|
1036 |
outputs=[success_message],
|
1037 |
+
).success(
|
1038 |
+
fn=hide_save_local,
|
1039 |
+
outputs=[csv_file, json_file, success_message],
|
1040 |
).success(
|
1041 |
fn=hide_success_message,
|
1042 |
outputs=[success_message],
|
|
|
1082 |
inputs=[],
|
1083 |
outputs=[pipeline_code_ui],
|
1084 |
)
|
1085 |
+
|
1086 |
btn_save_local.click(
|
1087 |
+
fn=hide_success_message,
|
1088 |
+
outputs=[success_message],
|
1089 |
+
).success(
|
1090 |
+
fn=hide_pipeline_code_visibility,
|
1091 |
+
inputs=[],
|
1092 |
+
outputs=[pipeline_code_ui],
|
1093 |
+
).success(
|
1094 |
+
fn=show_save_local,
|
1095 |
+
inputs=[],
|
1096 |
+
outputs=[csv_file, json_file, success_message],
|
1097 |
+
).success(
|
1098 |
save_local,
|
1099 |
inputs=[
|
1100 |
search_in,
|
|
|
1108 |
repo_name,
|
1109 |
temperature_completion,
|
1110 |
],
|
1111 |
+
outputs=[csv_file, json_file],
|
1112 |
+
).success(
|
1113 |
+
fn=generate_pipeline_code,
|
1114 |
+
inputs=[
|
1115 |
+
search_in,
|
1116 |
+
input_type,
|
1117 |
+
system_prompt,
|
1118 |
+
document_column,
|
1119 |
+
num_turns,
|
1120 |
+
num_rows,
|
1121 |
+
],
|
1122 |
+
outputs=[pipeline_code],
|
1123 |
+
).success(
|
1124 |
+
fn=show_pipeline_code_visibility,
|
1125 |
+
inputs=[],
|
1126 |
+
outputs=[pipeline_code_ui],
|
1127 |
)
|
1128 |
|
1129 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
|
|
1139 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
1140 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
1141 |
if SAVE_LOCAL_DIR is not None:
|
1142 |
+
app.load(fn=show_save_local_button, outputs=btn_save_local)
|
src/synthetic_dataset_generator/apps/rag.py
CHANGED
@@ -24,7 +24,12 @@ from synthetic_dataset_generator.apps.base import (
|
|
24 |
validate_argilla_user_workspace_dataset,
|
25 |
validate_push_to_hub,
|
26 |
)
|
27 |
-
from synthetic_dataset_generator.constants import
|
|
|
|
|
|
|
|
|
|
|
28 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
29 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
30 |
get_embeddings,
|
@@ -156,7 +161,7 @@ def generate_dataset(
|
|
156 |
is_sample=is_sample,
|
157 |
)
|
158 |
response_generator = get_response_generator(
|
159 |
-
temperature
|
160 |
)
|
161 |
if reranking:
|
162 |
reranking_generator = get_sentence_pair_generator(
|
@@ -564,11 +569,29 @@ def show_temperature_completion():
|
|
564 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
565 |
|
566 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
567 |
def show_save_local():
|
|
|
568 |
return {
|
569 |
-
btn_save_local: gr.Button(visible=True),
|
570 |
csv_file: gr.File(visible=True),
|
571 |
-
json_file: gr.File(visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
572 |
}
|
573 |
|
574 |
|
@@ -725,18 +748,24 @@ with gr.Blocks() as app:
|
|
725 |
interactive=True,
|
726 |
scale=1,
|
727 |
)
|
728 |
-
btn_push_to_hub = gr.Button(
|
729 |
-
"Push to Hub", variant="primary", scale=2
|
730 |
-
)
|
731 |
btn_save_local = gr.Button(
|
732 |
"Save locally", variant="primary", scale=2, visible=False
|
733 |
)
|
734 |
-
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
735 |
-
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
736 |
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
737 |
success_message = gr.Markdown(
|
738 |
-
visible=
|
739 |
-
min_height=
|
740 |
)
|
741 |
with gr.Accordion(
|
742 |
"Customize your pipeline with distilabel",
|
@@ -834,6 +863,9 @@ with gr.Blocks() as app:
|
|
834 |
fn=validate_push_to_hub,
|
835 |
inputs=[org_name, repo_name],
|
836 |
outputs=[success_message],
|
|
|
|
|
|
|
837 |
).success(
|
838 |
fn=hide_success_message,
|
839 |
outputs=[success_message],
|
@@ -881,6 +913,17 @@ with gr.Blocks() as app:
|
|
881 |
)
|
882 |
|
883 |
btn_save_local.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
884 |
save_local,
|
885 |
inputs=[
|
886 |
search_in,
|
@@ -894,7 +937,22 @@ with gr.Blocks() as app:
|
|
894 |
repo_name,
|
895 |
temperature_completion,
|
896 |
],
|
897 |
-
outputs=[csv_file, json_file]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
898 |
)
|
899 |
|
900 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
@@ -911,4 +969,4 @@ with gr.Blocks() as app:
|
|
911 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
912 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
913 |
if SAVE_LOCAL_DIR is not None:
|
914 |
-
app.load(fn=
|
|
|
24 |
validate_argilla_user_workspace_dataset,
|
25 |
validate_push_to_hub,
|
26 |
)
|
27 |
+
from synthetic_dataset_generator.constants import (
|
28 |
+
DEFAULT_BATCH_SIZE,
|
29 |
+
MODEL,
|
30 |
+
MODEL_COMPLETION,
|
31 |
+
SAVE_LOCAL_DIR,
|
32 |
+
)
|
33 |
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
|
34 |
from synthetic_dataset_generator.pipelines.embeddings import (
|
35 |
get_embeddings,
|
|
|
161 |
is_sample=is_sample,
|
162 |
)
|
163 |
response_generator = get_response_generator(
|
164 |
+
temperature=temperature_completion or temperature, is_sample=is_sample
|
165 |
)
|
166 |
if reranking:
|
167 |
reranking_generator = get_sentence_pair_generator(
|
|
|
569 |
return {temperature_completion: gr.Slider(value=0.9, visible=True)}
|
570 |
|
571 |
|
572 |
+
def show_save_local_button():
|
573 |
+
return {btn_save_local: gr.Button(visible=True)}
|
574 |
+
|
575 |
+
|
576 |
+
def hide_save_local_button():
|
577 |
+
return {btn_save_local: gr.Button(visible=False)}
|
578 |
+
|
579 |
+
|
580 |
def show_save_local():
|
581 |
+
gr.update(success_message, min_height=0)
|
582 |
return {
|
|
|
583 |
csv_file: gr.File(visible=True),
|
584 |
+
json_file: gr.File(visible=True),
|
585 |
+
success_message: success_message,
|
586 |
+
}
|
587 |
+
|
588 |
+
|
589 |
+
def hide_save_local():
|
590 |
+
gr.update(success_message, min_height=100)
|
591 |
+
return {
|
592 |
+
csv_file: gr.File(visible=False),
|
593 |
+
json_file: gr.File(visible=False),
|
594 |
+
success_message: success_message,
|
595 |
}
|
596 |
|
597 |
|
|
|
748 |
interactive=True,
|
749 |
scale=1,
|
750 |
)
|
751 |
+
btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
|
|
|
|
|
752 |
btn_save_local = gr.Button(
|
753 |
"Save locally", variant="primary", scale=2, visible=False
|
754 |
)
|
|
|
|
|
755 |
with gr.Column(scale=3):
|
756 |
+
csv_file = gr.File(
|
757 |
+
label="CSV",
|
758 |
+
elem_classes="datasets",
|
759 |
+
visible=False,
|
760 |
+
)
|
761 |
+
json_file = gr.File(
|
762 |
+
label="JSON",
|
763 |
+
elem_classes="datasets",
|
764 |
+
visible=False,
|
765 |
+
)
|
766 |
success_message = gr.Markdown(
|
767 |
+
visible=False,
|
768 |
+
min_height=0, # don't remove this otherwise progress is not visible
|
769 |
)
|
770 |
with gr.Accordion(
|
771 |
"Customize your pipeline with distilabel",
|
|
|
863 |
fn=validate_push_to_hub,
|
864 |
inputs=[org_name, repo_name],
|
865 |
outputs=[success_message],
|
866 |
+
).success(
|
867 |
+
fn=hide_save_local,
|
868 |
+
outputs=[csv_file, json_file, success_message],
|
869 |
).success(
|
870 |
fn=hide_success_message,
|
871 |
outputs=[success_message],
|
|
|
913 |
)
|
914 |
|
915 |
btn_save_local.click(
|
916 |
+
fn=hide_success_message,
|
917 |
+
outputs=[success_message],
|
918 |
+
).success(
|
919 |
+
fn=hide_pipeline_code_visibility,
|
920 |
+
inputs=[],
|
921 |
+
outputs=[pipeline_code_ui],
|
922 |
+
).success(
|
923 |
+
fn=show_save_local,
|
924 |
+
inputs=[],
|
925 |
+
outputs=[csv_file, json_file, success_message],
|
926 |
+
).success(
|
927 |
save_local,
|
928 |
inputs=[
|
929 |
search_in,
|
|
|
937 |
repo_name,
|
938 |
temperature_completion,
|
939 |
],
|
940 |
+
outputs=[csv_file, json_file],
|
941 |
+
).success(
|
942 |
+
fn=generate_pipeline_code,
|
943 |
+
inputs=[
|
944 |
+
search_in,
|
945 |
+
input_type,
|
946 |
+
system_prompt,
|
947 |
+
document_column,
|
948 |
+
retrieval_reranking,
|
949 |
+
num_rows,
|
950 |
+
],
|
951 |
+
outputs=[pipeline_code],
|
952 |
+
).success(
|
953 |
+
fn=show_pipeline_code_visibility,
|
954 |
+
inputs=[],
|
955 |
+
outputs=[pipeline_code_ui],
|
956 |
)
|
957 |
|
958 |
clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
|
|
|
969 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
970 |
app.load(fn=show_temperature_completion, outputs=[temperature_completion])
|
971 |
if SAVE_LOCAL_DIR is not None:
|
972 |
+
app.load(fn=show_save_local_button, outputs=btn_save_local)
|
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
import os
|
2 |
import json
|
|
|
3 |
import random
|
4 |
import uuid
|
5 |
from typing import List, Union
|
@@ -453,11 +453,29 @@ def hide_pipeline_code_visibility():
|
|
453 |
return {pipeline_code_ui: gr.Accordion(visible=False)}
|
454 |
|
455 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
def show_save_local():
|
|
|
457 |
return {
|
458 |
-
btn_save_local: gr.Button(visible=True),
|
459 |
csv_file: gr.File(visible=True),
|
460 |
-
json_file: gr.File(visible=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
}
|
462 |
|
463 |
|
@@ -579,18 +597,24 @@ with gr.Blocks() as app:
|
|
579 |
interactive=True,
|
580 |
scale=1,
|
581 |
)
|
582 |
-
btn_push_to_hub = gr.Button(
|
583 |
-
"Push to Hub", variant="primary", scale=2
|
584 |
-
)
|
585 |
btn_save_local = gr.Button(
|
586 |
"Save locally", variant="primary", scale=2, visible=False
|
587 |
)
|
588 |
-
csv_file = gr.File(label="CSV", elem_classes="datasets", visible=False)
|
589 |
-
json_file = gr.File(label="JSON", elem_classes="datasets", visible=False)
|
590 |
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
591 |
success_message = gr.Markdown(
|
592 |
-
visible=
|
593 |
-
min_height=
|
594 |
)
|
595 |
with gr.Accordion(
|
596 |
"Customize your pipeline with distilabel",
|
@@ -643,6 +667,9 @@ with gr.Blocks() as app:
|
|
643 |
fn=validate_input_labels,
|
644 |
inputs=[labels],
|
645 |
outputs=[labels],
|
|
|
|
|
|
|
646 |
).success(
|
647 |
fn=hide_success_message,
|
648 |
outputs=[success_message],
|
@@ -686,8 +713,19 @@ with gr.Blocks() as app:
|
|
686 |
inputs=[],
|
687 |
outputs=[pipeline_code_ui],
|
688 |
)
|
689 |
-
|
690 |
btn_save_local.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
save_local,
|
692 |
inputs=[
|
693 |
system_prompt,
|
@@ -699,7 +737,22 @@ with gr.Blocks() as app:
|
|
699 |
temperature,
|
700 |
repo_name,
|
701 |
],
|
702 |
-
outputs=[csv_file, json_file]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
703 |
)
|
704 |
|
705 |
gr.on(
|
@@ -719,4 +772,4 @@ with gr.Blocks() as app:
|
|
719 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
720 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
721 |
if SAVE_LOCAL_DIR is not None:
|
722 |
-
app.load(fn=
|
|
|
|
|
1 |
import json
|
2 |
+
import os
|
3 |
import random
|
4 |
import uuid
|
5 |
from typing import List, Union
|
|
|
453 |
return {pipeline_code_ui: gr.Accordion(visible=False)}
|
454 |
|
455 |
|
456 |
+
def show_save_local_button():
|
457 |
+
return {btn_save_local: gr.Button(visible=True)}
|
458 |
+
|
459 |
+
|
460 |
+
def hide_save_local_button():
|
461 |
+
return {btn_save_local: gr.Button(visible=False)}
|
462 |
+
|
463 |
+
|
464 |
def show_save_local():
|
465 |
+
gr.update(success_message, min_height=0)
|
466 |
return {
|
|
|
467 |
csv_file: gr.File(visible=True),
|
468 |
+
json_file: gr.File(visible=True),
|
469 |
+
success_message: success_message,
|
470 |
+
}
|
471 |
+
|
472 |
+
|
473 |
+
def hide_save_local():
|
474 |
+
gr.update(success_message, min_height=100)
|
475 |
+
return {
|
476 |
+
csv_file: gr.File(visible=False),
|
477 |
+
json_file: gr.File(visible=False),
|
478 |
+
success_message: success_message,
|
479 |
}
|
480 |
|
481 |
|
|
|
597 |
interactive=True,
|
598 |
scale=1,
|
599 |
)
|
600 |
+
btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
|
|
|
|
|
601 |
btn_save_local = gr.Button(
|
602 |
"Save locally", variant="primary", scale=2, visible=False
|
603 |
)
|
|
|
|
|
604 |
with gr.Column(scale=3):
|
605 |
+
csv_file = gr.File(
|
606 |
+
label="CSV",
|
607 |
+
elem_classes="datasets",
|
608 |
+
visible=False,
|
609 |
+
)
|
610 |
+
json_file = gr.File(
|
611 |
+
label="JSON",
|
612 |
+
elem_classes="datasets",
|
613 |
+
visible=False,
|
614 |
+
)
|
615 |
success_message = gr.Markdown(
|
616 |
+
visible=False,
|
617 |
+
min_height=0, # don't remove this otherwise progress is not visible
|
618 |
)
|
619 |
with gr.Accordion(
|
620 |
"Customize your pipeline with distilabel",
|
|
|
667 |
fn=validate_input_labels,
|
668 |
inputs=[labels],
|
669 |
outputs=[labels],
|
670 |
+
).success(
|
671 |
+
fn=hide_save_local,
|
672 |
+
outputs=[csv_file, json_file, success_message],
|
673 |
).success(
|
674 |
fn=hide_success_message,
|
675 |
outputs=[success_message],
|
|
|
713 |
inputs=[],
|
714 |
outputs=[pipeline_code_ui],
|
715 |
)
|
716 |
+
|
717 |
btn_save_local.click(
|
718 |
+
fn=hide_success_message,
|
719 |
+
outputs=[success_message],
|
720 |
+
).success(
|
721 |
+
fn=hide_pipeline_code_visibility,
|
722 |
+
inputs=[],
|
723 |
+
outputs=[pipeline_code_ui],
|
724 |
+
).success(
|
725 |
+
fn=show_save_local,
|
726 |
+
inputs=[],
|
727 |
+
outputs=[csv_file, json_file, success_message],
|
728 |
+
).success(
|
729 |
save_local,
|
730 |
inputs=[
|
731 |
system_prompt,
|
|
|
737 |
temperature,
|
738 |
repo_name,
|
739 |
],
|
740 |
+
outputs=[csv_file, json_file],
|
741 |
+
).success(
|
742 |
+
fn=generate_pipeline_code,
|
743 |
+
inputs=[
|
744 |
+
system_prompt,
|
745 |
+
difficulty,
|
746 |
+
clarity,
|
747 |
+
labels,
|
748 |
+
multi_label,
|
749 |
+
num_rows,
|
750 |
+
],
|
751 |
+
outputs=[pipeline_code],
|
752 |
+
).success(
|
753 |
+
fn=show_pipeline_code_visibility,
|
754 |
+
inputs=[],
|
755 |
+
outputs=[pipeline_code_ui],
|
756 |
)
|
757 |
|
758 |
gr.on(
|
|
|
772 |
app.load(fn=get_org_dropdown, outputs=[org_name])
|
773 |
app.load(fn=get_random_repo_name, outputs=[repo_name])
|
774 |
if SAVE_LOCAL_DIR is not None:
|
775 |
+
app.load(fn=show_save_local_button, outputs=btn_save_local)
|
src/synthetic_dataset_generator/constants.py
CHANGED
@@ -8,7 +8,7 @@ MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
|
|
8 |
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
|
9 |
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
|
10 |
|
11 |
-
# Directory
|
12 |
SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
|
13 |
|
14 |
# Models
|
|
|
8 |
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
|
9 |
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
|
10 |
|
11 |
+
# Directory to locally save the generated data
|
12 |
SAVE_LOCAL_DIR = os.getenv(key="SAVE_LOCAL_DIR", default=None)
|
13 |
|
14 |
# Models
|