File size: 2,430 Bytes
f5f8f9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import streamlit as st
from transformers import pipeline

@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline for the Zabantu model, once per process.

    Streamlit re-executes this whole script on every widget interaction
    (e.g. each click of "Submit"). Without caching, the model would be
    re-instantiated on every rerun. ``st.cache_resource`` keeps a single
    pipeline object alive across reruns and sessions.
    """
    return pipeline('fill-mask', model='dsfsi/zabantu-bantu-250m')


# Module-level handle read by fill_mask_for_languages; name kept unchanged.
unmasker = _load_unmasker()

# Demo inputs: one sentence per language, with "____" marking the word the
# model should fill in. Insertion order is the order used when the sentences
# and their predictions are displayed.
sample_sentences = dict(
    Zulu="Le ndoda ithi izo____ ukudla.",
    Tshivenda="Mufana uyo____ vhukuma.",
    Sepedi="Mosadi o ____ pheka.",
    Tswana="Monna o ____ tsamaya.",
    Tsonga="N'wana wa xisati u ____ ku tsaka.",
)

def fill_mask_for_languages(sentences, fill_mask=None, placeholder='____'):
    """Predict the missing word in each sentence with a fill-mask pipeline.

    Every occurrence of ``placeholder`` in a sentence is replaced by the
    pipeline tokenizer's mask token, and the pipeline is run on the result.

    Args:
        sentences: Mapping of language name -> sentence containing
            ``placeholder``.
        fill_mask: Fill-mask pipeline to use. Defaults to the module-level
            ``unmasker``. It must expose ``tokenizer.mask_token`` and be
            callable on a string.
        placeholder: Text marking the missing word in each sentence.

    Returns:
        Dict mapping each language, in the input's iteration order, to the raw
        pipeline output for its sentence.
    """
    if fill_mask is None:
        # Resolved at call time so callers (and tests) can inject a pipeline.
        fill_mask = unmasker
    # The mask token is a tokenizer property, identical for every sentence.
    mask_token = fill_mask.tokenizer.mask_token
    return {
        language: fill_mask(sentence.replace(placeholder, mask_token))
        for language, sentence in sentences.items()
    }

# Page header: app title plus a one-line description of what the app does.
st.title("Fill Mask for Multiple Languages | Zabantu-Bantu-250m")
st.write(
    "This app predicts the missing word for sentences in Zulu, Tshivenda, "
    "Sepedi, Tswana, and Tsonga using a Zabantu BERT model."
)

# Show every input sentence (language name in bold) before the user submits.
st.write("### Sample sentences:")
for language, sentence in sample_sentences.items():
    st.write("**{}**: {}".format(language, sentence))

# On "Submit", run the model over every sample sentence and show, per
# language, the original sentence, the predicted word and the completed
# sentence.
if st.button("Submit"):
    result = fill_mask_for_languages(sample_sentences)

    if result:
        st.write("### Predictions:")
        for language, predictions in result.items():
            original_sentence = sample_sentences[language]
            # For a single-mask input the pipeline returns a list of candidate
            # dicts; element 0 is taken as the top candidate (the transformers
            # fill-mask docs describe candidates as ordered by score).
            top = predictions[0]
            st.write(f"Original sentence ({language}): {original_sentence}")
            # Show the predicted token itself under the "masked token" label,
            # and the full completed sentence separately.
            st.write(
                f"Top prediction for the masked token: {top['token_str'].strip()} "
                f"(score: {top['score']:.3f})"
            )
            st.write(f"Completed sentence: {top['sequence']}")
            # Markdown horizontal rule separating one language from the next.
            st.markdown("---")

# Custom stylesheet injected as raw HTML: hides the Streamlit footer and
# restyles buttons, text inputs, markdown paragraphs and the app container.
# Focus styling uses :focus-within because .stTextInput/.stTextArea are the
# wrapper elements; focus lands on the inner <input>/<textarea>, so a plain
# :focus on the wrapper never matches.
# NOTE(review): 'Poppins' is not loaded anywhere in this file; unless it is
# provided elsewhere, browsers fall back to the generic sans-serif.
css = """
<style>
footer {display:none !important}

.stButton > button {
    background-color: #17152e;
    color: white;
    border: none;
    padding: 0.75em 2em;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 16px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 12px;
    transition: background-color 0.3s ease;
}

.stButton > button:hover {
    background-color: #3c4a6b;
}

.stTextInput, .stTextArea {
    border: 1px solid #e6e6e6;
    padding: 0.75em;
    border-radius: 10px;
    font-size: 16px;
    width: 100%;
}

.stTextInput:focus-within, .stTextArea:focus-within {
    border-color: #17152e;
    outline: none;
    box-shadow: 0px 0px 5px rgba(23, 21, 46, 0.5);
}

div[data-testid="stMarkdownContainer"] p {
    font-size: 16px;
}

.stApp {
    padding: 2em;
    font-family: 'Poppins', sans-serif;
}
</style>
"""
# unsafe_allow_html is required, otherwise the <style> tag is escaped as text.
st.markdown(css, unsafe_allow_html=True)