import html

import streamlit as st
from transformers import pipeline

# Configure the page before anything else touches Streamlit: the cached
# loader below may draw a spinner, and Streamlit has historically required
# set_page_config() to be the first Streamlit command in the script.
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the model would be re-instantiated each time.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')


# Module-level handle used by fill_mask().
unmasker = _load_unmasker()

# NOTE(review): `stop_rerun` does not appear to be a documented Streamlit
# attribute, so this assignment likely has no effect. Kept as-is; confirm
# before removing.
st.stop_rerun = True


def fill_mask(sentences):
    """Run the fill-mask model over every sentence that contains ``<mask>``.

    Args:
        sentences: Mapping of an arbitrary key to either a
            ``(language, sentence)`` pair or a bare sentence string. When a
            bare string is given, the key itself is used as the language
            label (this is the shape of the built-in example sentences).

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each key whose
        sentence contained ``<mask>`` to ``(language, predictions)``, where
        ``predictions`` is whatever the module-level ``unmasker`` pipeline
        returned for that sentence. ``warnings`` holds one message for each
        sentence that had no ``<mask>`` placeholder.
    """
    results = {}
    warnings = []

    for key, entry in sentences.items():
        # Accept both value shapes; unpacking a bare string as a pair would
        # raise ValueError.
        if isinstance(entry, str):
            language, sentence = key, entry
        else:
            language, sentence = entry

        if "<mask>" not in sentence:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
            continue

        # The UI uses the literal "<mask>" placeholder; swap in the mask
        # token the loaded tokenizer actually expects.
        masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
        results[key] = (language, unmasker(masked_sentence))

    return results, warnings


def replace_mask(sentence, predicted_word):
    """Return *sentence* with each ``<mask>`` replaced by the bolded prediction.

    Args:
        sentence: Text containing the literal ``<mask>`` placeholder.
        predicted_word: Token to insert; it is wrapped in ``**`` so it
            renders as bold Markdown.

    Returns:
        The sentence with every ``<mask>`` occurrence substituted. A sentence
        without ``<mask>`` is returned unchanged.
    """
    # Surrounding whitespace is stripped because Markdown does not treat
    # "** word**" as emphasis, which would leave literal asterisks on screen.
    word = str(predicted_word).strip()
    return sentence.replace("<mask>", f"**{word}**")


st.title("Fill Mask | Zabantu-XLM-Roberta")
st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combinations of South African Bantu languages.")

col1, col2 = st.columns(2)

# Seed the session-state slots on the first run only, so later reruns keep
# the values written by the button handlers. 'sentences' remembers the text
# that produced each stored result, so the output column can show the filled
# sentence even after the inputs change.
for _state_key, _initial in (
    ('text_input', ""),
    ('warnings', []),
    ('result', {}),
    ('sentences', {}),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']

# Single source of truth for the built-in examples: used both by the
# "Test Example" button and by the "Example" code panel.
sample_sentences = {
    'zulu': "Le ndoda ithi izo <mask> ukudla.",
    'tshivenda': "Mufana uyo <mask> vhukuma.",
    'sepedi': "Mosadi o <mask> pheka.",
    'tswana': "Monna o <mask> tsamaya.",
    'tsonga': "N'wana wa xisati u <mask> ku tsaka.",
}

# key -> (language, sentence) for every non-empty input row of this rerun.
input_sentences = {}

with col1:
    with st.container():
        st.markdown("Input :clipboard:")

        input1, input2 = st.columns(2)

        for i in range(5):
            with input1:
                language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
            with input2:
                sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
            if sentence:
                input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)

        button1, button2, _ = st.columns([2, 2, 4])
        with button1:
            test_clicked = st.button("Test Example")
        with button2:
            submit_clicked = st.button("Submit")

        # Work out which batch (if any) to run. The samples are converted to
        # the (language, sentence) pairs that fill_mask() unpacks.
        submitted = None
        if test_clicked:
            submitted = {
                sample_language: (sample_language, sample_text)
                for sample_language, sample_text in sample_sentences.items()
            }
        if submit_clicked:
            submitted = input_sentences

        if submitted is not None:
            st.session_state['result'], st.session_state['warnings'] = fill_mask(submitted)
            st.session_state['sentences'] = {
                sentence_key: pair[1] for sentence_key, pair in submitted.items()
            }

        for warning in st.session_state['warnings']:
            st.warning(warning)

        st.markdown("Example")
        st.code(sample_sentences, wrap_lines=True)

with col2:
    with st.container():
        st.markdown("Output :bar_chart:")

        for key, (language, predictions) in st.session_state['result'].items():
            if not predictions:
                continue

            # Only the first prediction returned by the pipeline is shown.
            top_prediction = predictions[0]
            predicted_word = top_prediction['token_str']
            score = top_prediction['score'] * 100

            # Model output and labels are escaped because the markup is
            # rendered with unsafe_allow_html=True. The HTML is built as one
            # unindented string so Markdown cannot misread it as a code block.
            bar_html = (
                '<div class="bar">'
                f'<div class="bar-fill" style="width: {score}%;"></div>'
                '</div>'
                '<div class="container">'
                f'<div style="text-align: left;">{html.escape(str(predicted_word))} ({html.escape(str(language))})</div>'
                f'<div style="text-align: right;">{score:.2f}%</div>'
                '</div>'
            )
            st.markdown(bar_html, unsafe_allow_html=True)

            # Show the sentence with the prediction filled in, using the text
            # that was actually submitted for this result.
            original_sentence = st.session_state['sentences'].get(key, "")
            if original_sentence:
                predicted_sentence = replace_mask(original_sentence, predicted_word)
                st.write(f"{language}: {predicted_sentence}\n")


# Page-wide stylesheet: hides the footer element and styles the `.bar`,
# `.bar-fill` and `.container` classes used by the prediction markup in the
# output column.
css = """
<style>
footer {display:none !important;}

.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}
</style>
"""

# Injected as raw HTML so the <style> element reaches the page unescaped.
st.markdown(css, unsafe_allow_html=True)