U-Net: Understanding and Implementation in Python

One of the biggest challenges in the world of computer vision is image segmentation. We have all played around with image classification a lot; in fact, when we hear of computer vision applications, we immediately think of convolutional neural networks and their use in image classification. However, there’s a new problem in town, and it’s called ‘semantic segmentation’. Spoiler alert! U-Net is the solution.

Well, it’s not exactly a new problem; it has been around for several years. But unlike classification, image segmentation is a much more complex task that requires a much more complex architecture and loads of training data. The U-Net architecture was introduced in 2015 and was specifically aimed at solving the problem of biomedical image segmentation.

In this article we will discuss the following things:

  • Semantic Segmentation
  • U-Net Architecture explanation
  • U-Net Implementation

Semantic Segmentation

Where Image Classification aims at predicting a single class for the whole input image, Image segmentation has two jobs to perform: localization and classification.

Localization means finding the location (pixels) of a particular object within a much larger image. Classification comes next and is self-explanatory; it means to classify the object that has been localized within the image.

It’s like a game of Where’s Waldo. You first look for an object that resembles a human/cartoon and then you classify whether it is Waldo or not.
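
To make the two jobs concrete, here is a minimal sketch in plain Python. The tiny 4×5 mask, the class labels, and the helper name `locate_class` are made up for illustration. The point is only that a per-pixel label map answers both "what is in the image?" and "where is it?".

```python
# A hypothetical 4x5 segmentation mask: one class label per pixel.
# 0 = background, 1 = "Waldo" (labels chosen for illustration only).
mask = [
    [0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 0, 0],
]

def locate_class(mask, class_id):
    """Return the bounding box (top, left, bottom, right) of class_id, or None if absent."""
    rows = [r for r, row in enumerate(mask) if class_id in row]
    cols = [c for row in mask for c, label in enumerate(row) if label == class_id]
    if not rows:
        return None
    return (min(rows), min(cols), max(rows), max(cols))

# Classification: a single label for the whole image ("is there a Waldo?").
image_label = int(any(1 in row for row in mask))

# Segmentation: the mask also tells us which pixels belong to Waldo.
waldo_box = locate_class(mask, 1)
print(image_label, waldo_box)  # 1 (1, 2, 2, 3)
```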

Semantic segmentation has several real-world applications. Perhaps the most common one is in self-driving cars. The cameras attached to an autonomous vehicle can distinguish the objects in their view, which lets the system decide when to turn, stop, or keep driving depending on what is in the path ahead.

The U-Net Architecture

Skateboard ramp representing U-shaped Architecture of the U-Net

First introduced in 2015 in the paper U-Net: Convolutional Networks for Biomedical Image Segmentation by Ronneberger et al., the U-Net takes an approach to image segmentation that outperformed its competitor at the time, a sliding-window convolutional network, all the while using fewer images in the training dataset and making use of image augmentation to increase the learning capability of the network.

The name itself gives away the overall shape of the network. It’s… well… a U-shaped network consisting of a contracting path and an expansive path. The network itself looks like a skateboard ramp. The basic intuition is that on the downslope (contracting path) the network learns to classify the object, and on the upslope (expansive path) it learns to localize the object.

Overall architecture of the U-Net
The overall architecture of U-Net as depicted in the original research paper

Looking at the network, you can see that the layers on the contracting path pass information over to the corresponding layers on the expansive path. This way, the classification context is transferred over to the localization module, which is what makes the overall network so effective.
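
The "U" can also be traced with a little arithmetic. The sketch below is plain Python, not Keras, and `unet_spatial_sizes` is a hypothetical helper. It assumes unpadded 3×3 convolutions (each trims 2 pixels), 2×2 max-pooling on even sizes, and 2×2 up-convolutions that double the size, as in the original paper's diagram.

```python
def unet_spatial_sizes(input_size, depth=4, kernel=3):
    """Trace the height/width through an unpadded U-Net.

    Returns (skip_sizes, bottom_size, output_size). Assumes every pooling
    input is even, so integer halving is exact.
    """
    shrink = 2 * (kernel - 1)      # two unpadded 3x3 convolutions remove 4 pixels
    size = input_size
    skip_sizes = []
    # Contracting path: conv, conv, then 2x2 max-pooling.
    for _ in range(depth):
        size -= shrink
        skip_sizes.append(size)    # feature map handed over to the expansive path
        size //= 2
    # Bottom of the "U": two more convolutions.
    size -= shrink
    bottom_size = size
    # Expansive path: 2x up-convolution, then conv, conv.
    for _ in range(depth):
        size = size * 2 - shrink
    return skip_sizes, bottom_size, size

skips, bottom, output = unet_spatial_sizes(572)
print(skips, bottom, output)  # [568, 280, 136, 64] 28 388
```

For a 572×572 input, this gives the 388×388 output map shown in the paper's figure, which is why the segmentation mask is smaller than the input tile.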

U-Net Implemented in Keras

The first step is always to import the relevant libraries.

import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Activation, Dense, Dropout, Conv2D, Conv2DTranspose, MaxPooling2D, Concatenate, Input, Cropping2D, Flatten
# Import Model from tensorflow.keras as well, so layers and model come from the same package
from tensorflow.keras.models import Model

We mostly use the Keras Sequential API to build models; however, that is only feasible when the model is a linear stack of layers. In the diagram above, we can see that features from one layer are passed on to another layer that comes much later in the network. To build this network, we need to use the Functional API.

Contracting Path
The first few layers of the U-Net
First block of convolutional layers.

We’ll start with these first few layers, as they represent our basic convolutional neural network.


x_input = Input(shape =(572,572, 3))
print(x_input)
#-------------------------------------------------------------------------
conv_1 = Conv2D(64, 3, activation = 'relu')(x_input)
conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)
pool_1 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_2)

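The ‘Input’ function helps us define the input to the model, and the rest is a series of basic convolutional operations. Since Conv2D applies no padding by default, each 3×3 convolution trims one pixel from every border of the feature map.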
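
As a rough illustration of why the feature maps shrink (572 to 570 to 568 in the first block), here is a naive unpadded convolution in plain Python. The function name `conv2d_valid` and the toy 5×5 input are made up for this sketch. Like Keras, it slides the kernel without flipping it.

```python
def conv2d_valid(image, kernel):
    """Naive single-channel 2-D convolution with no padding and stride 1."""
    k = len(kernel)
    out_h = len(image) - k + 1     # the kernel only fits in (size - k + 1) positions
    out_w = len(image[0]) - k + 1
    return [
        [
            sum(image[i + a][j + b] * kernel[a][b] for a in range(k) for b in range(k))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

image = [[1] * 5 for _ in range(5)]     # toy 5x5 input of ones
kernel = [[1] * 3 for _ in range(3)]    # toy 3x3 kernel of ones
feature_map = conv2d_valid(image, kernel)
print(len(feature_map), len(feature_map[0]))  # 3 3
```

A 3×3 kernel therefore removes 2 pixels per dimension, and two such layers in a row remove 4.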

Second block of the contracting path
The second block in the contracting path.

The output from the last Conv layer is max-pooled, and the next block then applies the same pair of convolutions with twice as many filters.


#-------------------------------------------------------------------------
conv_3 = Conv2D(128, 3, activation = 'relu')(pool_1)
conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)
pool_2 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_4)
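
For readers who have not looked inside a pooling layer, the following plain-Python sketch shows what 2×2 max-pooling with stride 2 does to a single channel. `max_pool_2x2` is a hypothetical helper written for this example and assumes an even height and width.

```python
def max_pool_2x2(feature_map):
    """2x2 max-pooling with stride 2 on a single-channel map of even height and width."""
    height, width = len(feature_map), len(feature_map[0])
    return [
        [
            max(
                feature_map[i][j], feature_map[i][j + 1],
                feature_map[i + 1][j], feature_map[i + 1][j + 1],
            )
            for j in range(0, width, 2)
        ]
        for i in range(0, height, 2)
    ]

feature_map = [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 10, 11],
    [12, 13, 14, 15],
]
pooled = max_pool_2x2(feature_map)
print(pooled)  # [[5, 7], [13, 15]]
```

Each output value is the maximum of one non-overlapping 2×2 window, so the height and width are halved.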

The same block is repeated 4 times, each time doubling the number of channels in the output. This resembles a very simple CNN model of the kind most often used for image classification, which is exactly the purpose of the contracting path, as we discussed in the previous section.

The following block of code represents the entire contracting path:

x_input = Input(shape =(572,572, 3))
print(x_input)
#-------------------------------------------------------------------------
conv_1 = Conv2D(64, 3, activation = 'relu')(x_input)
conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)
pool_1 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_2)
#-------------------------------------------------------------------------
conv_3 = Conv2D(128, 3, activation = 'relu')(pool_1)
conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)
pool_2 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_4)
#-------------------------------------------------------------------------
conv_5 = Conv2D(256, 3, activation = 'relu')(pool_2)
print(conv_5.shape)
conv_6 = Conv2D(256, 3, activation = 'relu')(conv_5)
print(conv_6.shape)
pool_3 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_6)
#-------------------------------------------------------------------------
conv_7 = Conv2D(512, 3, activation = 'relu')(pool_3)
conv_8 = Conv2D(512, 3, activation = 'relu')(conv_7)
print(conv_8.shape)
pool_4 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_8)

There is one more important thing to note here. The paper mentions one very important detail about the input size given to the network.

To allow a seamless tiling of the output segmentation map, it is important to select the input tile size such that all 2×2 max-pooling operations are applied to a layer with an even x- and y-size.

Validating Shapes

So, to ensure that we have an input size that always results in an even-dimensioned input to every max-pooling layer, I created the following function.

def model_UNET_validation(x, y):
    def check_even(layer):
        # Both height (axis 1) and width (axis 2) must be even before max-pooling
        if (layer.shape[1] % 2 != 0) or (layer.shape[2] % 2 != 0):
            raise ValueError("Input shape is invalid, please choose a different shape")

    #Descending Layers
    x_input = Input(shape =(x, y, 3))
    #-------------------------------------------------------------------------
    conv_1 = Conv2D(64, 3, activation = 'relu')(x_input)
    conv_2 = Conv2D(64, 3, activation = 'relu')(conv_1)
    check_even(conv_2)
    pool_1 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_2)
    #-------------------------------------------------------------------------
    conv_3 = Conv2D(128, 3, activation = 'relu')(pool_1)
    conv_4 = Conv2D(128, 3, activation = 'relu')(conv_3)
    check_even(conv_4)
    pool_2 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_4)
    #-------------------------------------------------------------------------
    conv_5 = Conv2D(256, 3, activation = 'relu')(pool_2)
    conv_6 = Conv2D(256, 3, activation = 'relu')(conv_5)
    check_even(conv_6)
    pool_3 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_6)
    #-------------------------------------------------------------------------
    conv_7 = Conv2D(512, 3, activation = 'relu')(pool_3)
    conv_8 = Conv2D(512, 3, activation = 'relu')(conv_7)
    check_even(conv_8)
    pool_4 = MaxPooling2D(pool_size=(2, 2),strides=2, padding='same')(conv_8)
    #final layer for output
    flat = Flatten()(pool_4)
    # A single output unit needs sigmoid; softmax over one unit would always give 1
    out = Dense(1, activation='sigmoid')(flat)
    # Reaching this point means every max-pooling input was even
    return Model(inputs=x_input, outputs=out)

The above function is nothing special: it is simply the complete contracting path, with an exception raised if the input size would give any max-pooling layer an odd-sized input.
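
The same check can be done without building any layers, using only arithmetic. This is a sketch under the assumptions of this article (four blocks, two unpadded 3×3 convolutions per block, then 2×2 max-pooling). `is_valid_input_size` is a hypothetical helper and checks one dimension at a time, so call it for both height and width.

```python
def is_valid_input_size(size, depth=4, shrink=4):
    """True if every 2x2 max-pooling in the contracting path sees an even size."""
    for _ in range(depth):
        size -= shrink                 # two unpadded 3x3 convolutions
        if size <= 0 or size % 2 != 0:
            return False
        size //= 2                     # 2x2 max-pooling
    return True

valid_sizes = [s for s in range(540, 600) if is_valid_input_size(s)]
print(valid_sizes)  # [540, 556, 572, 588]
```

The valid sizes are 16 apart, which follows from the four halvings: each pooling layer adds one more factor of 2 that the (shifted) input size has to satisfy.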

Intermediate Layer

Now we move on to the middle/intermediate layer.

intermediate layer of the U-Net
intermediate layer

This is nothing but another series of convolutions so…

#-------------------------------------------------------------------------
conv_9 = Conv2D(1024, 3, activation = 'relu')(pool_4)
conv_10 = Conv2D(1024, 3, activation = 'relu')(conv_9)

Expanding Path

Now we move to the expanding path (This is where the network learns to localize the object with help from a little context from the previous layers)

Expansive path
Beginning of the expansive path

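The expansive path involves 2 new concepts that you may not have seen before: deconvolution and layer concatenation.

Deconvolution (more precisely, transposed convolution) works in the opposite direction to convolution as far as size is concerned: the output is larger in height and width than the input. Here, each 2×2 transposed convolution with stride 2 doubles the height and width. Concatenation is simply joining two layers together along the channel axis (axis 3 in our case, since axis 0 is the batch dimension). The diagram below explains this better.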

A simple illustration of what deconvolution and layer concatenation looks like
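
The same two operations can be sketched in plain Python on a toy example. The helper names `up_conv_2x2` and `concat_channels` are made up for this illustration, and the kernel of ones stands in for learned weights. Note that channels are stored here as a list of 2-D maps, whereas Keras keeps them on the last axis.

```python
def up_conv_2x2(feature_map, kernel):
    """Transposed convolution with a 2x2 kernel and stride 2 on one channel.

    Each input pixel is multiplied by the kernel and written to its own
    2x2 patch of the output, so the height and width double.
    """
    height, width = len(feature_map), len(feature_map[0])
    out = [[0] * (2 * width) for _ in range(2 * height)]
    for i in range(height):
        for j in range(width):
            for a in range(2):
                for b in range(2):
                    out[2 * i + a][2 * j + b] = feature_map[i][j] * kernel[a][b]
    return out

def concat_channels(first, second):
    """Join two lists of channels whose 2-D maps have the same height and width."""
    assert len(first[0]) == len(second[0]) and len(first[0][0]) == len(second[0][0])
    return first + second

small = [[1, 2], [3, 4]]
upsampled = up_conv_2x2(small, kernel=[[1, 1], [1, 1]])   # toy kernel of ones
skip = [[9] * 4 for _ in range(4)]                        # stand-in for a cropped skip map
merged = concat_channels([upsampled], [skip])
print(upsampled)    # [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
print(len(merged))  # 2
```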

Before we pass the information from one layer to another, we need to crop it, because the unpadded convolutions leave the feature maps on the contracting path larger than the upsampled ones they are joined with. This can be done using the Cropping2D layer, but we also need to set a general rule for the output shapes so that we don’t run into dimension-mismatch errors.

We’ve already imported the relevant layers so here’s how we implement it.

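The following helper function determines how many pixels to crop from each side of a contracting-path output so that it matches the upsampled feature map.

def get_cropping_shape(previous_layer_shape, current_layer_shape):
    # Half of the size difference is cropped from each side
    return (previous_layer_shape - current_layer_shape) // 2

#Ascending Layers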
up_conv_10a = Conv2DTranspose(512, 2,strides=(2, 2),padding="same")(conv_10)
crop_shape_conv_8 = get_cropping_shape(conv_8.shape[1], up_conv_10a.shape[1])
conv_8_cropped = Cropping2D((crop_shape_conv_8,crop_shape_conv_8))(conv_8)
up_conv_10b = Concatenate(axis = 3)([up_conv_10a, conv_8_cropped])
conv_11 = Conv2D(512, 3, activation = 'relu')(up_conv_10b)
conv_12 = Conv2D(512, 3, activation = 'relu')(conv_11)
#-------------------------------------------------------------------------

So for the expansive path, we take the following steps.

  1. Carry out deconvolution on the output of the intermediate layer. 
  2. Crop the output from the contracting path.
  3. Join the contracted output and the deconvolved output on the third axis.
  4. Carry out normal convolution for the remaining layer.
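
The four steps can be followed as pure shape arithmetic. The sketch below is not Keras code, and `decoder_step` is a hypothetical helper. It assumes square feature maps, an even size difference between the two inputs, and a 572×572 network input, so the intermediate block outputs 28×28×1024 and the matching contracting-path output is 64×64×512.

```python
def decoder_step(bottom_shape, skip_shape, filters):
    """Shapes (height, width, channels) for one block of the expansive path."""
    h, w, _ = bottom_shape
    up = (2 * h, 2 * w, filters)                     # 1. 2x2 up-convolution, stride 2
    margin = (skip_shape[0] - up[0]) // 2            # 2. pixels cropped from each side
    cropped = (skip_shape[0] - 2 * margin, skip_shape[1] - 2 * margin, skip_shape[2])
    merged = (up[0], up[1], up[2] + cropped[2])      # 3. concatenate along the channels
    out = (merged[0] - 4, merged[1] - 4, filters)    # 4. two unpadded 3x3 convolutions
    return up, margin, cropped, merged, out

up, margin, cropped, merged, out = decoder_step((28, 28, 1024), (64, 64, 512), 512)
print(up)       # (56, 56, 512)
print(margin)   # 4
print(cropped)  # (56, 56, 512)
print(merged)   # (56, 56, 1024)
print(out)      # (52, 52, 512)
```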

The above steps are repeated 4 times, once per block, and the overall code for the expansive path is as follows.

#Ascending Layers
up_conv_10a = Conv2DTranspose(512, 2,strides=(2, 2),padding="same")(conv_10)
crop_shape_conv_8 = get_cropping_shape(conv_8.shape[1], up_conv_10a.shape[1])
conv_8_cropped = Cropping2D((crop_shape_conv_8,crop_shape_conv_8))(conv_8)
up_conv_10b = Concatenate(axis = 3)([up_conv_10a, conv_8_cropped])
conv_11 = Conv2D(512, 3, activation = 'relu')(up_conv_10b)
conv_12 = Conv2D(512, 3, activation = 'relu')(conv_11)
#-------------------------------------------------------------------------
up_conv_13a = Conv2DTranspose(256, 2,strides=(2, 2),padding="same")(conv_12)
crop_shape_conv_6 = get_cropping_shape(conv_6.shape[1], up_conv_13a.shape[1])
conv_6_cropped = Cropping2D((crop_shape_conv_6,crop_shape_conv_6))(conv_6)
up_conv_13b = Concatenate(axis = 3)([up_conv_13a, conv_6_cropped])
conv_14 = Conv2D(256, 3, activation = 'relu')(up_conv_13b)
conv_15 = Conv2D(256, 3, activation = 'relu')(conv_14)
#-------------------------------------------------------------------------
up_conv_16a = Conv2DTranspose(128, 2,strides=(2, 2),padding="same")(conv_15)
crop_shape_conv_4 = get_cropping_shape(conv_4.shape[1], up_conv_16a.shape[1])
conv_4_cropped = Cropping2D((crop_shape_conv_4,crop_shape_conv_4))(conv_4)
up_conv_16b = Concatenate(axis = 3)([up_conv_16a, conv_4_cropped])
conv_17 = Conv2D(128, 3, activation = 'relu')(up_conv_16b)
conv_18 = Conv2D(128, 3, activation = 'relu')(conv_17)
#-------------------------------------------------------------------------
up_conv_19a = Conv2DTranspose(64, 2,strides=(2, 2),padding="same")(conv_18)
crop_shape_conv_2 = get_cropping_shape(conv_2.shape[1], up_conv_19a.shape[1])
conv_2_cropped = Cropping2D((crop_shape_conv_2,crop_shape_conv_2))(conv_2)
up_conv_19b = Concatenate(axis = 3)([up_conv_19a, conv_2_cropped])
conv_19 = Conv2D(64, 3, activation = 'relu')(up_conv_19b)
conv_20 = Conv2D(64, 3, activation = 'relu')(conv_19)
#final layer for output: 1x1 convolution with a per-pixel softmax over the 2 classes
out = Conv2D(2, 1, activation = 'softmax')(conv_20)

You may have noticed that the above network has not used a single fully connected layer. Even the final output is a 2-D image since the network directly outputs the mask of the input image.
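
One common way to read such an output, assuming each pixel carries one score per class, is to take the most likely class at every pixel. The sketch below uses plain Python and a made-up 2×2 score map. `softmax` and `scores_to_mask` are hypothetical helpers, not part of Keras.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of class scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scores_to_mask(score_map):
    """Turn an H x W grid of per-class scores into an H x W grid of class indices."""
    mask = []
    for row in score_map:
        mask_row = []
        for scores in row:
            probs = softmax(scores)
            mask_row.append(probs.index(max(probs)))   # most likely class for this pixel
        mask.append(mask_row)
    return mask

# Hypothetical 2x2 output with two scores (background, object) per pixel.
score_map = [
    [[2.0, 0.5], [0.1, 1.2]],
    [[0.3, 0.2], [-1.0, 3.0]],
]
print(scores_to_mask(score_map))  # [[0, 1], [0, 1]]
```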

To summarise the model…

model = Model(inputs=x_input, outputs=out)
# summarize layers (summary() prints the table itself and returns None)
model.summary()

Model Summary

We can see it’s a quite large network.
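
To put a number on "quite large", the trainable parameters can be counted by hand. This is a back-of-the-envelope sketch in plain Python, assuming the layer sizes used in this article (3 input channels, 64 base filters, 4 blocks per path, 2 output classes). `conv_params` and `unet_params` are hypothetical helpers.

```python
def conv_params(kernel, in_channels, out_channels):
    """Weights plus biases of a Conv2D or Conv2DTranspose layer with a square kernel."""
    return kernel * kernel * in_channels * out_channels + out_channels

def unet_params(in_channels=3, num_classes=2, base=64, depth=4):
    total = 0
    channels = in_channels
    # Contracting path plus the intermediate block: two 3x3 convolutions each.
    for level in range(depth + 1):
        filters = base * 2 ** level
        total += conv_params(3, channels, filters) + conv_params(3, filters, filters)
        channels = filters
    # Expansive path: 2x2 up-convolution, then two 3x3 convolutions.
    for level in reversed(range(depth)):
        filters = base * 2 ** level
        total += conv_params(2, channels, filters)       # up-convolution
        total += conv_params(3, 2 * filters, filters)    # input is the concatenation
        total += conv_params(3, filters, filters)
        channels = filters
    # Final 1x1 convolution to the class scores.
    total += conv_params(1, channels, num_classes)
    return total

print(unet_params())
```

Under these assumptions the count comes to roughly 31 million parameters, with most of them sitting in the 512- and 1024-filter layers near the bottom of the "U".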

Visualize the Model

Keras provides us with a very nice function to plot the structure of the model.

# plot_model needs the pydot package and Graphviz to be installed
keras.utils.plot_model(model, to_file = "Model.png", show_shapes=True)
Plot of our constructed model.
U-Net Structure that we implemented

Now that’s a U shape.

Final Thoughts

U-Net was a very well-received architecture back in 2015 and won several competitions related to computer vision. Even today, it can be seen in several competitions on Kaggle. Understanding this model was a fun journey, and I hope it helped clarify the concepts for you.
