"""Step-by-step SciPIP idea-generation page (Streamlit).

The page walks the user through the generation pipeline one stage at a time;
every intermediate result is shown in an editable text area before it is fed
to the next stage. Progress is tracked in
``st.session_state["global_state_step"]`` (see ``step_by_step_generation``).
"""

import ast

import streamlit as st

# Sidebar labels of the pipeline stages; stage ``k`` (1-based) is entry ``k - 1``.
_PIPELINE_STAGES = (
    "1. Input Background",
    "2. Brainstorming",
    "3. Extracting Entities",
    "4. Retrieving Related Works",
    "5. Generate Initial Ideas",
    "6. Generate Final Ideas",
)
_DONE_COLOR = "black"
_UNDONE_COLOR = "gray"
_INPROGRESS_COLOR = "black"


def _stage_color(step, stage):
    """Return the sidebar colour for 1-based ``stage`` given the global ``step`` value.

    ``step`` uses the float encoding documented in ``step_by_step_generation``,
    e.g. 2.0 means stage 2 is running and 2.5 means stage 2 has finished.
    """
    if step < stage:
        return _UNDONE_COLOR
    if step == stage:
        return _INPROGRESS_COLOR
    return _DONE_COLOR


def generate_sidebar():
    """Render the sidebar: about text, colour-coded pipeline progress, supported fields, feedback link."""
    st.sidebar.header("About", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas step by step. The generation pipeline is the same as "
         "one-click generation, While you can improve each part manually after SciPIP providing the manuscript.")
    )

    st.sidebar.header("Pipeline", divider="red")
    step = st.session_state.get("global_state_step", 1.0)
    for stage, label in enumerate(_PIPELINE_STAGES, start=1):
        # Colour each stage as not started / in progress / done.
        color = _stage_color(step, stage)
        st.sidebar.markdown(f"<span style='color:{color}'>{label}</span>", unsafe_allow_html=True)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption(
        "The supported fields are temporarily limited because we only collect literature "
        "from ICML, ICLR, NeurIPS, ACL, and EMNLP.  \n"  # two spaces + newline = Markdown line break
        "Support for other fields are in progress."
    )
    st.sidebar.checkbox("Natural Language Processing (NLP)", value=True, disabled=True)
    st.sidebar.checkbox("Computer Vision (CV)", value=False, disabled=True)
    st.sidebar.checkbox("[Partial] Multimodal", value=True, disabled=True)
    st.sidebar.checkbox("Incoming Other Fields", value=False, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)


def get_textarea_height(text_content, *, chars_per_line=96, line_px=23, padding_px=20, min_px=68):
    """Estimate a pixel height for ``st.text_area`` that shows ``text_content`` without scrolling.

    Each ``\\n``-separated line counts once, plus one extra line per
    ``chars_per_line`` characters to approximate soft wrapping.

    Args:
        text_content: Text to display, or ``None``.
        chars_per_line: Approximate characters that fit on one visual line.
        line_px: Pixel height of one visual line.
        padding_px: Extra pixels for the widget's padding.
        min_px: Lower bound on the result. Recent Streamlit releases reject
            ``st.text_area`` heights below 68px.

    Returns:
        The height in pixels; 100 when ``text_content`` is ``None``.
    """
    if text_content is None:
        return 100
    lines = text_content.split("\n")
    visual_lines = len(lines) + sum(len(line) // chars_per_line for line in lines)
    return max(visual_lines * line_px + padding_px, min_px)


def _run_pipeline_step(in_progress, finished, spinner_text, expand_key, work):
    """Run one backend stage under a spinner, updating the pipeline state around it.

    Sets ``global_state_step`` to ``in_progress`` before calling ``work()`` and
    to ``finished`` afterwards, then sets ``expand_key`` so the expander showing
    the new result opens. If ``work()`` raises, the step stays at
    ``in_progress`` and the exception propagates.

    Returns:
        Whatever ``work()`` returns.
    """
    st.session_state["global_state_step"] = in_progress
    with st.spinner(text=spinner_text):
        result = work()
    st.session_state["global_state_step"] = finished
    st.session_state[expand_key] = True
    return result


def _text_editor_with_preview(label, state_key, empty_hint):
    """Render an auto-sized text area (left column) beside a Markdown preview (right column).

    The editor is pre-filled from ``st.session_state[state_key]`` (default ``""``).
    ``empty_hint`` is shown in the preview when the editor is empty.

    Returns:
        ``(text, editor_column)`` so the caller can place a button under the editor.
    """
    editor_col, preview_col = st.columns(2)
    current = st.session_state.get(state_key, "")
    text = editor_col.text_area(
        label=label,
        value=current,
        label_visibility="collapsed",
        height=get_textarea_height(current),
    )
    preview_col.markdown(f"{text}" if text else empty_hint)
    return text, editor_col


def _parse_entities(text):
    """Parse the entities editor text as a Python literal (normally a list).

    Blank text is treated as an empty list.

    Returns:
        ``(value, None)`` on success, or ``(None, message)`` when ``text`` is
        not a valid literal.
    """
    if text is None or not text.strip():
        return [], None
    try:
        # literal_eval only evaluates literals, so user text cannot run code.
        return ast.literal_eval(text), None
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as exc:
        return None, str(exc)


def _render_background_section(backend):
    """Render the background form; on Submit, run brainstorming. Returns the background text."""
    st.markdown("# 🐳 Background")
    with st.form('background_form'):
        background = st.text_area(
            "Input your field background",
            st.session_state.get("background", ""),
            placeholder="Input your field background",
            height=200,
            label_visibility="collapsed",
        )

        def click_demo_i(i):
            # Store the i-th demo background; it is read back as the text area's default above.
            st.session_state["background"] = backend.get_demo_i(i)

        for i, col in enumerate(st.columns(2)):
            col.form_submit_button(
                f"Example {i + 1}", use_container_width=True, on_click=click_demo_i, args=(i,)
            )
        submit_col, _ = st.columns([2, 30])
        submitted = submit_col.form_submit_button('Submit', type="primary")

    if submitted:
        st.session_state["brainstorms"] = _run_pipeline_step(
            2.0, 2.5, "Brainstorming...", "brainstorms_expand",
            lambda: backend.background2brainstorm_callback(background),
        )
    return background


def _render_brainstorms_section(backend, background):
    """Render the editable brainstorms; on Submit, extract entities from them."""
    st.markdown("# 👻 Brainstorms")
    with st.expander("Here is the generated brainstorms",
                     expanded=st.session_state.get("brainstorms_expand", False)):
        brainstorms, _ = _text_editor_with_preview(
            "brainstorms", "brainstorms", "Please input the brainstorms on the left."
        )
        # Persist manual edits so the next rerun starts from them.
        st.session_state["brainstorms"] = brainstorms
        submit_col, _ = st.columns([2, 30])
        if submit_col.button('Submit', key="brainstorms_button"):
            st.session_state["entities"] = _run_pipeline_step(
                3.0, 3.5, "Extracting entities...", "entities_expand",
                lambda: backend.brainstorm2entities_callback(background, brainstorms),
            )


def _render_entities_section(backend, background):
    """Render the editable entity list; on Submit, retrieve related works for it."""
    st.markdown("# 🐱 Extracted Entities")
    with st.expander("Here is the extracted entities",
                     expanded=st.session_state.get("entities_expand", False)):
        editor_col, preview_col = st.columns(2)
        # Session state holds the parsed value (a list after the first parse);
        # the text area needs its string form.
        entities_text = editor_col.text_area(
            label="entities",
            value=str(st.session_state.get("entities", "[]")),
            label_visibility="collapsed",
        )
        entities, error = _parse_entities(entities_text)
        if error is None:
            st.session_state["entities"] = entities
            preview_col.markdown(f"{entities}" if entities else "Please input the entities on the left.")
        else:
            # Keep the previous session value so the user's text is not reset.
            preview_col.error(
                f"Entities must be a Python literal list, e.g. ['entity A', 'entity B'] ({error})."
            )

        if editor_col.button('Submit', key="entities_button"):
            if error is not None:
                st.warning("Please fix the entities before retrieving related works.")
            else:
                related_works, related_works_intact = _run_pipeline_step(
                    4.0, 4.5, "Retrieving related works...", "related_works_expand",
                    lambda: backend.entities2literature_callback(background, entities),
                )
                st.session_state["related_works"] = related_works
                st.session_state["related_works_intact"] = related_works_intact


def _render_related_works_section(backend, background):
    """Render the retrieved related works; on Submit, generate initial ideas."""
    st.markdown("# 📖 Retrieved Related Works")
    with st.expander("Here is the retrieved related works",
                     expanded=st.session_state.get("related_works_expand", False)):
        # Edits here are display-only: generation uses "related_works_intact",
        # not the text in this editor.
        _, editor_col = _text_editor_with_preview(
            "related_works", "related_works", "Please input the related works on the left."
        )
        if editor_col.button('Submit', key="related_works_button"):
            if "related_works_intact" not in st.session_state:
                st.warning("Please retrieve related works before generating initial ideas.")
            else:
                res = _run_pipeline_step(
                    5.0, 5.5, "Generating initial ideas...", "initial_ideas_expand",
                    lambda: backend.literature2initial_ideas_callback(
                        background, st.session_state["related_works_intact"]
                    ),
                )
                st.session_state["initial_ideas"] = res[0]
                st.session_state["final_ideas"] = res[1]


def _render_initial_ideas_section(backend):
    """Render the editable initial ideas; on Submit, generate the final ideas."""
    st.markdown("# 😼 Generated Initial Ideas")
    with st.expander("Here is the generated initial ideas",
                     expanded=st.session_state.get("initial_ideas_expand", False)):
        initial_ideas, editor_col = _text_editor_with_preview(
            "initial_ideas", "initial_ideas", "Please input the initial ideas on the left."
        )
        if editor_col.button('Submit', key="initial_ideas_button"):
            if "final_ideas" not in st.session_state:
                st.warning("Please generate initial ideas before generating final ideas.")
            else:
                # NOTE(review): "final_ideas" holds the second value returned by the
                # initial-idea step and is overwritten here with the final ideas, so
                # submitting twice feeds the previous final ideas back in — confirm intended.
                st.session_state["final_ideas"] = _run_pipeline_step(
                    6.0, 6.5, "Generating final ideas...", "final_ideas_expand",
                    lambda: backend.initial2final_callback(initial_ideas, st.session_state["final_ideas"]),
                )


def _render_final_ideas_section():
    """Render the editable final ideas (last stage of the pipeline)."""
    st.markdown("# 😸 Generated Final Ideas")
    with st.expander("Here is the generated final ideas",
                     expanded=st.session_state.get("final_ideas_expand", False)):
        _, editor_col = _text_editor_with_preview(
            "final_ideas", "final_ideas", "Please input the final ideas on the left."
        )
        # Last stage: the button is rendered but no action is attached to it.
        editor_col.button('Submit', key="final_ideas_button")


def genrate_mainpage(backend):
    """Render the main page: one section per pipeline stage, top to bottom.

    Each section's Submit button runs the *next* stage through ``backend``,
    which must provide ``get_demo_i``, ``background2brainstorm_callback``,
    ``brainstorm2entities_callback``, ``entities2literature_callback``,
    ``literature2initial_ideas_callback`` and ``initial2final_callback``.
    The misspelled name is kept for existing callers.
    """
    st.title('💦 Generate Idea Step-by-step')
    background = _render_background_section(backend)
    _render_brainstorms_section(backend, background)
    _render_entities_section(backend, background)
    _render_related_works_section(backend, background)
    _render_initial_ideas_section(backend)
    _render_final_ideas_section()


def step_by_step_generation(backend):
    """Entry point of the step-by-step page.

    ``st.session_state["global_state_step"]`` tracks pipeline progress:

    * 1.0       input background in progress
    * 2.0 / 2.5 brainstorming in progress / finished
    * 3.0 / 3.5 extracting entities in progress / finished
    * 4.0 / 4.5 retrieving literature in progress / finished
    * 5.0 / 5.5 generating initial ideas in progress / finished
    * 6.0 / 6.5 generating final ideas in progress / finished
    """
    if "global_state_step" not in st.session_state:
        st.session_state["global_state_step"] = 1.0
    genrate_mainpage(backend)
    # Rendered after the main page so the sidebar reflects any step completed in this run.
    generate_sidebar()