BinaryNet on CIFAR10 (Advanced)

In this example we demonstrate how to use Larq to build and train BinaryNet on the CIFAR10 dataset, reaching a validation accuracy of around 90%. Training takes approximately 250 minutes on an Nvidia V100 GPU. Compared to the original papers, BinaryConnect: Training Deep Neural Networks with binary weights during propagations and Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, we do not implement image whitening, but we do use image augmentation and a stepped learning rate schedule.

import tensorflow as tf
import larq as lq
import numpy as np

Import CIFAR10 Dataset

Here we download the CIFAR10 dataset:

train_data, test_data = tf.keras.datasets.cifar10.load_data()
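
load_data returns tuples of NumPy arrays (images and integer labels). As a quick check of what we just loaded:

images, labels = train_data
print(images.shape, labels.shape)  # (50000, 32, 32, 3) (50000, 1)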

Next, we define our image augmentation techniques and create the dataset:

def resize_and_flip(image, labels, training):
    # Scale pixel values to [-1, 1].
    image = tf.cast(image, tf.float32) / (255. / 2.) - 1.
    if training:
        # Pad to 40x40, then randomly crop back to 32x32 and flip horizontally.
        image = tf.image.resize_with_crop_or_pad(image, 40, 40)
        image = tf.image.random_crop(image, [32, 32, 3])
        image = tf.image.random_flip_left_right(image)
    return image, labels

def create_dataset(data, batch_size, training):
    images, labels = data
    labels = tf.one_hot(np.squeeze(labels), 10)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.repeat()
    if training:
        dataset = dataset.shuffle(1000)
    dataset = dataset.map(lambda x, y: resize_and_flip(x, y, training))
    dataset = dataset.batch(batch_size)
    return dataset
batch_size = 50

train_dataset = create_dataset(train_data, batch_size, True)
test_dataset = create_dataset(test_data, batch_size, False)
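
If you want to verify the pipeline, you can pull a single batch and inspect its shape and value range (this assumes eager execution, e.g. TensorFlow 2):

images, labels = next(iter(train_dataset))
print(images.shape)  # (50, 32, 32, 3)
print(labels.shape)  # (50, 10)
print(float(tf.reduce_min(images)), float(tf.reduce_max(images)))  # within [-1, 1]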

Build BinaryNet

Here we build the BinaryNet model layer by layer using a Keras sequential model:

# All quantized layers except the first will use the same options
kwargs = dict(input_quantizer="ste_sign",
              kernel_quantizer="ste_sign",
              kernel_constraint="weight_clip",
              use_bias=False)

model = tf.keras.models.Sequential([
    # In the first layer we only quantize the weights and not the input
    lq.layers.QuantConv2D(128, 3,
                          kernel_quantizer="ste_sign",
                          kernel_constraint="weight_clip",
                          use_bias=False,
                          input_shape=(32, 32, 3)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(128, 3, padding="same", **kwargs),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(256, 3, padding="same", **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(256, 3, padding="same", **kwargs),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(512, 3, padding="same", **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantConv2D(512, 3, padding="same", **kwargs),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),
    tf.keras.layers.Flatten(),

    lq.layers.QuantDense(1024, **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantDense(1024, **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),

    lq.layers.QuantDense(10, **kwargs),
    tf.keras.layers.BatchNormalization(momentum=0.999, scale=False),
    tf.keras.layers.Activation("softmax")
])
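
The ste_sign quantizer used above binarizes inputs and kernels to -1 or +1 in the forward pass and uses a straight-through estimator in the backward pass, while the weight_clip constraint keeps the latent full-precision weights in [-1, 1]. A rough sketch of the quantizer's behaviour (not Larq's exact implementation, which differs in edge cases such as the sign of zero) could look like this:

@tf.custom_gradient
def ste_sign_sketch(x):
    # Forward pass: binarize to -1 / +1.
    forward = tf.sign(x)
    def grad(dy):
        # Backward pass (straight-through estimator): identity gradient
        # inside [-1, 1], zero gradient outside.
        return dy * tf.cast(tf.abs(x) <= 1.0, dy.dtype)
    return forward, grad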

Larq allows you to print a summary of the model that includes bit-precision information:

lq.models.summary(model)
+sequential_1 stats-------------------------------------------------------------------+
| Layer                   Input prec.            Outputs   # 1-bit  # 32-bit   Memory |
|                               (bit)                                            (kB) |
+-------------------------------------------------------------------------------------+
| quant_conv2d_6                    -  (-1, 30, 30, 128)      3456         0     0.42 |
| batch_normalization_9             -  (-1, 30, 30, 128)         0       384     1.50 |
| quant_conv2d_7                    1  (-1, 30, 30, 128)    147456         0    18.00 |
| max_pooling2d_3                   -  (-1, 15, 15, 128)         0         0     0.00 |
| batch_normalization_10            -  (-1, 15, 15, 128)         0       384     1.50 |
| quant_conv2d_8                    1  (-1, 15, 15, 256)    294912         0    36.00 |
| batch_normalization_11            -  (-1, 15, 15, 256)         0       768     3.00 |
| quant_conv2d_9                    1  (-1, 15, 15, 256)    589824         0    72.00 |
| max_pooling2d_4                   -    (-1, 7, 7, 256)         0         0     0.00 |
| batch_normalization_12            -    (-1, 7, 7, 256)         0       768     3.00 |
| quant_conv2d_10                   1    (-1, 7, 7, 512)   1179648         0   144.00 |
| batch_normalization_13            -    (-1, 7, 7, 512)         0      1536     6.00 |
| quant_conv2d_11                   1    (-1, 7, 7, 512)   2359296         0   288.00 |
| max_pooling2d_5                   -    (-1, 3, 3, 512)         0         0     0.00 |
| batch_normalization_14            -    (-1, 3, 3, 512)         0      1536     6.00 |
| flatten_1                         -         (-1, 4608)         0         0     0.00 |
| quant_dense_3                     1         (-1, 1024)   4718592         0   576.00 |
| batch_normalization_15            -         (-1, 1024)         0      3072    12.00 |
| quant_dense_4                     1         (-1, 1024)   1048576         0   128.00 |
| batch_normalization_16            -         (-1, 1024)         0      3072    12.00 |
| quant_dense_5                     1           (-1, 10)     10240         0     1.25 |
| batch_normalization_17            -           (-1, 10)         0        30     0.12 |
| activation_1                      -           (-1, 10)         0         0     0.00 |
+-------------------------------------------------------------------------------------+
| Total                                                   10352000     11550  1308.79 |
+-------------------------------------------------------------------------------------+
+sequential_1 summary-------------+
| Total params           10363550 |
| Trainable params       10355850 |
| Non-trainable params   7700     |
| Float-32 Equivalent    39.53 MB |
| Compression of Memory  30.93    |
+---------------------------------+
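
The memory numbers follow from simple arithmetic: binarized kernels cost 1 bit per parameter, while the batch normalization parameters stay in 32-bit floats. A quick back-of-the-envelope check of the totals above:

binary_params = 10352000  # 1-bit parameters from the table
fp_params = 11550         # 32-bit parameters from the table

memory_kb = (binary_params / 8 + fp_params * 4) / 1024
float32_equivalent_mb = (binary_params + fp_params) * 4 / 1024 ** 2
compression = float32_equivalent_mb * 1024 / memory_kb

print(memory_kb)              # ~1308.79 kB
print(float32_equivalent_mb)  # ~39.53 MB
print(compression)            # ~30.93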

Model Training

We compile and train the model just as we would any other Keras model:

initial_lr = 1e-3
var_decay = 1e-5

optimizer = tf.keras.optimizers.Adam(lr=initial_lr, decay=var_decay)
model.compile(
    optimizer=lq.optimizers.XavierLearningRateScaling(optimizer, model),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
def lr_schedule(epoch):
    return initial_lr * 0.1 ** (epoch // 100)

# model.fit returns a History object containing the per-epoch metrics
history = model.fit(
    train_dataset,
    epochs=500,
    steps_per_epoch=train_data[1].shape[0] // batch_size,
    validation_data=test_dataset,
    validation_steps=test_data[1].shape[0] // batch_size,
    verbose=1,
    callbacks=[tf.keras.callbacks.LearningRateScheduler(lr_schedule)]
)
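
With this schedule the learning rate starts at 1e-3 and is divided by 10 every 100 epochs. After training you can evaluate the binarized model on the test set and save the weights in the usual Keras way (the filename below is just an example):

test_loss, test_acc = model.evaluate(
    test_dataset, steps=test_data[1].shape[0] // batch_size)
print("Test accuracy:", test_acc)

model.save_weights("binarynet_cifar10_weights.h5")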