Arrcttacsrks commited on
Commit
6f77f61
·
verified ·
1 Parent(s): 5bdfe95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -32
app.py CHANGED
@@ -6,14 +6,12 @@ from torch.autograd import Variable
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
- import ezdxf # Thêm thư viện ezdxf để tạo file DXF
10
  from PIL import Image, UnidentifiedImageError
11
 
12
- # Chuẩn hóa dự đoán
13
  def normPRED(d):
14
  return (d - torch.min(d)) / (torch.max(d) - torch.min(d))
15
 
16
- # Hàm suy luận với U2NET
17
  def inference(net, input_img):
18
  input_img = input_img / np.max(input_img)
19
  tmpImg = np.zeros((input_img.shape[0], input_img.shape[1], 3))
@@ -22,49 +20,63 @@ def inference(net, input_img):
22
  tmpImg[:, :, 2] = (input_img[:, :, 0] - 0.485) / 0.229
23
  tmpImg = torch.from_numpy(tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]).type(torch.FloatTensor)
24
  tmpImg = Variable(tmpImg.cuda() if torch.cuda.is_available() else tmpImg)
25
- d1, _, _, _, _, _, _ = net(tmpImg)
26
  pred = normPRED(1.0 - d1[:, 0, :, :])
27
  return pred.cpu().data.numpy().squeeze()
28
 
29
- # Hàm tạo file DXF từ ảnh kết quả
30
- def convert_to_dxf(image, filename="output.dxf"):
 
 
 
 
 
 
 
 
31
  # Tìm các đường nét trong ảnh
32
- contours, _ = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
33
 
34
- # Tạo file DXF và thêm các đường nét
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  doc = ezdxf.new(dxfversion="R2010")
36
  msp = doc.modelspace()
37
  for contour in contours:
38
  points = contour.reshape(-1, 2)
39
  msp.add_lwpolyline(points, close=True)
40
-
41
  doc.saveas(filename)
42
  return filename
43
 
44
- # Hàm chính để xử lý ảnh đầu vào, trả về ảnh chân dung và lưu DXF
45
  def process_image(img, bw_option):
46
  try:
47
- # Đảm bảo ảnh đầu vào hợp lệ
48
- img = Image.open(img).convert("RGB") # Chuyển ảnh sang RGB nếu cần
49
  img = np.array(img)
50
-
51
- # Chuyển đổi ảnh thành đen trắng nếu được chọn
52
  if bw_option:
53
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
54
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) # Chuyển lại thành ảnh 3 kênh cho mô hình
55
-
56
- # Chạy suy luận để tạo ảnh chân dung
57
  result = inference(u2net, img)
58
  result_img = (result * 255).astype(np.uint8)
59
-
60
- # Tạo file DXF từ kết quả
61
- dxf_path = convert_to_dxf(result_img)
62
-
63
  return result_img, dxf_path
64
  except UnidentifiedImageError:
65
- return "Lỗi: Không thể nhận diện file ảnh. Hãy đảm bảo file đầu vào ảnh hợp lệ.", None
66
 
67
- # Tải mô hình từ Hugging Face Hub
68
  def load_u2net_model():
69
  model_path = hf_hub_download(repo_id="Arrcttacsrks/U2net", filename="u2net_portrait.pth", use_auth_token=os.getenv("HF_TOKEN"))
70
  net = U2NET(3, 1)
@@ -72,22 +84,20 @@ def load_u2net_model():
72
  net.eval()
73
  return net
74
 
75
- # Khởi tạo mô hình U2NET
76
  u2net = load_u2net_model()
77
 
78
- # Tạo giao diện với Gradio
79
  iface = gr.Interface(
80
  fn=process_image,
81
  inputs=[
82
- gr.Image(type="filepath", label="Tải lên ảnh của bạn"),
83
- gr.Checkbox(label="Chuyển sang trắng đen?", value=False)
84
  ],
85
  outputs=[
86
- gr.Image(type="numpy", label="Kết quả chân dung"),
87
- gr.File(label="Tải xuống file DXF") # Thêm output cho file DXF
88
  ],
89
- title="Tạo ảnh chân dung file DXF từ ảnh",
90
- description="Tải lên một ảnh để tạo ảnh chân dung file DXF từ ảnh đó."
91
  )
92
 
93
- iface.launch(share=True)
 
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
+ import ezdxf
10
  from PIL import Image, UnidentifiedImageError
11
 
 
12
  def normPRED(d):
13
  return (d - torch.min(d)) / (torch.max(d) - torch.min(d))
14
 
 
15
  def inference(net, input_img):
16
  input_img = input_img / np.max(input_img)
17
  tmpImg = np.zeros((input_img.shape[0], input_img.shape[1], 3))
 
20
  tmpImg[:, :, 2] = (input_img[:, :, 0] - 0.485) / 0.229
21
  tmpImg = torch.from_numpy(tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]).type(torch.FloatTensor)
22
  tmpImg = Variable(tmpImg.cuda() if torch.cuda.is_available() else tmpImg)
23
+ d1, *, *, *, *, *, * = net(tmpImg)
24
  pred = normPRED(1.0 - d1[:, 0, :, :])
25
  return pred.cpu().data.numpy().squeeze()
26
 
27
+ def extract_contours(portrait_mask):
28
+ """
29
+ Trích xuất các đường nét (contours) từ ảnh chân dung.
30
+
31
+ Parameters:
32
+ portrait_mask (numpy.ndarray): Ảnh chân dung dạng binary (đen trắng).
33
+
34
+ Returns:
35
+ list: Danh sách các đường nét (contours) được trích xuất.
36
+ """
37
  # Tìm các đường nét trong ảnh
38
+ contours, _ = cv2.findContours(portrait_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
39
 
40
+ # Làm sạch các đường nét
41
+ contours = [cnt for cnt in contours if cv2.contourArea(cnt) > 100] # Loại bỏ các đường nét nhỏ
42
+ contours = [cv2.approxPolyDP(cnt, 0.01 * cv2.arcLength(cnt, True), True) for cnt in contours] # Làm trơn đường nét
43
+
44
+ return contours
45
+
46
+ def convert_to_dxf(contours, filename="output.dxf"):
47
+ """
48
+ Tạo file DXF từ các đường nét (contours).
49
+
50
+ Parameters:
51
+ contours (list): Danh sách các đường nét.
52
+ filename (str): Tên file DXF.
53
+
54
+ Returns:
55
+ str: Đường dẫn đến file DXF.
56
+ """
57
  doc = ezdxf.new(dxfversion="R2010")
58
  msp = doc.modelspace()
59
  for contour in contours:
60
  points = contour.reshape(-1, 2)
61
  msp.add_lwpolyline(points, close=True)
 
62
  doc.saveas(filename)
63
  return filename
64
 
 
65
  def process_image(img, bw_option):
66
  try:
67
+ img = Image.open(img).convert("RGB")
 
68
  img = np.array(img)
 
 
69
  if bw_option:
70
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
71
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
 
72
  result = inference(u2net, img)
73
  result_img = (result * 255).astype(np.uint8)
74
+ contours = extract_contours(result_img)
75
+ dxf_path = convert_to_dxf(contours)
 
 
76
  return result_img, dxf_path
77
  except UnidentifiedImageError:
78
+ return "Error: Unable to identify the image file. Please ensure the input file is a valid image.", None
79
 
 
80
  def load_u2net_model():
81
  model_path = hf_hub_download(repo_id="Arrcttacsrks/U2net", filename="u2net_portrait.pth", use_auth_token=os.getenv("HF_TOKEN"))
82
  net = U2NET(3, 1)
 
84
  net.eval()
85
  return net
86
 
 
87
  u2net = load_u2net_model()
88
 
 
89
  iface = gr.Interface(
90
  fn=process_image,
91
  inputs=[
92
+ gr.Image(type="filepath", label="Upload your image"),
93
+ gr.Checkbox(label="Convert to black and white?", value=False)
94
  ],
95
  outputs=[
96
+ gr.Image(type="numpy", label="Portrait result"),
97
+ gr.File(label="Download DXF file")
98
  ],
99
+ title="Create Portrait Images and DXF Files from Images",
100
+ description="Upload an image to generate a portrait image and a DXF file from it."
101
  )
102
 
103
+ iface.launch(share=True)