stn.pytorch

by fxia22

PyTorch version of spatial transformer network

Ported from https://github.com/qassemoquab/stnbhwd following the PyTorch tutorial. It now supports both CPU and GPU. To use the FFI you need to install the cffi package from pip.

Build and test

cd script
./make.sh # build the CUDA code; remember to set the -arch argument to match your GPU's compute capability
python build.py
python test.py

There is a demo in test_stn.ipynb.

Modules

STN is the spatial transformer module. It takes a B*H*W*D tensor and a B*H*W*2 grid normalized to [-1,1] as input and performs bilinear sampling.
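
To make the sampling step concrete, here is a minimal NumPy sketch of bilinear sampling on a B*H*W*D tensor. It is not the repository's FFI implementation. The function name, the (x, y) ordering of the last grid axis, the mapping of -1 and +1 to the first and last pixel, and the border clamping are all assumptions made for illustration; the actual module may use different conventions.

```python
import numpy as np

def bilinear_sample(image, grid):
    """Sample image (B, H, W, D) at the locations in grid (B, H_out, W_out, 2).

    Assumptions (not taken from the repository): grid[..., 0] is x and
    grid[..., 1] is y, -1 maps to the first pixel and +1 to the last,
    and locations outside the image are clamped to the border.
    """
    B, H, W, D = image.shape

    # Map normalized coordinates in [-1, 1] to pixel coordinates.
    x = np.clip((grid[..., 0] + 1.0) * (W - 1) / 2.0, 0, W - 1)
    y = np.clip((grid[..., 1] + 1.0) * (H - 1) / 2.0, 0, H - 1)

    # Integer corners surrounding each sampling location.
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, W - 1)
    y1 = np.minimum(y0 + 1, H - 1)

    # Fractional offsets, with a trailing axis to broadcast over channels.
    wx = (x - x0)[..., None]
    wy = (y - y0)[..., None]

    # Batch index that broadcasts against the (B, H_out, W_out) corner arrays.
    b = np.arange(B)[:, None, None]

    top = (1 - wx) * image[b, y0, x0] + wx * image[b, y0, x1]
    bottom = (1 - wx) * image[b, y1, x0] + wx * image[b, y1, x1]
    return (1 - wy) * top + wy * bottom
```

Under these assumptions, sampling with an evenly spaced grid of the same size as the input returns the input unchanged, which is a convenient sanity check.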

AffineGridGen takes a B*2*3 matrix and generates an affine transformation grid.
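
As a rough illustration of what generating an affine grid means, the NumPy sketch below applies each 2*3 matrix to a base grid of homogeneous coordinates spanning [-1,1]. The function name `affine_grid` and the (x, y) row ordering are assumptions for this example, not the repository's API, and the module's own ordering may differ.

```python
import numpy as np

def affine_grid(theta, height, width):
    """Build a (B, H, W, 2) sampling grid from a (B, 2, 3) affine matrix.

    Assumption: the two rows of theta produce (x, y) in that order and the
    base grid spans [-1, 1] on both axes.
    """
    xs = np.linspace(-1.0, 1.0, width)
    ys = np.linspace(-1.0, 1.0, height)
    gx, gy = np.meshgrid(xs, ys)                        # each (H, W)
    base = np.stack([gx, gy, np.ones_like(gx)], -1)     # (H, W, 3) homogeneous coords

    # For every batch item and pixel: grid[b, h, w, :] = theta[b] @ base[h, w, :]
    return np.einsum('bij,hwj->bhwi', theta, base)
```

With the identity matrix [[1, 0, 0], [0, 1, 0]] the result is the base grid itself; a non-zero third column shifts it.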

CylinderGridGen takes a B*1 theta vector and generates a transformation grid that remaps equirectangular images along the x axis.

DenseAffineGridGen takes a B*H*W*6 tensor and applies an affine transformation to each pixel. An example of a convolutional spatial transformer can be found in test_conv_stn.ipynb.
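
The per-pixel variant can be sketched the same way: each pixel carries its own six affine parameters, which are applied to that pixel's base coordinate. This is a NumPy illustration only. The function name and the row-major reading of the six values as a 2*3 matrix are assumptions, and the repository may lay the parameters out differently.

```python
import numpy as np

def dense_affine_grid(params):
    """Turn per-pixel affine parameters (B, H, W, 6) into a (B, H, W, 2) grid.

    Assumption: the six values of each pixel form a row-major 2x3 matrix
    whose rows produce (x, y), applied to base coordinates in [-1, 1].
    """
    B, H, W, _ = params.shape
    gx, gy = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(-1.0, 1.0, H))
    base = np.stack([gx, gy, np.ones_like(gx)], -1)   # (H, W, 3)

    theta = params.reshape(B, H, W, 2, 3)             # one 2x3 matrix per pixel
    # grid[b, h, w, :] = theta[b, h, w] @ base[h, w, :]
    return np.einsum('bhwij,hwj->bhwi', theta, base)
```

Filling every pixel with [1, 0, 0, 0, 1, 0] reproduces the identity grid, and changing the parameters of a single pixel moves only that pixel's sampling location.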

An example of the landscape of the loss function of a simple STN with L1 Loss can be found in the demo.

Train hacks

  • set a learning rate multiplier; 1e-3 or 1e-4 works fine.
  • add an auxiliary loss that regularizes the difference between the affine transformation and the identity mapping, to avoid sampling outside the original image.
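
The two hacks above can be sketched as follows. This is a minimal NumPy illustration, not code from the repository: the names `identity_penalty` and `scaled_lr`, the squared-error form of the penalty, the default `weight`, and the reading of the multiplier as a factor applied to the base learning rate for the transformer parameters are all assumptions.

```python
import numpy as np

# Identity affine transform as a 2x3 matrix.
IDENTITY = np.array([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])

def identity_penalty(theta, weight=1e-2):
    """Auxiliary loss: mean squared deviation of theta (B, 2, 3) from identity.

    `weight` is a hypothetical scaling factor chosen for illustration.
    """
    return weight * np.mean((theta - IDENTITY) ** 2)

def scaled_lr(base_lr, multiplier=1e-3):
    """Learning rate for the transformer parameters (assumed interpretation)."""
    return base_lr * multiplier
```

The penalty would be added to the task loss, for example `total = task_loss + identity_penalty(theta)`, so that predicted transforms stay close to the identity early in training.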

Complex grid demo

STN can handle a complex grid; however, how to parameterize such a grid remains an open problem.

