|
import streamlit as st |
|
from transformers import pipeline |
|
|
|
@st.cache_resource
def _load_unmasker(model_name='dsfsi/zabantu-bantu-250m'):
    """Build the fill-mask pipeline once per process and reuse it.

    Streamlit re-runs this whole script on every widget interaction (e.g. each
    click of "Submit"); without caching, the model weights would be reloaded
    on every rerun.

    Args:
        model_name: Hugging Face model id passed to ``pipeline``.

    Returns:
        The ``transformers`` fill-mask pipeline for ``model_name``.
    """
    return pipeline('fill-mask', model=model_name)


# Shared fill-mask pipeline; fill_mask_for_languages() uses it by default.
unmasker = _load_unmasker()


# One example sentence per language; "____" marks the word the model predicts.
sample_sentences = {
    'Zulu': "Le ndoda ithi izo____ ukudla.",
    'Tshivenda': "Mufana uyo____ vhukuma.",
    'Sepedi': "Mosadi o ____ pheka.",
    'Tswana': "Monna o ____ tsamaya.",
    'Tsonga': "N'wana wa xisati u ____ ku tsaka.",
}
|
|
|
def fill_mask_for_languages(sentences, *, placeholder='____', fill_mask=None):
    """Run the fill-mask model on one sentence per language.

    Args:
        sentences: Mapping of language name -> sentence in which ``placeholder``
            marks the word to predict.
        placeholder: Text replaced by the model's mask token before inference.
            Defaults to ``'____'``, the marker used by ``sample_sentences``.
        fill_mask: Fill-mask callable to run; it must expose
            ``tokenizer.mask_token``. Defaults to the module-level ``unmasker``
            pipeline.

    Returns:
        Dict mapping each language (in input order) to the raw output of
        ``fill_mask`` for its sentence. Presumably, for a single mask this is
        a list of candidate dicts including a ``'sequence'`` key — verify
        against the pipeline version in use.
    """
    if fill_mask is None:
        fill_mask = unmasker
    # The mask token is the same for every sentence, so look it up once.
    mask_token = fill_mask.tokenizer.mask_token
    return {
        language: fill_mask(sentence.replace(placeholder, mask_token))
        for language, sentence in sentences.items()
    }
|
|
|
# Page heading and a short description of what the app does.
st.title("Fill Mask for Multiple Languages | Zabantu-Bantu-250m")
st.write(
    "This app predicts the missing word for sentences in Zulu, Tshivenda, "
    "Sepedi, Tswana, and Tsonga using a Zabantu BERT model."
)

# Show every example sentence, prefixed by its language name in bold Markdown.
st.write("### Sample sentences:")
for lang, text in sample_sentences.items():
    st.write(f"**{lang}**: {text}")
|
|
|
# On "Submit", run the model over all sample sentences and list the top
# candidate's full sequence next to the original sentence for each language.
if st.button("Submit"):
    predictions_by_language = fill_mask_for_languages(sample_sentences)

    # Only render the section when at least one language produced output.
    if predictions_by_language:
        st.write("### Predictions:")
        for lang, candidates in predictions_by_language.items():
            # Index 0 is the first candidate returned by the pipeline.
            best_sequence = candidates[0]['sequence']
            st.write(f"Original sentence ({lang}): {sample_sentences[lang]}")
            st.write(f"Top prediction for the masked token: {best_sequence}\n")
            st.write("=" * 80)
|
|
|
# Custom stylesheet injected into the page as raw HTML. It:
#   - hides the page <footer>;
#   - styles `.stButton > button` (the "Submit" button) with a dark
#     background, rounded corners and a hover colour change;
#   - styles `.stTextInput` / `.stTextArea` (NOTE(review): this file creates
#     no text inputs, so those rules currently appear to have no effect);
#   - sets 16px paragraph text inside Markdown containers;
#   - sets padding and a font-family on `.stApp`.
# NOTE(review): 'Poppins' is not loaded anywhere in this file, so browsers
# presumably fall back to `sans-serif` unless the font is available locally.
css = """

<style>

footer {display:none !important}



.stButton > button {

background-color: #17152e;

color: white;

border: none;

padding: 0.75em 2em;

text-align: center;

text-decoration: none;

display: inline-block;

font-size: 16px;

margin: 4px 2px;

cursor: pointer;

border-radius: 12px;

transition: background-color 0.3s ease;

}



.stButton > button:hover {

background-color: #3c4a6b;

}



.stTextInput, .stTextArea {

border: 1px solid #e6e6e6;

padding: 0.75em;

border-radius: 10px;

font-size: 16px;

width: 100%;

}



.stTextInput:focus, .stTextArea:focus {

border-color: #17152e;

outline: none;

box-shadow: 0px 0px 5px rgba(23, 21, 46, 0.5);

}



div[data-testid="stMarkdownContainer"] p {

font-size: 16px;

}



.stApp {

padding: 2em;

font-family: 'Poppins', sans-serif;

}

</style>

"""

# unsafe_allow_html=True is required so the <style> tag is emitted as HTML
# instead of being escaped as text.
# NOTE(review): this runs at the very end of the script, after the prediction
# step; consider injecting it near the top so styling is in place earlier.
st.markdown(css, unsafe_allow_html=True)
|
|