Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from transformers import pipeline, AutoTokenizer, AutoModel
|
3 |
+
from torchvision import models, transforms
|
4 |
+
from PIL import Image
|
5 |
+
import faiss
|
6 |
+
|
7 |
+
class TextClassifier:
|
8 |
+
def __init__(self, model_name='distilbert-base-uncased'):
|
9 |
+
self.classifier = pipeline("text-classification", model=model_name)
|
10 |
+
|
11 |
+
def classify(self, text):
|
12 |
+
return self.classifier(text)[0]['label']
|
13 |
+
|
14 |
+
|
15 |
+
class SentimentAnalyzer:
|
16 |
+
def __init__(self, model_name='nlptown/bert-base-multilingual-uncased-sentiment'):
|
17 |
+
self.analyzer = pipeline("sentiment-analysis", model=model_name)
|
18 |
+
|
19 |
+
def analyze(self, text):
|
20 |
+
return self.analyzer(text)[0]
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
class ImageRecognizer:
|
26 |
+
def __init__(self, model_name='resnet50'):
|
27 |
+
self.model = models.resnet50(pretrained=True)
|
28 |
+
self.model.eval()
|
29 |
+
self.transform = transforms.Compose([
|
30 |
+
transforms.Resize(256),
|
31 |
+
transforms.CenterCrop(224),
|
32 |
+
transforms.ToTensor(),
|
33 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
34 |
+
])
|
35 |
+
|
36 |
+
def recognize(self, image_path):
|
37 |
+
image = Image.open(image_path)
|
38 |
+
image = self.transform(image).unsqueeze(0)
|
39 |
+
with torch.no_grad():
|
40 |
+
outputs = self.model(image)
|
41 |
+
_, predicted = torch.max(outputs, 1)
|
42 |
+
return predicted.item()
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
class TextGenerator:
|
47 |
+
def __init__(self, model_name='gpt2'):
|
48 |
+
self.generator = pipeline("text-generation", model=model_name)
|
49 |
+
|
50 |
+
def generate(self, prompt):
|
51 |
+
response = self.generator(prompt, max_length=100, num_return_sequences=1)
|
52 |
+
return response[0]['generated_text']
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
class FAQRetriever:
|
58 |
+
def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2'):
|
59 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
60 |
+
self.model = AutoModel.from_pretrained(model_name)
|
61 |
+
self.index = faiss.IndexFlatL2(384) # Dimension of MiniLM embeddings
|
62 |
+
|
63 |
+
def embed(self, text):
|
64 |
+
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True)
|
65 |
+
with torch.no_grad():
|
66 |
+
embeddings = self.model(**inputs).last_hidden_state.mean(dim=1)
|
67 |
+
return embeddings.cpu().numpy()
|
68 |
+
|
69 |
+
def add_faqs(self, faqs):
|
70 |
+
self.faq_embeddings = np.concatenate([self.embed(faq) for faq in faqs])
|
71 |
+
faiss.normalize_L2(self.faq_embeddings)
|
72 |
+
self.index.add(self.faq_embeddings)
|
73 |
+
|
74 |
+
def retrieve(self, query):
|
75 |
+
query_embedding = self.embed(query)
|
76 |
+
faiss.normalize_L2(query_embedding)
|
77 |
+
D, I = self.index.search(query_embedding, 5)
|
78 |
+
return I[0] # Return top 5 FAQ indices
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
class CustomerSupportAssistant:
|
83 |
+
def __init__(self):
|
84 |
+
self.text_classifier = TextClassifier()
|
85 |
+
self.sentiment_analyzer = SentimentAnalyzer()
|
86 |
+
self.image_recognizer = ImageRecognizer()
|
87 |
+
self.text_generator = TextGenerator()
|
88 |
+
self.faq_retriever = FAQRetriever()
|
89 |
+
self.faqs = [
|
90 |
+
"How to reset my password?",
|
91 |
+
"What is the return policy?",
|
92 |
+
"How to track my order?",
|
93 |
+
"How to contact customer support?",
|
94 |
+
"What payment methods are accepted?"
|
95 |
+
]
|
96 |
+
self.faq_retriever.add_faqs(self.faqs)
|
97 |
+
|
98 |
+
def process_query(self, text, image_path=None):
|
99 |
+
topic = self.text_classifier.classify(text)
|
100 |
+
sentiment = self.sentiment_analyzer.analyze(text)
|
101 |
+
if image_path:
|
102 |
+
image_info = self.image_recognizer.recognize(image_path)
|
103 |
+
else:
|
104 |
+
image_info = "No image provided."
|
105 |
+
faqs = self.faq_retriever.retrieve(text)
|
106 |
+
faq_responses = [self.faqs[i] for i in faqs]
|
107 |
+
response_prompt = f"Topic: {topic}, Sentiment: {sentiment['label']} with confidence {sentiment['score']}. FAQs: {faq_responses}. Image info: {image_info}. Generate a response."
|
108 |
+
response = self.text_generator.generate(response_prompt)
|
109 |
+
return response
|
110 |
+
|
111 |
+
# Example usage:
|
112 |
+
assistant = CustomerSupportAssistant()
|
113 |
+
input_text = "I'm having trouble with my recent order."
|
114 |
+
output = assistant.process_query(input_text)
|
115 |
+
print(output)
|