import os
import gradio as gr
import pdfplumber
import requests
import faiss
import json
import torch
from bs4 import BeautifulSoup
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
import tempfile
import logging
from datetime import datetime
from typing import List, Dict

# Optimize CUDA memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CaseStudyGenerator:
    """Generates case studies with OPT-2.7B and indexes them in FAISS for later retrieval."""

    def __init__(self):
        self.model_name = "facebook/opt-2.7b"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Clear any reserved memory before loading the model
        if self.device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        model_kwargs = {
            'torch_dtype': torch.float16 if self.device == "cuda" else torch.float32
        }

        try:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, **model_kwargs)
            if self.device == "cuda":
                self.model = self.model.to(self.device)
                self.model.gradient_checkpointing_enable()
        except RuntimeError as e:
            logger.warning(f"Memory issue detected: {e}, attempting 8-bit loading.")
            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto"
                )
            except ImportError:
                logger.error("Missing 'bitsandbytes'. Install it using 'pip install -U bitsandbytes'")
                logger.info("Switching to CPU to continue operations.")
                self.device = "cpu"
                self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float32)

        # A quantized model manages its own placement via its device map,
        # so only pin the pipeline to a device when no map is present
        pipeline_device = None if getattr(self.model, "hf_device_map", None) else (0 if self.device == "cuda" else -1)
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=pipeline_device,
            max_length=2048,
            num_return_sequences=1,
            temperature=0.8,
            top_p=0.95,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id
        )

        # Sentence embeddings + FAISS index for storing generated case studies
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dimension = 384  # all-MiniLM-L6-v2 embedding size
        self.index = faiss.IndexFlatL2(self.dimension)
        self.stored_texts: List[Dict] = []

    def clean_url(self, url: str) -> str:
        """Keep only absolute http(s) URLs, stripped of query strings and truncated."""
        if not url.startswith(('http://', 'https://')):
            return ""
        return url.split('?')[0][:100]

    def fetch_articles(self, topic: str) -> List[str]:
        """Scrape a Google search results page for links related to the topic."""
        try:
            search_url = f"https://www.google.com/search?q={topic.replace(' ', '+')}+case+study+manufacturing+strategy"
            headers = {"User-Agent": "Mozilla/5.0"}
            response = requests.get(search_url, headers=headers, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, "html.parser")
            # Drop Google-internal links and anything clean_url rejects
            articles = [
                url for url in (
                    self.clean_url(link.get("href", ""))
                    for link in soup.find_all("a")
                    if "google" not in link.get("href", "")
                )
                if url
            ]
            return articles[:5] or ["No articles found"]
        except Exception as e:
            logger.error(f"Error fetching articles: {str(e)}")
            return ["Error fetching articles"]

    def process_pdf(self, pdf_file) -> str:
        """Extract plain text from an uploaded PDF."""
        try:
            if pdf_file is None:
                return "No PDF provided"
            # gr.File(type="binary") passes raw bytes; also accept file-like objects
            data = pdf_file if isinstance(pdf_file, (bytes, bytearray)) else pdf_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
                temp_pdf.write(data)
                temp_path = temp_pdf.name
            with pdfplumber.open(temp_path) as pdf:
                text = [page_text.strip() for page in pdf.pages if (page_text := page.extract_text())]
            os.unlink(temp_path)
            return "\n".join(text) or "No text extracted from PDF"
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            return "Error processing PDF"

    def generate_case_study(self, topic: str, pdf=None) -> str:
        """Generate a case study for the topic, then embed and store it for retrieval."""
        try:
            if self.device == "cuda":
                torch.cuda.empty_cache()
            articles = self.fetch_articles(topic)
            pdf_text = self.process_pdf(pdf) if pdf else "No PDF provided"
            prompt = f"""Write a professional case study about {topic}.

Background Information:
- Topic: {topic}
- Supporting Documents: {pdf_text[:500]}
- Related Sources: {', '.join(articles)}

Format your response as:
1. Executive Summary
2. Company Background
3. Challenge Analysis
4. Strategic Implementation
5. Results and Impact
6. Key Learnings
"""
            output = self.generator(
                prompt,
                max_new_tokens=1024,
                num_return_sequences=1,
                temperature=0.8,
                top_p=0.95,
                do_sample=True,
                repetition_penalty=1.2,
                no_repeat_ngram_size=3
            )
            case_study = output[0]['generated_text'].replace(prompt, "").strip()

            # Embed the generated text and add it to the FAISS index (FAISS expects float32)
            embedding = self.embedding_model.encode([case_study])[0]
            self.index.add(np.asarray(embedding, dtype="float32").reshape(1, -1))
            self.stored_texts.append({
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "topic": topic,
                "content": case_study
            })
            return case_study
        except Exception as e:
            logger.error(f"Error generating case study: {str(e)}")
            return f"Error generating case study: {str(e)}"

    def retrieve_past_case_studies(self) -> str:
        """Return the five most recently generated case studies as formatted text."""
        try:
            if not self.stored_texts:
                return "No case studies generated yet."
            result = ""
            for idx, case in enumerate(self.stored_texts[-5:], start=1):
                result += (
                    f"Case Study {idx}\n"
                    f"Topic: {case['topic']}\n"
                    f"Generated on: {case['timestamp']}\n\n"
                    f"{case['content']}\n\n"
                    "=== End of Case Study ===\n\n"
                )
            return result
        except Exception as e:
            logger.error(f"Error retrieving past case studies: {str(e)}")
            return "Error retrieving past case studies"


# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# AI Case Study Generator (Optimized for T4 GPU & CPU)")
    with gr.Row():
        topic = gr.Textbox(label="Enter Topic")
        pdf = gr.File(label="Upload PDF", type="binary")
    with gr.Row():
        generate_btn = gr.Button("Generate Case Study")
        retrieve_btn = gr.Button("Retrieve Past Case Studies")
    output = gr.Textbox(label="Generated Case Study", lines=20)
    past_cases = gr.Textbox(label="Past Case Studies", lines=20)

    generator = CaseStudyGenerator()
    generate_btn.click(generator.generate_case_study, inputs=[topic, pdf], outputs=output)
    retrieve_btn.click(generator.retrieve_past_case_studies, outputs=past_cases)

# Launch the application
if __name__ == "__main__":
    # enable_queue was removed; on Gradio 3.x or later use app.queue().launch(share=True)
    app.launch(share=True)