Source code for estorch.modules

import torch
from torch import nn
import torch.nn.functional as F


class VirtualBatchNorm(nn.Module):
    """
    Applies Virtual Batch Normalization over an input of shape
    :math:`(N, C, *)` -- e.g. a 4D mini-batch of 2D inputs with an additional
    channel dimension, or a 2D ``(N, C)`` batch of feature vectors -- as
    described in the paper `Improved Techniques for Training GANs`:
    https://arxiv.org/abs/1606.03498

    .. math::

        y = \\frac{x - \\mathrm{E}[x_\\text{ref}]}{
        \\sqrt{\\mathrm{Var}[x_\\text{ref}] + \\epsilon}} * \\gamma + \\beta

    VirtualBatchNorm requires two forward passes. The first one computes the
    mean and variance over a reference batch (and normalizes that batch with
    its own statistics). The second one normalizes the actual input with those
    reference statistics and then discards them, so the call after that is
    treated as a new reference pass.

    The statistics are reduced over the batch dimension only (``dim=0``), so
    every element of the ``(C, *)`` trailing shape gets its own mean and
    variance, while gamma (``weight``) and beta (``bias``) are per channel.
    The variance is the default (unbiased) estimate of :func:`torch.var`, so
    the reference batch should contain at least two samples.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, *)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
    """

    def __init__(self, num_features, eps=1e-5):
        super(VirtualBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        # Reference-batch statistics; both None means the next forward call
        # is a reference pass.
        self.mean = None
        self.var = None
        # Learnable per-channel scale (gamma) and shift (beta); initialized
        # by reset_parameters().
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Reset gamma (``weight``) to ones and beta (``bias``) to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def _param_shape(self, x):
        """Return the shape that broadcasts per-channel parameters over ``x``.

        ``(1, C)`` for 2D input, ``(1, C, 1)`` for 3D, ``(1, C, 1, 1)`` for 4D.
        """
        return (1, self.num_features) + (1,) * max(x.dim() - 2, 0)

    def normalize(self, x):
        """Normalize ``x`` with the stored reference statistics and apply the
        per-channel affine transform (gamma, beta).

        Requires ``self.mean`` and ``self.var`` to have been set by a
        reference pass.
        """
        shape = self._param_shape(x)
        x_hat = (x - self.mean) / torch.sqrt(self.var + self.eps)
        return x_hat * self.weight.view(shape) + self.bias.view(shape)

    def forward(self, x):
        """Run either a reference pass or a normalization pass.

        If no reference statistics are stored, ``x`` is the reference batch:
        its mean and variance over ``dim=0`` are stored and ``x`` is
        normalized with them. Otherwise ``x`` is normalized with the stored
        statistics, which are then cleared. The stored statistics are not
        detached from the reference batch's autograd graph.

        Args:
            x: tensor of shape :math:`(N, C, *)`.

        Returns:
            The normalized tensor, with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                "VirtualBatchNorm expects an input of shape (N, C, *), "
                "got a {}D input".format(x.dim()))
        if self.mean is None and self.var is None:
            # Reference pass: keep the statistics for the next call.
            self.mean = torch.mean(x, dim=0, keepdim=True)
            self.var = torch.var(x, dim=0, keepdim=True)
            return self.normalize(x)
        out = self.normalize(x)
        # Normalization pass finished: the next call starts a new reference.
        self.mean = None
        self.var = None
        return out