Data augmentation

import numpy as np
import xarray as xr
from matplotlib              import pyplot as plt
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow_examples.models.pix2pix   import pix2pix

In order to enhance the generalizability of a model, one can make use of data augmentation.

By changing the training data set in certain ways (rotation, stretching, flipping, color and brightness adjustments, etc.) you can increase the size of the available training data significantly without having to gather further labelled data (you just have to make sure to modify the labels accordingly). This way, the model doesn’t overfit on the original training data as fast and its generalizability increases.

In our example, we saw that the model was previously “overfitted” on orthophoto images with certain properties (facing north-upwards, sunny conditions, shadows on the top-left side of objects, etc.). This leads the model to look for patterns that might not be transferable to other orthophotos, e.g. the model learns that houses are always situated at the bottom right of low-brightness areas (the houses’ shadows). To prevent the model from learning such “wrong” connections, we can augment the data, e.g. by rotating the input images by 90°, 180° and 270°.
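
The key point is that every transformation applied to an image tile must also be applied to its label mask so that both stay aligned. A minimal sketch of this idea using tf.image.rot90 on dummy data (the array names here are illustrative and not part of the data generator below):

import numpy as np
import tensorflow as tf

# illustrative dummy tile and matching binary building mask (not the real data)
image = np.random.rand(224, 224, 3).astype("float32")
mask  = np.random.randint(0, 2, size=(224, 224, 1)).astype("float32")

# rotate image and mask by the same multiple of 90 degrees so they stay aligned
image_rot = tf.image.rot90(image, k=1)  # 90° counter-clockwise
mask_rot  = tf.image.rot90(mask, k=1)   # identical rotation for the label

print(image_rot.shape, mask_rot.shape)  # (224, 224, 3) (224, 224, 1)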

So let’s load our training data again:

# Load data
training_data = xr.open_dataset("data/oberschelden.nc")
tilesize = 224

Within the image data generator we can then include the data augmentation. TensorFlow offers a lot of features for easily augmenting image data:

# Data generator for our model. Tiles the given xr.Dataset into multiple tiles and returns them in batches (row-wise)
# Also splits the data into training/validation/test
class CustomImageDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, ds, sampletype, tilesize=tilesize):
        self.ds   = ds
        self.ylen = self.ds.y.size // tilesize
        self.xlen = self.ds.x.size // tilesize
        self.sampletype = sampletype
        
    def __len__(self):
        return self.ylen

    def __getitem__(self, index):
        
        # select one row of tiles and crop the right edge so that the width is an
        # exact multiple of the tile size (assumes x.size is not exactly divisible by tilesize)
        red       = self.ds.red[index*tilesize:(index+1)*tilesize,:-(self.ds.x.size%tilesize)]
        green     = self.ds.green[index*tilesize:(index+1)*tilesize,:-(self.ds.x.size%tilesize)]
        blue      = self.ds.blue[index*tilesize:(index+1)*tilesize,:-(self.ds.x.size%tilesize)]
        buildings = self.ds.buildings_mask[index*tilesize:(index+1)*tilesize,:-(self.ds.x.size%tilesize)]
        
        rgb       = np.array([red,green,blue]).transpose(1,2,0)
        buildings = np.array(buildings)
        
        # now split into tiles...
        rgb_tiles    = np.array(np.split(rgb, self.xlen,axis=1))
        target_tiles = np.array(np.split(buildings, self.xlen,axis=1))
        
        if self.sampletype == "training" or self.sampletype == "validation":
            rgb_tiles_training, rgb_tiles_validation, target_tiles_training, target_tiles_validation = train_test_split(rgb_tiles, target_tiles, shuffle=True, test_size=0.1, random_state=0)
            if self.sampletype == "training": 
                # do some augmentation (here we only rotate the images by 90°, 180° and 270°)
                # the target masks below are rotated in exactly the same way so that they stay aligned
                rgb_tiles_training = np.concatenate((rgb_tiles_training,
                                                     tf.image.rot90(image=rgb_tiles_training),
                                                     tf.image.rot90(image=rgb_tiles_training,k=2),
                                                     tf.image.rot90(image=rgb_tiles_training,k=3)),
                                                    axis=0)
                target_tiles_training = np.concatenate((target_tiles_training, 
                                                        tf.image.rot90(np.expand_dims(target_tiles_training,axis=-1))[:,:,:,0],
                                                        tf.image.rot90(np.expand_dims(target_tiles_training,axis=-1),k=2)[:,:,:,0], 
                                                        tf.image.rot90(np.expand_dims(target_tiles_training,axis=-1),k=3)[:,:,:,0]),
                                                       axis=0)
                
                # shuffle the original and augmented tiles so the rotations are mixed within the batch
                rng = np.random.RandomState(42)
                indexes = rng.permutation(rgb_tiles_training.shape[0])
                rgb_tiles_training    = rgb_tiles_training[indexes]
                target_tiles_training = target_tiles_training[indexes]
                
                return rgb_tiles_training, target_tiles_training
            else:                             return rgb_tiles_validation, target_tiles_validation
        
        if self.sampletype == "test": return rgb_tiles, target_tiles
        
        return None
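
Before training, we can sanity-check the generator by pulling a single batch and inspecting its shapes. This is just a quick check using the training_data loaded above; it is not required for the workflow:

cidg_check = CustomImageDataGenerator(training_data, sampletype="training")
rgb_batch, target_batch = cidg_check[0]

# with the fourfold rotation augmentation, each row of tiles yields four times as many training tiles
print(rgb_batch.shape)     # (number of augmented tiles, 224, 224, 3)
print(target_batch.shape)  # (number of augmented tiles, 224, 224)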

Train and apply a model

We can now load the pretrained MobileNetV2 model again and reshape it for our case study:

base_model = tf.keras.applications.MobileNetV2(input_shape=[tilesize, tilesize, 3], include_top=False)

layer_names = [
    'block_1_expand_relu',   # 112x112
    'block_3_expand_relu',   # 56x56
    'block_6_expand_relu',   # 28x28
    'block_13_expand_relu',  # 14x14
    'block_16_project',      # 7x7
]
layers = [base_model.get_layer(name).output for name in layer_names]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
down_stack.trainable = False

up_stack = [
    pix2pix.upsample(512, 3),  # 7x7 -> 14x14
    pix2pix.upsample(256, 3),  # 14x14 -> 28x28
    pix2pix.upsample(128, 3),  # 28x28 -> 56x56
    pix2pix.upsample(64, 3),   # 56x56 -> 112x112
]

def unet_model(output_channels):
    inputs = tf.keras.layers.Input(shape=[tilesize, tilesize, 3])
    x = inputs
    
    # Downsampling through the model
    skips = down_stack(x)
    x = skips[-1]
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    # This is the last layer of the model
    last = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2,
        padding='same')  # 112x112 -> 224x224

    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

OUTPUT_CHANNELS = 2

model = unet_model(OUTPUT_CHANNELS)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
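
Before feeding in the real data, we can quickly verify that the reshaped U-Net maps a 224×224 RGB tile to a 224×224 map with one logit per class. This is just a sanity check with an all-zero dummy tile, not part of the original workflow:

# run a dummy batch of one all-zero tile through the untrained model
dummy = tf.zeros([1, tilesize, tilesize, 3])
print(model(dummy).shape)  # expected: (1, 224, 224, 2)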

And now we can provide the model with our original and augmented data during training:

cidg_training   = CustomImageDataGenerator(training_data,sampletype="training")
cidg_validation = CustomImageDataGenerator(training_data,sampletype="validation")

model_history = model.fit(cidg_training,validation_data=cidg_validation,epochs=20,initial_epoch=0)
Epoch 1/20
49/49 [==============================] - 251s 5s/step - loss: 0.3108 - accuracy: 0.9162 - val_loss: 0.2399 - val_accuracy: 0.9551
Epoch 2/20
49/49 [==============================] - 249s 5s/step - loss: 0.1537 - accuracy: 0.9590 - val_loss: 0.4743 - val_accuracy: 0.9551
Epoch 3/20
49/49 [==============================] - 264s 5s/step - loss: 0.1284 - accuracy: 0.9589 - val_loss: 0.2661 - val_accuracy: 0.9553
Epoch 4/20
49/49 [==============================] - 266s 5s/step - loss: 0.1151 - accuracy: 0.9591 - val_loss: 0.1700 - val_accuracy: 0.9567
Epoch 5/20
49/49 [==============================] - 259s 5s/step - loss: 0.1073 - accuracy: 0.9608 - val_loss: 0.1704 - val_accuracy: 0.9566
Epoch 6/20
49/49 [==============================] - 256s 5s/step - loss: 0.1057 - accuracy: 0.9608 - val_loss: 0.1820 - val_accuracy: 0.9565
Epoch 7/20
49/49 [==============================] - 271s 6s/step - loss: 0.1060 - accuracy: 0.9605 - val_loss: 0.1108 - val_accuracy: 0.9589
Epoch 8/20
49/49 [==============================] - 249s 5s/step - loss: 0.1003 - accuracy: 0.9625 - val_loss: 0.1033 - val_accuracy: 0.9606
Epoch 9/20
49/49 [==============================] - 251s 5s/step - loss: 0.0982 - accuracy: 0.9629 - val_loss: 0.1105 - val_accuracy: 0.9580
Epoch 10/20
49/49 [==============================] - 257s 5s/step - loss: 0.0958 - accuracy: 0.9640 - val_loss: 0.1077 - val_accuracy: 0.9609
Epoch 11/20
49/49 [==============================] - 256s 5s/step - loss: 0.0962 - accuracy: 0.9641 - val_loss: 0.1240 - val_accuracy: 0.9599
Epoch 12/20
49/49 [==============================] - 242s 5s/step - loss: 0.0897 - accuracy: 0.9657 - val_loss: 0.1043 - val_accuracy: 0.9571
Epoch 13/20
49/49 [==============================] - 246s 5s/step - loss: 0.0898 - accuracy: 0.9654 - val_loss: 0.1010 - val_accuracy: 0.9621
Epoch 14/20
49/49 [==============================] - 252s 5s/step - loss: 0.0897 - accuracy: 0.9652 - val_loss: 0.1080 - val_accuracy: 0.9565
Epoch 15/20
49/49 [==============================] - 247s 5s/step - loss: 0.0905 - accuracy: 0.9655 - val_loss: 0.1035 - val_accuracy: 0.9612
Epoch 16/20
49/49 [==============================] - 239s 5s/step - loss: 0.0854 - accuracy: 0.9673 - val_loss: 0.1044 - val_accuracy: 0.9562
Epoch 17/20
49/49 [==============================] - 236s 5s/step - loss: 0.0851 - accuracy: 0.9675 - val_loss: 0.1104 - val_accuracy: 0.9536
Epoch 18/20
49/49 [==============================] - 235s 5s/step - loss: 0.0879 - accuracy: 0.9668 - val_loss: 0.1054 - val_accuracy: 0.9584
Epoch 19/20
49/49 [==============================] - 235s 5s/step - loss: 0.0842 - accuracy: 0.9678 - val_loss: 0.1185 - val_accuracy: 0.9498
Epoch 20/20
49/49 [==============================] - 241s 5s/step - loss: 0.0816 - accuracy: 0.9686 - val_loss: 0.1000 - val_accuracy: 0.9633

Plotting the training history shows that training and validation accuracy now lie closer to each other, which implies that the model doesn’t overfit as fast anymore. Still, we can see a divergence in the later epochs, so there is still room for improvement here:

plt.plot(model_history.history['accuracy'])
plt.plot(model_history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.ylim((0.92,0.98))
plt.show()
[Figure: training vs. validation accuracy per epoch]
plt.plot(model_history.history['loss'])
plt.plot(model_history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
[Figure: training vs. validation loss per epoch]

We quickly save the model in case something goes wrong, so that we can reload it later:

model.save('trained_models/UNet_buildings_with_augmentation')
model = tf.keras.models.load_model('trained_models/UNet_buildings_with_augmentation')

Finally, we can apply the model to our rotated test data set and see if it performs better now:

test_data = xr.open_dataset("data/volnsberg.nc")

# rotate 180°
for data_var in list(test_data.data_vars): test_data[data_var].values = np.rot90(test_data[data_var].values,k=2)

cidg_test = CustomImageDataGenerator(test_data,sampletype="test")

# predict on all test tiles and take the most likely class per pixel
predictions = model.predict(cidg_test)
pred_mask = np.argmax(predictions, axis=-1)

# extent of the area that is covered by full tiles
xs = test_data.x.size-test_data.x.size%tilesize
ys = test_data.y.size-test_data.y.size%tilesize

# number of tiles per row / column
x_cnt = test_data.x.size // tilesize
y_cnt = test_data.y.size // tilesize

# reassemble the predicted tiles (row-wise) into one mask covering the whole test area
final_result = np.concatenate([np.concatenate(pred_mask[i*x_cnt:(i+1)*x_cnt], axis=1) for i in range(y_cnt)], axis=0)

# remainder that was cut off at the bottom / right edge during tiling
y_rest = (test_data.y.size%tilesize)
x_rest = (test_data.x.size%tilesize)

# plot the RGB test orthophoto with the predicted building mask overlaid
plt.figure(figsize=(18,18))
plt.imshow(np.array([test_data.red[:-y_rest,:-x_rest],test_data.green[:-y_rest,:-x_rest],test_data.blue[:-y_rest,:-x_rest]]).transpose(1,2,0))
plt.imshow(np.where(final_result==1,1,np.nan),cmap="autumn_r",interpolation="None",alpha=0.5)
plt.show()

m = tf.keras.metrics.Accuracy()
m.update_state(final_result, test_data.buildings_mask[:-y_rest,:-x_rest])
print("Accuracy: {0:0.5f}".format(m.result().numpy()))

# MeanIoU (Image segmentation metrics)
m = tf.keras.metrics.MeanIoU(num_classes=2)
m.update_state(final_result, test_data.buildings_mask[:-y_rest,:-x_rest])
print("MeanIoU:  {0:0.5f}".format(m.result().numpy()))
[Figure: predicted building mask overlaid on the rotated test orthophoto]
Accuracy: 0.93575
MeanIoU:  0.65823

We see that the model is now also able to detect houses in orthophotos that were rotated by 180°, which shows that its generalizability has increased.