Update app.py
app.py
CHANGED
@@ -3,72 +3,111 @@ from transformers import pipeline
 
 unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
 
-    'zulu': "Le ndoda ithi izo <mask> ukudla.",
-    'tshivenda': "Mufana uyo <mask> vhukuma.",
-    'sepedi': "Mosadi o <mask> pheka.",
-    'tswana': "Monna o <mask> tsamaya.",
-    'tsonga': "N'wana wa xisati u <mask> ku tsaka."
-}
+st.set_page_config(layout="wide")
 
-def
+def fill_mask(sentences):
     results = {}
+    warnings = []
     for language, sentence in sentences.items():
+        if "<mask>" in sentence:
+            masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
+            unmasked = unmasker(masked_sentence)
+            results[language] = unmasked
+        else:
+            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
+    return results, warnings
 
 def replace_mask(sentence, predicted_word):
     return sentence.replace("<mask>", f"**{predicted_word}**")
 
-st.title("Fill Mask| Zabantu-XLM-Roberta")
+st.title("Fill Mask | Zabantu-XLM-Roberta")
 st.write(f"")
 
+st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages. These include: Zabantu-VEN, Zabantu-NSO, Zabantu-NSO+VEN, Zabantu-SOT+VEN, Zabantu-BANTU (from 9 South African Bantu languages)")
+
 col1, col2 = st.columns(2)
 
+if 'text_input' not in st.session_state:
+    st.session_state['text_input'] = ""
+
+if 'warnings' not in st.session_state:
+    st.session_state['warnings'] = []
+
+with col1:
+    with st.container(border=True):
+        st.markdown("Input :clipboard:")
+        sample_sentence = {
+            'zulu': "Le ndoda ithi izo <mask> ukudla.",
+            'tshivenda': "Mufana uyo <mask> vhukuma.",
+            'sepedi': "Mosadi o <mask> pheka.",
+            'tswana': "Monna o <mask> tsamaya.",
+            'tsonga': "N'wana wa xisati u <mask> ku tsaka."
+        }
+
+        text_input = st.text_area(
+            "Enter sentences with <mask> token:",
+            value=st.session_state['text_input']
+        )
+
+        input_sentences = text_input.split("\n")
+
+        button1, button2, _ = st.columns([2, 2, 4])
+        with button1:
+            if st.button("Test Example"):
+                user_sentence = "\n".join(f"'{lang}': '{sentence}'," for lang, sentence in sample_sentence.items())
+                user_masked_sentence = user_sentence.replace('<mask>', unmasker.tokenizer.mask_token)
+                # st.rerun()
+                # result, warnings = fill_mask(sample_sentence.split("\n"))
+                # st.session_state['text_input'] = sample_sentence
+
+        with button2:
+            if st.button("Submit"):
+                user_masked_sentence = text_input.replace('<mask>', unmasker.tokenizer.mask_token)
+                # result, warnings = fill_mask(input_sentences)
+                # st.session_state['warnings'] = warnings
+
+        if st.session_state['warnings']:
+            for warning in st.session_state['warnings']:
+                st.warning(warning)
 
+        st.markdown("Example")
+        st.code(sample_sentence, wrap_lines=True)
 
 with col2:
-        <div class="
-        </div>
-        """, unsafe_allow_html=True)
+    with st.container(border=True):
+        st.markdown("Output :bar_chart:")
+        if 'user_masked_sentence' in locals():
+            if user_masked_sentence:
+                user_predictions = unmasker(user_masked_sentence)
+
+                # st.write(user_predictions)
+
+                if len(user_predictions) > 0:
+                    # st.write(f"Top prediction for the masked token: {user_predictions[0]['sequence']}")
+
+                    predictions, _ = fill_mask(sample_sentence)
+                    for language, language_predictions in predictions.items():
+                        predicted_word = language_predictions[0]['token_str']
+                        score = language_predictions[0]['score'] * 100
+
+                        st.markdown(f"""
+                        <div class="bar">
+                            <div class="bar-fill" style="width: {score}%;"></div>
+                        </div>
+                        <div class="container">
+                            <div style="align-items: left;">{predicted_word} ({language})</div>
+                            <div style="align-items: right;">{score:.2f}%</div>
+                        </div>
+                        """, unsafe_allow_html=True)
 
     if 'predictions' in locals():
+        if predictions:
+            for language, language_predictions in predictions.items():
+                original_sentence = sample_sentence[language]
+                predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
+                # st.write(language_predictions)
+                # st.write(f"Original sentence ({language}): {original_sentence}")
+                st.write(f"{language}: {predicted_sentence}\n")
 
 css = """
 <style>
@@ -135,6 +174,7 @@ footer {display:none !important;}
     --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
 }
+
 .container {
     display: flex;
     justify-content: space-between;
@@ -143,7 +183,7 @@ footer {display:none !important;}
     width: 100%;
 }
 .bar {
-    width: 70%;
+    /* width: 70%; */
     background-color: #e6e6e6;
     border-radius: 12px;
     overflow: hidden;
@@ -155,6 +195,8 @@ footer {display:none !important;}
     height: 100%;
     border-radius: 12px;
 }
+
 </style>
 """
+
+st.markdown(css, unsafe_allow_html=True)
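For reference, the bar rendering in the diff relies on the standard output shape of a Hugging Face fill-mask pipeline: for one sentence with one mask token it returns a list of candidate dicts (keys 'score', 'token', 'token_str', 'sequence'), highest score first, which is what the code indexes with [0]['token_str'] and [0]['score']. Below is a minimal standalone sketch of that flow; the model id and the Zulu sample sentence are taken from the diff, while running it as a plain script outside Streamlit is only for illustration.

# Minimal sketch of the fill-mask flow used in app.py; not part of the commit.
from transformers import pipeline

unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')

sentence = "Le ndoda ithi izo <mask> ukudla."
# The model's literal mask token may differ from "<mask>", so substitute the
# tokenizer's own token, exactly as fill_mask() does in the app.
masked = sentence.replace('<mask>', unmasker.tokenizer.mask_token)

predictions = unmasker(masked)          # list of dicts, ordered by score
top = predictions[0]
print(top['token_str'], f"{top['score'] * 100:.2f}%")             # predicted word and confidence
print(sentence.replace("<mask>", f"**{top['token_str']}**"))      # same idea as replace_mask()

With several mask tokens in one input, or a list of inputs, the pipeline returns a list of lists instead, which is why fill_mask() feeds the sentences to the model one language at a time.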