Real vectorized trimap generation
In Listing 1 of your paper, you give an implementation of
an efficient vectorized version of the confidence trimap generation algorithm.
I think your implementation could still be improved a bit. It is not yet fully vectorized, since it steps through various thresholds in the while loop in line 33.
I do not know the range of `pred`
in your paper, but here is a fully vectorized version, together with a benchmark on a sample image on which it is much faster. Maybe it is useful to you.
# https://arxiv.org/pdf/2501.06230
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import time
import urllib.request
def generate_trimap_vectorized(pred, min_unknown_pixels=60_000):
    """Build a trimap from logits without looping over thresholds.

    The sigmoid of ``pred`` is scaled by 256 and histogrammed into 256 bins.
    A single symmetric cut-off ``index`` is then read off the cumulative
    histogram: it is the first level at which the number of pixels left
    between the low and the high cut-off no longer exceeds
    ``min_unknown_pixels``.

    Args:
        pred: Floating-point tensor of logits. Any shape is accepted and the
            tensor does not have to be contiguous.
        min_unknown_pixels: Pixel-count threshold used to pick ``index``.

    Returns:
        A float tensor shaped like ``pred`` that holds 255.0 where the scaled
        sigmoid is above ``255 - index``, 0.0 where it is below ``index`` and
        128.0 in between.
    """
    # Quantize predictions to 256 levels and count them.
    values = pred.sigmoid() * 256
    # reshape (not view): view(-1) raises for non-contiguous tensors such as
    # a transposed prediction, reshape copies only when it has to.
    histc = values.reshape(-1).histc(bins=256, min=0, max=255)
    # Position i of (histc + flipped histc) counts bin i plus bin 255 - i, so
    # the running sum is the number of pixels in the i + 1 lowest and i + 1
    # highest bins; subtracting from the total leaves the pixels in between.
    counts = pred.numel() - (histc + histc.flip(0)).cumsum(0)
    # argmin over the 0/1 mask selects the first level whose remaining pixel
    # count does not exceed min_unknown_pixels.
    index = (counts > min_unknown_pixels).to(int).argmin()
    # If you want to, you could clamp the index to some reasonable range.
    # 127 + 128 = 255 where both comparisons hold, 128 where only the lower
    # one holds, 0 where neither holds.
    trimap = (values > 255 - index) * 127.0 + (values >= index) * 128.0
    return trimap
def _threshold_trimap(mask, t_low, t_high):
    """Label ``mask`` 255 where >= ``t_high``, 0 where <= ``t_low``, else 128."""
    return torch.where(
        mask >= t_high,
        mask.new_tensor(255.0),
        torch.where(
            mask <= t_low,
            mask.new_tensor(0.0),
            mask.new_tensor(128.0),
        ),
    )


def generate_trimap(pred, min_pixels=60000, *, t_high=0.90, t_low=0.10,
                    t_min=0.03, t_max=0.97, step=0.001):
    """Build a trimap from logits, widening the gray band until it is large enough.

    The sigmoid of ``pred`` is thresholded into 255 (>= ``t_high``),
    0 (<= ``t_low``) and 128 (everything else). While fewer than
    ``min_pixels`` pixels are 128, ``t_low`` is lowered and ``t_high`` raised
    by ``step`` and the trimap is rebuilt.

    Once both thresholds have reached ``t_min`` / ``t_max`` the loop exits
    *before* rebuilding, so the returned trimap reflects the last thresholds
    applied prior to hitting both bounds.

    Args:
        pred: Tensor of logits.
        min_pixels: Minimum number of 128-valued pixels to aim for.
        t_high: Initial upper threshold on the sigmoid output.
        t_low: Initial lower threshold on the sigmoid output.
        t_min: Lower bound for ``t_low``.
        t_max: Upper bound for ``t_high``.
        step: Amount by which each threshold moves per iteration.

    Returns:
        Tensor shaped like ``pred`` containing the values 0, 128 and 255.

    Raises:
        ValueError: If ``step`` is not positive (the loop could never end).
    """
    if step <= 0:
        raise ValueError("step must be positive")

    mask = pred.sigmoid()  # Apply sigmoid to prediction

    # Generate initial trimap
    trimap = _threshold_trimap(mask, t_low, t_high)

    # Count gray pixels and adjust thresholds
    n_gray = (trimap == 128).sum().item()
    while n_gray < min_pixels:
        t_low = max(t_low - step, t_min)
        t_high = min(t_high + step, t_max)
        if t_low <= t_min and t_high >= t_max:
            break  # Exit if bounds reached
        trimap = _threshold_trimap(mask, t_low, t_high)
        n_gray = (trimap == 128).sum().item()
    return trimap
def main():
    """Benchmark both trimap generators on a sample alpha matte and plot them.

    Downloads the test image to ``alpha.png`` if it is not already present,
    turns it into logits, times ``generate_trimap`` and
    ``generate_trimap_vectorized`` ten times each, and finally shows the
    alpha matte next to both trimaps.
    """
    # Download test file
    url = "https://raw.githubusercontent.com/frcs/alternative-matting-laplacian/master/result-alpha-GT04.png"
    if not os.path.isfile("alpha.png"):
        urllib.request.urlretrieve(url, "alpha.png")

    # The context manager closes the file handle once the pixels are copied.
    with Image.open("alpha.png") as image:
        alpha = np.array(image.convert("L"))
    alpha = alpha.astype(np.float32) / 255.0
    alpha = torch.from_numpy(alpha)
    # Map alpha in [0, 1] linearly onto logits in [-5, 5].
    pred = alpha * 10 - 5

    # torch.cuda.synchronize() raises on machines without CUDA, so only call
    # it when a CUDA device is actually available.
    has_cuda = torch.cuda.is_available()

    def synchronize():
        if has_cuda:
            torch.cuda.synchronize()

    for _ in range(10):
        synchronize()
        t = time.perf_counter()
        trimap = generate_trimap(pred)
        synchronize()
        dt1 = time.perf_counter() - t

        synchronize()
        t = time.perf_counter()
        trimap_new = generate_trimap_vectorized(pred)
        synchronize()
        dt2 = time.perf_counter() - t

        print(f"{dt1 * 1000:7.3f} ms for generate_trimap")
        print(f"{dt2 * 1000:7.3f} ms for generate_trimap_vectorized")
        print()

    for i, img in enumerate([alpha, trimap, trimap_new]):
        plt.subplot(1, 3, 1 + i)
        plt.imshow(img.detach().cpu().numpy(), cmap="gray")
    plt.show()


# Run the benchmark only when executed as a script.
if __name__ == "__main__":
    main()
Hello, thank you for your comment. We are still preparing the paper for peer review, so outside perspectives are greatly appreciated. The range of pred is between 0 and 1, as it is the result of sigmoid(logits). This algorithm was used for BEN, and our BEN2 models employ a fixed range. This avoids stepping through the thresholds entirely:
# BEN2 trimap generation: fixed thresholds, no iterative adjustment.
min_low_threshold = 0.01  # At or below this the pixel is labelled 0
max_high_threshold = 0.99  # At or above this the pixel is labelled 255
mask = predicted_output.sigmoid()  # Use the sigmoid output directly

# First split into 0 / 128 on the low threshold, then overwrite with 255
# wherever the high threshold is met, so the high test takes priority.
is_low = mask <= min_low_threshold
is_high = mask >= max_high_threshold
trimap = torch.where(is_low, mask.new_tensor(0.0), mask.new_tensor(128.0))
trimap = torch.where(is_high, mask.new_tensor(255.0), trimap)
We found better model generalization results this way. The purpose of the PyTorch implementation was more for research and demonstration.