
I am using a U-Net to segment cancer cells in images of patients' arms. I would like to add patient data to it (a table containing features such as gender, age, etc.) to see whether it can improve the segmentation. So far, my research has led me nowhere. How can I achieve this?

nbro
Skyris

1 Answer


What you describe calls for a multi-input model; the steps below combine it with multi-task learning by adding an auxiliary classification loss. Here's what you do:

  1. Create a second Input for the tabular patient data.
  2. Feed it into a small 1D CNN (2-3 layers) so that it aggregates the tabular information into a feature vector.
  3. Concatenate this feature vector with an intermediate feature generated by the U-Net, using a Concatenate layer.
  4. Put a dense layer of 2 after this.
  5. Add a softmax layer with units = number of classes.
  6. Compute the cross-entropy (CE) loss between this output and the ground-truth labels, and add it to the U-Net loss.
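
The steps above can be sketched in Keras roughly as follows. This is only a minimal sketch under several assumptions that are not in the answer: the tiny encoder/decoder stands in for a real U-Net, the bottleneck is chosen as the "intermediate feature" and is global-average-pooled so it can be concatenated with the tabular vector, the hidden dense layer width (32) is arbitrary, and `build_model`, the layer names, and the default shapes are hypothetical.

```python
from tensorflow.keras import layers, Model


def build_model(img_shape=(64, 64, 1), n_features=8, n_classes=2):
    # Image branch: a deliberately tiny U-Net-like network (stand-in for the real one).
    image_in = layers.Input(shape=img_shape, name="image")
    e1 = layers.Conv2D(16, 3, padding="same", activation="relu")(image_in)
    p1 = layers.MaxPooling2D(2)(e1)
    bottleneck = layers.Conv2D(32, 3, padding="same", activation="relu")(p1)
    u1 = layers.UpSampling2D(2)(bottleneck)
    u1 = layers.Concatenate()([u1, e1])  # skip connection
    d1 = layers.Conv2D(16, 3, padding="same", activation="relu")(u1)
    mask_out = layers.Conv2D(1, 1, activation="sigmoid", name="mask")(d1)

    # Steps 1-2: second input, aggregated by a small 1D CNN.
    tab_in = layers.Input(shape=(n_features,), name="tabular")
    t = layers.Reshape((n_features, 1))(tab_in)
    t = layers.Conv1D(8, 3, padding="same", activation="relu")(t)
    t = layers.Conv1D(16, 3, padding="same", activation="relu")(t)
    t = layers.GlobalAveragePooling1D()(t)

    # Step 3: concatenate with an intermediate U-Net feature.
    # Assumption: the bottleneck is pooled to a vector before concatenation.
    img_vec = layers.GlobalAveragePooling2D()(bottleneck)
    fused = layers.Concatenate()([img_vec, t])

    # Steps 4-5: dense layer, then softmax over the classes.
    # Assumption: the hidden width of 32 is arbitrary.
    h = layers.Dense(32, activation="relu")(fused)
    class_out = layers.Dense(n_classes, activation="softmax", name="label")(h)

    model = Model(inputs=[image_in, tab_in], outputs=[mask_out, class_out])

    # Step 6: the total loss is the (weighted) sum of the segmentation loss
    # and the cross-entropy loss of the classification head.
    model.compile(
        optimizer="adam",
        loss=["binary_crossentropy", "sparse_categorical_crossentropy"],
        loss_weights=[1.0, 1.0],
    )
    return model
```

Training would then look like `model.fit([images, tabular], [masks, labels], ...)`, with integer class labels for the second output. Note that in this arrangement the tabular data influences the segmentation only through the shared encoder and the auxiliary loss; the loss weights are a tuning knob, not a value given in the answer.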

These steps are described in terms of TensorFlow (Keras). The same can be done easily in PyTorch.

Abhishek Verma