import cv2
import numpy as np
import torch
from torchvision import models, transforms
from PIL import Image
# Tải mô hình DeepLabV3+ đã được huấn luyện trước
def load_model():
model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()
return model
# Hàm xử lý ảnh với DeepLabV3+
def predict(model, image):
preprocess = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)
with torch.no_grad():
output = model(input_batch)['out'][0]
output_predictions = output.argmax(0).byte().cpu().numpy()
return output_predictions
# Hàm tách nền
def remove_background(image_path, output_path):
# Tải mô hình
model = load_model()
# Đọc ảnh đầu vào
image = Image.open(image_path).convert('RGB')
# Dự đoán mask
mask = predict(model, image)
# Chọn lớp người (class = 15) từ mask
person_class = 15
binary_mask = (mask == person_class).astype(np.uint8) * 255
# Đọc lại ảnh đầu vào với OpenCV
image = cv2.imread(image_path)
# Tạo ảnh kết quả với nền trong suốt
result = cv2.bitwise_and(image, image, mask=binary_mask)
b_channel, g_channel, r_channel = cv2.split(result)
a_channel = binary_mask
result = cv2.merge((b_channel, g_channel, r_channel, a_channel))
# Lưu ảnh kết quả
cv2.imwrite(output_path, result)
# Đường dẫn tới ảnh đầu vào và đầu ra
image_path = "input_image.jpg" # Đường dẫn tới ảnh áo thun cần tách nền
output_path = "output_image.png" # Đường dẫn lưu ảnh kết quả
# Gọi hàm tách nền
remove_background(image_path, output_path)