Neural networks with extremely low-precision weights and activations, such as Binarized Neural Networks (BNNs), usually contain a mix of low-precision weights (e.g. 1-bit) and higher-precision weights (e.g. 8-bit, 16-bit, or 32-bit). Examples of this include the first and last layers of image classification models, which have higher-precision weights in most BNN architectures from the literature.
Training a BNN, then, consists of optimizing both low-precision and higher-precision weights. In `larq`, we provide a mechanism to target different bit-precision variables with different optimizers using the `CaseOptimizer` class. Modeled after the …, `CaseOptimizer` accepts pairs of predicates and optimizers. A predicate, given a variable, decides whether its optimizer should train that variable.

`CaseOptimizer` behaves much like any other Keras optimizer, and once you instantiate it you can pass it to your `model.compile()` as usual. To instantiate a `CaseOptimizer`, pass one or more `(predicate, optimizer)` tuples, along with a `default_optimizer` which trains any variables not claimed by another optimizer. A variable may not be claimed by more than one optimizer's predicate.
```python
import larq as lq
import tensorflow as tf

case_optimizer = lq.optimizers.CaseOptimizer(
    (
        lq.optimizers.Bop.is_binary_variable,  # predicate
        lq.optimizers.Bop(threshold=1e-6, gamma=1e-3),  # optimizer
    ),
    default_optimizer=tf.keras.optimizers.Adam(0.01),  # trains all other variables
)
```
An optimizer is used to train a variable iff its accompanying predicate evaluates to `True`.

For each variable, at most one optimizer's predicate may evaluate to `True`. If no optimizer's predicate evaluates to `True` for a variable, it is trained with the `default_optimizer`. If a variable is claimed by no optimizer and `default_optimizer` is `None`, the variable is not trained.
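The assignment rules above can be sketched in plain Python. This is a hypothetical illustration, not larq's implementation: variables are stand-in `FakeVariable` objects with an assumed `precision` attribute, optimizers are represented by name strings, and the helpers `assign_optimizers` and `is_binary` are made up for this example (larq's own predicate is `lq.optimizers.Bop.is_binary_variable`). Raising `ValueError` on a double claim is likewise an assumed choice.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Tuple


@dataclass(frozen=True)
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable."""
    name: str
    precision: int  # assumed bit width of the variable


Predicate = Callable[[FakeVariable], bool]


def assign_optimizers(
    variables: Sequence[FakeVariable],
    pairs: Sequence[Tuple[Predicate, str]],
    default_optimizer: Optional[str] = None,
) -> dict:
    """Map each variable name to the optimizer that trains it (None = not trained)."""
    assignment = {}
    for var in variables:
        # Collect every optimizer whose predicate claims this variable.
        claims = [opt for pred, opt in pairs if pred(var)]
        if len(claims) > 1:
            raise ValueError(f"{var.name} is claimed by more than one optimizer: {claims}")
        # Unclaimed variables fall back to the default, which may be None.
        assignment[var.name] = claims[0] if claims else default_optimizer
    return assignment


def is_binary(var: FakeVariable) -> bool:
    """Hypothetical predicate: claim only 1-bit variables."""
    return var.precision == 1


variables = [
    FakeVariable("conv1/kernel", 32),
    FakeVariable("quant_conv2/kernel", 1),
    FakeVariable("dense/kernel", 32),
]

print(assign_optimizers(variables, [(is_binary, "Bop")], default_optimizer="Adam"))
print(assign_optimizers(variables, [(is_binary, "Bop")]))  # no default: FP weights untrained
```

With `default_optimizer="Adam"` the two full-precision kernels go to Adam and the binary kernel to Bop; without a default, the full-precision kernels map to `None`.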
predicate_optimizer_pairs: One or more `(pred, tf.keras.optimizers.Optimizer)` pairs, where `pred` takes a `tf.Variable` as argument and returns `True` if the optimizer should be used for that variable, e.g. `pred(var) == True`.
default_optimizer: A `tf.keras.optimizers.Optimizer` to be applied to any variable not claimed by any other optimizer. (Must be passed as a keyword argument.)
Bop(threshold=1e-08, gamma=0.0001, name='Bop', **kwargs)
Bop is a latent-free optimizer for Binarized Neural Networks (BNNs) and Binary Weight Networks (BWN).
Bop maintains an exponential moving average of the gradients, controlled by `gamma`. If this average exceeds the `threshold` in magnitude and has the same sign as the weight, the weight is flipped.
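One Bop step can be sketched in NumPy, assuming the binary weights are stored as floats in {-1, +1}. The flip condition (the average must exceed the threshold in magnitude and agree in sign with the weight, i.e. moving against the gradient means changing the weight's sign) follows Bop as originally formulated; the function name `bop_step` is hypothetical and this is not larq's code. The defaults mirror the `Bop` signature above.

```python
import numpy as np


def bop_step(w, m, grad, threshold=1e-8, gamma=1e-4):
    """One Bop update on binary weights w (entries in {-1, +1}).

    Returns the updated (w, m). Illustrative sketch only.
    """
    # Exponential moving average of the gradients, with adaptivity rate gamma.
    m = (1.0 - gamma) * m + gamma * grad
    # Flip where |m| exceeds the threshold and m has the same sign as w.
    flip = (np.abs(m) > threshold) & (np.sign(m) == np.sign(w))
    w = np.where(flip, -w, w)
    return w, m


w = np.array([1.0, 1.0, -1.0, -1.0])
m = np.zeros(4)
grad = np.array([0.5, -0.5, 0.5, -0.5])
w, m = bop_step(w, m, grad, threshold=1e-3, gamma=0.01)
# m is now [0.005, -0.005, 0.005, -0.005]; only entries 0 and 3 agree in sign
# with their weights, so w becomes [-1, 1, -1, 1].
```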
`gamma` is somewhat analogous to the learning rate in SGD methods: a high `gamma` results in rapid convergence but also makes training more noisy.
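The learning-rate analogy can be made concrete with a pure-Python sketch of the moving average alone (no weight flipping), assuming it starts at zero. The helpers `ema_trace` and `steps_to_reach` are hypothetical. A consistent unit gradient shows how fast the average builds up; alternating +1/-1 gradients show how far pure noise pushes it.

```python
def ema_trace(grads, gamma):
    """Exponential moving average m_t = (1 - gamma) * m_{t-1} + gamma * g_t, starting at 0."""
    m, trace = 0.0, []
    for g in grads:
        m = (1.0 - gamma) * m + gamma * g
        trace.append(m)
    return trace


def steps_to_reach(level, gamma, max_steps=10_000):
    """Number of constant unit-gradient steps until the average first exceeds `level`."""
    m = 0.0
    for step in range(1, max_steps + 1):
        m = (1.0 - gamma) * m + gamma * 1.0
        if m > level:
            return step
    return None


# Consistent signal: a larger gamma reaches the same level in fewer steps.
for gamma in (1e-3, 1e-2, 1e-1):
    print(gamma, steps_to_reach(0.05, gamma))

# Pure noise (alternating +1/-1 gradients): a larger gamma lets the average swing further.
noise = [(-1.0) ** t for t in range(1000)]
for gamma in (1e-3, 1e-2, 1e-1):
    print(gamma, max(abs(x) for x in ema_trace(noise, gamma)))
```

Both effects scale with `gamma`, which is the trade-off between fast convergence and noisy training described above.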
Note that the default `threshold` is not optimal for all situations. Setting the threshold too high results in little learning, while setting it too low results in overly noisy behaviour.
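The effect of the threshold can be explored with a toy experiment on a single binary weight. This is a hypothetical sketch: the helper `count_flips`, the synthetic gradient stream (a weak bias of 0.05 buried in uniform noise), and the fixed `gamma=1e-2` are all assumptions chosen for illustration, so the printed flip counts only indicate qualitative behaviour, not tuned values.

```python
import random


def count_flips(grads, threshold, gamma=1e-2, w=1.0):
    """Run the Bop flip rule on a single binary weight and count how often it flips."""
    m, flips = 0.0, 0
    for g in grads:
        m = (1.0 - gamma) * m + gamma * g
        # Flip when |m| exceeds the threshold and m agrees in sign with w.
        if abs(m) > threshold and m * w > 0:
            w, flips = -w, flips + 1
    return flips


rng = random.Random(0)
# Synthetic gradient stream: a weak consistent signal plus bounded uniform noise.
grads = [0.05 + rng.uniform(-1.0, 1.0) for _ in range(5000)]

for threshold in (0.0, 1e-3, 1e-2, 1e-1):
    print(f"threshold={threshold:g}: {count_flips(grads, threshold)} flips")
```

Since the moving average never exceeds the largest gradient magnitude, a threshold above that bound never flips anything (no learning), while a threshold of zero flips the weight whenever the noisy average changes sign relative to it.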
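```python
import larq as lq
import tensorflow as tf

optimizer = lq.optimizers.CaseOptimizer(
    (
        lq.optimizers.Bop.is_binary_variable,  # predicate: selects binary weights
        lq.optimizers.Bop(),  # trains the binary weights
    ),
    default_optimizer=tf.keras.optimizers.Adam(0.01),  # for full-precision weights
)
```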
threshold: determines whether to flip each weight.
gamma: the adaptivity rate.
name: name of the optimizer.