## SIREN in Pytorch

Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions
## Install

```bash
$ pip install siren-pytorch
```
## Usage

A SIREN-based multi-layered neural network
```python
import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3)
```
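For context, here is a minimal sketch of what fitting an implicit representation with `SirenNet` looks like end to end: build a coordinate grid and regress the network's output against target pixel values. The [-1, 1] coordinate range follows the SIREN paper; the optimizer, learning rate, step count, and random stand-in target are illustrative assumptions, not part of this library's API.

```python
import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)

# (H * W, 2) grid of coordinates, normalized to [-1, 1] as in the paper
H = W = 64
grid = torch.stack(torch.meshgrid(
    torch.linspace(-1, 1, H),
    torch.linspace(-1, 1, W),
    indexing = 'ij'
), dim = -1).reshape(-1, 2)

target = torch.rand(H * W, 3)  # stand-in for the real rgb values of an image

opt = torch.optim.Adam(net.parameters(), lr = 1e-4)

for _ in range(100):
    loss = nn.functional.mse_loss(net(grid), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
```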
One SIREN layer
```python
import torch
from siren_pytorch import Siren

neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)
```
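Individual `Siren` layers are ordinary `nn.Module`s, so they compose with the usual PyTorch containers. A sketch using only the constructor arguments shown above; the `nn.Linear` output head is an illustrative choice, and this is not a drop-in replacement for `SirenNet`:

```python
import torch
from torch import nn
from siren_pytorch import Siren

model = nn.Sequential(
    Siren(dim_in = 2, dim_out = 256),
    Siren(dim_in = 256, dim_out = 256),
    nn.Linear(256, 3)   # plain linear head mapping features to output values
)

coor = torch.randn(1, 2)
model(coor) # (1, 3)
```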
Sine activation (just a wrapper around `torch.sin`)

```python
import torch
from siren_pytorch import Sine

act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
```
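A quick sanity check of that behavior, assuming `Sine(w0)` computes `sin(w0 * x)` (the scaled sine activation from the SIREN paper):

```python
import torch
from siren_pytorch import Sine

act = Sine(30.)
x = torch.randn(1, 2)

# assumption: the wrapper scales the input by w0 before applying torch.sin
assert torch.allclose(act(x), torch.sin(30. * x))
```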
Wrapper to train on a specific image of specified height and width from a given `SirenNet`, and then to subsequently generate.

```python
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything

pred_img = wrapper() # (1, 3, 256, 256)
```
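Putting it together, a minimal training-loop sketch around `SirenWrapper`, using only the calls shown above (`wrapper(img)` returns the loss, `wrapper()` generates). The optimizer, learning rate, step count, and random stand-in image are illustrative assumptions:

```python
import torch
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)
wrapper = SirenWrapper(net, image_width = 256, image_height = 256)

img = torch.randn(1, 3, 256, 256)  # stand-in for a real image tensor
opt = torch.optim.Adam(wrapper.parameters(), lr = 1e-4)

for _ in range(500):
    loss = wrapper(img)   # wrapper computes the reconstruction loss against img
    opt.zero_grad()
    loss.backward()
    opt.step()

pred_img = wrapper()  # (1, 3, 256, 256)
```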
## Citations
```bibtex
@misc{sitzmann2020implicit,
    title         = {Implicit Neural Representations with Periodic Activation Functions},
    author        = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year          = {2020},
    eprint        = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```