Code to run it

#1
by antonin-poche - opened

I had quite a hard time finding the right way to load and call this model.

So here is the code I use to do it. It is inspired by https://github.com/chathumal93/EuroSat-RGB-Classifiers/blob/main/evaluation.ipynb.

With this code, the model reaches a top-1 accuracy of 0.9774 on the EuroSAT test set (the `test` split of `cm93/eurosat`).

# Notebook (Jupyter/Colab) shell command. Outside a notebook, run the same
# `pip install -q torch torchvision timm datasets` in a terminal, without the "!".
# NOTE(review): tqdm (imported below) is not listed here; it is presumably
# installed as a dependency of `datasets` — verify if running elsewhere.
!pip install -q torch torchvision timm datasets

import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import timm
from tqdm import tqdm

# Preprocessing applied to every test image before it reaches the model.
# NOTE(review): the assignment at the bottom rebinds `transforms`, shadowing
# the `torchvision.transforms` module imported above. The rest of the script
# (the EuroSat default argument and the test-set construction) relies on the
# name referring to this Compose pipeline, so it is kept as-is.
_eval_steps = [
    transforms.Resize(232),      # scale so the shorter side is 232 px
    transforms.CenterCrop(224),  # keep the central 224x224 region
    transforms.ToTensor(),       # PIL image -> float tensor in [0, 1]
    # Per-channel normalisation; presumably the statistics used when the
    # checkpoint was trained — not verifiable from this script.
    transforms.Normalize(
        [0.3445, 0.3803, 0.4077],
        [0.0915, 0.0652, 0.0553],
    ),
]
transforms = transforms.Compose(_eval_steps)

class EuroSat(Dataset):
    """Map-style PyTorch dataset over an indexable EuroSAT split.

    Each row of ``data`` must support ``row['image']`` and ``row['label']``
    (as rows of a Hugging Face ``datasets`` split do). Items are returned as
    ``(image, label)`` tuples.

    Args:
        data: Sized, indexable collection of rows (e.g. a ``datasets.Dataset``).
        transform: Callable applied to the image; skipped when ``None``.
            Defaults to the module-level ``transforms`` object as it was bound
            when this class was defined.
        target_transform: Callable applied to the label; skipped when ``None``.
    """

    def __init__(self, data, transform=transforms, target_transform=None):
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the wrapped split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (optionally transformed) ``(image, label)`` pair at ``idx``."""
        # Index the underlying split only once. On a Hugging Face dataset,
        # every ``data[idx]`` call materialises (and decodes) the whole row,
        # so fetching the image and the label separately decoded every image twice.
        row = self.data[idx]
        image, label = row['image'], row['label']
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

# Parameters
batch_size: int = 256  # number of test images per DataLoader batch / forward pass

# Human-readable class names. This list is not used by the evaluation code;
# presumably entry i names the model's output index i — TODO confirm against
# the checkpoint's label mapping before using it to decode predictions.
labels = [
    "Forest", "River", "Highway", "AnnualCrop", "SeaLake",
    "HerbaceousVegetation", "Industrial", "Residential",
    "PermanentCrop", "Pasture",
]

# Data (test set): download the `test` split of cm93/eurosat from the Hub,
# wrap it for PyTorch, and batch it in a fixed order (no shuffling).
test_split = load_dataset("cm93/eurosat", split='test')
test_data = EuroSat(test_split, transform=transforms)
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,  # worker subprocesses that load/transform images in parallel
)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained EuroSAT classifier pulled from the Hugging Face Hub through timm
# (a ResNet-18, going by the repository name).
model = timm.create_model("hf_hub:cm93/resnet18-eurosat", pretrained=True)

# Move the weights to `device` and switch to evaluation mode; both
# nn.Module.to() and nn.Module.eval() act in place and return the module.
model = model.to(device).eval()

# Evaluate the single loaded model on the test set by counting how many
# top-1 predictions match the ground-truth labels.
correct_model = 0  # correctly classified test images so far
total_samples = 0  # test images processed so far

with torch.no_grad():  # inference only: no gradients are needed
    for data, target in tqdm(test_dataloader):
        # Put the batch on the same device as the model.
        data, target = data.to(device), target.to(device)

        # Forward pass. The argmax over dim 1 below implies one row of
        # class scores per image, i.e. shape (batch, num_classes).
        outputs_model = model(data)

        # Predicted class index = position of the highest score per row.
        preds_model = outputs_model.argmax(dim=1)

        # Accumulate the number of hits and the batch size.
        correct_model += (preds_model == target).sum().item()
        total_samples += target.size(0)

# Compute top-1 accuracy over the whole test set.
if total_samples == 0:
    # An empty loader would otherwise surface as a bare ZeroDivisionError
    # that gives no hint about the cause.
    raise RuntimeError(
        "No test samples were evaluated; check the test split and DataLoader."
    )
accuracy_model = correct_model / total_samples

# Print results
print(f"Accuracy of Original Model: {accuracy_model:.4f}")

Sign up or log in to comment