import torch
import torchvision
from torch import nn


def create_effnetb2_model(
    num_classes: int = 10,
    seed: int = 42,
    is_TrivialAugmentWide: bool = True,
    freeze_layers: bool = True,
):
    """Create an EfficientNet-B2 model with a new classifier head, plus its transforms.

    Args:
        num_classes (int, optional): Number of classes in the new classifier
            head. Defaults to 10.
        seed (int, optional): Seed passed to ``torch.manual_seed`` before the
            new head is built, so its initial weights are reproducible.
            Note: this reseeds torch's global RNG as a side effect.
            Defaults to 42.
        is_TrivialAugmentWide (bool, optional): If True, prepend
            ``TrivialAugmentWide`` data augmentation to the pretrained
            weights' preprocessing transforms. The returned transform is then
            randomized and intended for training data; use the plain
            ``weights.transforms()`` for evaluation. Defaults to True.
        freeze_layers (bool, optional): If True, set ``requires_grad=False``
            on every pretrained parameter, giving a feature extractor in which
            only the new classifier head is trainable. If False, the whole
            network stays trainable (fine-tuning). Defaults to True.

    Returns:
        tuple: ``(effnetb2_model, effnetb2_transforms)`` where
            effnetb2_model (torch.nn.Module): EfficientNet-B2 with a fresh
                ``Dropout -> Linear(in_features, num_classes)`` head.
            effnetb2_transforms: the preprocessing transform from the
                pretrained weights, optionally wrapped in a
                ``transforms.Compose`` with ``TrivialAugmentWide`` first.
    """
    # 1, 2, 3. Pretrained weights, their matching preprocessing, and the model.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    effnetb2_transforms = weights.transforms()
    if is_TrivialAugmentWide:
        # Augment the raw image first, then apply the standard
        # resize/crop/normalize preprocessing expected by the weights.
        effnetb2_transforms = torchvision.transforms.Compose([
            torchvision.transforms.TrivialAugmentWide(),
            effnetb2_transforms,
        ])
    effnetb2_model = torchvision.models.efficientnet_b2(weights=weights)

    # 4. Optionally freeze every pretrained parameter. This runs before the
    # head is replaced, so the new head below stays trainable.
    if freeze_layers:
        for param in effnetb2_model.parameters():
            param.requires_grad = False

    # Read the head's input width from the pretrained classifier
    # (Sequential(Dropout, Linear)) instead of hard-coding it (1408 for B2).
    in_features = effnetb2_model.classifier[1].in_features

    # 5. Replace the classifier head; seed first so its init is reproducible.
    torch.manual_seed(seed)
    effnetb2_model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    return effnetb2_model, effnetb2_transforms