U-Net Segmentation BraTS 2019

Naomi Fridman
Nov 14, 2019

Multi-class Image Segmentation

BraTS brain tumors segmentation challenge

BraTS is a challenge on the segmentation of brain tumors in multimodal magnetic resonance imaging (MRI) scans. BraTS 2019 uses multi-institutional pre-operative MRI scans and focuses on the segmentation of intrinsically heterogeneous (in appearance, shape, and histology) brain tumors, namely gliomas.

The collective BraTS challenge article, which includes this work, is linked at the end of this post.

Figure: ground-truth label and model predictions (multi-class U-Net brain tumor segmentation).

Loading the MRI data

The data contains 259 training images and 76 validation images, with 4 label types. Each image is a multimodal MRI scan with 4 channels: FLAIR, T1, T1CE, and T2. The SimpleITK Python package is used to read, write, and convert the MRI data into Python arrays. For convenience, label 4 is remapped to label 3. The training data was split, with 0.2 held out as test data. The labels are not equally represented; their distribution is:

Figure: distribution of tumor segmentation label types.
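A minimal sketch of the loading step, assuming BraTS-style file naming (<case>_<modality>.nii.gz); load_case is a hypothetical helper, not code from the repository:

import numpy as np
import SimpleITK as sitk

MODALITIES = ['flair', 't1', 't1ce', 't2']  # the four input channels

def load_case(case_dir, case_id):
    # Read the four MRI modalities of one case and stack them as channels
    channels = []
    for m in MODALITIES:
        img = sitk.ReadImage(f'{case_dir}/{case_id}_{m}.nii.gz')
        channels.append(sitk.GetArrayFromImage(img))  # (slices, 240, 240)
    x = np.stack(channels, axis=-1)                   # (slices, 240, 240, 4)

    # Read the segmentation mask and remap label 4 to 3, as described above
    seg = sitk.ReadImage(f'{case_dir}/{case_id}_seg.nii.gz')
    y = sitk.GetArrayFromImage(seg)
    y[y == 4] = 3
    return x, y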

From viewing a few segmented images, we can anticipate that distinguishing between the labels will be the hardest part.

U-Net Architecture and Training

The U-Net was developed by Olaf Ronneberger et al. for biomedical image segmentation. The original network won the ISBI cell tracking challenge 2015 by a large margin and has since become a state-of-the-art deep learning tool for image segmentation.

The original network was built for 2-D microscopy images; here it is modified to an input shape of 240x240x4, one channel per MRI modality.

from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Dropout,
                                     UpSampling2D, concatenate)
from tensorflow.keras.models import Model

input_size = (240, 240, 4)  # one channel per modality: FLAIR, T1, T1CE, T2
dropout = 0.5               # assumed value; not stated in the post
hn = 'he_normal'            # kernel initializer ('he_normal' assumed; the post uses hn undefined)
IMG_CHANNELS = 4            # number of output classes: background + labels 1-3

inputs = Input(input_size)

# Contracting path: two 3x3 convolutions, then 2x2 max pooling at each level
conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=hn)(inputs)
conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=hn)(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer=hn)(pool1)
conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer=hn)(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer=hn)(pool2)
conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer=hn)(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer=hn)(pool3)
conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer=hn)(conv4)
drop4 = Dropout(dropout)(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

# Bottleneck
conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer=hn)(pool4)
conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer=hn)(conv5)
drop5 = Dropout(dropout)(conv5)

# Expanding path: upsample, concatenate the skip connection, two 3x3 convolutions
up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer=hn)(UpSampling2D(size=(2, 2))(drop5))
merge6 = concatenate([drop4, up6], axis=3)
conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer=hn)(merge6)
conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer=hn)(conv6)
up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer=hn)(UpSampling2D(size=(2, 2))(conv6))
merge7 = concatenate([conv3, up7], axis=3)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer=hn)(merge7)
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer=hn)(conv7)
up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer=hn)(UpSampling2D(size=(2, 2))(conv7))
merge8 = concatenate([conv2, up8], axis=3)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer=hn)(merge8)
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer=hn)(conv8)
up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer=hn)(UpSampling2D(size=(2, 2))(conv8))
merge9 = concatenate([conv1, up9], axis=3)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=hn)(merge9)
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=hn)(conv9)

# 1x1 convolution maps each pixel to a softmax over the 4 classes
conv10 = Conv2D(IMG_CHANNELS, (1, 1), activation='softmax')(conv9)

model = Model(inputs=inputs, outputs=conv10)
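The compile step is not shown in the post. A minimal sketch, assuming categorical cross-entropy and the Adam optimizer; the repository may use a different loss (e.g., a Dice-based one), so treat both choices as assumptions:

from tensorflow.keras.optimizers import Adam

# Assumed settings: the post reports loss/accuracy but does not state
# the loss function, optimizer, or learning rate.
model.compile(optimizer=Adam(learning_rate=1e-4),
              loss='categorical_crossentropy',
              metrics=['accuracy'])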

To compensate for label imbalance, training batches were constructed in the following way: half of each batch consists of randomly selected planes from a few randomly selected images, and the other half consists of planes that include the "rare" labels, 1 and 3 (see the sketch below). For test images, random planes were chosen.
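A simplified sketch of this sampling scheme, assuming the planes have been pre-stacked into an array X of shape (n_planes, 240, 240, 4) with integer masks Y; gen_balanced_batches is a hypothetical helper, not the repository's generator:

import numpy as np

def gen_balanced_batches(X, Y, batch_size=32, rare_labels=(1, 3)):
    # Indices of planes that contain at least one "rare" label (1 or 3)
    rare_idx = np.array([i for i in range(len(Y))
                         if np.isin(Y[i], rare_labels).any()])
    half = batch_size // 2
    while True:
        idx_random = np.random.choice(len(X), half)               # any planes
        idx_rare = np.random.choice(rare_idx, batch_size - half)  # rare-label planes
        idx = np.concatenate([idx_random, idx_rare])
        # One-hot encoding of Y (needed for a softmax output) is omitted here
        yield X[idx], Y[idx]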

Training:
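The three callbacks passed to fit_generator below are not defined in the post. A minimal sketch using standard Keras callbacks; the patience values and checkpoint filename are assumptions:

from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint,
                                        ReduceLROnPlateau)

earlystopper = EarlyStopping(patience=10, verbose=1)   # stop when val_loss stalls
checkpointer = ModelCheckpoint('unet_brats.h5', verbose=1,
                               save_best_only=True)    # keep only the best weights
reduce_lr = ReduceLROnPlateau(factor=0.1, patience=5,
                              min_lr=1e-6, verbose=1)  # shrink LR on plateau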

history = model.fit_generator(gen_train_fast,
                              validation_data=gen_test_im, validation_steps=1,
                              steps_per_epoch=30,
                              epochs=100,
                              callbacks=[earlystopper, checkpointer, reduce_lr])

The model kept improving up to epoch 11, reaching:

loss: 0.0519 - accuracy: 0.9819 
val_loss: 0.0323 - val_accuracy: 0.9876

Let's view a few predictions:
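Each prediction image is produced by collapsing the per-pixel softmax output with an argmax; a minimal sketch, where x_batch is a hypothetical (n, 240, 240, 4) batch of planes:

import numpy as np

probs = model.predict(x_batch)           # (n, 240, 240, 4) class probabilities
pred_labels = np.argmax(probs, axis=-1)  # (n, 240, 240) integer labels 0..3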

The collective article that includes this work: https://www.researchgate.net/publication/346716362_Identifying_the_Best_Machine_Learning_Algorithms_for_Brain_Tumor_Segmentation_Progression_Assessment_and_Overall_Survival_Prediction_in_the_BRATS_Challenge

Code on GitHub: https://github.com/naomifridman/Unet_Brain_tumor_segmentation

A lot more can be done, mainly with augmentation and the loss function.

Please don't hesitate to write remarks, questions, and ideas.


Naomi Fridman

MSc in Mathematics. Data Scientist. Loves deep learning, machine learning, mathematics, and surfing. https://github.com/naomifridman