max-long commited on
Commit
f9bc688
·
verified ·
1 Parent(s): 28b0e98

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from gliner import GLiNER
3
+ import gradio as gr
4
+ from datasets import load_dataset
5
+
6
+ # Load the BL dataset as a streaming iterator
7
+ dataset_iter = load_dataset(
8
+ "TheBritishLibrary/blbooks",
9
+ split="train",
10
+ streaming=True, # Enable streaming
11
+ trust_remote_code=True
12
+ ).shuffle(seed=42) # Shuffle added
13
+
14
+ # Load the model
15
+ model = GLiNER.from_pretrained("max-long/textile_machines_3_oct", trust_remote_code=True)
16
+
17
+ def ner(text: str, labels: str, threshold: float):
18
+ # Convert user-provided labels (comma-separated string) into a list
19
+ labels_list = [label.strip() for label in labels.split(",")]
20
+
21
+ # Predict entities using the fine-tuned GLiNER model
22
+ entities = model.predict_entities(text, labels_list, flat_ner=True, threshold=threshold)
23
+
24
+ # Prepare data for HighlightedText
25
+ highlighted_text = text
26
+ for ent in sorted(entities, key=lambda x: x['start'], reverse=True):
27
+ highlighted_text = (
28
+ highlighted_text[:ent['start']] +
29
+ f"<span style='background-color: yellow; font-weight: bold;'>{highlighted_text[ent['start']:ent['end']]}</span>" +
30
+ highlighted_text[ent['end']:]
31
+ )
32
+
33
+ return highlighted_text, entities
34
+
35
+ with gr.Blocks(title="General NER Demo") as demo:
36
+ gr.Markdown(
37
+ """
38
+ # General Entity Recognition Demo
39
+ This demo selects a random text snippet from a subset of the British Library's books dataset and identifies entities using a fine-tuned GLiNER model. You can specify the entities you want to find.
40
+ """
41
+ )
42
+
43
+ # Display a random example
44
+ input_text = gr.Textbox(
45
+ value="The machine is fed by means of an endless apron, the wool entering at the smaller end...",
46
+ label="Text input",
47
+ placeholder="Enter your text here",
48
+ lines=5
49
+ )
50
+
51
+ with gr.Row() as row:
52
+ labels = gr.Textbox(
53
+ value="Machine, Wool", # Default example labels
54
+ label="Labels",
55
+ placeholder="Enter your labels here (comma separated)",
56
+ scale=2,
57
+ )
58
+ threshold = gr.Slider(
59
+ 0,
60
+ 1,
61
+ value=0.5, # Adjusted to match the threshold used in the function
62
+ step=0.01,
63
+ label="Threshold",
64
+ info="Lower the threshold to increase how many entities get predicted.",
65
+ scale=1,
66
+ )
67
+
68
+ # Define output components
69
+ output_highlighted = gr.HTML(label="Predicted Entities")
70
+ output_entities = gr.JSON(label="Entities")
71
+
72
+ submit_btn = gr.Button("Find Entities!")
73
+ refresh_btn = gr.Button("Get New Snippet")
74
+
75
+ def get_new_snippet():
76
+ attempts = 0
77
+ max_attempts = 1000 # Prevent infinite loops
78
+ for sample in dataset_iter:
79
+ return sample['text']
80
+ return "No more snippets available." # Return this if no valid snippets are found
81
+
82
+ # Connect refresh button
83
+ refresh_btn.click(fn=get_new_snippet, outputs=input_text)
84
+
85
+ # Connect submit button
86
+ submit_btn.click(
87
+ fn=lambda text, labels, threshold: ner(text, labels, threshold),
88
+ inputs=[input_text, labels, threshold],
89
+ outputs=[output_highlighted, output_entities]
90
+ )
91
+
92
+ demo.queue()
93
+ demo.launch(debug=True)