最近想到
如果 pytorch 要顯示模型樹狀圖架構
類似於 keras 的 plot_model
要怎麼做呢?
找一找發現以下方法可以
我有用 resnet18 做一個範例給大家參考
主要是要先安裝 graphviz
https://graphviz.org/download/
直接下載到特定路徑就可以
然後在執行 python 時候特別指定就可以
import os import torch import torch.nn as nn import torch.nn.functional as F import timm # 用 resnet18 來示範 class Resnet18Models(nn.Module): def __init__(self, num_classes=2, pretrained=False): super(Resnet18Models, self).__init__() self.base_model = timm.create_model('resnet18', pretrained=pretrained).cuda() # 取到倒數第二層 特徵層 self.base_model = nn.Sequential(*list(self.base_model.children())[:-1]).cuda() self.bn1 = nn.BatchNorm1d(512).cuda() self.relu = torch.relu self.fc = nn.Linear(512, num_classes).cuda() def forward(self, x): # 資料進入 resnet18 x = self.base_model(x) x = self.bn1(x) x = self.relu(x) x = self.fc(x) # 分類架構 x = F.softmax(x, dim=1) return x # 載入 torchviz from torchviz import make_dot os.environ["PATH"] += os.pathsep + 'C:\graphviz-2.44.1-win32/Graphviz/bin/' # 安裝graphviz的路徑 # 新增模型 mainModel = Resnet18Models() # 決定輸入大小 inputSize = (3, 128, 128) # 產生輸入的示範資料 inputDatas = torch.zeros((2, inputSize[0], inputSize[1], inputSize[2]), requires_grad=False).cuda() # 輸入模型 output = mainModel(inputDatas) # make_dot 用輸出資料反推模型 modelImg = make_dot(output, params=dict(mainModel.named_parameters()), show_saved=True) # 顯示架構 modelImg.view()
成功則顯示以下圖片
顯示出來的圖片是這樣
感覺跟keras 的 plot_model 出來的圖還是不太一樣
感覺沒這麼易讀
參數部分感覺是依照每一層的計算方式去列出來的
連 fc weight 跟 bias 參數都有顯示
算是很詳細
但感覺過於複雜?
不過整體架構還是可以理解的
給大家參考囉
留言板
歡迎留下建議與分享!希望一起交流!感恩!