File size: 9,842 Bytes
f5f8f9a 15eff6a f5f8f9a 0802504 cff14a8 4d5270d f5f8f9a 4d5270d f5f8f9a 4d5270d b90e99a 8d1b776 4199132 b90e99a 39352aa 2cf67bf 39352aa b90e99a 4d5270d 0ce811e 2cf67bf 4d5270d 2cf67bf 4d5270d f5f8f9a 7a621b0 b92b795 7a621b0 4d5270d cff14a8 15eff6a cff14a8 4d5270d b51c864 d104ff1 cff14a8 5b5eeef cff14a8 4d5270d 39352aa cff14a8 fa59459 cff14a8 0716ef3 cff14a8 1192cb6 65213e3 1192cb6 cff14a8 1192cb6 b055af8 cff14a8 539b8f0 1192cb6 a575f88 1192cb6 cff14a8 1192cb6 79dbb07 cff14a8 08d27cb cff14a8 818719c cff14a8 f8f19a2 cff14a8 f8f19a2 cff14a8 9ba3728 b51c864 cff14a8 2cf67bf cff14a8 50ef43a cff14a8 15eff6a cff14a8 2cf67bf cff14a8 f5f8f9a b51c864 cff14a8 f5f8f9a cff14a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 |
import html
from io import StringIO

import streamlit as st
from transformers import pipeline
# Configure the page before anything else can draw to it (the cached loader
# below may show a spinner while the model loads).
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the pipeline would be constructed again on
    each rerun.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


unmasker = _load_unmasker()
def fill_mask(sentences):
    """Run the fill-mask model over every submitted sentence.

    Args:
        sentences: Mapping of an input key to either a
            ``(language, sentence)`` pair or a bare sentence string. For a
            bare string the mapping key is used as the language, which is the
            shape of the bundled example (``{language: sentence}``).

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each accepted key
        to ``(predictions, language, sentence)``, where ``predictions`` is
        whatever the module-level ``unmasker`` pipeline returned for that
        sentence. ``warnings`` is a list of user-facing messages, one per
        entry that was skipped.
    """
    results = {}
    warnings = []

    for key, entry in sentences.items():
        # Accept ``{language: sentence}`` as well as
        # ``{key: (language, sentence)}`` so the example mapping does not
        # fail on tuple unpacking.
        if isinstance(entry, str):
            language, sentence = str(key), entry
        else:
            language, sentence = entry

        if language.strip().lower() == 'choose language':
            warnings.append(f"Warning: Choose language for {sentence}")
            continue

        # Whitespace-only input is treated the same as an empty sentence.
        if not sentence.strip():
            warnings.append(f"Warning: Enter sentence for {language}")
            continue

        mask_count = sentence.count("<mask>")
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue
        if mask_count > 1:
            # With several masks the pipeline returns one prediction list per
            # mask; the rendering code indexes predictions[0]['token_str']
            # and cannot handle that nested shape.
            warnings.append(f"Warning: Use exactly one <mask> token in sentence: {sentence}")
            continue

        # Translate the UI placeholder into the tokenizer's own mask token.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (unmasker(masked_sentence), language, sentence)

    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return *sentence* with every ``<mask>`` swapped for the bolded word.

    The prediction is wrapped in ``**`` so it renders bold in Markdown.
    """
    highlighted = f"**{predicted_word}**"
    # Splitting on the placeholder and re-joining with the replacement
    # substitutes every occurrence, exactly like str.replace.
    pieces = sentence.split("<mask>")
    return highlighted.join(pieces)
# --- Page header ------------------------------------------------------------
st.title("Fill Mask | Zabantu-XLM-Roberta")
st.write("")
st.markdown(
    "Zabantu-XLMR refers to a fleet of models trained on different "
    "combinations of South African Bantu languages. It supports the "
    "following languages Tshivenda, Nguni languages (Zulu, Xhosa, Swati), "
    "Sotho languages (Northern Sotho, Southern Sotho, Setswana), and Xitsonga."
)

# Two side-by-side panels: col1 holds the inputs, col2 the predictions.
col1, col2 = st.columns(2)

# Seed the session-state slots on the first run only, so later reruns keep
# whatever is already stored there.
for state_key, initial_value in (('text_input', ""), ('warnings', [])):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Filled by the input panel: input key -> (language, sentence).
input_sentences = {}
def _parse_uploaded_text(text):
    """Split uploaded text into fill-mask inputs.

    Every non-blank line is expected to look like ``language: sentence``.

    Returns:
        An ``(entries, warnings)`` tuple. ``entries`` maps the 1-based line
        number (as a string) to ``(language, sentence)``; ``warnings`` lists
        the lines that had no ``:`` separator.
    """
    entries = {}
    parse_warnings = []
    # splitlines() also handles "\r\n" endings, which split("\n") would leave
    # attached to the sentence.
    for line_number, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            # Blank lines (e.g. a trailing newline) are not malformed input.
            continue
        # partition() cuts at the first ':' only, so colons inside the
        # sentence itself are preserved.
        language, separator, sentence_mask = line.partition(":")
        if not separator:
            parse_warnings.append(f"Warning: No ':' token found in sentence: {line} in line {line_number}")
            continue
        # Only the text after the language prefix is sent to the model.
        entries[f'{line_number}'] = (language.strip().lower(), sentence_mask.strip())
    return entries, parse_warnings


# Input panel: collects sentences from text boxes, an uploaded file, or the
# built-in example, and binds `result` when one of the buttons is pressed.
with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")
        select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
        sample_sentence = {
            'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima.",
            "tsonga": "N'wana wa xisati u <mask> ku tsaka.",
        }
        language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
        option_selected = st.selectbox("Select an input option:", select_options, index=0)

        if option_selected == 'Enter text input':
            st.session_state['warnings'].clear()

            @st.fragment
            def choose_language(i):
                """Render the language selector for input row *i* and return its value."""
                return st.selectbox(f"Select language for input {i+1}:",
                                    language_options, key=f'language_{i}', index=0)

            input1, input2 = st.columns(2)
            for i in range(5):
                with input1:
                    language = choose_language(i)
                with input2:
                    sentence = st.text_input(f"Enter sentence for input {i+1} (with <mask>):", key=f'text_input_{i}')
                if not sentence:
                    continue
                if language:
                    input_sentences[f'{i+1}'] = (language.lower(), sentence)
                else:
                    st.session_state['warnings'] = [f"Warning: Choose the language for input {i+1}"]

            if st.button("Submit", use_container_width=True):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

            for warning in st.session_state['warnings']:
                st.warning(warning)
            st.session_state['warnings'].clear()

        if option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            file_warnings = []
            if uploaded_file is not None:
                try:
                    # "utf-8-sig" also drops a leading byte-order mark, which
                    # would otherwise end up inside the first language name.
                    file_text = uploaded_file.getvalue().decode("utf-8-sig")
                except UnicodeDecodeError:
                    file_warnings.append("Warning: The uploaded file is not valid UTF-8 text")
                else:
                    parsed_sentences, file_warnings = _parse_uploaded_text(file_text)
                    input_sentences.update(parsed_sentences)

            if st.button("Submit", use_container_width=True):
                result, warnings = fill_mask(input_sentences)
                # Report malformed lines together with the model-side warnings
                # instead of discarding them.
                st.session_state['warnings'] = file_warnings + warnings

            for warning in st.session_state['warnings']:
                st.warning(warning)
            st.session_state['warnings'].clear()

        st.markdown("Example")
        st.code(sample_sentence, wrap_lines=True)
        if st.button("Test Example", use_container_width=True):
            # fill_mask expects key -> (language, sentence); the sample maps
            # language -> sentence, so adapt it before the call.
            example_inputs = {
                sample_language: (sample_language, sample_text)
                for sample_language, sample_text in sample_sentence.items()
            }
            result, warnings = fill_mask(example_inputs)
def _prediction_bar_html(label, score):
    """Return the HTML for one score bar followed by its label and percentage.

    Args:
        label: Text shown under the bar. It is HTML-escaped because it is
            built from model tokens and language names, which may contain
            characters such as ``<`` that the browser would otherwise parse
            as markup.
        score: Confidence as a percentage; used for the bar width and the
            displayed value.
    """
    # Built as a single line without indentation so Markdown treats it as one
    # HTML block.
    return (
        '<div class="bar">'
        f'<div class="bar-fill" style="width: {score}%;"></div>'
        '</div>'
        '<div class="container">'
        f'<div>{html.escape(label)}</div>'
        f'<div>{score:.2f}%</div>'
        '</div>'
    )


# Output panel: renders the predictions produced by the input panel.
with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")
        # `result` is only bound when one of the input buttons was pressed
        # during this run of the script.
        if 'result' in locals() and result:
            # A single sentence shows every candidate; several sentences show
            # only the top candidate of each.
            show_all_candidates = len(result) == 1
            for predictions, language, sentence in result.values():
                if show_all_candidates:
                    for prediction in predictions:
                        st.markdown(
                            _prediction_bar_html(prediction['token_str'], prediction['score'] * 100),
                            unsafe_allow_html=True,
                        )
                elif predictions:
                    top_prediction = predictions[0]
                    st.markdown(
                        _prediction_bar_html(
                            f"{top_prediction['token_str']} ({language})",
                            top_prediction['score'] * 100,
                        ),
                        unsafe_allow_html=True,
                    )

            # Completed sentences, with the top prediction highlighted.
            for line, (predictions, language, sentence) in enumerate(result.values(), start=1):
                if not predictions:
                    continue
                predicted_word = predictions[0]['token_str']
                full_sentence = replace_mask(sentence, predicted_word)
                st.write(f"**Sentence {line}:** {full_sentence}")
# Page-wide style sheet, injected once at the end of the script.
# A raw string keeps the CSS escapes (e.g. "\:") literal without relying on
# Python passing unknown escape sequences through.
# NOTE(review): the .gr-button-primary and *-orange-* rules look like leftovers
# from another UI framework - confirm whether they match anything on this page.
# The .container, .bar and .bar-fill rules style the prediction bars.
css = r"""
<style>
footer {display:none !important;}
.gr-button-primary {
z-index: 14;
height: 43px;
width: 130px;
left: 0px;
top: 0px;
padding: 0px;
cursor: pointer !important;
background: none rgb(17, 20, 45) !important;
border: none !important;
text-align: center !important;
font-family: Poppins !important;
font-size: 14px !important;
font-weight: 500 !important;
color: rgb(255, 255, 255) !important;
line-height: 1 !important;
border-radius: 12px !important;
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
box-shadow: none !important;
}
.gr-button-primary:hover{
z-index: 14;
height: 43px;
width: 130px;
left: 0px;
top: 0px;
padding: 0px;
cursor: pointer !important;
background: none rgb(66, 133, 244) !important;
border: none !important;
text-align: center !important;
font-family: Poppins !important;
font-size: 14px !important;
font-weight: 500 !important;
color: rgb(255, 255, 255) !important;
line-height: 1 !important;
border-radius: 12px !important;
transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
--tw-bg-opacity: 1 !important;
background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
--tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
--tw-gradient-from: rgb(17, 20, 45) !important;
--tw-gradient-to: rgb(255 150 51 / 0);
--tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
--tw-gradient-from:rgb(17, 20, 45) !important;
--tw-gradient-to: rgb(37 56 133 / 37%);
--tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
--tw-text-opacity: 1 !important;
color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}
.container {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 5px;
width: 100%;
}
.bar {
/* width: 70%; */
background-color: #e6e6e6;
border-radius: 12px;
overflow: hidden;
margin-right: 10px;
height: 5px;
}
.bar-fill {
background-color: #17152e;
height: 100%;
border-radius: 12px;
}
</style>
"""
st.markdown(css, unsafe_allow_html=True)