import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import os
import glob
from datetime import datetime
import edge_tts
import asyncio
import requests
from audio_recorder_streamlit import audio_recorder
from urllib.parse import quote
from xml.etree import ElementTree as ET
from datasets import load_dataset

# 🧠 Initialize session state variables
SESSION_VARS = {
    'search_history': [],             # Track search history
    'last_voice_input': "",           # Last voice input
    'transcript_history': [],         # Conversation history
    'should_rerun': False,            # Trigger for UI updates
    'search_columns': [],             # Available search columns
    'initial_search_done': False,     # First search flag
    'tts_voice': "en-US-AriaNeural",  # Default voice
    'arxiv_last_query': "",           # Last ArXiv search
    'dataset_loaded': False,          # Dataset load status
    'current_page': 0,                # Current data page
    'data_cache': None,               # Data cache
    'dataset_info': None              # Dataset metadata
}

# Constants
ROWS_PER_PAGE = 100

# Initialize session state
for var, default in SESSION_VARS.items():
    if var not in st.session_state:
        st.session_state[var] = default

@st.cache_resource
def get_model():
    return SentenceTransformer('all-MiniLM-L6-v2')

@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception as e:
        st.error(f"Error loading page {page}: {str(e)}")
        return pd.DataFrame()

@st.cache_data
def get_dataset_info(dataset_id, token):
    try:
        dataset = load_dataset(dataset_id, token=token, streaming=True)
        return dataset['train'].info
    except Exception as e:
        st.error(f"Error loading dataset info: {str(e)}")
        return None

def fetch_dataset_info(dataset_id):
    info_url = f"https://huggingface.co/api/datasets/{dataset_id}"
    try:
        response = requests.get(info_url, timeout=30)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        st.warning(f"Error fetching dataset info: {e}")
    return None

def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
    url = f"https://datasets-server.huggingface.co/first-rows?dataset={dataset_id}&config={config}&split={split}"
    try:
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            data = response.json()
            if 'rows' in data:
                processed_rows = []
                for row_data in data['rows']:
                    row = row_data.get('row', row_data)
                    # Parse stringified embeddings back into float lists if present
                    for key in row:
                        if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
                            if isinstance(row[key], str):
                                try:
                                    row[key] = [float(x.strip())
                                                for x in row[key].strip('[]').split(',')
                                                if x.strip()]
                                except ValueError:
                                    continue
                    row['_config'] = config
                    row['_split'] = split
                    processed_rows.append(row)
                return processed_rows
    except Exception as e:
        st.warning(f"Error fetching rows: {e}")
    return []
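# A minimal, standalone sketch of the embedding-string parsing that
# fetch_dataset_rows() performs above. The sample value is hypothetical; the
# datasets-server API can return vector columns serialized as "[0.1, 0.2, ...]"
# strings, which the loop converts back into float lists:
#
#     raw = "[0.12, -0.56, 0.98]"
#     parsed = [float(x.strip()) for x in raw.strip('[]').split(',') if x.strip()]
#     assert parsed == [0.12, -0.56, 0.98]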
class FastDatasetSearcher:
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.text_model = get_model()
        self.token = os.environ.get('DATASET_KEY')
        if not self.token:
            st.error("Please set the DATASET_KEY environment variable")
            st.stop()
        if st.session_state['dataset_info'] is None:
            st.session_state['dataset_info'] = get_dataset_info(self.dataset_id, self.token)

    def load_page(self, page=0):
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

    def quick_search(self, query, df):
        """Enhanced search with strict token matching and semantic relevance."""
        if df.empty or not query.strip():
            return df
        try:
            # Stricter thresholds
            MIN_SEMANTIC_SCORE = 0.5  # Minimum blended score to keep a non-matching row
            EXACT_MATCH_BOOST = 2.0   # Boost for exact matches

            # Columns whose values can be searched as text
            searchable_cols = []
            for col in df.columns:
                sample_val = df[col].iloc[0]
                if not isinstance(sample_val, (np.ndarray, bytes)):
                    searchable_cols.append(col)

            query_lower = query.lower()
            query_terms = set(query_lower.split())
            query_embedding = self.text_model.encode([query], show_progress_bar=False)[0]

            scores = []
            matched_any = []
            for _, row in df.iterrows():
                text_parts = []
                row_matched = False
                exact_match = False

                # Prioritize description and matched_text fields
                priority_fields = ['description', 'matched_text']
                other_fields = [col for col in searchable_cols if col not in priority_fields]

                # First check priority fields for exact matches
                for col in priority_fields:
                    if col in row:
                        val = row[col]
                        if val is not None:
                            val_str = str(val).lower()
                            # Exact match: the whole query appears as a single token
                            if query_lower in val_str.split():
                                exact_match = True
                            if any(term in val_str.split() for term in query_terms):
                                row_matched = True
                            text_parts.append(str(val))

                # Then check the remaining fields
                for col in other_fields:
                    val = row[col]
                    if val is not None:
                        val_str = str(val).lower()
                        if query_lower in val_str.split():
                            exact_match = True
                        if any(term in val_str.split() for term in query_terms):
                            row_matched = True
                        text_parts.append(str(val))

                text = ' '.join(text_parts)
                if text.strip():
                    # Fraction of query terms that appear verbatim in the row text
                    text_tokens = set(text.lower().split())
                    matching_terms = query_terms.intersection(text_tokens)
                    keyword_score = len(matching_terms) / len(query_terms)

                    # Semantic similarity between query and row text
                    text_embedding = self.text_model.encode([text], show_progress_bar=False)[0]
                    semantic_score = float(cosine_similarity([query_embedding], [text_embedding])[0][0])

                    # Weighted blend, with extra weight for exact matches
                    combined_score = 0.8 * keyword_score + 0.2 * semantic_score
                    if exact_match:
                        combined_score *= EXACT_MATCH_BOOST
                    elif row_matched:
                        combined_score *= 1.2
                else:
                    combined_score = 0.0
                    row_matched = False

                scores.append(combined_score)
                matched_any.append(row_matched)

            results_df = df.copy()
            results_df['score'] = scores
            results_df['matched'] = matched_any

            # Keep direct matches, plus rows whose blended score clears the threshold
            filtered_df = results_df[
                (results_df['matched']) |
                (results_df['score'] > MIN_SEMANTIC_SCORE)
            ]
            return filtered_df.sort_values('score', ascending=False)
        except Exception as e:
            st.error(f"Search error: {str(e)}")
            return df
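# Worked example (hypothetical numbers) of quick_search()'s blended score:
# with keyword_score = 0.5 (half the query terms appear verbatim) and
# semantic_score = 0.8 (cosine similarity), the base score is
#
#     0.8 * 0.5 + 0.2 * 0.8 = 0.56
#
# An exact phrase match then doubles it (EXACT_MATCH_BOOST = 2.0) to 1.12,
# while a mere token match only scales it by 1.2, giving 0.672.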
class VideoSearch:
    def __init__(self):
        self.text_model = get_model()  # Reuse the cached sentence-transformer
        self.dataset_id = "omegalabsinc/omega-multimodal"
        self.load_dataset()

    def fetch_dataset_rows(self):
        try:
            df, configs, splits = search_dataset(
                self.dataset_id,
                "",
                include_configs=None,
                include_splits=None
            )
            if not df.empty:
                st.session_state['search_columns'] = [
                    col for col in df.columns
                    if col not in ['video_embed', 'description_embed', 'audio_embed']
                    and not col.startswith('_')
                ]
                return df
            return self.load_example_data()
        except Exception as e:
            st.warning(f"Error loading videos: {e}")
            return self.load_example_data()

    def load_example_data(self):
        example_data = [{
            "video_id": "sample-123",
            "youtube_id": "dQw4w9WgXcQ",
            "description": "An example video",
            "views": 12345,
            "start_time": 0,
            "end_time": 60
        }]
        return pd.DataFrame(example_data)

    def load_dataset(self):
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    def prepare_features(self):
        try:
            embed_cols = [col for col in self.dataset.columns
                          if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]
            embeddings = {}
            for col in embed_cols:
                try:
                    data = []
                    for row in self.dataset[col]:
                        if isinstance(row, str):
                            values = [float(x.strip())
                                      for x in row.strip('[]').split(',')
                                      if x.strip()]
                        elif isinstance(row, list):
                            values = row
                        else:
                            continue
                        data.append(values)
                    if data:
                        embeddings[col] = np.array(data)
                except (ValueError, TypeError):
                    continue

            self.video_embeds = embeddings.get(
                'video_embed',
                next(iter(embeddings.values())) if embeddings else None
            )
            self.text_embeds = embeddings.get('description_embed', self.video_embeds)

            # Fall back to random features if no usable embeddings were found,
            # so search still runs on example data
            if self.video_embeds is None:
                num_rows = len(self.dataset)
                self.video_embeds = np.random.randn(num_rows, 384)
                self.text_embeds = np.random.randn(num_rows, 384)
        except Exception:
            num_rows = len(self.dataset)
            self.video_embeds = np.random.randn(num_rows, 384)
            self.text_embeds = np.random.randn(num_rows, 384)

    def search(self, query, column=None, top_k=20):
        """Enhanced search with better relevance scoring."""
        MIN_RELEVANCE = 0.3  # Minimum relevance threshold

        query_embedding = self.text_model.encode([query])[0]
        video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
        text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
        combined_sims = 0.7 * text_sims + 0.3 * video_sims  # Favor text matches

        if column and column in self.dataset.columns and column != "All Fields":
            # Boost rows with direct substring matches in the chosen column
            matches = self.dataset[column].astype(str).str.contains(query, case=False)
            combined_sims[matches.to_numpy()] *= 1.5

        # Filter by minimum relevance
        relevant_indices = np.where(combined_sims >= MIN_RELEVANCE)[0]
        if len(relevant_indices) == 0:
            return []

        top_k = min(top_k, len(relevant_indices))
        top_indices = relevant_indices[np.argsort(combined_sims[relevant_indices])[-top_k:][::-1]]

        results = []
        for idx in top_indices:
            result = {'relevance_score': float(combined_sims[idx])}
            for col in self.dataset.columns:
                if col not in ['video_embed', 'description_embed', 'audio_embed']:
                    result[col] = self.dataset.iloc[idx][col]
            results.append(result)
        return results

def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        return pd.DataFrame(), [], []

    configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
    all_rows = []
    available_splits = set()

    for config in configs:
        try:
            splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
            splits_response = requests.get(splits_url, timeout=30)
            if splits_response.status_code == 200:
                splits_data = splits_response.json()
                splits = [split['split'] for split in splits_data.get('splits', [])]
                if not splits:
                    splits = ['train']
                if include_splits:
                    splits = [s for s in splits if s in include_splits]
                available_splits.update(splits)

                for split in splits:
                    rows = fetch_dataset_rows(dataset_id, config, split)
                    for row in rows:
                        text_content = ' '.join(str(v) for v in row.values()
                                                if isinstance(v, (str, int, float)))
                        # An empty search_text matches every row
                        if search_text.lower() in text_content.lower():
                            row['_matched_text'] = text_content
                            row['_relevance_score'] = text_content.lower().count(search_text.lower())
                            all_rows.append(row)
        except Exception as e:
            st.warning(f"Error processing config {config}: {e}")
            continue

    if all_rows:
        df = pd.DataFrame(all_rows)
        df = df.sort_values('_relevance_score', ascending=False)
        return df, configs, list(available_splits)
    return pd.DataFrame(), configs, list(available_splits)
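# Sketch of the top-k selection used in VideoSearch.search() above, with
# made-up similarity values:
#
#     sims = np.array([0.2, 0.9, 0.5, 0.7])
#     top2 = np.argsort(sims)[-2:][::-1]   # -> array([1, 3]), best first
#
# argsort sorts ascending, so the last k indices are the k highest scores;
# reversing puts the best match first.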
@st.cache_resource
def get_speech_model():
    return edge_tts.Communicate  # Returns the class itself; instantiated per call

async def generate_speech(text, voice=None):
    if not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, voice)
        audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None

def transcribe_audio(audio_path):
    """Placeholder for ASR implementation"""
    return "ASR not implemented. Add your preferred speech recognition here!"

def arxiv_search(query, max_results=5):
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    try:
        r = requests.get(search_url, timeout=30)
        if r.status_code == 200:
            root = ET.fromstring(r.text)
            ns = {'atom': 'http://www.w3.org/2005/Atom'}
            entries = root.findall('atom:entry', ns)
            results = []
            for entry in entries:
                title = entry.find('atom:title', ns).text.strip()
                summary = entry.find('atom:summary', ns).text.strip()
                link = next((l.get('href') for l in entry.findall('atom:link', ns)
                             if l.get('type') == 'text/html'), None)
                results.append((title, summary, link))
            return results
    except Exception as e:
        st.error(f"ArXiv search error: {e}")
    return []

def show_file_manager():
    st.subheader("📂 File Manager")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            with open(uploaded_file.name, "wb") as f:
                f.write(uploaded_file.getvalue())
            st.success(f"Uploaded: {uploaded_file.name}")
            st.rerun()
    with col2:
        if st.button("🗑 Clear Files"):
            for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
                os.remove(f)
            st.success("All files cleared!")
            st.rerun()

    files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
    if files:
        st.write("### Existing Files")
        for f in files:
            with st.expander(f"📄 {os.path.basename(f)}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    with open(f, 'r', encoding='utf-8') as file:
                        st.text_area("Content", file.read(), height=100)
                if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
                    os.remove(f)
                    st.rerun()

def perform_arxiv_lookup(query, vocal_summary=True, titles_summary=True, full_audio=False):
    results = arxiv_search(query, max_results=5)
    if not results:
        st.write("No results found.")
        return

    st.markdown(f"**ArXiv Results for '{query}':**")
    for i, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{i}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")

    if vocal_summary:
        spoken_text = f"Here are ArXiv results for {query}. "
        if titles_summary:
            spoken_text += " Titles: " + ", ".join([res[0] for res in results])
        else:
            spoken_text += " " + results[0][1][:200]
        audio_file = asyncio.run(generate_speech(spoken_text))
        if audio_file:
            st.audio(audio_file)

    if full_audio:
        full_text = ""
        for i, (title, summary, _) in enumerate(results, start=1):
            full_text += f"Result {i}: {title}. {summary} "
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio Summary")
            st.audio(audio_file_full)
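# transcribe_audio() above is a stub. A minimal sketch of one way to fill it
# in, assuming the openai-whisper package is installed
# (`pip install openai-whisper`); this is illustrative, not the app's actual ASR:
#
#     import whisper
#     _asr_model = whisper.load_model("base")  # downloads weights on first use
#
#     def transcribe_audio(audio_path):
#         return _asr_model.transcribe(audio_path)["text"]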
{summary} " audio_file_full = asyncio.run(generate_speech(full_text)) if audio_file_full: st.write("### Full Audio Summary") st.audio(audio_file_full) def render_result(result): """Render a search result with voice selection and TTS options""" score = result.get('relevance_score', 0) result_filtered = {k: v for k, v in result.items() if k not in ['relevance_score', 'video_embed', 'description_embed', 'audio_embed']} if 'youtube_id' in result: st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}") cols = st.columns([2, 1]) with cols[0]: text_content = [] # Collect text for TTS for key, value in result_filtered.items(): if isinstance(value, (str, int, float)): st.write(f"**{key}:** {value}") if isinstance(value, str) and len(value.strip()) > 0: text_content.append(f"{key}: {value}") with cols[1]: st.metric("Relevance Score", f"{score:.2%}") # Voice selection for TTS voices = { "Aria (US Female)": "en-US-AriaNeural", "Guy (US Male)": "en-US-GuyNeural", "Sonia (UK Female)": "en-GB-SoniaNeural", "Tony (UK Male)": "en-GB-TonyNeural", "Jenny (US Female)": "en-US-JennyNeural" } selected_voice = st.selectbox( "Select Voice", list(voices.keys()), key=f"voice_{result.get('video_id', '')}" ) if st.button("🔊 Read Description", key=f"read_{result.get('video_id', '')}"): text_to_read = ". ".join(text_content) audio_file = asyncio.run(generate_speech(text_to_read, voices[selected_voice])) if audio_file: st.audio(audio_file) def main(): st.title("🎥 Advanced Video & Dataset Search with Voice") # Initialize search search = VideoSearch() # Create tabs tab1, tab2, tab3, tab4 = st.tabs([ "🔍 Search", "🎙️ Voice Input", "📚 ArXiv", "📂 Files" ]) # Search Tab with tab1: st.subheader("Search Videos") col1, col2 = st.columns([3, 1]) with col1: query = st.text_input("Enter search query:", value="" if st.session_state['initial_search_done'] else "aliens") with col2: search_column = st.selectbox("Search in:", ["All Fields"] + st.session_state['search_columns']) col3, col4 = st.columns(2) with col3: num_results = st.slider("Max results:", 1, 100, 20) with col4: search_button = st.button("🔍 Search") if (search_button or not st.session_state['initial_search_done']) and query: st.session_state['initial_search_done'] = True selected_column = None if search_column == "All Fields" else search_column with st.spinner("Searching..."): results = search.search(query, selected_column, num_results) if results: st.session_state['search_history'].append({ 'query': query, 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 'results': results[:5] }) st.write(f"Found {len(results)} results:") for i, result in enumerate(results, 1): with st.expander(f"Result {i}", expanded=(i==1)): render_result(result) else: st.warning("No matching results found.") # Voice Input Tab with tab2: st.subheader("Voice Search") st.write("🎙️ Record your query:") audio_bytes = audio_recorder() if audio_bytes: with st.spinner("Processing audio..."): audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav" with open(audio_path, "wb") as f: f.write(audio_bytes) voice_query = transcribe_audio(audio_path) st.markdown("**Transcribed Text:**") st.write(voice_query) st.session_state['last_voice_input'] = voice_query if st.button("🔍 Search from Voice"): results = search.search(voice_query, None, 20) for i, result in enumerate(results, 1): with st.expander(f"Result {i}", expanded=(i==1)): render_result(result) if os.path.exists(audio_path): os.remove(audio_path) # ArXiv Tab with tab3: st.subheader("ArXiv 
Search") arxiv_query = st.text_input("Search ArXiv:", value=st.session_state['arxiv_last_query']) vocal_summary = st.checkbox("🎙 Quick Audio Summary", value=True) titles_summary = st.checkbox("🔖 Titles Only", value=True) full_audio = st.checkbox("📚 Full Audio Summary", value=False) if st.button("🔍 Search ArXiv"): st.session_state['arxiv_last_query'] = arxiv_query perform_arxiv_lookup(arxiv_query, vocal_summary, titles_summary, full_audio) # File Manager Tab with tab4: show_file_manager() # Sidebar with st.sidebar: st.subheader("⚙️ Settings & History") if st.button("🗑️ Clear History"): st.session_state['search_history'] = [] st.experimental_rerun() st.markdown("### Recent Searches") for entry in reversed(st.session_state['search_history'][-5:]): with st.expander(f"{entry['timestamp']}: {entry['query']}"): for i, result in enumerate(entry['results'], 1): st.write(f"{i}. {result.get('description', '')[:100]}...") st.markdown("### Voice Settings") st.selectbox("TTS Voice:", [ "en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural" ], key="tts_voice") if __name__ == "__main__": main()