multimodalart HF staff commited on
Commit
8350a42
·
1 Parent(s): da291f3

Update sketch_helper.py

Browse files
Files changed (1) hide show
  1. sketch_helper.py +10 -12
sketch_helper.py CHANGED
@@ -6,20 +6,18 @@ from skimage.color import lab2rgb
6
  from sklearn.cluster import KMeans
7
 
8
  def color_quantization(image, n_colors):
9
- # Determine the number of bins dynamically
10
- unique_colors = np.unique(image.reshape(-1, 3), axis=0)
11
- n_bins = int(np.ceil(np.sqrt(unique_colors.shape[0])))
12
 
13
- # Cluster the colors using k-means
14
- kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(unique_colors)
15
- colors = kmeans.cluster_centers_
16
 
17
- # Replace each pixel with the closest color
18
- dists = np.sqrt(np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2))
19
- labels = np.argmin(dists, axis=1)
20
- new_image = colors[labels].reshape(image.shape).astype(np.uint8)
21
-
22
- return new_image
23
 
24
 
25
  def get_high_freq_colors(image):
 
6
  from sklearn.cluster import KMeans
7
 
8
  def color_quantization(image, n_colors):
9
+ # Reshape the image into a 2D array of pixels
10
+ pixels = image.reshape(-1, 3)
 
11
 
12
+ # Fit k-means clustering algorithm to the pixel data
13
+ kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(pixels)
 
14
 
15
+ # Replace each pixel with its nearest cluster center
16
+ labels = kmeans.predict(pixels)
17
+ colors = kmeans.cluster_centers_.astype(np.uint8)
18
+ quantized_image = colors[labels].reshape(image.shape)
19
+
20
+ return quantized_image
21
 
22
 
23
  def get_high_freq_colors(image):