Spaces:
Runtime error
Runtime error
####################################################### | |
# This file stores all the models used in the project.# | |
####################################################### | |
import torch | |
from torchvision.models import resnet50 | |
from torchvision.models import resnet18 | |
# resnet50 | |
class Bottleneck(torch.nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with a skip connection.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Width of the two inner convolutions. The block itself
            emits ``out_channels * expansion`` channels.
        i_downsample: Optional module applied to the skip branch before the
            addition; ``None`` means the input is added unchanged.
        stride: Stride of the middle 3x3 convolution.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()
        nn = torch.nn
        expanded_channels = out_channels * self.expansion

        # 1x1 convolution: map the input to the inner width.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution: the only layer of the block that applies `stride`.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution: expand to the block's output width.
        self.conv3 = nn.Conv2d(
            out_channels, expanded_channels, kernel_size=1, stride=1, padding=0
        )
        self.batch_norm3 = nn.BatchNorm2d(expanded_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv_branch(x) + skip(x))``."""
        shortcut = x.clone()

        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.batch_norm3(out)

        # Bring the skip branch to the conv branch's shape when a projection
        # module was supplied.
        if self.i_downsample is not None:
            shortcut = self.i_downsample(shortcut)

        out += shortcut
        return self.relu(out)
# Not used yet; this basic block can be reused to build shallower ResNets (e.g. ResNet18).
class Block(torch.nn.Module):
    """Basic residual block: two 3x3 convolutions plus a skip connection.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions
            (``expansion`` is 1, so this is also the block's output width).
        i_downsample: Optional module applied to the skip branch before the
            addition; ``None`` means the input is added unchanged.
        stride: Stride of the first 3x3 convolution.
    """

    # Ratio between the block's output channels and ``out_channels``.
    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()
        self.conv1 = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=stride,
            bias=False,
        )
        self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
        # Only conv1 applies `stride`. The second convolution keeps stride 1 so
        # the main branch is downsampled by the same factor as the skip branch
        # and the two can be added.
        self.conv2 = torch.nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)
        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv_branch(x) + skip(x))``."""
        # No copy is needed: nothing below writes into `x` in place (the
        # in-place add targets the freshly created batch-norm output).
        identity = x

        # Each convolution is normalised by its own batch-norm layer.
        out = self.relu(self.batch_norm1(self.conv1(x)))
        out = self.batch_norm2(self.conv2(out))

        # Bring the skip branch to the conv branch's shape when a projection
        # module was supplied.
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        out += identity
        return self.relu(out)
class ResNet(torch.nn.Module):
    """ResNet backbone plus linear classifier, parameterised by the block type.

    Args:
        ResBlock: Block factory. It must expose an ``expansion`` attribute and
            be callable as ``ResBlock(in_channels, planes, i_downsample=...,
            stride=...)``.
        layer_list: Indexable with at least four entries; entry ``k`` is the
            number of blocks stacked in ``layer{k+1}``.
        num_classes: Output size of the final linear layer.
        num_channels: Number of channels of the input images.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super().__init__()
        nn = torch.nn

        # Running channel count; `_make_layer` updates it after every stage.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(
            num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; every stage after the first uses stride 2.
        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)

        # Head: global average pooling followed by the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * ResBlock.expansion, num_classes)

    def forward(self, x):
        """Run the stem, the four stages and the head; return the ``fc`` output."""
        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.max_pool(out)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = self.avgpool(out)
        # Flatten everything except the batch dimension.
        out = out.reshape(out.shape[0], -1)
        return self.fc(out)

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks as a ``Sequential``.

        ``planes`` is the width handed to each block; the stage emits
        ``planes * ResBlock.expansion`` channels, and ``self.in_channels`` is
        updated to that value.
        """
        stage_channels = planes * ResBlock.expansion

        # A 1x1 conv + batch norm projection is built for the skip branch
        # whenever the first block changes resolution or channel count.
        shortcut = None
        if not (stride == 1 and self.in_channels == stage_channels):
            shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(
                    self.in_channels, stage_channels, kernel_size=1, stride=stride
                ),
                torch.nn.BatchNorm2d(stage_channels),
            )

        # Only the first block receives the stride and the projection.
        stage = [
            ResBlock(self.in_channels, planes, i_downsample=shortcut, stride=stride)
        ]
        self.in_channels = stage_channels

        # The remaining blocks use the block's default stride and skip branch.
        stage.extend(ResBlock(self.in_channels, planes) for _ in range(blocks - 1))
        return torch.nn.Sequential(*stage)
# The list gives the number of residual blocks in each of the four layers.
def ResNet50(num_classes, channels=3):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with 3, 4, 6 and 3 blocks per stage.

    Args:
        num_classes: Output size of the final linear layer.
        channels: Number of channels of the input images.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(Bottleneck, blocks_per_stage, num_classes, num_channels=channels)
# VGG16 model | |
class VGG16(torch.nn.Module):
    """VGG16-style classifier: five conv blocks, 3x3 adaptive pooling, MLP head.

    Each block is a stack of 3x3 convolutions (stride 1, padding 1), each
    followed by ReLU, and closed by a 2x2 max pool. The blocks hold 2, 2, 3, 3
    and 3 convolutions with 64, 128, 256, 512 and 512 output channels.

    All ``Conv2d`` and ``Linear`` weights are re-initialised with Kaiming
    uniform (``fan_in``, ReLU gain) and their biases are set to zero.

    Args:
        num_classes: Output size of the last linear layer.
        num_channels: Number of channels of the input images (default 3).
    """

    def __init__(self, num_classes, num_channels=3):
        super().__init__()
        self.block_1 = self._conv_block(num_channels, 64, num_convs=2)
        self.block_2 = self._conv_block(64, 128, num_convs=2)
        self.block_3 = self._conv_block(128, 256, num_convs=3)
        self.block_4 = self._conv_block(256, 512, num_convs=3)
        self.block_5 = self._conv_block(512, 512, num_convs=3)

        # The feature map is pooled to a fixed height x width before the head,
        # so the first linear layer's input size does not depend on image size.
        height, width = 3, 3
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512 * height * width, 4096),
            torch.nn.ReLU(True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, num_classes),
        )

        for module in self.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.kaiming_uniform_(
                    module.weight, mode="fan_in", nonlinearity="relu"
                )
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

        self.avgpool = torch.nn.AdaptiveAvgPool2d((height, width))

    @staticmethod
    def _conv_block(in_channels, out_channels, num_convs):
        """Return ``num_convs`` (3x3 conv, ReLU) pairs followed by a 2x2 max pool."""
        layers = []
        for _ in range(num_convs):
            layers.append(
                torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=1,
                )
            )
            layers.append(torch.nn.ReLU())
            # Every convolution after the first maps out_channels -> out_channels.
            in_channels = out_channels
        layers.append(torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier logits for ``x`` (no softmax is applied)."""
        for block in (
            self.block_1,
            self.block_2,
            self.block_3,
            self.block_4,
            self.block_5,
        ):
            x = block(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        return self.classifier(x)
# ResNet18 model | |