Code to run it
#1
Opened by antonin-poche
I had a hard time figuring out how to load and run this model, so below is the code I used.
It is adapted from https://github.com/chathumal93/EuroSat-RGB-Classifiers/blob/main/evaluation.ipynb.
With this code, the model reaches an accuracy of 0.9774 on the EuroSAT test split (`cm93/eurosat`, `split='test'`).
!pip install -q torch torchvision timm datasets
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from datasets import load_dataset
import timm
from tqdm import tqdm
# Evaluation preprocessing:
# - resize the shorter image side to 232 px;
# - centre-crop to 224x224;
# - convert to a float tensor in [0, 1];
# - normalise each RGB channel with the given mean/std.
# NOTE(review): the mean/std values are presumably the statistics the model was
# trained with (taken from the referenced evaluation notebook). Confirm against
# the model card before reusing them elsewhere.
#
# The name `transforms` is deliberately rebound here, from the torchvision
# module to this pipeline. `EuroSat`'s default argument and the test-set setup
# refer to the pipeline by that name. The pipeline is built from a separately
# imported alias, so re-running this code (e.g. a notebook cell) does not call
# `.Compose` on the already-built Compose object.
import torchvision.transforms as tv_transforms

transforms = tv_transforms.Compose([
    tv_transforms.Resize(232),
    tv_transforms.CenterCrop(224),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize([0.3445, 0.3803, 0.4077], [0.0915, 0.0652, 0.0553]),
])
class EuroSat(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from EuroSAT rows.

    ``data`` is indexed with an integer. Each row must be a mapping with
    ``'image'`` and ``'label'`` keys, e.g. a Hugging Face ``datasets.Dataset``
    split or a list of dicts.

    Args:
        data: Indexable collection of ``{'image': ..., 'label': ...}`` rows.
        transform: Callable applied to the image when truthy. The default is
            whatever the module-level name ``transforms`` referred to when
            this class was defined; Python evaluates default arguments once,
            at definition time.
        target_transform: Callable applied to the label when truthy.
    """

    def __init__(self, data, transform=transforms, target_transform=None):
        self.data = data
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and label at ``idx``."""
        # Fetch the row once. Indexing a Hugging Face ``Dataset`` materialises
        # the whole row (decoding the image), so indexing twice to read
        # 'image' and then 'label' decoded every image two times.
        row = self.data[idx]
        image = row['image']
        label = row['label']
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
# --- Configuration ----------------------------------------------------------
# Number of test images fed to the model per forward pass.
batch_size = 256

# Human-readable EuroSAT class names. This list is not referenced anywhere
# else in the script shown here.
# NOTE(review): before mapping predicted indices to names with this list,
# confirm its order matches the label ids of the `cm93/eurosat` dataset.
# TODO: verify.
labels = (
    "Forest River Highway AnnualCrop SeaLake "
    "HerbaceousVegetation Industrial Residential PermanentCrop Pasture"
).split()
# --- Test data --------------------------------------------------------------
# Download the EuroSAT test split from the Hugging Face Hub and wrap it so
# each item is passed through the preprocessing pipeline bound to `transforms`.
test_data = EuroSat(
    load_dataset("cm93/eurosat", split="test"),
    transform=transforms,
)

# Iterate in dataset order (no shuffling), using 2 worker processes to load
# and preprocess batches.
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)
# --- Model ------------------------------------------------------------------
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained ResNet-18 checkpoint published on the Hugging Face Hub
# as `cm93/resnet18-eurosat`, via timm's `hf_hub:` model-name prefix.
model = timm.create_model("hf_hub:cm93/resnet18-eurosat", pretrained=True)

# Switch to evaluation mode (affects layers such as BatchNorm/Dropout), then
# move the parameters to the chosen device; both calls modify `model` in place.
model.eval().to(device)
# --- Evaluation ---------------------------------------------------------------
# Run the model over the whole test set and report top-1 accuracy.
correct_model = 0
total_samples = 0
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for data, target in tqdm(test_dataloader):
        # Move the batch to the same device as the model.
        data, target = data.to(device), target.to(device)
        # Forward pass. The argmax over dim=1 below assumes the output is
        # (batch, num_classes) scores. TODO: confirm for this checkpoint.
        outputs_model = model(data)
        # Top-1 prediction: index of the highest score for each image.
        preds_model = torch.argmax(outputs_model, dim=1)
        # Accumulate the batch's counts as plain Python ints.
        correct_model += (preds_model == target).sum().item()
        total_samples += target.size(0)

# An empty test set would otherwise surface as a bare ZeroDivisionError.
if total_samples == 0:
    raise RuntimeError("Test set is empty; cannot compute accuracy.")
accuracy_model = correct_model / total_samples

print(f"Accuracy of Original Model: {accuracy_model:.4f}")