UnarineLeo
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -5,14 +5,17 @@ unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
|
|
5 |
|
6 |
st.set_page_config(layout="wide")
|
7 |
|
|
|
|
|
|
|
8 |
def fill_mask(sentences):
|
9 |
results = {}
|
10 |
warnings = []
|
11 |
-
for language, sentence in sentences.items():
|
12 |
if "<mask>" in sentence:
|
13 |
masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
|
14 |
unmasked = unmasker(masked_sentence)
|
15 |
-
results[
|
16 |
else:
|
17 |
warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
|
18 |
return results, warnings
|
@@ -25,52 +28,52 @@ st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combi
|
|
25 |
|
26 |
col1, col2 = st.columns(2)
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
if 'warnings' not in st.session_state:
|
32 |
st.session_state['warnings'] = []
|
33 |
|
|
|
|
|
|
|
34 |
language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
|
35 |
|
|
|
|
|
36 |
with col1:
|
37 |
with st.container():
|
38 |
st.markdown("Input :clipboard:")
|
39 |
|
40 |
input1, input2 = st.columns(2)
|
41 |
|
42 |
-
|
43 |
for i in range(5):
|
44 |
with input1:
|
45 |
-
language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
|
46 |
with input2:
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
input_sentences[language.lower()] = sentence
|
52 |
|
53 |
button1, button2, _ = st.columns([2, 2, 4])
|
54 |
-
|
55 |
-
with button1:
|
56 |
-
if st.button("Test Example"):
|
57 |
-
sample_sentence = {
|
58 |
-
'zulu': "Le ndoda ithi izo <mask> ukudla.",
|
59 |
-
'tshivenda': "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis.",
|
60 |
-
'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima",
|
61 |
-
'tswana': "Monna o <mask> tsamaya.",
|
62 |
-
'tsonga': "N'wana wa xisati u <mask> ku tsaka."
|
63 |
-
}
|
64 |
-
input_sentences = sample_sentence
|
65 |
-
result, warnings = fill_mask(input_sentences)
|
66 |
-
|
67 |
-
with button2:
|
68 |
-
# Set session state when "Submit" is clicked
|
69 |
-
if st.button("Submit"):
|
70 |
-
st.session_state['submit_clicked'] = True
|
71 |
-
result, warnings = fill_mask(input_sentences)
|
72 |
-
st.session_state['warnings'] = warnings
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
if st.session_state['warnings']:
|
75 |
for warning in st.session_state['warnings']:
|
76 |
st.warning(warning)
|
@@ -87,13 +90,10 @@ with col1:
|
|
87 |
with col2:
|
88 |
with st.container():
|
89 |
st.markdown("Output :bar_chart:")
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
|
95 |
-
predictions = unmasker(masked_sentence)
|
96 |
-
|
97 |
if predictions:
|
98 |
top_prediction = predictions[0]
|
99 |
predicted_word = top_prediction['token_str']
|
@@ -109,7 +109,9 @@ with col2:
|
|
109 |
</div>
|
110 |
""", unsafe_allow_html=True)
|
111 |
|
112 |
-
|
|
|
|
|
113 |
css = """
|
114 |
<style>
|
115 |
footer {display:none !important;}
|
|
|
5 |
|
6 |
st.set_page_config(layout="wide")
|
7 |
|
8 |
+
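# NOTE(review): meant to disable auto-rerun when selecting options or typing, but the assignment below only sets an arbitrary attribute on the `st` module, which Streamlit never reads — the script still reruns on every widget change. Wrap the inputs in `st.form` to defer reruns until submit.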
|
9 |
+
st.stop_rerun = True
|
10 |
+
|
11 |
def fill_mask(sentences):
    """Run the fill-mask pipeline over every submitted sentence.

    Args:
        sentences: Mapping of a unique entry key to a ``(language, sentence)``
            pair. Each sentence should contain exactly one ``<mask>``
            placeholder.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each entry key to
        ``(language, predictions)``, where ``predictions`` is whatever the
        module-level ``unmasker`` pipeline returned for that sentence.
        ``warnings`` lists one message per sentence that was skipped.
    """
    placeholder = '<mask>'
    results = {}
    warnings = []
    for key, (language, sentence) in sentences.items():
        mask_count = sentence.count(placeholder)
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one prediction list per
            # mask (a list of lists). The output code reads
            # predictions[0]['token_str'] and cannot handle that shape, so
            # report the sentence instead of letting rendering fail later.
            warnings.append(
                f"Warning: Only one <mask> token is supported per sentence: {sentence}"
            )
            continue
        # Swap the user-facing placeholder for the tokenizer's own mask token.
        masked_sentence = sentence.replace(placeholder, unmasker.tokenizer.mask_token)
        results[key] = (language, unmasker(masked_sentence))
    return results, warnings
|
|
|
28 |
|
29 |
col1, col2 = st.columns(2)
|
30 |
|
31 |
+
if 'text_input' not in st.session_state:
|
32 |
+
st.session_state['text_input'] = ""
|
33 |
+
|
34 |
if 'warnings' not in st.session_state:
|
35 |
st.session_state['warnings'] = []
|
36 |
|
37 |
+
if 'result' not in st.session_state:
|
38 |
+
st.session_state['result'] = {}
|
39 |
+
|
40 |
language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
|
41 |
|
42 |
+
input_sentences = {}
|
43 |
+
|
44 |
with col1:
|
45 |
with st.container():
|
46 |
st.markdown("Input :clipboard:")
|
47 |
|
48 |
input1, input2 = st.columns(2)
|
49 |
|
50 |
+
# Loop to gather input sentences
|
51 |
for i in range(5):
|
52 |
with input1:
|
53 |
+
language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
|
54 |
with input2:
|
55 |
+
sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
|
56 |
+
if sentence:
|
57 |
+
# Use a unique key for each sentence (even if languages are the same)
|
58 |
+
input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
|
|
|
59 |
|
60 |
button1, button2, _ = st.columns([2, 2, 4])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
# Call fill_mask on button click, not on form input
|
63 |
+
if st.button("Test Example"):
|
64 |
+
sample_sentences = {
|
65 |
+
'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
|
66 |
+
'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
|
67 |
+
'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
|
68 |
+
'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
|
69 |
+
'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka.")
|
70 |
+
}
|
71 |
+
st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
|
72 |
+
|
73 |
+
if st.button("Submit"):
|
74 |
+
st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
|
75 |
+
|
76 |
+
# Display warnings
|
77 |
if st.session_state['warnings']:
|
78 |
for warning in st.session_state['warnings']:
|
79 |
st.warning(warning)
|
|
|
90 |
with col2:
|
91 |
with st.container():
|
92 |
st.markdown("Output :bar_chart:")
|
93 |
+
# Check for the result in session_state and display predictions
|
94 |
+
if st.session_state['result']:
|
95 |
+
for key, (language, predictions) in st.session_state['result'].items():
|
96 |
+
original_sentence = input_sentences[key][1] if key in input_sentences else ""
|
|
|
|
|
|
|
97 |
if predictions:
|
98 |
top_prediction = predictions[0]
|
99 |
predicted_word = top_prediction['token_str']
|
|
|
109 |
</div>
|
110 |
""", unsafe_allow_html=True)
|
111 |
|
112 |
+
predicted_sentence = replace_mask(original_sentence, predicted_word)
|
113 |
+
st.write(f"{language}: {predicted_sentence}\n")
|
114 |
+
|
115 |
css = """
|
116 |
<style>
|
117 |
footer {display:none !important;}
|