File size: 4,072 Bytes
5cefe7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

import faiss
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from transformers import pipeline, AutoTokenizer, AutoModel

class TextClassifier:
    """Predict a single topic label for a piece of text.

    Wraps a Hugging Face ``text-classification`` pipeline and exposes only
    the label of the first prediction it returns.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        # NOTE(review): the default checkpoint looks like a base (not
        # fine-tuned) model, so its classification head may be untrained and
        # the labels arbitrary -- confirm a fine-tuned model is intended.
        self.classifier = pipeline("text-classification", model=model_name)

    def classify(self, text):
        """Return the ``'label'`` of the pipeline's first prediction for *text*."""
        predictions = self.classifier(text)
        first = predictions[0]
        return first['label']


class SentimentAnalyzer:
    """Score the sentiment of text with a Hugging Face pipeline."""

    def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
        self.analyzer = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text):
        """Return the first result the sentiment pipeline produces for *text*.

        The result is returned unchanged; presumably a dict with ``'label'``
        and ``'score'`` keys (the assistant in this file reads both).
        """
        results = self.analyzer(text)
        top_result = results[0]
        return top_result




class ImageRecognizer:
    """Classify an image file with a pretrained torchvision model.

    ``recognize`` returns the index of the highest-scoring output class
    (an int), not a human-readable label.
    """

    def __init__(self, model_name='resnet50'):
        """Load ``torchvision.models.<model_name>`` with pretrained weights.

        Args:
            model_name: Name of a model constructor in ``torchvision.models``
                (e.g. ``'resnet50'``). Previously this argument was ignored
                and ResNet-50 was always loaded.

        Raises:
            ValueError: If ``torchvision.models`` has no callable ``model_name``.
        """
        builder = getattr(models, model_name, None)
        if not callable(builder):
            raise ValueError(f"Unknown torchvision model: {model_name!r}")
        # NOTE(review): newer torchvision versions deprecate ``pretrained=True``
        # in favour of ``weights=...``; kept as-is for compatibility.
        self.model = builder(pretrained=True)
        self.model.eval()  # inference mode (disables dropout, freezes batch-norm stats)
        # Resize/crop to 224x224 and normalise with the standard ImageNet
        # mean/std. NOTE(review): 224 suits ResNets; confirm for other models.
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def recognize(self, image_path):
        """Return the predicted class index for the image at *image_path*.

        The image is converted to RGB so that grayscale, palette and RGBA
        files produce the 3-channel tensor the normalisation expects. The
        file is closed as soon as its pixels are loaded.
        """
        with Image.open(image_path) as img:
            rgb = img.convert('RGB')  # loads pixel data into a new image
        batch = self.transform(rgb).unsqueeze(0)  # add a leading batch dimension
        with torch.no_grad():
            outputs = self.model(batch)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()



class TextGenerator:
    """Generate a text continuation for a prompt with a Hugging Face pipeline."""

    def __init__(self, model_name='gpt2'):
        self.generator = pipeline("text-generation", model=model_name)

    def generate(self, prompt, max_new_tokens=100):
        """Return the generated text for *prompt*.

        Args:
            prompt: Input text to continue.
            max_new_tokens: Upper bound on the number of tokens generated
                *after* the prompt. The previous ``max_length=100`` also
                counted prompt tokens, so a long prompt (like the one
                ``CustomerSupportAssistant`` builds) could leave no budget
                for any new text.

        Returns:
            The ``'generated_text'`` of the single returned sequence. With
            the pipeline's default settings this includes the prompt itself.
        """
        outputs = self.generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
        )
        return outputs[0]['generated_text']




class FAQRetriever:
    """Nearest-neighbour FAQ lookup over mean-pooled transformer embeddings.

    Embeddings are L2-normalised before they are stored in, or queried
    against, a flat L2 FAISS index. For unit vectors, ranking by L2 distance
    is the same as ranking by cosine similarity.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Size the index from the model config instead of hard-coding 384
        # (MiniLM's width), so other encoders work too.
        self.index = faiss.IndexFlatL2(self.model.config.hidden_size)
        # Every embedding added so far, row-aligned with the index ids.
        self.faq_embeddings = None

    def embed(self, text):
        """Return mean-pooled token embeddings for *text* as a NumPy array.

        Args:
            text: A string or a list of strings.

        Returns:
            Array with one row per input text (one row for a single string).

        Padding positions are excluded from the mean via the attention mask,
        so batched inputs of different lengths are pooled correctly.
        """
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state  # (batch, seq, hidden)
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
        return (summed / counts).cpu().numpy()

    def add_faqs(self, faqs):
        """Embed *faqs* and append them to the index.

        Index ids are assigned in insertion order across all calls, so id
        ``i`` is the ``i``-th FAQ ever added. An empty iterable is a no-op.
        """
        texts = list(faqs)
        if not texts:
            return
        # FAISS requires C-contiguous float32 input.
        embeddings = np.ascontiguousarray(
            np.concatenate([self.embed(t) for t in texts]), dtype=np.float32
        )
        faiss.normalize_L2(embeddings)  # normalises in place
        self.index.add(embeddings)
        previous = getattr(self, 'faq_embeddings', None)
        if previous is None:
            self.faq_embeddings = embeddings
        else:
            self.faq_embeddings = np.vstack([previous, embeddings])

    def retrieve(self, query, k=5):
        """Return the indices of up to *k* FAQs nearest to *query*, best first.

        ``k`` is capped at the number of indexed FAQs so FAISS never pads the
        result with ``-1`` placeholders. A caller indexing a Python list
        would silently read a ``-1`` as the last element. An empty index
        yields an empty array.
        """
        k = min(k, self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype=np.int64)
        query_embedding = np.ascontiguousarray(self.embed(query), dtype=np.float32)
        faiss.normalize_L2(query_embedding)
        _, indices = self.index.search(query_embedding, k)
        return indices[0]



class CustomerSupportAssistant:
    """Combine the helper models into a single support-reply pipeline.

    Construction builds every helper and indexes a small built-in FAQ list.
    """

    def __init__(self):
        self.text_classifier = TextClassifier()
        self.sentiment_analyzer = SentimentAnalyzer()
        self.image_recognizer = ImageRecognizer()
        self.text_generator = TextGenerator()
        self.faq_retriever = FAQRetriever()
        self.faqs = [
            "How to reset my password?",
            "What is the return policy?",
            "How to track my order?",
            "How to contact customer support?",
            "What payment methods are accepted?"
        ]
        self.faq_retriever.add_faqs(self.faqs)

    def process_query(self, text, image_path=None):
        """Generate a reply for a customer message.

        Args:
            text: The customer's message.
            image_path: Optional path to an attached image; falsy means none.

        Returns:
            Whatever ``self.text_generator.generate`` returns for a prompt
            summarising topic, sentiment, matching FAQs and image info.
        """
        topic = self.text_classifier.classify(text)
        sentiment = self.sentiment_analyzer.analyze(text)
        if image_path:
            # NOTE(review): recognize() currently returns a raw class index,
            # so the prompt shows a number rather than a label.
            image_info = self.image_recognizer.recognize(image_path)
        else:
            image_info = "No image provided."
        faq_ids = self.faq_retriever.retrieve(text)
        # Skip out-of-range ids (FAISS uses -1 for "no result"); indexing
        # with -1 would silently pick the last FAQ instead.
        faq_responses = [self.faqs[i] for i in faq_ids if 0 <= i < len(self.faqs)]
        response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
        response = self.text_generator.generate(response_prompt)
        return response

# Example usage. Guarded so that importing this module does not build the
# assistant (and load every model) as a side effect.
if __name__ == "__main__":
    assistant = CustomerSupportAssistant()
    input_text = "I'm having trouble with my recent order."
    output = assistant.process_query(input_text)
    print(output)