anonymousatom commited on
Commit
c0f6127
·
verified ·
1 Parent(s): 5fc1dbc

Upload 2 files

Browse files
Files changed (2) hide show
  1. CLIP_CreativeTesting.py +76 -0
  2. detect_adv.py +139 -0
CLIP_CreativeTesting.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from PIL import Image
4
+ import requests
5
+ # Create function to pass input image and text to model and return the label probabilities
6
+ import torch
7
+ import time
8
+ from detect_adv import detect_text, analyze_layout, analyze_shapes
9
+ from transformers import CLIPProcessor, CLIPModel
10
+ # Streamlit code to upload image and output label probabilities
11
+ import streamlit as st
12
+ import tempfile
13
+
14
+ model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
15
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
16
+
17
+
18
+ def get_label_probs(image, text, model, processor):
19
+ torch.cuda.empty_cache() # Release cached memory
20
+ inputs = processor(text=text, images=image, return_tensors="pt", padding=True)
21
+ inputs = inputs
22
+ outputs = model(**inputs)
23
+ logits_per_image = outputs.logits_per_image
24
+ probs = logits_per_image.softmax(dim=1)
25
+ # Clear GPU memory
26
+ torch.cuda.empty_cache()
27
+ del inputs, outputs, logits_per_image
28
+ return probs
29
+
30
+ text = ['Advertisement Creative(Contains Text)', 'Not an Advertisement Creative(Contains No Text)', 'Simple Product Image and not an Advertisement)']
31
+
32
+
33
+
34
+ st.title("Advertisement Detection using CLIP")
35
+
36
+ # Upload image
37
+ uploaded_image = st.file_uploader("Choose an image...", type="jpg")
38
+
39
+
40
+ if uploaded_image is not None:
41
+ temp_dir = tempfile.mkdtemp()
42
+ path = os.path.join(temp_dir, uploaded_image.name)
43
+ with open(path, "wb") as f:
44
+ f.write(uploaded_image.getvalue())
45
+
46
+ image = Image.open(uploaded_image)
47
+ st.image(image, caption="Uploaded Image.", use_column_width=True)
48
+ # Get label probabilities
49
+ probs = get_label_probs(image, text, model, processor)
50
+ # Output label probabilities
51
+ prob = probs.tolist()
52
+ prob = prob[0]
53
+ # st.write("Label Probabilities:", prob)
54
+ # st.write("Label Probabilities:", probs)
55
+ # # Output predicted label
56
+ # predicted_label = text[torch.argmax(probs[0])]
57
+ # st.write("Predicted Label:", predicted_label)
58
+
59
+ # Augmenting using classic techniques
60
+ layout_result = analyze_layout(path)
61
+ shape_result = analyze_shapes(path)
62
+ #
63
+ # # Output classic technique results
64
+ # st.write("Layout Analysis Result:", layout_result)
65
+ # st.write("Shape Analysis Result:", shape_result)
66
+ final_out = False
67
+ # Find index of max value from list
68
+ max_index = prob.index(max(prob))
69
+ if max_index == 0 and (layout_result == True or shape_result == True):
70
+ final_out = True
71
+ # Write 'Advertisement' if the image is an advertisement
72
+ if final_out == True:
73
+ st.write("Advertisement")
74
+ else:
75
+ st.write("Not an Advertisement")
76
+
detect_adv.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from collections import Counter
4
+
5
+
6
+
7
+ from paddleocr import PaddleOCR, draw_ocr
8
+
9
+ # Paddleocr supports Chinese, English, French, German, Korean and Japanese.
10
+ # You can set the parameter `lang` as `ch`, `en`, `fr`, `german`, `korean`, `japan`
11
+ # to switch the language model in order.
12
+ ocr = PaddleOCR(use_angle_cls=True, lang='en') # need to run only once to download and load model into memory
13
+ def detect_text(image):
14
+
15
+ result = ocr.ocr(image, cls=True)
16
+ txt = ''
17
+ for idx in range(len(result)):
18
+ res = result[idx]
19
+ for line in res:
20
+ txt += line[1][0]
21
+
22
+ return txt
23
+ def analyze_text(text):
24
+ marketing_keywords = ['sale', 'offer', 'discount', 'promotion', 'limited', 'buy', 'now', ]
25
+
26
+ # Count the occurrences of marketing keywords
27
+ word_count = Counter([word.lower() for word in text.split()])
28
+ keyword_count = sum(word_count[keyword] for keyword in marketing_keywords)
29
+
30
+ # Classify based on the number of marketing keywords
31
+ if keyword_count > 2:
32
+ return "Advertisement"
33
+ else:
34
+ return "Normal Product Image"
35
+
36
+
37
+ # Point 2: Layout and Composition Analysis
38
+
39
+ def analyze_layout(image_path):
40
+
41
+ image = cv2.imread(image_path)
42
+
43
+ # Convert image to grayscale
44
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
45
+
46
+ # Invert the grayscale image
47
+ inverted = cv2.bitwise_not(gray)
48
+
49
+ # Apply Otsu's thresholding
50
+ _, thresholded = cv2.threshold(inverted, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
51
+
52
+ # Find contours in the thresholded image
53
+ contours, _ = cv2.findContours(thresholded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
54
+
55
+ # Initialize counters
56
+ asymmetric_count = 0
57
+ dynamic_shape_count = 0
58
+
59
+ # Iterate through contours
60
+ for contour in contours:
61
+ # Calculate the bounding rectangle of the contour
62
+ x, y, w, h = cv2.boundingRect(contour)
63
+
64
+ # Calculate aspect ratio
65
+ aspect_ratio = float(w) / h
66
+
67
+ # Check for asymmetric layout
68
+ if aspect_ratio < 0.8 or aspect_ratio > 1.2:
69
+ asymmetric_count += 1
70
+
71
+ # Check for dynamic shape
72
+ if len(contour) > 5:
73
+ _, _, angle = cv2.fitEllipse(contour)
74
+ if angle > 30 and angle < 150:
75
+ dynamic_shape_count += 1
76
+
77
+ # Determine if it's an advertisement based on criteria
78
+ is_advertisement = False
79
+ if asymmetric_count > 1 or dynamic_shape_count > 1:
80
+ is_advertisement = True
81
+
82
+ return is_advertisement
83
+
84
+
85
+
86
+
87
+
88
+ # Point 3: Color Analysis
89
+ def analyze_color(image_path):
90
+ # Load the image
91
+ image = cv2.imread(image_path)
92
+
93
+ # Convert image to HSV
94
+ hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
95
+
96
+ # Calculate mean saturation and value
97
+ mean_saturation = np.mean(hsv_image[:, :, 1])
98
+ mean_value = np.mean(hsv_image[:, :, 2])
99
+
100
+ # Check for high saturation and value (vivid colors)
101
+ if mean_saturation > 150 and mean_value > 150:
102
+ return "Advertisement"
103
+ else:
104
+ return "Not Advertisement"
105
+
106
+
107
+
108
+ # Point 4: Edge Detection and Shape Analysis
109
+ def analyze_shapes(image_path):
110
+ image = cv2.imread(image_path)
111
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
112
+ edges = cv2.Canny(gray, 100, 200)
113
+
114
+ # Find contours
115
+ contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
116
+
117
+ # Check for specific shapes (e.g., arrows, starbursts)
118
+ for cnt in contours:
119
+ approx = cv2.approxPolyDP(cnt, 0.01 * cv2.arcLength(cnt, True), True)
120
+ if len(approx) in [3, 5, 7]: # Triangles, pentagons, or starbursts
121
+ return True
122
+
123
+ return False
124
+
125
+
126
+ # # Load the image
127
+ # image = '/home/karun/PycharmProjects/AdGod/250.jpg'
128
+ # img_open = Image.open(image)
129
+ #
130
+ # # Analyze the image using different techniques
131
+ # text_result = analyze_text(detect_text(image))
132
+ # layout_result = analyze_layout(image)
133
+ # color_result = analyze_color(image)
134
+ # shape_result = analyze_shapes(image)
135
+ #
136
+ # # Print the results
137
+ # print("Text Analysis Result:", text_result)
138
+ # print("Layout Analysis Result:", layout_result)
139
+ # print("Shape Analysis Result:", shape_result)