Visualisation tool for CNNs in PyTorch
Francesco Saverio Zuppichini
This is a raw beta so expect lots of things to change and improve over time.
An interactive version of this tutorial can be found here
To install mirror, run:
pip install git+https://github.com/FrancescoSaverioZuppichini/mirror.git
A basic example
from mirror import mirror
from mirror.visualisations.web import *
from PIL import Image
from torchvision.models import resnet101, resnet18, vgg16, alexnet
from torchvision.transforms import ToTensor, Resize, Compose

# create a model
model = vgg16(pretrained=True)
# open some images
cat = Image.open("./cat.jpg")
dog_and_cat = Image.open("./dog_and_cat.jpg")
# resize the image and make it a tensor
to_input = Compose([Resize((224, 224)), ToTensor()])
# call mirror with the inputs and the model
mirror([to_input(cat), to_input(dog_and_cat)], model, visualisations=[BackProp, GradCam, DeepDream])
* Serving Flask app "mirror.App" (lazy loading)
It will automatically open a new tab in your browser.
On the left you can see your model's tree structure; clicking on a layer shows all its children. On the right are the visualisation settings. You can select your input by clicking on the bottom tab.
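Going back to the to_input transform used in the example: Resize((224, 224)) fixes the spatial size, and ToTensor turns an H x W x C image with values in 0..255 into a C x H x W float tensor scaled to [0, 1]. The sketch below reproduces only the ToTensor part with NumPy on a synthetic array, assuming an 8-bit RGB image that is already resized; the helper to_tensor_like is hypothetical and is not torchvision's implementation. It also shows the extra batch dimension that the core GradCam call later adds with unsqueeze(0).

```python
import numpy as np


def to_tensor_like(image_hwc: np.ndarray) -> np.ndarray:
    """Hypothetical helper mimicking torchvision's ToTensor for an 8-bit H x W x C image."""
    # move channels first: (H, W, C) -> (C, H, W)
    chw = image_hwc.transpose(2, 0, 1)
    # scale 0..255 integers to floats in [0, 1]
    return chw.astype(np.float32) / 255.0


# a synthetic, already resized 224 x 224 RGB "image" (all white)
fake_image = np.full((224, 224, 3), 255, dtype=np.uint8)
x = to_tensor_like(fake_image)
print(x.shape, x.max())

# models work on batches, so a leading batch dimension is added (like unsqueeze(0))
batch = x[np.newaxis, ...]
print(batch.shape)
```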
All the visualisations available for the web app are inside
mirror.visualisations.web.
By selecting the 'guide' radio button, the negative gradients flowing back through the ReLUs are set to zero (guided backpropagation), producing a nicer-looking image.
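A minimal NumPy sketch of the ReLU backward rule behind the 'guide' option, assuming x is the ReLU input and grad_out the gradient arriving from the layer above. Vanilla backprop lets the gradient through wherever the ReLU was active; guided backprop additionally drops negative gradients. In PyTorch this is usually done with a backward hook on each ReLU; the exact mechanism mirror uses is not shown here, and the function names are illustrative.

```python
import numpy as np


def relu_backward(x, grad_out):
    # vanilla backprop: pass the gradient only where the ReLU input was positive
    return grad_out * (x > 0)


def guided_relu_backward(x, grad_out):
    # guided backprop: additionally zero out negative gradients
    return np.clip(relu_backward(x, grad_out), 0, None)


x = np.array([-1.0, 2.0, 3.0, 0.5])
grad_out = np.array([0.7, -0.4, 0.9, 0.2])
vanilla = relu_backward(x, grad_out)
guided = guided_relu_backward(x, grad_out)
print(vanilla)
print(guided)
```

The first entry is blocked in both cases because its ReLU input is negative; the second is kept by vanilla backprop but removed by the guided version because its gradient is negative.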
If you want, you can use the vanilla version of each visualisation by importing it from
mirror.visualisations.core.
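from mirror.visualisations.core import GradCam
from PIL import Image
from torchvision.models import vgg16
from torchvision.transforms import ToTensor, Resize, Compose

# create a model
model = vgg16(pretrained=True)
# open some images
cat = Image.open("./cat.jpg")
dog_and_cat = Image.open("./dog_and_cat.jpg")
# resize the image and make it a tensor
to_input = Compose([Resize((224, 224)), ToTensor()])

cam = GradCam(model, device='cpu')
# unsqueeze(0) adds the batch dimension; the second argument is the layer (None here)
cam(to_input(cat).unsqueeze(0), None)  # will return the output image and some additional information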
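For reference, Grad-CAM weights each feature map A_k of a convolutional layer by the spatial average of the gradient of the class score with respect to that map, sums the weighted maps and applies a ReLU. Below is a minimal NumPy sketch of that final step, assuming the activations and gradients for one image and one layer have already been captured (for example with forward and backward hooks). It is an illustration of the technique, not mirror's GradCam code; grad_cam_map is a hypothetical name.

```python
import numpy as np


def grad_cam_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """activations, gradients: arrays of shape (K, H, W) for one image and one layer."""
    # alpha_k: global-average-pool the gradients over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # weighted sum of the K feature maps
    cam = np.tensordot(weights, activations, axes=1)  # shape (H, W)
    # keep only the regions with a positive influence on the class score
    cam = np.maximum(cam, 0)
    # normalise to [0, 1] so the map can be overlaid on the image
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam


rng = np.random.default_rng(0)
acts = rng.random((4, 7, 7))
grads = rng.standard_normal((4, 7, 7))
heatmap = grad_cam_map(acts, grads)
print(heatmap.shape)
```

In the full method, this low-resolution map is then upsampled to the input size before being overlaid on the image.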
To create a visualisation, you first have to subclass the
Visualisation class and override the
__call__ method to return an image and, if needed, additional information. The following example creates a custom visualisation that just repeats the input
repeat times:
from mirror.visualisations.core import Visualisation

class RepeatInput(Visualisation):

    def __call__(self, inputs, layer, repeat=1):
        # stack `repeat` copies along the first (batch) dimension; no additional information
        return inputs.repeat(repeat, 1, 1, 1), None
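Here inputs.repeat(repeat, 1, 1, 1) stacks
repeat copies of the input along the first (batch) dimension, and None is returned because there is no additional information.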
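To check what RepeatInput returns without loading a model, here is a self-contained sketch. It uses a hypothetical stand-in for the Visualisation base class (storing the model and device, matching the GradCam(model, device='cpu') call above) and NumPy's np.tile, which for a 4-D array behaves like torch.Tensor.repeat(repeat, 1, 1, 1): both stack copies along the batch dimension.

```python
import numpy as np


class Visualisation:
    """Hypothetical stand-in for mirror's base class, for illustration only."""

    def __init__(self, module, device):
        self.module, self.device = module, device

    def __call__(self, inputs, layer, **kwargs):
        raise NotImplementedError


class RepeatInput(Visualisation):
    def __call__(self, inputs, layer, repeat=1):
        # np.tile(x, (repeat, 1, 1, 1)) matches torch's x.repeat(repeat, 1, 1, 1); layer is unused
        return np.tile(inputs, (repeat, 1, 1, 1)), None


vis = RepeatInput(module=None, device='cpu')
image = np.zeros((1, 3, 224, 224), dtype=np.float32)  # one image, batch dimension first
out, info = vis(image, layer=None, repeat=3)
print(out.shape, info)
```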
To connect our fancy visualisation to the web interface, we have to create a
WebInterface. The easiest way is to use
WebInterface.from_visualisation, which generates the communication channel between our visualisation and the web app.
Here is an example:
from mirror.visualisations.web import WebInterface
from functools import partial

params = {'repeat': {'type': 'slider', 'min': 1, 'max': 100, 'value': 2, 'step': 1, 'params': {}}}

visualisation = partial(WebInterface.from_visualisation, RepeatInput, params=params, name='Repeat')
First we import
WebInterface and
partial. Then, we create a dictionary where each key is the name of a visualisation parameter. In our example,
RepeatInput takes a parameter called
repeat, thus we have to define a dictionary
{ 'repeat' : { ... } }.
The value for each key is the configuration of one of the basic UI elements: slider, textfield or radio.
The user's input is stored in the
value slot.
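A sketch of how the value slots could be turned into keyword arguments for the visualisation's __call__. The helper current_kwargs is hypothetical and only illustrates the idea: the slider keys are the ones used in the example, while clamping the value to min/max is an assumption about slider behaviour, not mirror's code.

```python
params = {'repeat': {'type': 'slider', 'min': 1, 'max': 100,
                     'value': 2, 'step': 1, 'params': {}}}


def current_kwargs(params):
    """Hypothetical helper: map each parameter name to its current value."""
    kwargs = {}
    for name, config in params.items():
        value = config['value']
        if config['type'] == 'slider':
            # assumption: keep slider values inside their declared range
            value = min(max(value, config['min']), config['max'])
        kwargs[name] = value
    return kwargs


kwargs = current_kwargs(params)
print(kwargs)
# these would then be forwarded as visualisation(inputs, layer, **kwargs), i.e. repeat=2
```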
Then we call
WebInterface.from_visualisation, passing the visualisation, the params and the name. We need to wrap this function using
partial since
mirror will need to dynamically pass some other parameters, the current layer and the input, at run time.
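To see why partial is needed, here is a toy version with a hypothetical factory, make_interface, standing in for WebInterface.from_visualisation (its real signature may differ). The arguments known when writing the code are bound up front with functools.partial; the layer and the input are supplied later, as mirror does at run time.

```python
from functools import partial


def make_interface(visualisation, layer, inputs, params, name):
    """Hypothetical factory used only to illustrate partial application."""
    kwargs = {key: cfg['value'] for key, cfg in params.items()}
    return name, visualisation(inputs, layer, **kwargs)


def repeat_input(inputs, layer, repeat=1):
    # toy visualisation working on a plain list instead of a tensor
    return inputs * repeat, None


params = {'repeat': {'type': 'slider', 'min': 1, 'max': 100,
                     'value': 2, 'step': 1, 'params': {}}}

# bind what we know now: the visualisation, its params and its name
factory = partial(make_interface, repeat_input, params=params, name='Repeat')

# ...later, the current layer (a placeholder name here) and the input are passed in
name, (output, info) = factory('features.0', ['img'])
print(name, output, info)
```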
The whole front-end is developed using React and Material-UI, two well-known frameworks, which makes it easy for anybody to contribute.
You can customise the front-end by changing the source code in
mirror/client. After that, you need to build the React app and move the files to the server's static folder.
I was not able to serve the static files directly from the /mirror/client/build folder; if you know how to do it, any pull request is welcome :)
cd ./mirror/mirror/client  # assuming the root folder is called mirror
npm install  # only needed if the dependencies are not installed yet
npm run build
Then you need to move the files from the
mirror/mirror/client/build folder to
mirror/mirror. Remove all the files in
mirror/mirror/static first, since mv will not replace a non-empty folder, then run the following from mirror/mirror/client:
mv ./build/static ../ && cp ./build/* ../static/
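If you prefer to script this step, the sketch below does the same thing with Python's standard library. It assumes the folder layout described above (client/build produced by npm run build, static files served from mirror/mirror/static); the function name install_build is hypothetical.

```python
import shutil
from pathlib import Path


def install_build(client_dir: Path) -> None:
    """Hypothetical helper: copy a React build into the server's static folder."""
    build_dir = client_dir / 'build'
    static_dir = client_dir.parent / 'static'
    # start from a clean static folder, like removing its old files by hand
    if static_dir.exists():
        shutil.rmtree(static_dir)
    # equivalent of: mv ./build/static ../
    shutil.move(str(build_dir / 'static'), str(static_dir))
    # equivalent of: cp ./build/* ../static/  (top-level files only)
    for item in build_dir.iterdir():
        if item.is_file():
            shutil.copy2(item, static_dir / item.name)
```

Usage: install_build(Path('./mirror/mirror/client')) from the repository root.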
Add all the visualisations available at https://github.com/utkuozbulak/pytorch-cnn-visualizations
[ ] Add an
output_transformation param to each visualisation to allow better customisation
[ ] Add an
input_transformation param to each visualisation to allow better customisation