AlGe commited on
Commit
886193f
·
verified ·
1 Parent(s): 6b0ab1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -14,6 +14,9 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Auto
14
  import os
15
  import colorsys
16
  import matplotlib.pyplot as plt
 
 
 
17
 
18
  def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
19
  hex_color = hex_color.lstrip('#')
@@ -86,26 +89,32 @@ def process_classification(text: str, model1, model2, tokenizer1) -> Tuple[str,
86
  score = prediction1 / (prediction2 + prediction1)
87
 
88
  return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
89
-
90
- def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[plt.Figure, plt.Figure]:
91
  entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
92
  entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
93
 
94
- all_entities = entities_bin + entities_ext
95
- entity_counts = {entity: all_entities.count(entity) for entity in set(all_entities)}
 
 
96
 
97
- pie_labels = list(entity_counts.keys())
98
- pie_sizes = list(entity_counts.values())
99
-
100
- fig1, ax1 = plt.subplots()
101
- ax1.pie(pie_sizes, labels=pie_labels, autopct='%1.1f%%', startangle=90)
102
- ax1.axis('equal')
 
 
103
 
104
- fig2, ax2 = plt.subplots()
105
- ax2.bar(entity_counts.keys(), entity_counts.values())
106
- ax2.set_ylabel('Count')
107
- ax2.set_xlabel('Entity Type')
108
- ax2.set_title('Entity Counts')
 
 
109
 
110
  return fig1, fig2
111
 
 
14
  import os
15
  import colorsys
16
  import matplotlib.pyplot as plt
17
+ import plotly.graph_objects as go
18
+ from typing import Tuple
19
+ import plotly.io as pio
20
 
21
  def hex_to_rgb(hex_color: str) -> tuple[int, int, int]:
22
  hex_color = hex_color.lstrip('#')
 
89
  score = prediction1 / (prediction2 + prediction1)
90
 
91
  return f"{round(prediction1, 1)}", f"{round(prediction2, 1)}", f"{round(score, 2)}"
92
+
93
+ def generate_charts(ner_output_bin: dict, ner_output_ext: dict) -> Tuple[go.Figure, go.Figure]:
94
  entities_bin = [entity['entity'] for entity in ner_output_bin['entities']]
95
  entities_ext = [entity['entity'] for entity in ner_output_ext['entities']]
96
 
97
+ # Counting entities for binary classification
98
+ entity_counts_bin = {entity: entities_bin.count(entity) for entity in set(entities_bin)}
99
+ bin_labels = list(entity_counts_bin.keys())
100
+ bin_sizes = list(entity_counts_bin.values())
101
 
102
+ # Counting entities for extended classification
103
+ entity_counts_ext = {entity: entities_ext.count(entity) for entity in set(entities_ext)}
104
+ ext_labels = list(entity_counts_ext.keys())
105
+ ext_sizes = list(entity_counts_ext.values())
106
+
107
+ # Create pie chart for extended classification
108
+ fig1 = go.Figure(data=[go.Pie(labels=ext_labels, values=ext_sizes, textinfo='label+percent', hole=.3)])
109
+ fig1.update_layout(title_text='Extended Sequence Classification Subclasses')
110
 
111
+ # Create bar chart for binary classification
112
+ fig2 = go.Figure(data=[go.Bar(x=bin_labels, y=bin_sizes)])
113
+ fig2.update_layout(
114
+ title='Binary Sequence Classification Classes',
115
+ xaxis_title='Entity Type',
116
+ yaxis_title='Count'
117
+ )
118
 
119
  return fig1, fig2
120