SIREN in Pytorch


Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions

Install

$ pip install siren-pytorch

Usage

A SIREN-based multi-layered neural network

import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3)
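
Since SirenNet maps coordinates to values, a typical use is to fit it to a single image by regressing RGB values at pixel coordinates. The sketch below is a minimal, illustrative training loop; the coordinate-grid construction, learning rate, and step count are assumptions, not part of the library's API.

import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, final_activation = nn.Sigmoid(), w0_initial = 30.)

# build an (H*W, 2) grid of coordinates normalized to [-1, 1]
h = w = 64
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing = 'ij')
coords = torch.stack((xs, ys), dim = -1).reshape(-1, 2)

target = torch.rand(h * w, 3) # stand-in for real pixel values in [0, 1]

opt = torch.optim.Adam(net.parameters(), lr = 1e-4)
for _ in range(100): # step count is illustrative
    loss = nn.functional.mse_loss(net(coords), target)
    opt.zero_grad()
    loss.backward()
    opt.step()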

One SIREN layer

import torch
from siren_pytorch import Siren

neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)
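
A Siren layer is an ordinary nn.Module, so you can also compose layers by hand instead of using SirenNet. A minimal sketch, assuming only the dim_in / dim_out arguments shown above:

import torch
from torch import nn
from siren_pytorch import Siren

model = nn.Sequential(
    Siren(dim_in = 2, dim_out = 256),
    Siren(dim_in = 256, dim_out = 256),
    nn.Linear(256, 3) # plain linear head for the final output
)

coor = torch.randn(1, 2)
model(coor) # (1, 3)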

Sine activation (just a wrapper around torch.sin)

import torch
from siren_pytorch import Sine

act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
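
Because Sine(w0) just applies sin(w0 * x) elementwise, it can serve as the activation in any custom network. A minimal sketch with illustrative layer sizes:

import torch
from torch import nn
from siren_pytorch import Sine

mlp = nn.Sequential(
    nn.Linear(2, 64),
    Sine(30.), # sin(30 * x), elementwise
    nn.Linear(64, 1)
)

mlp(torch.randn(8, 2)) # (8, 1)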

Wrapper to train on a specific image of specified height and width from a given SirenNet, and then to subsequently generate.

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()

# after much training ...

# simply invoke the wrapper without passing in anything
pred_img = wrapper() # (1, 3, 256, 256)
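
The "after much training" step above is elided; a minimal sketch of that loop, assuming Adam and illustrative hyperparameters, might look like this:

import torch
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)
wrapper = SirenWrapper(net, image_width = 256, image_height = 256)

img = torch.randn(1, 3, 256, 256) # stand-in for a real image tensor

opt = torch.optim.Adam(wrapper.parameters(), lr = 1e-4)
for _ in range(500): # step count is illustrative
    loss = wrapper(img) # the wrapper computes the reconstruction loss itself
    opt.zero_grad()
    loss.backward()
    opt.step()

pred_img = wrapper() # (1, 3, 256, 256)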

Modulation with Latent Code

A recent paper, Mehta et al. (2021, cited below), proposes that the best way to condition a Siren with a latent code is to pass the latent vector through a modulator feedforward network, where each layer's hidden state is elementwise multiplied with the corresponding layer of the Siren.
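
Conceptually, the modulation works roughly like the sketch below. This is an illustrative reconstruction of the idea, not the library's internal code; the layer sizes and the ReLU nonlinearity are assumptions.

import torch
from torch import nn

latent_dim, dim_hidden = 512, 256

# one modulator layer per Siren layer; after the first, each layer also
# sees the latent again via concatenation
modulator = nn.ModuleList([
    nn.Linear(latent_dim, dim_hidden),
    nn.Linear(latent_dim + dim_hidden, dim_hidden),
])

def modulations(latent):
    mods, x = [], latent
    for layer in modulator:
        x = torch.relu(layer(x))
        mods.append(x) # one modulation signal per Siren layer
        x = torch.cat((x, latent)) # feed the latent back in at each step
    return mods

# inside the Siren, each hidden state h_i would become h_i * mods[i]
mods = modulations(torch.randn(latent_dim))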

You can use this simply by setting an extra keyword latent_dim on the SirenWrapper.

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
loss.backward()

# after much training ...

# simply invoke the wrapper without passing in anything
pred_img = wrapper(latent = latent) # (1, 3, 256, 256)
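
In practice the latent code is a parameter optimized jointly with the network (auto-decoder style), so it comes to encode the particular image. A minimal sketch of that elided loop, with illustrative hyperparameters:

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)
wrapper = SirenWrapper(net, latent_dim = 512, image_width = 256, image_height = 256)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256) # stand-in for a real image tensor

# include the latent in the optimizer so it trains alongside the weights
opt = torch.optim.Adam(list(wrapper.parameters()) + [latent], lr = 1e-4)

for _ in range(500): # step count is illustrative
    loss = wrapper(img, latent = latent)
    opt.zero_grad()
    loss.backward()
    opt.step()

pred_img = wrapper(latent = latent) # (1, 3, 256, 256)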

Citations

@misc{sitzmann2020implicit,
    title   = {Implicit Neural Representations with Periodic Activation Functions},
    author  = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year    = {2020},
    eprint  = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mehta2021modulated,
    title   = {Modulated Periodic Activations for Generalizable Local Functional Representations}, 
    author  = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year    = {2021},
    eprint  = {2104.03960},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
