cycool29 commited on
Commit
28449df
·
1 Parent(s): 3de491d
Files changed (1) hide show
  1. predict.py +9 -9
predict.py CHANGED
@@ -10,16 +10,16 @@ from configs import *
10
 
11
 
12
  # Load your model (change this according to your model definition)
13
- model2 = EfficientNetB2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
14
- model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB2WithDropout.pth"))
15
- model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
16
- model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth"))
17
- model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
18
- model3.load_state_dict(torch.load("output\checkpoints\MobileNetV2WithDropout.pth"))
19
 
20
- model1.eval()
21
- model2.eval()
22
- model3.eval()
23
 
24
  # Load the model
25
  model = MODEL.to(DEVICE)
 
10
 
11
 
12
  # Load your model (change this according to your model definition)
13
+ # model2 = EfficientNetB2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
14
+ # model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB2WithDropout.pth"))
15
+ # model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
16
+ # model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth"))
17
+ # model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
18
+ # model3.load_state_dict(torch.load("output\checkpoints\MobileNetV2WithDropout.pth"))
19
 
20
+ # model1.eval()
21
+ # model2.eval()
22
+ # model3.eval()
23
 
24
  # Load the model
25
  model = MODEL.to(DEVICE)