最近想到

如果 pytorch 要顯示模型樹狀圖架構

類似於 keras 的 plot_model

要怎麼做呢?

 

找一找發現以下方法可以

我有用 resnet18 做一個範例給大家參考

主要是要先安裝 graphviz

https://graphviz.org/download/

直接下載到特定路徑就可以

然後在執行 python 時候特別指定就可以

 

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

# 用 resnet18 來示範
class Resnet18Models(nn.Module):
    def __init__(self, num_classes=2, pretrained=False):
        super(Resnet18Models, self).__init__()
        self.base_model     = timm.create_model('resnet18', pretrained=pretrained).cuda()
        # 取到倒數第二層 特徵層
        self.base_model     = nn.Sequential(*list(self.base_model.children())[:-1]).cuda()
        self.bn1            = nn.BatchNorm1d(512).cuda()
        self.relu           = torch.relu
        self.fc             = nn.Linear(512, num_classes).cuda()
        
    def forward(self, x):
        # 資料進入 resnet18
        x = self.base_model(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.fc(x)
        # 分類架構
        x = F.softmax(x, dim=1)
        return x

# 載入 torchviz
from torchviz import make_dot
os.environ["PATH"] += os.pathsep + 'C:\graphviz-2.44.1-win32/Graphviz/bin/'  # 安裝graphviz的路徑
# 新增模型
mainModel = Resnet18Models()
# 決定輸入大小
inputSize = (3, 128, 128)
# 產生輸入的示範資料
inputDatas = torch.zeros((2, inputSize[0], inputSize[1], inputSize[2]), requires_grad=False).cuda()
# 輸入模型
output = mainModel(inputDatas)
# make_dot 用輸出資料反推模型
modelImg = make_dot(output, params=dict(mainModel.named_parameters()), show_saved=True)
# 顯示架構
modelImg.view()

 

 


 

成功則顯示以下圖片

 

 

顯示出來的圖片是這樣

感覺跟keras 的 plot_model 出來的圖還是不太一樣

感覺沒這麼易讀

 

 

參數部分感覺是依照每一層的計算方式去列出來的

連 fc weight 跟 bias 參數都有顯示

算是很詳細

但感覺過於複雜?

不過整體架構還是可以理解的

給大家參考囉