# data_project/app.py
# Hugging Face Space file (author kartik91, commit 5cefe7f, 4.07 kB) —
# original page header converted to comments so the module parses.
import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline, AutoTokenizer, AutoModel
class TextClassifier:
    """Thin wrapper around a Hugging Face text-classification pipeline."""

    def __init__(self, model_name='distilbert-base-uncased'):
        # Build the pipeline once; it is reused for every classify() call.
        self.classifier = pipeline("text-classification", model=model_name)

    def classify(self, text):
        """Return the label of the top prediction for *text*."""
        predictions = self.classifier(text)
        top_prediction = predictions[0]
        return top_prediction['label']
class SentimentAnalyzer:
    """Wraps a multilingual BERT sentiment-analysis pipeline."""

    def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
        # One pipeline instance shared across all analyze() calls.
        self.analyzer = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text):
        """Return the top result dict ({'label': ..., 'score': ...}) for *text*."""
        results = self.analyzer(text)
        return results[0]
class ImageRecognizer:
    """ImageNet classifier built on a torchvision ResNet-style backbone."""

    def __init__(self, model_name='resnet50'):
        # Bug fix: model_name was previously ignored (resnet50 hard-coded).
        # Looking it up on torchvision.models honors the parameter while the
        # default keeps the original behavior.
        self.model = getattr(models, model_name)(pretrained=True)
        self.model.eval()  # inference mode: disables dropout/batch-norm updates
        # Standard ImageNet preprocessing (mean/std per torchvision docs).
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize(self, image_path):
        """Return the predicted ImageNet class index (int) for the image file.

        Args:
            image_path: Path to an image readable by PIL.
        """
        # Bug fix: force RGB so grayscale/RGBA/palette images don't break the
        # 3-channel Normalize step.
        image = Image.open(image_path).convert('RGB')
        batch = self.transform(image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            outputs = self.model(batch)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()
class TextGenerator:
    """Wraps a causal-LM text-generation pipeline (GPT-2 by default)."""

    def __init__(self, model_name='gpt2'):
        self.generator = pipeline("text-generation", model=model_name)

    def generate(self, prompt, max_length=100):
        """Generate a single continuation of *prompt*.

        Args:
            prompt: Seed text for the model.
            max_length: Total token budget (prompt + continuation). Default 100
                matches the previously hard-coded value, so existing callers
                are unaffected.

        Returns:
            The generated text (prompt included) from the first sequence.
        """
        response = self.generator(prompt, max_length=max_length, num_return_sequences=1)
        return response[0]['generated_text']
class FAQRetriever:
    """Dense FAQ retriever: mean-pooled MiniLM embeddings + a FAISS L2 index."""

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # 384 is MiniLM-L6's hidden size. Vectors are L2-normalized before
        # insertion/search, so L2 distance ranks identically to cosine similarity.
        self.index = faiss.IndexFlatL2(384)

    def embed(self, text):
        """Return a (1, hidden_size) numpy embedding for *text*."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            # NOTE(review): plain mean over all token positions, padding
            # included — an attention-mask-weighted mean would be more faithful
            # to sentence-transformers pooling; behavior kept as-is.
            embeddings = self.model(**inputs).last_hidden_state.mean(dim=1)
        return embeddings.cpu().numpy()

    def add_faqs(self, faqs):
        """Embed every FAQ string in *faqs* and add them to the index."""
        self.faq_embeddings = np.concatenate([self.embed(faq) for faq in faqs])
        faiss.normalize_L2(self.faq_embeddings)  # in-place normalization
        self.index.add(self.faq_embeddings)

    def retrieve(self, query, top_k=5):
        """Return the indices of the *top_k* FAQs nearest to *query*.

        Args:
            query: Free-text user query.
            top_k: Number of neighbors to return (default 5, the previously
                hard-coded value — backward compatible).
        """
        query_embedding = self.embed(query)
        faiss.normalize_L2(query_embedding)
        D, I = self.index.search(query_embedding, top_k)
        return I[0]
class CustomerSupportAssistant:
def __init__(self):
self.text_classifier = TextClassifier()
self.sentiment_analyzer = SentimentAnalyzer()
self.image_recognizer = ImageRecognizer()
self.text_generator = TextGenerator()
self.faq_retriever = FAQRetriever()
self.faqs = [
"How to reset my password?",
"What is the return policy?",
"How to track my order?",
"How to contact customer support?",
"What payment methods are accepted?"
]
self.faq_retriever.add_faqs(self.faqs)
def process_query(self, text, image_path=None):
topic = self.text_classifier.classify(text)
sentiment = self.sentiment_analyzer.analyze(text)
if image_path:
image_info = self.image_recognizer.recognize(image_path)
else:
image_info = "No image provided."
faqs = self.faq_retriever.retrieve(text)
faq_responses = [self.faqs[i] for i in faqs]
response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
response = self.text_generator.generate(response_prompt)
return response
# Example usage: runs at import time — builds every model (downloads weights
# on first run) and answers one sample query.
assistant = CustomerSupportAssistant()
input_text = "I'm having trouble with my recent order."
# No image_path is passed, so the image branch reports "No image provided."
output = assistant.process_query(input_text)
print(output)