import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.hugging_face import (
    HuggingFaceSeq2SeqGenerator,
    HuggingFaceGenerationAlgorithm,
)
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def run_inference(
    model_name_or_path: str,
    prefix: str,
    prompt: str,
    num_beams: int,
):
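    """Generate a single output for the given prefix and prompt.

    Builds a HuggingFaceSeq2SeqGenerator configuration for the selected
    Text-chem-T5 checkpoint, samples one sequence, and strips the echoed
    prefix/prompt and the T5 special tokens from the result.
    """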
    config = HuggingFaceSeq2SeqGenerator(
        algorithm_version=model_name_or_path,
        prefix=prefix,
        prompt=prompt,
        num_beams=num_beams,
    )
    model = HuggingFaceGenerationAlgorithm(config)
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
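    # Post-process: drop the echoed prefix/prompt and the T5 special tokens.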
    text = list(model.sample(1))[0]
    text = text.replace(prefix + prompt, "")
    text = text.split(tokenizer.eos_token)[0]
    text = text.replace(tokenizer.pad_token, "")
    text = text.strip()
    return text
if __name__ == "__main__":
    # Preparation (retrieve all available algorithms)
    models = [
        "text-chem-t5-small-standard",
        "text-chem-t5-small-augm",
        "text-chem-t5-base-standard",
        "text-chem-t5-base-augm",
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), header=None
    ).fillna("")
    print("Examples: ", examples.values.tolist())
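    # Load the markdown article and description rendered on the demo page.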
    with open(metadata_root.joinpath("article.md"), "r") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r") as f:
        description = f.read()
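    # Assemble the Gradio interface: model dropdown, prefix/prompt boxes, and beam slider.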
    demo = gr.Interface(
        fn=run_inference,
        title="Text-chem-T5 model",
        inputs=[
            gr.Dropdown(
                models,
                label="Language model",
                value="text-chem-t5-base-augm",
            ),
            gr.Textbox(
                label="Prefix", placeholder="A task-specific prefix", lines=1
            ),
            gr.Textbox(
                label="Text prompt",
                placeholder="I'm a stochastic parrot.",
                lines=1,
            ),
            gr.Slider(minimum=1, maximum=50, value=10, label="num_beams", step=1),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)