Vedai-image / model.py
randomshit11's picture
Upload 3 files
2e1bb0d verified
raw
history blame contribute delete
571 Bytes
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
import random
class ResNet50(nn.Module):
    """ResNet-50 classifier with a frozen backbone and a small trainable head.

    A torchvision ResNet-50 is built (with pretrained weights by default), its
    parameters are frozen (by default), and its final fully-connected layer is
    replaced by a freshly created ``nn.Linear`` head. Because the head is
    created after freezing, it is the only part left trainable.

    Args:
        num_classes: Number of output scores produced by the head. Defaults
            to 2, matching the originally hard-coded head.
        pretrained: Passed through to ``torchvision.models.resnet50``.
            Defaults to True, the original behavior.
        freeze_backbone: When True (the default, the original behavior),
            every backbone parameter gets ``requires_grad = False``.
    """

    def __init__(
        self,
        *,
        num_classes: int = 2,
        pretrained: bool = True,
        freeze_backbone: bool = True,
    ) -> None:
        super().__init__()
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=`; kept so the model still builds on the older
        # torchvision versions this file was written against.
        self.resnet = models.resnet50(pretrained=pretrained)
        if freeze_backbone:
            for param in self.resnet.parameters():
                param.requires_grad = False
        # Take the head's input width from the backbone's existing fc layer
        # instead of hard-coding 2048.
        in_features = self.resnet.fc.in_features
        # The new head is created *after* freezing, so its parameters stay
        # trainable. It is kept inside nn.Sequential so state_dict keys
        # ("resnet.fc.0.weight" / "resnet.fc.0.bias") match checkpoints saved
        # from the original model.
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class scores for a batch of images.

        Args:
            x: Input batch; presumably (N, 3, H, W) float images normalized
                the way the ResNet backbone expects — TODO confirm against
                the caller's transforms.

        Returns:
            Tensor whose last dimension is ``num_classes``: unnormalized
            scores from the linear head (no softmax is applied here).
        """
        return self.resnet(x)