File size: 657 Bytes
73666ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# from torchview import draw_graph
from torchviz import make_dot
from configs import *
import os
# import graphviz

# when running on VSCode run the below command
# svg format on vscode does not give desired result
# graphviz.set_jupyter_format("png")

model = EfficientNetB3WithNorm(num_classes=7)

batch_size = 2
# device='meta' -> no memory is consumed for visualization
# model_graph = draw_graph(model, input_size=(32, 3, 224, 224), save_graph=True, filename="model_graph.png")
# model_graph.visual_graph

model_graph = make_dot(
    model(torch.randn(batch_size, 3, 224, 224)), params=dict(model.named_parameters())
).render("torchviz", format="png")