GAN Tricks: Equalized Learning Rate

Generalized implementation in PyTorch

Equalized Learning Rate is a trick introduced in the works on progressive growing of GANs and StyleGAN to stabilize and improve training.

The idea is to scale the parameters of each layer just before every forward propagation that passes through it. How much to scale by is determined, per layer, by a statistic of that layer's weight tensor.
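In the implementation discussed below, that statistic is the layer's fan-in: every weight is multiplied at run time by the constant familiar from He initialization,

$$ \hat{w} = w \, \sqrt{2 / \text{fan\_in}} $$

where fan_in is the number of input connections to the layer (input features for a linear layer, input channels times kernel size for a convolution). This is exactly the factor computed in EqualLR.compute_weight below.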

This post focuses on how this trick is implemented in the StyleGAN library github.com/rosinality/style-based-gan-pytorch, rather than on how the value of the scaling factor comes about.

Generalized implementation

In contrast with several other StyleGAN libraries, where equalized learning rate is implemented separately for nn.Linear and nn.Conv2d modules, style-based-gan-pytorch has a single class, EqualLR, that can be applied to either nn.Linear or nn.Conv2d to introduce equalized learning rate into the module.

Here is its definition:


import math

from torch import nn


class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)    #1
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)    #2

        return fn

    def __call__(self, module, input):    #3
        weight = self.compute_weight(module)    #4
        setattr(module, self.name, weight)    #5

  1. In applying equalized learning rate to a module's name parameter, the parameter is first re-registered under the name f'{name}_orig', and the original name parameter itself is deleted. At this point, a forward propagation through the module would fail, because the module no longer has a name attribute, which its forward method expects to find.

  2. Therefore, just before each forward propagation, a name attribute needs to be recreated. This is what fn does; to have it run just before every forward pass through the module, it is registered as a forward pre-hook via the module's register_forward_pre_hook method.

  3. A forward pre-hook does not have to be a function; any callable works. Here fn is an instance of the EqualLR class, which is callable because it defines a __call__ method.

  4. Inside __call__, the first thing to do is retrieve the weight_orig parameter and scale it; the details of the scaling factor are in the EqualLR.compute_weight method. The important thing to observe is that the weight returned by compute_weight is a plain torch.Tensor, not an nn.Parameter. It is therefore just a non-leaf variable in the computation graph, and back propagation will pass through it to reach the weight_orig parameter, the leaf variable to be updated during training.

  5. The obtained weight tensor is then set as the name attribute on the module, so forward propagation can now go through the module with parameter values that are scaled versions of the original. It is important to create the attribute with setattr here. module.register_parameter(name=name, param=nn.Parameter(weight)) could also create it, but that would lead to unintended behavior: back propagation would stop at weight and never reach weight_orig, so the gradient w.r.t. weight_orig would not be available. A minimal usage sketch follows this list.
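To see the class in action, here is a minimal usage sketch (written for this post, not taken from the library) that applies EqualLR to a plain nn.Linear and checks that the gradient reaches weight_orig:

import torch
from torch import nn

# assumes the EqualLR class defined above is in scope
linear = nn.Linear(4, 8, bias=False)
nn.init.normal_(linear.weight)        # start from N(0, 1) weights, for illustration
EqualLR.apply(linear, 'weight')       # 'weight' is replaced by 'weight_orig' plus a pre-hook

conv = nn.Conv2d(3, 16, kernel_size=3)
EqualLR.apply(conv, 'weight')         # the same class works for nn.Conv2d

x = torch.randn(2, 4)
y = linear(x)                         # pre-hook sets linear.weight = weight_orig * sqrt(2 / fan_in)
y.sum().backward()

print(type(linear.weight))            # <class 'torch.Tensor'> -- a non-leaf tensor, not an nn.Parameter
print(linear.weight_orig.grad.shape)  # torch.Size([8, 4]) -- the gradient reaches weight_orig

Because the scaled weight is rebuilt from weight_orig at every forward pass, the optimizer only ever sees and updates weight_orig.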

Simple linear example

To help understand the above, consider the simplest possible example: a linear layer without a bias, $$ y = w x $$ where \(w = 2\) is the "weight" parameter of the layer. Suppose that, every time before passing some input \(x\) through the layer, we want to scale \(w\) by a factor of 3. The first step is to rename \(w\) to \(w_{orig}\), so that

$$ w_{orig} = 2 \\ w = 3 w_{orig} \\ y = w x $$

\(w_{orig}\) is like weight_orig, and \(w\) is like the weight returned by EqualLR.compute_weight in EqualLR.__call__. After back propagation from \(y\), \(\partial y / \partial w_{orig}\) will become available and accessible via module.weight_orig.grad. So, when the optimizer takes a step, module.weight_orig will be updated according to this gradient.
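The same can be verified directly with autograd; the following standalone snippet mirrors the toy example (it does not use EqualLR itself):

import torch
from torch import nn

w_orig = nn.Parameter(torch.tensor(2.0))   # the leaf parameter, like weight_orig
x = torch.tensor(5.0)

w = 3 * w_orig        # non-leaf tensor, like the weight set in the pre-hook
y = w * x
y.backward()

print(w_orig.grad)    # tensor(15.) == 3 * x, so an optimizer step would update w_orig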

If, as pointed out above, module.weight were instead created through the registration of an nn.Parameter, then, after back propagation from \(y\), \(\partial y / \partial w\), i.e. module.weight.grad, would become available instead of \(\partial y / \partial w_{orig}\); weight_orig would receive no gradient and would no longer be updated by the optimizer. This is not the intended behavior of equalized learning rate.
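For contrast, here is a variation of the same snippet in which the scaled value is wrapped in a fresh nn.Parameter (the explicit detach() stands in for the re-registration; a new parameter is a new leaf either way), showing the gradient stopping short of \(w_{orig}\):

import torch
from torch import nn

w_orig = nn.Parameter(torch.tensor(2.0))
x = torch.tensor(5.0)

# Wrapping the scaled value in a new nn.Parameter creates a new leaf that is
# cut off from w_orig in the computation graph.
w = nn.Parameter((3 * w_orig).detach())
y = w * x
y.backward()

print(w.grad)         # tensor(5.) -- the gradient lands on w instead
print(w_orig.grad)    # None       -- w_orig is never reached, so it would never be updated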

