View model summaries in PyTorch!
(formerly torch-summary)
Torchinfo provides information complementary to what is provided by print(your_model) in PyTorch, similar to TensorFlow's model.summary() API, which shows a visualization of the model and is helpful while debugging your network. In this project, we implement similar functionality in PyTorch and create a clean, simple interface to use in your projects.
This is a completely rewritten version of the original torchsummary and torchsummaryX projects by @sksq96 and @nmhkahn. This project addresses all of the issues and pull requests left on the original projects by introducing a completely new API.
pip install torchinfo
from torchinfo import summary

model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [16, 10, 24, 24]          260
├─Conv2d: 1-2                            [16, 20, 8, 8]            5,020
├─Dropout2d: 1-3                         [16, 20, 8, 8]            --
├─Linear: 1-4                            [16, 50]                  16,050
├─Linear: 1-5                            [16, 10]                  510
==========================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 0.48
==========================================================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
==========================================================================================
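The example above uses a ConvNet class that is not defined in this document. The sketch below is an assumed definition (a small MNIST-style CNN) chosen because its layer shapes and parameter counts agree with the table; it is not necessarily the exact model used to produce it. It also rechecks two of the reported sizes by hand, assuming 4 bytes per float32 value and 1 MB = 10^6 bytes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    """Assumed definition: a small CNN whose layers match the summary table."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1*10*5*5 + 10 = 260 params
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10*20*5*5 + 20 = 5,020 params
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                  # 320*50 + 50 = 16,050 params
        self.fc2 = nn.Linear(50, 10)                   # 50*10 + 10 = 510 params

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = ConvNet()
batch_size = 16
x = torch.randn(batch_size, 1, 28, 28)

# Count parameters directly and run one forward pass to check the output shape.
total_params = sum(p.numel() for p in model.parameters())
out_shape = tuple(model(x).shape)

# Assumption: float32 tensors (4 bytes per value), 1 MB = 10^6 bytes.
bytes_per_float = 4
input_mb = x.numel() * bytes_per_float / 1e6
params_mb = total_params * bytes_per_float / 1e6

print(total_params, out_shape, round(input_mb, 2), round(params_mb, 2))
```

With this definition, the parameter total and the rounded input and parameter sizes agree with the "Total params", "Input size (MB)" and "Params size (MB)" rows of the table.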
This version now supports:
Other new features:
def summary(
    model: nn.Module,
    input_size: Optional[INPUT_SIZE_TYPE] = None,
    input_data: Optional[INPUT_DATA_TYPE] = None,
    batch_dim: Optional[int] = None,
    col_names: Optional[Iterable[str]] = None,
    col_width: int = 25,
    depth: int = 3,
    device: Optional[torch.device] = None,
    dtypes: Optional[List[torch.dtype]] = None,
    verbose: int = 1,
    **kwargs: Any,
) -> ModelStatistics:
    """
    Summarize the given PyTorch model. Summarized information includes:
        1) Layer names,
        2) input/output shapes,
        3) kernel shape,
        4) # of parameters,
        5) # of operations (Mult-Adds)

    NOTE: If neither input_data nor input_size is provided, no forward pass through the
    network is performed, and the provided model information is limited to layer names.

    Args:
        model (nn.Module):
                PyTorch model to summarize. The model should be fully in either train()
                or eval() mode. If layers are not all in the same mode, running summary
                may have side effects on batchnorm or dropout statistics. If you
                encounter an issue with this, please open a GitHub issue.

        input_size (Sequence of Sizes):
                Shape of input data as a List/Tuple/torch.Size
                (dtypes must match model input, default is FloatTensors).
                You should include batch size in the tuple.
                Default: None

        input_data (Sequence of Tensors):
                Example input tensor of the model (dtypes inferred from model input).
                Default: None

        batch_dim (int):
                Batch dimension of input data. If batch_dim is None, assume
                input_data / input_size contains the batch dimension, which is used
                in all calculations. Else, expand all tensors to contain the batch_dim.
                Specifying batch_dim can be a runtime optimization, since if batch_dim
                is specified, torchinfo uses a batch size of 2 for the forward pass.
                Default: None

        col_names (Iterable[str]):
                Specify which columns to show in the output. Currently supported: (
                    "input_size",
                    "output_size",
                    "num_params",
                    "kernel_size",
                    "mult_adds",
                )
                Default: ("output_size", "num_params")
                If input_data / input_size are not provided, only "num_params" is used.

        col_width (int):
                Width of each column.
                Default: 25

        depth (int):
                Number of nested layers to traverse (e.g. Sequentials).
                Default: 3

        device (torch.Device):
                Uses this torch device for model and input_data.
                If not specified, uses result of torch.cuda.is_available().
                Default: None

        dtypes (List[torch.dtype]):
                For multiple inputs, specify the size of both inputs, and
                also specify the types of each parameter here.
                Default: None

        verbose (int):
                0 (quiet): No output
                1 (default): Print model summary
                2 (verbose): Show weight and bias layers in full detail
                Default: 1

        **kwargs:
                Other arguments used in `model.forward` function. Passing *args is no
                longer supported.

    Return:
        ModelStatistics object
                See torchinfo/model_statistics.py for more information.
    """
from torchinfo import summary

model_stats = summary(your_model, (1, 3, 28, 28), verbose=0)
summary_str = str(model_stats)
summary_str contains the string representation of the summary. See below for examples.
import torchvision

model = torchvision.models.resnet50()
summary(model, (1, 3, 224, 224), depth=3)
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
|    └─Bottleneck: 2-1                   [1, 256, 56, 56]          --
|    |    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,096
|    |    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
|    |    └─ReLU: 3-3                    [1, 64, 56, 56]           --
|    |    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
|    |    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
|    |    └─ReLU: 3-6                    [1, 64, 56, 56]           --
|    |    └─Conv2d: 3-7                  [1, 256, 56, 56]          16,384
|    |    └─BatchNorm2d: 3-8             [1, 256, 56, 56]          512
|    |    └─Sequential: 3-9              [1, 256, 56, 56]          16,896
|    |    └─ReLU: 3-10                   [1, 256, 56, 56]          --
...
...
...
├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 60,192,808
Trainable params: 60,192,808
Non-trainable params: 0
Total mult-adds (G): 11.63
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 360.87
Params size (MB): 240.77
Estimated Total Size (MB): 602.25
==========================================================================================
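The depth argument limits how many levels of nested modules appear in the table. The sketch below illustrates the idea of a depth-limited traversal on a small nested model. The helper list_layers and the toy model are hypothetical, and this is not torchinfo's implementation, only a way to see which layers fall at which nesting level.

```python
from torch import nn


def list_layers(module, max_depth, depth=1):
    """Return (depth, class name) pairs for submodules up to max_depth levels deep."""
    rows = []
    if depth > max_depth:
        return rows
    for child in module.children():
        rows.append((depth, type(child).__name__))
        # Recurse into the child; deeper levels are cut off by max_depth.
        rows.extend(list_layers(child, max_depth, depth + 1))
    return rows


# Toy model with three levels of nesting.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(
        nn.Conv2d(8, 8, kernel_size=3),
        nn.Sequential(nn.ReLU()),
    ),
)

shallow = list_layers(model, max_depth=1)
deep = list_layers(model, max_depth=3)

for level, name in deep:
    print("    " * (level - 1) + f"{name}: depth {level}")
```

With max_depth=1 only the two top-level modules are listed; with max_depth=3 the inner Conv2d, the inner Sequential, and the ReLU appear as well.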
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultipleInputNetDifferentDtypes(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1a = nn.Linear(300, 50)
        self.fc1b = nn.Linear(50, 10)

        self.fc2a = nn.Linear(300, 50)
        self.fc2b = nn.Linear(50, 10)

    def forward(self, x1, x2):
        x1 = F.relu(self.fc1a(x1))
        x1 = self.fc1b(x1)
        # x2 arrives as a long tensor, so cast it before the linear layers.
        x2 = x2.type(torch.float)
        x2 = F.relu(self.fc2a(x2))
        x2 = self.fc2b(x2)
        x = torch.cat((x1, x2), 0)
        return F.log_softmax(x, dim=1)
model = MultipleInputNetDifferentDtypes()
summary(model, [(1, 300), (1, 300)], dtypes=[torch.float, torch.long])
Alternatively, you can also pass in the input_data itself, and torchinfo will automatically infer the data types.
input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
model = MultipleInputNetDifferentDtypes()

summary(model, input_data=[input_data, other_input_data, ...])
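To see why passing input_data makes the dtypes argument unnecessary, note that each tensor already carries its own shape and dtype. The short sketch below recovers both from the example tensors using plain PyTorch; it shows the information that is available for inference, and does not claim to reproduce how torchinfo reads it internally.

```python
import torch

input_data = torch.randn(1, 300)
other_input_data = torch.randn(1, 300).long()
inputs = [input_data, other_input_data]

# Shapes and dtypes recovered from the tensors themselves.
input_sizes = [tuple(t.shape) for t in inputs]
dtypes = [t.dtype for t in inputs]

print(input_sizes)
print(dtypes)
```

The recovered values correspond to the explicit form shown earlier, [(1, 300), (1, 300)] with dtypes=[torch.float, torch.long], since torch.float and torch.long are aliases for torch.float32 and torch.int64.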
class LSTMNet(nn.Module):
    """ Batch-first LSTM model. """

    def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embed = self.embedding(x)
        out, hidden = self.encoder(embed)
        out = self.decoder(out)
        out = out.view(-1, out.size(2))
        return out, hidden
summary(
    LSTMNet(),
    (1, 100),
    dtypes=[torch.long],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
)
========================================================================================================================
Layer (type:depth-idx)                   Kernel Shape        Output Shape        Param #             Mult-Adds
========================================================================================================================
├─Embedding: 1-1                         [300, 20]           [1, 100, 300]       6,000               6,000
├─LSTM: 1-2                              --                  [1, 100, 512]       3,768,320           3,760,128
|    └─weight_ih_l0                      [2048, 300]
|    └─weight_hh_l0                      [2048, 512]
|    └─weight_ih_l1                      [2048, 512]
|    └─weight_hh_l1                      [2048, 512]
├─Linear: 1-3                            [512, 20]           [1, 100, 20]        10,260              10,240
========================================================================================================================
Total params: 3,784,580
Trainable params: 3,784,580
Non-trainable params: 0
Total mult-adds (M): 3.78
========================================================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.67
Params size (MB): 15.14
Estimated Total Size (MB): 15.80
========================================================================================================================
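The parameter counts in this table can be cross-checked with plain PyTorch. The sketch below rebuilds the three layers of LSTMNet with the default sizes from the class above and counts their parameters directly. The closed-form LSTM expression is included as a sanity check and assumes the standard four-gate layout with two bias vectors per layer; the helper count is a hypothetical name.

```python
from torch import nn

vocab_size, embed_dim, hidden_dim, num_layers = 20, 300, 512, 2

embedding = nn.Embedding(vocab_size, embed_dim)
encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
decoder = nn.Linear(hidden_dim, vocab_size)


def count(module):
    """Total number of parameter elements in a module."""
    return sum(p.numel() for p in module.parameters())


# Closed form for a 2-layer LSTM: 4 gates, each with input and hidden weights.
lstm_weights = (
    4 * hidden_dim * (embed_dim + hidden_dim)     # layer 0: weight_ih_l0, weight_hh_l0
    + 4 * hidden_dim * (hidden_dim + hidden_dim)  # layer 1: weight_ih_l1, weight_hh_l1
)
lstm_biases = num_layers * 2 * 4 * hidden_dim     # bias_ih and bias_hh per layer

total = count(embedding) + count(encoder) + count(decoder)
print(count(embedding), count(encoder), count(decoder), total)
```

The weight-only term equals the LSTM row's Mult-Adds entry (3,760,128), and adding the biases gives its Param # entry (3,768,320).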
class ContainerModule(nn.Module):
    """ Model using ModuleList. """

    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList()
        self._layers.append(nn.Linear(5, 5))
        self._layers.append(ContainerChildModule())
        self._layers.append(nn.Linear(5, 5))

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x
class ContainerChildModule(nn.Module):
    """ Model using Sequential in different ways. """

    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        out = self._sequential(x)
        out = self._between(out)
        for l in self._sequential:
            out = l(out)

        out = self._sequential(x)
        for l in self._sequential:
            out = l(out)
        return out
summary(ContainerModule(), (1, 5))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─ModuleList: 1                          []                        --
|    └─Linear: 2-1                       [1, 5]                    30
|    └─ContainerChildModule: 2-2         [1, 5]                    --
|    |    └─Sequential: 3-1              [1, 5]                    --
|    |    |    └─Linear: 4-1             [1, 5]                    30
|    |    |    └─Linear: 4-2             [1, 5]                    30
|    |    └─Linear: 3-2                  [1, 5]                    30
|    |    └─Sequential: 3                []                        --
|    |    |    └─Linear: 4-3             [1, 5]                    (recursive)
|    |    |    └─Linear: 4-4             [1, 5]                    (recursive)
|    |    └─Sequential: 3-3              [1, 5]                    (recursive)
|    |    |    └─Linear: 4-5             [1, 5]                    (recursive)
|    |    |    └─Linear: 4-6             [1, 5]                    (recursive)
|    |    |    └─Linear: 4-7             [1, 5]                    (recursive)
|    |    |    └─Linear: 4-8             [1, 5]                    (recursive)
|    └─Linear: 2-3                       [1, 5]                    30
==========================================================================================
Total params: 150
Trainable params: 150
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================
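The rows marked (recursive) come from layers that run more than once in a single forward pass but own only one set of weights. A minimal way to observe this with plain PyTorch is to count forward calls per Linear layer using forward hooks and compare that with the number of unique parameters. The classes are repeated here so the sketch is self-contained; the calls counter and make_hook helper are hypothetical names.

```python
from collections import Counter

import torch
from torch import nn


class ContainerChildModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._sequential = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
        self._between = nn.Linear(5, 5)

    def forward(self, x):
        out = self._sequential(x)
        out = self._between(out)
        for l in self._sequential:
            out = l(out)
        out = self._sequential(x)
        for l in self._sequential:
            out = l(out)
        return out


class ContainerModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._layers = nn.ModuleList(
            [nn.Linear(5, 5), ContainerChildModule(), nn.Linear(5, 5)]
        )

    def forward(self, x):
        for layer in self._layers:
            x = layer(x)
        return x


model = ContainerModule()
calls = Counter()


def make_hook(name):
    # Each hook increments the call count for the layer it is attached to.
    def hook(module, inputs, output):
        calls[name] += 1
    return hook


handles = [
    m.register_forward_hook(make_hook(name))
    for name, m in model.named_modules()
    if isinstance(m, nn.Linear)
]
model(torch.randn(1, 5))
for h in handles:
    h.remove()

unique_params = sum(p.numel() for p in model.parameters())
total_linear_calls = sum(calls.values())
print(dict(calls), unique_params, total_linear_calls)
```

Each Linear inside _sequential is called four times, which lines up with the eight Linear rows numbered 4-1 through 4-8 in the table, while the parameter total stays at 150 because repeated calls reuse the same weights.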
All issues and pull requests are much appreciated! If you are wondering how to build the project:
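Install the development dependencies: pip install -r requirements-dev.txt. We use the latest versions of all dev packages.
Install the pre-commit hooks: pre-commit install.
The hooks are configured in .pre-commit-config.yaml.
Run the tests: pytest.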