import itertools
import random

from datasets import load_dataset
from gliner import GLiNER
import gradio as gr

# Maximum number of *characters* (not tokens) kept from the user's text and
# from each dataset snippet. A character budget is used as a simple,
# conservative proxy for the model's input limit.
# NOTE(review): confirm 384 chars stays under the model's max length.
MAX_CHARS = 384

# How many dataset records are sampled into the pool that the
# "Get New Snippet" button draws from.
SNIPPET_POOL_SIZE = 100

# Load the BL dataset as a streaming iterator (records are fetched lazily).
dataset_iter = load_dataset(
    "TheBritishLibrary/blbooks",
    split="train",
    streaming=True,  # Enable streaming
    trust_remote_code=True,
).shuffle(seed=42)  # Fixed seed: the shuffled order is reproducible

# Load the model
model = GLiNER.from_pretrained("urchade/gliner_multi-v2.1", trust_remote_code=True)

# Truncated, non-blank snippet texts; filled on the first "Get New Snippet"
# click and reused afterwards.
_snippet_pool = []


def ner(text: str, labels: str, threshold: float, nested_ner: bool):
    """Run GLiNER over ``text`` and format the result for the UI.

    Args:
        text: Input text. Only the first ``MAX_CHARS`` characters are analysed.
        labels: Comma-separated entity labels, e.g. ``"Person, Location"``.
            Blank entries (``"a, , b"``, trailing commas) are ignored.
        threshold: Minimum confidence for a prediction to be kept.
        nested_ner: If True, request nested (overlapping) entities; otherwise
            the model is called with ``flat_ner=True``.

    Returns:
        A ``(highlighted, entities)`` pair:

        - ``highlighted`` is a ``{"text": ..., "entities": [...]}`` dict for
          ``gr.HighlightedText``. Each entity is ``{"start", "end", "entity"}``.
        - ``entities`` is the raw prediction list from
          ``model.predict_entities``, shown in the JSON panel.

        When there is no text or no usable label, both parts are empty and
        the model is not called.
    """
    # Drop blank labels so the model never receives an empty label string.
    labels_list = [
        label.strip() for label in (labels or "").split(",") if label.strip()
    ]

    # Character-level truncation (see MAX_CHARS); the displayed text is the
    # truncated text, so highlight offsets always line up with it.
    truncated_text = (text or "")[:MAX_CHARS]

    if not truncated_text.strip() or not labels_list:
        return {"text": truncated_text, "entities": []}, []

    entities = model.predict_entities(
        truncated_text, labels_list, flat_ner=not nested_ner, threshold=threshold
    )

    # Reshape predictions into the span format gr.HighlightedText expects.
    highlights = [
        {"start": ent["start"], "end": ent["end"], "entity": ent["label"]}
        for ent in entities
    ]
    return {"text": truncated_text, "entities": highlights}, entities


def _load_snippet_pool():
    """Stream up to SNIPPET_POOL_SIZE records and return their truncated, non-blank texts."""
    snippets = []
    # islice stops after exactly SNIPPET_POOL_SIZE records. zip(it, range(n))
    # would pull one extra record before stopping.
    for sample in itertools.islice(dataset_iter, SNIPPET_POOL_SIZE):
        snippet = (sample.get("text") or "")[:MAX_CHARS]
        if snippet.strip():  # skip missing/blank texts
            snippets.append(snippet)
    return snippets


def get_new_snippet():
    """Return a random text snippet from the British Library dataset pool.

    The pool is streamed on first use and then reused. The dataset is shuffled
    with a fixed seed and its epoch is never advanced, so (per the ``datasets``
    streaming docs) each fresh iteration replays the same leading records.
    Re-streaming on every click would therefore only add network latency.
    If loading produced no usable snippets, the next call retries.
    """
    if not _snippet_pool:
        _snippet_pool.extend(_load_snippet_pool())
    if _snippet_pool:
        return random.choice(_snippet_pool)
    return "No more snippets available."  # Returned if no valid snippets are found


with gr.Blocks(title="General NER with Color-Coded Output") as demo:
    gr.Markdown(
        """
# GLiNER British Library Books Demo

This demo selects a random text snippet from the British Library's books
dataset and identifies entities using GLiNER (urchade/gliner_multi-v2.1).
"""
    )

    # Display a random example
    input_text = gr.Textbox(
        value="Click on 'Get New Snippet' to load a piece of text from the British Library dataset",
        label="Text input",
        placeholder="Enter your text here",
        lines=5,
    )
    refresh_btn = gr.Button("Get New Snippet")

    with gr.Row():
        labels = gr.Textbox(
            value="Person, Location",  # Default example labels
            label="Labels",
            placeholder="Enter your labels here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=0.5,  # Initial UI value; ner() receives whatever the user sets
            step=0.01,
            label="Threshold",
            info="Lower the threshold to increase how many entities get predicted.",
            scale=1,
        )
        nested_ner = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Enable Nested NER?",
        )

    submit_btn = gr.Button("Find Entities!")

    # Output components: color-coded spans plus the raw predictions.
    output_highlighted = gr.HighlightedText(label="Predicted Entities")
    output_entities = gr.JSON(label="Entities")

    # Connect refresh button
    refresh_btn.click(fn=get_new_snippet, outputs=input_text)

    # Connect submit button
    submit_btn.click(
        fn=ner,
        inputs=[input_text, labels, threshold, nested_ner],
        outputs=[output_highlighted, output_entities],
    )

demo.queue()
demo.launch(debug=True)