"""Gradio demo: GLiNER entity recognition on British Library book snippets."""

import html
import random
from gliner import GLiNER
import gradio as gr
from datasets import load_dataset

# Load the BL dataset as a streaming iterator, shuffled with a fixed seed.
dataset_iter = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,  # Enable streaming
    trust_remote_code=True,
).shuffle(seed=42)  # Shuffle added

# One persistent iterator shared by every "Get New Snippet" click. Looping over
# ``dataset_iter`` directly starts a new pass over the seeded shuffle on each
# call, so every click would return the same first row.
snippet_stream = iter(dataset_iter)

# Load the model
model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)


def _render_highlights(text, entities):
    """Return *text* as HTML-escaped markup with each entity span wrapped in <mark>.

    Args:
        text: The original passage.
        entities: Entity dicts carrying integer ``start``/``end`` character
            offsets into *text* (and, presumably, a ``label`` key — read with
            ``.get`` so a missing label does not crash rendering).

    Returns:
        An HTML string safe to hand to ``gr.HTML``. All source text is escaped,
        so characters such as ``<`` or ``&`` in book snippets render literally.
    """
    pieces = []
    cursor = 0
    for ent in sorted(entities, key=lambda e: e["start"]):
        start, end = ent["start"], ent["end"]
        if start < cursor:
            # Defensive: flat_ner=True is expected to give non-overlapping
            # spans; if one overlaps an already-rendered span, skip it.
            continue
        pieces.append(html.escape(text[cursor:start]))
        label = html.escape(str(ent.get("label", "")))
        pieces.append(f'<mark title="{label}">{html.escape(text[start:end])}</mark>')
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    # pre-wrap keeps the passage's line breaks, matching the multi-line textbox.
    return '<div style="white-space: pre-wrap">' + "".join(pieces) + "</div>"


def ner(text: str, labels: str, threshold: float):
    """Predict entities in *text* and return (highlighted HTML, entity list).

    Args:
        text: Passage to tag.
        labels: Comma-separated entity labels, e.g. ``"Machine, Wool"``.
            Whitespace is stripped and blank entries (from stray or trailing
            commas) are ignored.
        threshold: Confidence threshold forwarded to ``model.predict_entities``.

    Returns:
        A tuple ``(html, entities)``. ``html`` is the escaped passage with each
        predicted span wrapped in ``<mark>``. ``entities`` is the list returned
        by the model, or ``[]`` when there is no text or no usable label.
    """
    # Convert user-provided labels (comma-separated string) into a list
    labels_list = [label.strip() for label in labels.split(",") if label.strip()]
    if not text or not labels_list:
        # Nothing to predict; avoid calling the model with empty inputs.
        return _render_highlights(text or "", []), []

    # Predict entities using the fine-tuned GLiNER model
    entities = model.predict_entities(text, labels_list, flat_ner=True, threshold=threshold)
    return _render_highlights(text, entities), entities


with gr.Blocks(title="General NER Demo") as demo:
    gr.Markdown(
        """
# General Entity Recognition Demo

This demo selects a random text snippet from a subset of the British Library's books dataset and identifies entities using a fine-tuned GLiNER model. You can specify the entities you want to find.
"""
    )

    # Display a random example
    input_text = gr.Textbox(
        value="The machine is fed by means of an endless apron, the wool entering at the smaller end...",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )

    with gr.Row():
        labels = gr.Textbox(
            value="Machine, Wool",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Default confidence threshold passed to ner()
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )

    # Define output components
    output_highlighted = gr.HTML(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    submit_btn = gr.Button("Find Entities!")
    refresh_btn = gr.Button("Get New Snippet")

    def get_new_snippet():
        """Return the next non-blank ``text`` from the shuffled stream.

        Pulls from the shared ``snippet_stream`` so successive clicks advance
        through the dataset instead of restarting it. Gives up after
        ``max_attempts`` blank rows, or when the stream is exhausted.
        """
        max_attempts = 1000  # Prevent unbounded scanning past blank rows
        for _ in range(max_attempts):
            try:
                sample = next(snippet_stream)
            except StopIteration:
                break
            text = sample.get("text")
            if text and text.strip():
                return text
        return "No more snippets available."  # Return this if no valid snippets are found

    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)