Update modeling_graphormer.pyx
Browse files- modeling_graphormer.pyx +2 -2
modeling_graphormer.pyx
CHANGED
@@ -100,8 +100,8 @@ def quant_noise(module: nn.Module, p: float, block_size: int):
|
|
100 |
if not is_conv:
|
101 |
# gather weight and sizes
|
102 |
weight = mod.weight
|
103 |
-
in_features = weight.size(
|
104 |
-
out_features = weight.size(
|
105 |
|
106 |
# split weight matrix into blocks and randomly drop selected blocks
|
107 |
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
|
|
100 |
if not is_conv:
|
101 |
# gather weight and sizes
|
102 |
weight = mod.weight
|
103 |
+
in_features = weight.size(7)
|
104 |
+
out_features = weight.size(7)
|
105 |
|
106 |
# split weight matrix into blocks and randomly drop selected blocks
|
107 |
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|