kartik91 commited on
Commit
5cefe7f
·
verified ·
1 Parent(s): 4d576a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import pipeline, AutoTokenizer, AutoModel
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+ import faiss
6
+
7
+ class TextClassifier:
8
+ def __init__(self, model_name='distilbert-base-uncased'):
9
+ self.classifier = pipeline("text-classification", model=model_name)
10
+
11
+ def classify(self, text):
12
+ return self.classifier(text)[0]['label']
13
+
14
+
15
+ class SentimentAnalyzer:
16
+ def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
17
+ self.analyzer = pipeline("sentiment-analysis", model=model_name)
18
+
19
+ def analyze(self, text):
20
+ return self.analyzer(text)[0]
21
+
22
+
23
+
24
+
25
+ class ImageRecognizer:
26
+ def __init__(self, model_name='resnet50'):
27
+ self.model = models.resnet50(pretrained=True)
28
+ self.model.eval()
29
+ self.transform = transforms.Compose([
30
+ transforms.Resize(256),
31
+ transforms.CenterCrop(224),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
34
+ ])
35
+
36
+ def recognize(self, image_path):
37
+ image = Image.open(image_path)
38
+ image = self.transform(image).unsqueeze(0)
39
+ with torch.no_grad():
40
+ outputs = self.model(image)
41
+ _, predicted = torch.max(outputs, 1)
42
+ return predicted.item()
43
+
44
+
45
+
46
+ class TextGenerator:
47
+ def __init__(self, model_name='gpt2'):
48
+ self.generator = pipeline("text-generation", model=model_name)
49
+
50
+ def generate(self, prompt):
51
+ response = self.generator(prompt, max_length=100, num_return_sequences=1)
52
+ return response[0]['generated_text']
53
+
54
+
55
+
56
+
57
+ class FAQRetriever:
58
+ def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
59
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
60
+ self.model = AutoModel.from_pretrained(model_name)
61
+ self.index = faiss.IndexFlatL2(384) # Dimension of MiniLM embeddings
62
+
63
+ def embed(self, text):
64
+ inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
65
+ with torch.no_grad():
66
+ embeddings = self.model(**inputs).last_hidden_state.mean(dim=1)
67
+ return embeddings.cpu().numpy()
68
+
69
+ def add_faqs(self, faqs):
70
+ self.faq_embeddings = np.concatenate([self.embed(faq) for faq in faqs])
71
+ faiss.normalize_L2(self.faq_embeddings)
72
+ self.index.add(self.faq_embeddings)
73
+
74
+ def retrieve(self, query):
75
+ query_embedding = self.embed(query)
76
+ faiss.normalize_L2(query_embedding)
77
+ D, I = self.index.search(query_embedding, 5)
78
+ return I[0] # Return top 5 FAQ indices
79
+
80
+
81
+
82
+ class CustomerSupportAssistant:
83
+ def __init__(self):
84
+ self.text_classifier = TextClassifier()
85
+ self.sentiment_analyzer = SentimentAnalyzer()
86
+ self.image_recognizer = ImageRecognizer()
87
+ self.text_generator = TextGenerator()
88
+ self.faq_retriever = FAQRetriever()
89
+ self.faqs = [
90
+ "How to reset my password?",
91
+ "What is the return policy?",
92
+ "How to track my order?",
93
+ "How to contact customer support?",
94
+ "What payment methods are accepted?"
95
+ ]
96
+ self.faq_retriever.add_faqs(self.faqs)
97
+
98
+ def process_query(self, text, image_path=None):
99
+ topic = self.text_classifier.classify(text)
100
+ sentiment = self.sentiment_analyzer.analyze(text)
101
+ if image_path:
102
+ image_info = self.image_recognizer.recognize(image_path)
103
+ else:
104
+ image_info = "No image provided."
105
+ faqs = self.faq_retriever.retrieve(text)
106
+ faq_responses = [self.faqs[i] for i in faqs]
107
+ response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
108
+ response = self.text_generator.generate(response_prompt)
109
+ return response
110
+
111
+ # Example usage:
112
+ assistant = CustomerSupportAssistant()
113
+ input_text = "I'm having trouble with my recent order."
114
+ output = assistant.process_query(input_text)
115
+ print(output)