Equalized learning rate is a trick introduced in the works on progressive growing of GANs and StyleGAN to stabilise and improve training.
The idea is to scale the parameters of each layer just before every forward propagation that passes through it. The scaling factor is determined by a statistic of each layer's weight tensor, namely its fan-in.
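Concretely, in the implementation discussed below, a weight \(w\) whose fan-in (the number of input connections feeding each output unit) is \(n\) is replaced at each forward pass by the scaled version
$$ \hat{w} = \sqrt{\frac{2}{n}} \, w $$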
This post focuses on how this trick is implemented in the StyleGAN library github.com/rosinality/style-based-gan-pytorch, rather than on how the value of the scaling factor comes about.
Generalized implementation
In contrast with several other StyleGAN libraries, where equalized learning rate is implemented separately for `nn.Linear` and `nn.Conv2d` modules, style-based-gan-pytorch has a single class, `EqualLR`, that can be applied to either an `nn.Linear` or an `nn.Conv2d` module to introduce equalized learning rate.
Here is its definition:
```python
import math

from torch import nn


class EqualLR:
    def __init__(self, name):
        self.name = name

    def compute_weight(self, module):
        weight = getattr(module, self.name + '_orig')
        # fan-in: the number of input connections feeding each output unit
        fan_in = weight.data.size(1) * weight.data[0][0].numel()

        return weight * math.sqrt(2 / fan_in)

    @staticmethod
    def apply(module, name):
        fn = EqualLR(name)

        weight = getattr(module, name)  #1
        del module._parameters[name]
        module.register_parameter(name + '_orig', nn.Parameter(weight.data))
        module.register_forward_pre_hook(fn)  #2

        return fn

    def __call__(self, module, input):  #3
        weight = self.compute_weight(module)  #4
        setattr(module, self.name, weight)  #5
```
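To make the walkthrough below concrete, here is a minimal sketch of how `EqualLR.apply` could be attached to a plain `nn.Linear` module. The `make_equal_lr` helper and the layer sizes are made up for illustration; they are not taken from the library.

```python
import torch
from torch import nn


def make_equal_lr(module, name='weight'):
    # Hypothetical convenience wrapper around EqualLR.apply (defined above).
    EqualLR.apply(module, name)
    return module


linear = make_equal_lr(nn.Linear(8, 4))

# 'weight' has been deleted and re-registered as 'weight_orig'.
print([n for n, _ in linear.named_parameters()])  # ['bias', 'weight_orig']

# The forward pre-hook recreates a scaled 'weight' just before this call.
y = linear(torch.randn(2, 8))
print(type(linear.weight))                                  # <class 'torch.Tensor'>, not nn.Parameter
print(linear.weight.requires_grad, linear.weight.is_leaf)   # True False
```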
- In applying equalized learning rate to a module's `name` parameter, the module's `name` parameter is first renamed to `f'{name}_orig'`, and the `name` parameter itself is deleted. At this point, a forward propagation through the module would fail, because the module no longer has a `name` attribute, which PyTorch looks for by default.
- Therefore, just before each forward propagation, a `name` attribute needs to be created. This happens in the `fn` callable, and to run it just before a forward propagation through the module, it is registered as a forward pre-hook, using the module's `register_forward_pre_hook` method.
- A forward pre-hook only needs to be a callable, so `fn` does not have to be a plain function: here it is an instance of the `EqualLR` class, which defines a `__call__` method.
- In this method, the first thing to do is retrieve the `weight_orig` parameter and scale it by some number; the details are in the `EqualLR.compute_weight` method (see the small numeric check after this list). The important thing to observe is that the `weight` returned by `compute_weight` is just a `torch.Tensor`, not an `nn.Parameter`. This makes it simply a non-leaf variable in the computation graph, and back propagation will pass through it to reach the `weight_orig` parameter, the leaf variable to be updated during training.
- Then, the obtained `weight` tensor is set as the `name` attribute of the module. This means that forward propagation can now go through the module, with parameter values that are scaled versions of the original. Here it is important to create the `weight` attribute with `setattr`. `module.register_parameter(name=name, param=nn.Parameter(weight))` could also create the attribute, but that would lead to unintended behaviour: back propagation would stop at `weight` and never reach `weight_orig`, so the gradient w.r.t. `weight_orig` would not be available.
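Regarding the fan-in statistic used in `EqualLR.compute_weight`: the expression `weight.data.size(1) * weight.data[0][0].numel()` counts the number of input connections feeding each output unit, for both linear and convolutional weights. A small check, with arbitrary layer sizes of my own choosing:

```python
import math

from torch import nn

lin_w = nn.Linear(512, 256).weight            # shape (256, 512)
print(lin_w.size(1) * lin_w[0][0].numel())    # 512 * 1 = 512

conv_w = nn.Conv2d(64, 128, 3).weight         # shape (128, 64, 3, 3)
print(conv_w.size(1) * conv_w[0][0].numel())  # 64 * 9 = 576

print(math.sqrt(2 / 512))                     # 0.0625, the scale applied to the linear weight
```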
Simple linear example
To help understand the above, consider the simplest possible linear layer, with no bias: $$ y = w x $$ where \(w = 2\) is the "weight" parameter of the layer. Suppose that, every time before passing some input \(x\) through the layer, we want to scale \(w\) by a factor of 3. Then the first step is to rename \(w\) to \(w_{orig}\), so that
$$ w_{orig} = 2 \\ w = 3 w_{orig} \\ y = w x $$
\(w_{orig}\) is like `weight_orig`, and \(w\) is like the `weight` returned by `EqualLR.compute_weight` in `EqualLR.__call__`. After back propagation from \(y\), \(\partial y / \partial w_{orig}\) will become available and accessible via `module.weight_orig.grad`. So, when the optimizer takes a step, `module.weight_orig` will be updated according to this gradient.
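The same behaviour can be reproduced with a few lines of plain autograd code. The scalar tensors below are stand-ins for `weight_orig` and the scaled `weight`, with made-up values:

```python
import torch

w_orig = torch.tensor(2.0, requires_grad=True)  # leaf, plays the role of weight_orig
w = 3 * w_orig                                  # non-leaf, the scaled weight
x = torch.tensor(5.0)

y = w * x
y.backward()

print(w.is_leaf)     # False: back propagation passes through w ...
print(w_orig.grad)   # tensor(15.): ... and reaches w_orig (dy/dw_orig = 3 * x)
```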
If, as pointed out above, `module.weight` were instead created through the registration of an `nn.Parameter`, then, after back propagation from \(y\), \(\partial y / \partial w\), i.e. `module.weight.grad`, would become available instead of \(\partial y / \partial w_{orig}\), and \(w_{orig}\) would no longer be updated by the optimizer. This is not the intended behaviour of equalized learning rate.
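For contrast, a sketch of this unintended behaviour: wrapping the scaled value in an `nn.Parameter` (as `register_parameter` would require) detaches it from the computation graph, so the gradient never reaches \(w_{orig}\). Again, the values are made up:

```python
import torch
from torch import nn

w_orig = nn.Parameter(torch.tensor(2.0))
w = nn.Parameter(3 * w_orig)   # wrapping in nn.Parameter severs the link to w_orig
x = torch.tensor(5.0)

y = w * x
y.backward()

print(w.grad)        # tensor(5.): the gradient stops at w
print(w_orig.grad)   # None: w_orig would never be updated by the optimizer
```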