Update app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,9 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Auto
|
|
14 |
import os
|
15 |
import colorsys
|
16 |
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
17 |
|
18 |
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
|
19 |
hex_color = hex_color.lstrip('#')
|
@@ -86,26 +89,32 @@ def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str,
|
|
86 |
score = prediction1 / (prediction2 + prediction1)
|
87 |
|
88 |
return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
|
89 |
-
|
90 |
-
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[
|
91 |
entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
|
92 |
entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
|
93 |
|
94 |
-
|
95 |
-
|
|
|
|
|
96 |
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
|
110 |
return fig1, fig2
|
111 |
|
|
|
14 |
import os
|
15 |
import colorsys
|
16 |
import matplotlib.pyplot as plt
|
17 |
+
import plotly.graph_objects as go
|
18 |
+
from typing import Tuple
|
19 |
+
import plotly.io as pio
|
20 |
|
21 |
def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
|
22 |
hex_color = hex_color.lstrip('#')
|
|
|
89 |
score = prediction1 / (prediction2 + prediction1)
|
90 |
|
91 |
return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
|
92 |
+
|
93 |
+
def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure]:
|
94 |
entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
|
95 |
entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
|
96 |
|
97 |
+
# Counting entities for binary classification
|
98 |
+
entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
|
99 |
+
bin_labels = list(entity_counts_bin.keys())
|
100 |
+
bin_sizes = list(entity_counts_bin.values())
|
101 |
|
102 |
+
# Counting entities for extended classification
|
103 |
+
entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
|
104 |
+
ext_labels = list(entity_counts_ext.keys())
|
105 |
+
ext_sizes = list(entity_counts_ext.values())
|
106 |
+
|
107 |
+
# Create pie chart for extended classification
|
108 |
+
fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3)])
|
109 |
+
fig1.update_layout(title_text='Extended Sequence Classification Subclasses')
|
110 |
|
111 |
+
# Create bar chart for binary classification
|
112 |
+
fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes)])
|
113 |
+
fig2.update_layout(
|
114 |
+
title='Binary Sequence Classification Classes',
|
115 |
+
xaxis_title='Entity Type',
|
116 |
+
yaxis_title='Count'
|
117 |
+
)
|
118 |
|
119 |
return fig1, fig2
|
120 |
|