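# Page-header residue ("Spaces" / "Sleeping") from the Hugging Face Space this
# file was copied from; kept as a comment so the module parses.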
import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer, pipeline
class TextClassifier:
    """Thin wrapper around a Hugging Face ``text-classification`` pipeline.

    NOTE(review): the default checkpoint looks like a base model rather than a
    fine-tuned classifier; confirm the returned labels are meaningful.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        self.classifier = pipeline("text-classification", model=model_name)

    def classify(self, text):
        """Return the ``'label'`` of the first prediction the pipeline yields for *text*."""
        predictions = self.classifier(text)
        top_prediction = predictions[0]
        return top_prediction['label']
class SentimentAnalyzer:
    """Score the sentiment of a text with a Hugging Face ``sentiment-analysis`` pipeline."""

    def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
        self.analyzer = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text):
        """Return the first prediction dict produced by the pipeline for *text*.

        ``CustomerSupportAssistant.process_query`` reads its ``'label'`` and
        ``'score'`` keys.
        """
        predictions = self.analyzer(text)
        first_prediction = predictions[0]
        return first_prediction
class ImageRecognizer:
    """Predict a class index for an image file with a pretrained torchvision model."""

    def __init__(self, model_name='resnet50'):
        """Load a pretrained torchvision classifier by name.

        Args:
            model_name: Name of a model builder in ``torchvision.models`` (for
                example ``'resnet50'`` or ``'resnet18'``). Previously this
                argument was ignored and resnet50 was always loaded.

        Raises:
            ValueError: If ``torchvision.models`` has no callable with that name.
        """
        builder = getattr(models, model_name, None)
        if not callable(builder):
            raise ValueError(f"Unknown torchvision model: {model_name!r}")
        # NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
        # `weights=...`; kept here for compatibility with older releases.
        self.model = builder(pretrained=True)
        self.model.eval()
        # Resize, centre-crop to 224x224, convert to a [0, 1] tensor, then
        # normalise per channel with the usual ImageNet mean/std values.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize(self, image_path):
        """Return the index of the highest-scoring output class for an image.

        Args:
            image_path: Path (or file object) accepted by ``PIL.Image.open``.

        Returns:
            The argmax class index as a Python ``int``.
        """
        with Image.open(image_path) as img:
            # convert() returns a loaded copy, so the file can be closed here.
            # Forcing RGB makes grayscale / palette / RGBA inputs match the
            # 3-channel Normalize above.
            rgb_image = img.convert("RGB")
        batch = self.transform(rgb_image).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            outputs = self.model(batch)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()
class TextGenerator:
    """Continue a prompt with a Hugging Face ``text-generation`` pipeline."""

    def __init__(self, model_name='gpt2'):
        self.generator = pipeline("text-generation", model=model_name)

    def generate(self, prompt, max_new_tokens=100):
        """Generate one continuation of *prompt*.

        Args:
            prompt: Text to continue.
            max_new_tokens: Upper bound on tokens generated after the prompt.
                The previous call used ``max_length=100``, which also counts
                the prompt's tokens, so a long prompt left little or no room
                for generated text.

        Returns:
            The ``'generated_text'`` of the single returned sequence.
            NOTE(review): with the pipeline's defaults this presumably still
            starts with the prompt itself; confirm against the transformers version.
        """
        response = self.generator(prompt, max_new_tokens=max_new_tokens, num_return_sequences=1)
        return response[0]['generated_text']
class FAQRetriever:
    """Embed FAQ strings with a transformer encoder and find the nearest ones with FAISS."""

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Size the index from the model config instead of hard-coding 384 (the
        # MiniLM width) so other encoder checkpoints also work.
        self.index = faiss.IndexFlatL2(self.model.config.hidden_size)

    def embed(self, text):
        """Return mean-pooled token embeddings for *text*.

        Args:
            text: A single string or a list of strings.

        Returns:
            A C-contiguous ``float32`` numpy array with one row per input text,
            which is the layout FAISS expects.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
        # Average over real tokens only, so padding in a batch does not skew the mean.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return np.ascontiguousarray((summed / counts).cpu().numpy(), dtype=np.float32)

    def add_faqs(self, faqs):
        """Embed, L2-normalise and add *faqs* to the index, in the given order.

        Index positions follow insertion order, so callers can map search
        results back to their FAQ list. An empty iterable is a no-op.
        """
        faqs = list(faqs)
        if not faqs:
            return
        embeddings = self.embed(faqs)
        # With unit-length vectors, L2 distance ranks results the same way as cosine similarity.
        faiss.normalize_L2(embeddings)
        self.faq_embeddings = embeddings
        self.index.add(embeddings)

    def retrieve(self, query, k=5):
        """Return indices of up to *k* indexed FAQs nearest to *query*, nearest first.

        Fewer than *k* indices are returned when the index holds fewer entries.
        FAISS's ``-1`` placeholders for missing neighbours are filtered out.
        """
        if k <= 0 or self.index.ntotal == 0:
            return np.empty(0, dtype=np.int64)
        query_embedding = self.embed(query)
        faiss.normalize_L2(query_embedding)
        _, indices = self.index.search(query_embedding, min(k, self.index.ntotal))
        row = indices[0]
        return row[row >= 0]
class CustomerSupportAssistant: | |
def __init__(self): | |
self.text_classifier = TextClassifier() | |
self.sentiment_analyzer = SentimentAnalyzer() | |
self.image_recognizer = ImageRecognizer() | |
self.text_generator = TextGenerator() | |
self.faq_retriever = FAQRetriever() | |
self.faqs = [ | |
"How to reset my password?", | |
"What is the return policy?", | |
"How to track my order?", | |
"How to contact customer support?", | |
"What payment methods are accepted?" | |
] | |
self.faq_retriever.add_faqs(self.faqs) | |
def process_query(self, text, image_path=None): | |
topic = self.text_classifier.classify(text) | |
sentiment = self.sentiment_analyzer.analyze(text) | |
if image_path: | |
image_info = self.image_recognizer.recognize(image_path) | |
else: | |
image_info = "No image provided." | |
faqs = self.faq_retriever.retrieve(text) | |
faq_responses = [self.faqs[i] for i in faqs] | |
response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response." | |
response = self.text_generator.generate(response_prompt) | |
return response | |
# Example usage. Guarded so importing this module does not load every model
# and run inference as a side effect.
if __name__ == "__main__":
    assistant = CustomerSupportAssistant()
    input_text = "I'm having trouble with my recent order."
    output = assistant.process_query(input_text)
    print(output)