Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,6 @@ unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
|
|
5 |
|
6 |
st.set_page_config(layout="wide")
|
7 |
|
8 |
-
# Disable auto-rerun when selecting options or typing
|
9 |
st.stop_rerun = True
|
10 |
|
11 |
def fill_mask(sentences):
|
@@ -47,33 +46,29 @@ with col1:
|
|
47 |
|
48 |
input1, input2 = st.columns(2)
|
49 |
|
50 |
-
# Loop to gather input sentences
|
51 |
for i in range(5):
|
52 |
with input1:
|
53 |
language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
|
54 |
with input2:
|
55 |
sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
|
56 |
if sentence:
|
57 |
-
# Use a unique key for each sentence (even if languages are the same)
|
58 |
input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
|
59 |
|
60 |
button1, button2, _ = st.columns([2, 2, 4])
|
61 |
|
62 |
-
# Call fill_mask on button click, not on form input
|
63 |
if st.button("Test Example"):
|
64 |
sample_sentences = {
|
65 |
-
'
|
66 |
-
'
|
67 |
-
'
|
68 |
-
'
|
69 |
-
'
|
70 |
}
|
71 |
st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
|
72 |
|
73 |
if st.button("Submit"):
|
74 |
st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
|
75 |
|
76 |
-
# Display warnings
|
77 |
if st.session_state['warnings']:
|
78 |
for warning in st.session_state['warnings']:
|
79 |
st.warning(warning)
|
@@ -90,7 +85,6 @@ with col1:
|
|
90 |
with col2:
|
91 |
with st.container():
|
92 |
st.markdown("Output :bar_chart:")
|
93 |
-
# Check for the result in session_state and display predictions
|
94 |
if st.session_state['result']:
|
95 |
for key, (language, predictions) in st.session_state['result'].items():
|
96 |
original_sentence = input_sentences[key][1] if key in input_sentences else ""
|
@@ -109,8 +103,12 @@ with col2:
|
|
109 |
</div>
|
110 |
""", unsafe_allow_html=True)
|
111 |
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
114 |
|
115 |
css = """
|
116 |
<style>
|
@@ -137,4 +135,4 @@ footer {display:none !important;}
|
|
137 |
}
|
138 |
</style>
|
139 |
"""
|
140 |
-
st.markdown(css, unsafe_allow_html=True)
|
|
|
5 |
|
6 |
st.set_page_config(layout="wide")
|
7 |
|
|
|
8 |
st.stop_rerun = True
|
9 |
|
10 |
def fill_mask(sentences):
|
|
|
46 |
|
47 |
input1, input2 = st.columns(2)
|
48 |
|
|
|
49 |
for i in range(5):
|
50 |
with input1:
|
51 |
language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
|
52 |
with input2:
|
53 |
sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
|
54 |
if sentence:
|
|
|
55 |
input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
|
56 |
|
57 |
button1, button2, _ = st.columns([2, 2, 4])
|
58 |
|
|
|
59 |
if st.button("Test Example"):
|
60 |
sample_sentences = {
|
61 |
+
'zulu': "Le ndoda ithi izo <mask> ukudla.",
|
62 |
+
'tshivenda': "Mufana uyo <mask> vhukuma.",
|
63 |
+
'sepedi': "Mosadi o <mask> pheka.",
|
64 |
+
'tswana': "Monna o <mask> tsamaya.",
|
65 |
+
'tsonga': "N'wana wa xisati u <mask> ku tsaka."
|
66 |
}
|
67 |
st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
|
68 |
|
69 |
if st.button("Submit"):
|
70 |
st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
|
71 |
|
|
|
72 |
if st.session_state['warnings']:
|
73 |
for warning in st.session_state['warnings']:
|
74 |
st.warning(warning)
|
|
|
85 |
with col2:
|
86 |
with st.container():
|
87 |
st.markdown("Output :bar_chart:")
|
|
|
88 |
if st.session_state['result']:
|
89 |
for key, (language, predictions) in st.session_state['result'].items():
|
90 |
original_sentence = input_sentences[key][1] if key in input_sentences else ""
|
|
|
103 |
</div>
|
104 |
""", unsafe_allow_html=True)
|
105 |
|
106 |
+
if 'predictions' in locals():
|
107 |
+
if result:
|
108 |
+
for language, language_predictions in result.items():
|
109 |
+
original_sentence = sample_sentence[language]
|
110 |
+
predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
|
111 |
+
st.write(f"{language}: {predicted_sentence}\n")
|
112 |
|
113 |
css = """
|
114 |
<style>
|
|
|
135 |
}
|
136 |
</style>
|
137 |
"""
|
138 |
+
st.markdown(css, unsafe_allow_html=True)
|