Spaces:
Runtime error
Runtime error
File size: 1,165 Bytes
fe70fd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
from torch.utils.data.dataloader import DataLoader,Dataset
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
class Segmentation_Dataset(Dataset):
    """Paired image/mask dataset for binary segmentation.

    Every file in ``img_dir`` whose name ends with ``image_suffix`` is treated
    as an image; its mask is the file in ``mask_dir`` with the same stem and
    ``mask_suffix`` (e.g. ``car_01.jpg`` -> ``car_01.png`` with the defaults).

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` keys (the albumentations convention).
        image_suffix: File-name suffix selecting image files. Defaults to
            ``".jpg"``.
        mask_suffix: Suffix of the matching mask files. Defaults to ``".png"``.
    """

    def __init__(self, img_dir, mask_dir, transform=None,
                 image_suffix=".jpg", mask_suffix=".png"):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.image_suffix = image_suffix
        self.mask_suffix = mask_suffix
        # endswith (not a substring test) so names such as "x.jpg.txt" are
        # skipped; sorted because os.listdir order is arbitrary and sample
        # indices should be reproducible across runs and machines.
        self.images = sorted(
            name for name in os.listdir(img_dir) if name.endswith(image_suffix)
        )

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.images)

    def _mask_filename(self, image_name):
        """Return the mask file name for ``image_name``.

        Only the trailing ``image_suffix`` is swapped for ``mask_suffix``;
        earlier occurrences of the suffix inside the name are kept intact.
        """
        stem = image_name[:len(image_name) - len(self.image_suffix)]
        return stem + self.mask_suffix

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        The image is converted to RGB and returned as a NumPy array; the mask
        is converted to grayscale ("L") as a float32 array in which pixels
        equal to 255 are set to 1.0 (all other values are left unchanged).
        If ``transform`` is set, both arrays are passed through it and its
        ``"image"`` / ``"mask"`` outputs are returned instead.
        """
        image_name = self.images[idx]
        img_path = os.path.join(self.img_dir, image_name)
        mask_path = os.path.join(self.mask_dir, self._mask_filename(image_name))
        # Context managers close the underlying file handles deterministically
        # instead of leaving that to garbage collection.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask = np.array(msk.convert("L"), dtype=np.float32)
        # Map the 255 foreground value to 1.0 for a {0, 1}-style target.
        mask[mask == 255] = 1.0
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
        return image, mask
|