|
import streamlit as st |
|
from transformers import pipeline |
|
|
|
@st.cache_resource
def _load_unmasker():
    """Create the fill-mask pipeline for the Zabantu XLM-RoBERTa checkpoint.

    Streamlit re-executes the whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource`` to build the pipeline once per
    server process instead of on every rerun.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Shared pipeline used for every prediction in the app.
unmasker = _load_unmasker()

# One example sentence per language. '____' marks the word to predict; it is
# swapped for the tokenizer's mask token before the sentence reaches the model.
sample_sentences = {
    'Zulu': "Le ndoda ithi izo____ ukudla.",
    'Tshivenda': "Mufana uyo____ vhukuma.",
    'Sepedi': "Mosadi o ____ pheka.",
    'Tswana': "Monna o ____ tsamaya.",
    'Tsonga': "N'wana wa xisati u ____ ku tsaka.",
}
|
|
|
def fill_mask_for_languages(sentences):
    """Run the fill-mask pipeline on one sentence per language.

    Args:
        sentences: Mapping of language name to a sentence in which the word to
            predict is written as ``'____'``.

    Returns:
        A dict with the same keys (in the same order) whose values are the raw
        outputs of ``unmasker`` for the corresponding sentence. An empty
        mapping yields an empty dict without invoking the model.
    """
    return {
        # Swap the human-friendly placeholder for the model's own mask token
        # right before inference.
        language: unmasker(sentence.replace('____', unmasker.tokenizer.mask_token))
        for language, sentence in sentences.items()
    }
|
|
|
def replace_mask(sentence, predicted_word):
    """Return *sentence* with each ``'____'`` placeholder filled in.

    The prediction is wrapped in ``**`` so it renders in bold when the result
    is passed to a Markdown-aware writer.

    Args:
        sentence: Text containing the ``'____'`` placeholder.
        predicted_word: Word to substitute for the placeholder.

    Returns:
        The sentence with every placeholder replaced by ``**<predicted_word>**``;
        unchanged if it contains no placeholder.
    """
    # The braces are required: without them the f-string inserts the literal
    # text "predicted_word" instead of the prediction.
    return sentence.replace("____", f"**{predicted_word}**")
|
|
|
# --- Page header -------------------------------------------------------------
st.title("Fill Mask for Multiple Languages | Zabantu-XLM-Roberta")
st.write("This app predicts the missing word for sentences in Zulu, Tshivenda, Sepedi, Tswana, and Tsonga using a Zabantu BERT model.")
st.write("")  # empty write used as a vertical spacer

# Left column: user input. Right column: prediction output.
col1, col2 = st.columns(2)

# Explicit "nothing yet" defaults, so later code can test these names directly
# instead of probing locals() for their existence.
user_masked_sentence = None
predictions = None

with col1:
    # Pre-fill the text area with the sample sentences, one per line.
    default_text = "\n".join(
        f"'{lang}': '{sentence}'," for lang, sentence in sample_sentences.items()
    )
    user_sentence = st.text_area(
        "Enter your own sentence with a masked word (use '____'):",
        default_text,
    )

    if st.button("Submit"):
        # Translate the '____' placeholder into the tokenizer's mask token.
        user_masked_sentence = user_sentence.replace('____', unmasker.tokenizer.mask_token)

with col2:
    if user_masked_sentence:
        if unmasker.tokenizer.mask_token not in user_masked_sentence:
            # The fill-mask pipeline raises when the input has no mask token;
            # tell the user what is missing instead of crashing the app.
            st.warning("Your sentence needs a masked word: put '____' where the missing word should go.")
        else:
            # NOTE(review): the result for the user's own text only gates the
            # output below and is never displayed -- confirm whether it should
            # be shown as well.
            user_predictions = unmasker(user_masked_sentence)

            if user_predictions:
                st.write("### Predictions for Sample Sentences:")
                predictions = fill_mask_for_languages(sample_sentences)
                st.write(f"{predictions}")

if predictions:
    for language, language_predictions in predictions.items():
        # Look up the sentence this prediction belongs to; the name was
        # previously used without ever being assigned (NameError).
        original_sentence = sample_sentences[language]
        # The first entry is used as the prediction to display.
        predicted_sentence = replace_mask(original_sentence, language_predictions[0]['token_str'])
        st.write(f"{language}: {predicted_sentence}\n")
|
|
|
|
|
# Custom stylesheet injected into the page.
# The literal is a raw string because selectors such as `.hover\:bg-orange-50`
# contain `\:`, which is an invalid escape sequence in a normal string literal.
# NOTE(review): the `.gr-*` and Tailwind-style selectors look like they target a
# Gradio front end -- confirm they match anything in the Streamlit DOM.
css = r"""
<style>
footer {display:none !important;}

.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-primary:hover{
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(66, 133, 244) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
    --tw-bg-opacity: 1 !important;
    background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
    --tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
    --tw-gradient-from: rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(255 150 51 / 0);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
    --tw-gradient-from:rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(37 56 133 / 37%);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
    --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}
.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    width: 70%;
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""

# unsafe_allow_html is needed so the <style> element is emitted as HTML rather
# than shown as escaped text.
st.markdown(css, unsafe_allow_html=True)
|
|