Deep Learning
torchviz로 모델 시각화 하기
jinmc
2022. 7. 27. 09:49
반응형
from torchviz import make_dot
from torch.autograd import Variable
x = Variable(torch.randn(6, 2, 64, 344))
x = x.to("cuda")
make_dot(myModel(x), params=dict(list(myModel.named_parameters()))).render("myModel", format="png")
간단해 보이는 script지만 상당히 많은 부분에서 오류가 났습니다.
1. Anaconda environment의 경우에는 pip install torchviz가 아니라 conda install torchviz를 해 줍니다.
2. ubuntu의 경우에는 sudo apt도 해줘야 합니다.
sudo apt-get install graphviz
3. torch Variable을 설정할 때, dimension을 맞춰줘야 합니다. 위의 6,2,64, 344 의 숫자들이 괜히 나온게 아니에요.
model에 들어가는 input을 보려면,
print(input.shape)
로 확인 할 수 있습니다.
4. 만약 cuda variable을 사용하는 경우에는, 위의 x = x.to("cuda") 를 통해서 cuda variable로 변환해주지 않으면 에러가 생기게 됩니다.
5. make_dot도 다음 포맷에 맞춰줘서 불러야 합니다.
정확한 사항은 참조를 참고하시면 좋을것 같습니다.
참조 : torchviz
반응형