# Spaces: Sleeping  (page-status text captured along with the code; not part of the program)
# Install required packages
# !pip install gradio yfinance transformers torch pandas plotly requests beautifulsoup4
import logging
from datetime import datetime, timedelta

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch
import yfinance as yf
from bs4 import BeautifulSoup
from transformers import pipeline
# Run the models on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Build the two transformers pipelines once at import time; both are placed
# on the device selected above and shared by the rest of the module.
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=device,
)
sentiment_analyzer = pipeline(
    "sentiment-analysis",
    model="finiteautomata/bertweet-base-sentiment-analysis",
    device=device,
)
class CompanyResearchAgent:
    """Collect market data, news sentiment and a short summary for one ticker.

    The methods rely on names provided at module level: ``yf`` (yfinance),
    ``go`` (plotly graph objects), ``requests``, ``BeautifulSoup`` and the two
    transformers pipelines ``summarizer`` and ``sentiment_analyzer``.
    """

    _log = logging.getLogger(__name__)

    # Result of analyze_company() whenever no usable market data is available.
    _INVALID_SYMBOL = ("Invalid stock symbol", None, None, None)

    _NEWS_FEED_URL = "https://news.google.com/rss/search"
    _NEWS_TIMEOUT_SECONDS = 10
    _MAX_HEADLINES = 5

    def __init__(self):
        # Nothing in this class reads or writes the cache yet.
        self.cache = {}

    @staticmethod
    def _format_money(value):
        """Render *value* as ``$1,234.56``, or ``"N/A"`` if it is not numeric.

        The data source may report a key with a ``None`` value; formatting
        that directly with ``:,.2f`` raises ``TypeError``.
        """
        try:
            return f"${float(value):,.2f}"
        except (TypeError, ValueError):
            return "N/A"

    @staticmethod
    def _signed_score(result):
        """Turn one classifier result into a signed polarity.

        The pipeline's ``score`` is the confidence of the predicted label,
        not a polarity, so it is signed by the label: positive labels keep the
        score, negative labels negate it, anything else counts as neutral (0).
        Assuming ``score`` is a probability, the result lies in [-1, 1].
        """
        label = str(result.get("label", "")).upper()
        confidence = float(result.get("score", 0.0))
        if label.startswith("POS"):
            return confidence
        if label.startswith("NEG"):
            return -confidence
        return 0.0

    def get_stock_data(self, symbol, period="1y"):
        """Fetch the yfinance ticker handle and its price history.

        Args:
            symbol: Ticker symbol, e.g. ``"AAPL"``.
            period: History window forwarded to ``Ticker.history``.

        Returns:
            ``(ticker, history)`` on success, ``(None, None)`` if the lookup
            raised.
        """
        try:
            stock = yf.Ticker(symbol)
            hist = stock.history(period=period)
        except Exception:
            # Best effort: callers treat (None, None) as "no data", but the
            # cause is logged instead of being silently discarded.
            self._log.exception("Could not fetch stock data for %r", symbol)
            return None, None
        return stock, hist

    def create_stock_chart(self, hist):
        """Create an interactive candlestick chart from a price history.

        Args:
            hist: Price history with ``Open``/``High``/``Low``/``Close``
                columns, indexed by date.

        Returns:
            A plotly ``Figure``, or ``None`` when *hist* is ``None`` or empty.
        """
        if hist is None or hist.empty:
            return None
        fig = go.Figure()
        fig.add_trace(go.Candlestick(
            x=hist.index,
            open=hist['Open'],
            high=hist['High'],
            low=hist['Low'],
            close=hist['Close'],
            name='Stock Price',
        ))
        fig.update_layout(
            title="Stock Price History",
            yaxis_title="Price",
            xaxis_title="Date",
            template="plotly_dark",
        )
        return fig

    def get_news_sentiment(self, company_name):
        """Score the sentiment of recent news headlines about a company.

        Fetches the Google News RSS search feed for the last seven days, takes
        up to five headlines and averages their signed sentiment.

        Args:
            company_name: Free-text company name used as the search query.

        Returns:
            A dict with ``'average_sentiment'`` (mean signed polarity rounded
            to two decimals; ``0`` when nothing could be scored) and
            ``'recent_news'`` (the list of headlines, or a single placeholder
            message when none are available).
        """
        try:
            # Passing the query through ``params`` URL-encodes names that
            # contain spaces or reserved characters such as "&".
            response = requests.get(
                self._NEWS_FEED_URL,
                params={"q": f"{company_name} when:7d"},
                timeout=self._NEWS_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            # NOTE(review): the 'xml' parser feature is provided by lxml --
            # confirm it is installed, otherwise this always falls back.
            soup = BeautifulSoup(response.content, 'xml')
            titles = [
                item.title.get_text(strip=True)
                for item in soup.find_all('item')[:self._MAX_HEADLINES]
                if item.title is not None
            ]
            if not titles:
                return {
                    'average_sentiment': 0,
                    'recent_news': ['No recent news found'],
                }
            polarities = [
                self._signed_score(result)
                for result in sentiment_analyzer(titles)
            ]
            avg_sentiment = (
                sum(polarities) / len(polarities) if polarities else 0
            )
            return {
                'average_sentiment': round(avg_sentiment, 2),
                'recent_news': titles,
            }
        except Exception:
            # Sentiment is optional for the report, so degrade gracefully --
            # but leave a trace of what went wrong.
            self._log.exception(
                "Could not compute news sentiment for %r", company_name
            )
            return {
                'average_sentiment': 0,
                'recent_news': ['Unable to fetch news'],
            }

    def generate_swot(self, stock, company_name):
        """Summarise a handful of company fundamentals.

        Despite the name, the text handed to the summarizer contains only the
        sector, industry, market cap, P/E ratio and revenue growth reported by
        ``stock.info``.

        Args:
            stock: yfinance ticker handle, or ``None``.
            company_name: Display name used in the generated text.

        Returns:
            The summarizer's output text, or a fixed error message when
            *stock* is ``None``.
        """
        if stock is None:
            return "Unable to generate SWOT analysis - invalid stock data"
        info = stock.info or {}
        swot_text = "\n".join([
            "",
            f"Company Analysis for {company_name}:",
            f"Sector: {info.get('sector', 'N/A')}",
            f"Industry: {info.get('industry', 'N/A')}",
            f"Market Cap: {self._format_money(info.get('marketCap'))}",
            f"P/E Ratio: {info.get('trailingPE', 'N/A')}",
            f"Revenue Growth: {info.get('revenueGrowth', 'N/A')}",
            "",
        ])
        summary = summarizer(swot_text, max_length=150, min_length=50)
        return summary[0]['summary_text']

    def analyze_company(self, symbol):
        """Run the full analysis for one ticker symbol.

        Args:
            symbol: Ticker symbol as typed by the user; surrounding whitespace
                is ignored and the symbol is upper-cased.

        Returns:
            A 4-tuple ``(overview_markdown, chart_figure, None, None)``. The
            last two slots are always ``None``. On failure the first element
            is an error message and the chart is ``None``.
        """
        symbol = str(symbol or "").strip().upper()
        if not symbol:
            return self._INVALID_SYMBOL
        try:
            stock, hist = self.get_stock_data(symbol)
            # An empty history means there is nothing to chart or analyse,
            # so it is rejected the same way as a failed lookup.
            if stock is None or hist is None or hist.empty:
                return self._INVALID_SYMBOL

            stock_chart = self.create_stock_chart(hist)

            info = stock.info or {}
            company_name = info.get('longName') or symbol

            swot_analysis = self.generate_swot(stock, company_name)
            sentiment_data = self.get_news_sentiment(company_name)

            news_lines = "\n".join(
                f"- {news}" for news in sentiment_data['recent_news']
            )
            market_cap = self._format_money(info.get('marketCap'))
            current_price = self._format_money(info.get('currentPrice'))
            company_overview = (
                f"## {company_name} ({symbol})\n\n"
                f"**Sector:** {info.get('sector', 'N/A')}\n"
                f"**Industry:** {info.get('industry', 'N/A')}\n"
                f"**Market Cap:** {market_cap}\n"
                f"**Current Price:** {current_price}\n\n"
                f"### News Sentiment Score: {sentiment_data['average_sentiment']}\n\n"
                "Recent News:\n"
                f"{news_lines}\n\n"
                "### SWOT Analysis Summary:\n"
                f"{swot_analysis}"
            )
            return company_overview, stock_chart, None, None
        except Exception as e:
            # Top-level boundary for the UI: report the failure to the user
            # and keep the traceback in the log.
            self._log.exception("Error analyzing company %r", symbol)
            return f"Error analyzing company: {str(e)}", None, None, None
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the company research agent.

    The page has a ticker text box with an "Analyze Company" button, and below
    them a Markdown overview next to a price chart.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    agent = CompanyResearchAgent()

    def run_analysis(symbol):
        # analyze_company() returns four values, but only two output
        # components are wired up below. Hand Gradio exactly those two so the
        # handler's return count matches its outputs.
        overview, chart, *_ = agent.analyze_company(symbol)
        return overview, chart

    with gr.Blocks(theme=gr.themes.Base()) as interface:
        gr.Markdown("# Company Research Agent 📈")
        gr.Markdown("Enter a stock symbol (e.g., AAPL, GOOGL, MSFT)")
        with gr.Row():
            symbol_input = gr.Textbox(label="Stock Symbol")
            analyze_btn = gr.Button("Analyze Company", variant="primary")
        with gr.Row():
            with gr.Column():
                overview_output = gr.Markdown(label="Company Overview")
            with gr.Column():
                chart_output = gr.Plot(label="Stock Price Chart")
        analyze_btn.click(
            fn=run_analysis,
            inputs=[symbol_input],
            outputs=[overview_output, chart_output],
        )
    return interface
# Build the UI, then start serving it with debug output and a public share link.
interface = create_interface()
interface.launch(
    share=True,
    debug=True,
)