UnarineLeo commited on
Commit
5b5eeef
·
verified ·
1 Parent(s): 4eb39ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -39
app.py CHANGED
@@ -5,14 +5,17 @@ unmasker = pipeline('fill-mask', model='dsfsi/zabantu-xlm-roberta')
5
 
6
  st.set_page_config(layout="wide")
7
 
 
 
 
8
  def fill_mask(sentences):
9
  results = {}
10
  warnings = []
11
- for language, sentence in sentences.items():
12
  if "<mask>" in sentence:
13
  masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
14
  unmasked = unmasker(masked_sentence)
15
- results[language] = unmasked
16
  else:
17
  warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
18
  return results, warnings
@@ -25,52 +28,52 @@ st.markdown("Zabantu-XLMR refers to a fleet of models trained on different combi
25
 
26
  col1, col2 = st.columns(2)
27
 
28
- # Initialize session states
29
- if 'submit_clicked' not in st.session_state:
30
- st.session_state['submit_clicked'] = False
31
  if 'warnings' not in st.session_state:
32
  st.session_state['warnings'] = []
33
 
 
 
 
34
  language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
35
 
 
 
36
  with col1:
37
  with st.container():
38
  st.markdown("Input :clipboard:")
39
 
40
  input1, input2 = st.columns(2)
41
 
42
- input_sentences = {}
43
  for i in range(5):
44
  with input1:
45
- language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}')
46
  with input2:
47
- # Disable text input if language is not selected
48
- disabled = True if language == "Choose language" else False
49
- sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}', disabled=disabled)
50
- if not disabled and sentence:
51
- input_sentences[language.lower()] = sentence
52
 
53
  button1, button2, _ = st.columns([2, 2, 4])
54
-
55
- with button1:
56
- if st.button("Test Example"):
57
- sample_sentence = {
58
- 'zulu': "Le ndoda ithi izo <mask> ukudla.",
59
- 'tshivenda': "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis.",
60
- 'tshivenda': "Rabulasi wa <mask> u khou bvelela nga u lima",
61
- 'tswana': "Monna o <mask> tsamaya.",
62
- 'tsonga': "N'wana wa xisati u <mask> ku tsaka."
63
- }
64
- input_sentences = sample_sentence
65
- result, warnings = fill_mask(input_sentences)
66
-
67
- with button2:
68
- # Set session state when "Submit" is clicked
69
- if st.button("Submit"):
70
- st.session_state['submit_clicked'] = True
71
- result, warnings = fill_mask(input_sentences)
72
- st.session_state['warnings'] = warnings
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  if st.session_state['warnings']:
75
  for warning in st.session_state['warnings']:
76
  st.warning(warning)
@@ -87,13 +90,10 @@ with col1:
87
  with col2:
88
  with st.container():
89
  st.markdown("Output :bar_chart:")
90
-
91
- # Ensure output only runs after "Submit" is clicked
92
- if st.session_state['submit_clicked'] and input_sentences:
93
- for language, sentence in input_sentences.items():
94
- masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
95
- predictions = unmasker(masked_sentence)
96
-
97
  if predictions:
98
  top_prediction = predictions[0]
99
  predicted_word = top_prediction['token_str']
@@ -109,7 +109,9 @@ with col2:
109
  </div>
110
  """, unsafe_allow_html=True)
111
 
112
- # CSS to hide footer and style the output
 
 
113
  css = """
114
  <style>
115
  footer {display:none !important;}
 
5
 
6
  st.set_page_config(layout="wide")
7
 
8
+ # Disable auto-rerun when selecting options or typing
9
+ st.stop_rerun = True
10
+
11
  def fill_mask(sentences):
12
  results = {}
13
  warnings = []
14
+ for key, (language, sentence) in sentences.items():
15
  if "<mask>" in sentence:
16
  masked_sentence = sentence.replace('<mask>', unmasker.tokenizer.mask_token)
17
  unmasked = unmasker(masked_sentence)
18
+ results[key] = (language, unmasked)
19
  else:
20
  warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
21
  return results, warnings
 
28
 
29
  col1, col2 = st.columns(2)
30
 
31
+ if 'text_input' not in st.session_state:
32
+ st.session_state['text_input'] = ""
33
+
34
  if 'warnings' not in st.session_state:
35
  st.session_state['warnings'] = []
36
 
37
+ if 'result' not in st.session_state:
38
+ st.session_state['result'] = {}
39
+
40
  language_options = ['Choose language', 'Zulu', 'Tshivenda', 'Sepedi', 'Tswana', 'Tsonga']
41
 
42
+ input_sentences = {}
43
+
44
  with col1:
45
  with st.container():
46
  st.markdown("Input :clipboard:")
47
 
48
  input1, input2 = st.columns(2)
49
 
50
+ # Loop to gather input sentences
51
  for i in range(5):
52
  with input1:
53
+ language = st.selectbox(f"Select language for sentence {i+1}:", language_options, key=f'language_{i}', index=0)
54
  with input2:
55
+ sentence = st.text_input(f"Enter sentence for {language} (with <mask>):", key=f'text_input_{i}')
56
+ if sentence:
57
+ # Use a unique key for each sentence (even if languages are the same)
58
+ input_sentences[f'{language.lower()}_{i+1}'] = (language.lower(), sentence)
 
59
 
60
  button1, button2, _ = st.columns([2, 2, 4])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Call fill_mask on button click, not on form input
63
+ if st.button("Test Example"):
64
+ sample_sentences = {
65
+ 'zulu_1': ('zulu', "Le ndoda ithi izo <mask> ukudla."),
66
+ 'tshivenda_2': ('tshivenda', "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."),
67
+ 'tshivenda_3': ('tshivenda', "Rabulasi wa <mask> u khou bvelela nga u lima"),
68
+ 'tswana_4': ('tswana', "Monna o <mask> tsamaya."),
69
+ 'tsonga_5': ('tsonga', "N'wana wa xisati u <mask> ku tsaka.")
70
+ }
71
+ st.session_state['result'], st.session_state['warnings'] = fill_mask(sample_sentences)
72
+
73
+ if st.button("Submit"):
74
+ st.session_state['result'], st.session_state['warnings'] = fill_mask(input_sentences)
75
+
76
+ # Display warnings
77
  if st.session_state['warnings']:
78
  for warning in st.session_state['warnings']:
79
  st.warning(warning)
 
90
  with col2:
91
  with st.container():
92
  st.markdown("Output :bar_chart:")
93
+ # Check for the result in session_state and display predictions
94
+ if st.session_state['result']:
95
+ for key, (language, predictions) in st.session_state['result'].items():
96
+ original_sentence = input_sentences[key][1] if key in input_sentences else ""
 
 
 
97
  if predictions:
98
  top_prediction = predictions[0]
99
  predicted_word = top_prediction['token_str']
 
109
  </div>
110
  """, unsafe_allow_html=True)
111
 
112
+ predicted_sentence = replace_mask(original_sentence, predicted_word)
113
+ st.write(f"{language}: {predicted_sentence}\n")
114
+
115
  css = """
116
  <style>
117
  footer {display:none !important;}