Challenges of Training Models on Medical Data
Techniques to tackle Class Imbalance, Multi-Task, and Dataset Size
Amongst the many problems faced when training algorithms on medical datasets, these three are the most common:
- Class Imbalance challenge
- Multi-Task challenge
- Dataset Size challenge
For each of these problems, I will share a few techniques to tackle it. So let's go through them one by one!
Class Imbalance challenge
In the real world, we see many more healthy people than diseased people, and this is reflected in medical datasets as well: the number of examples in the healthy and diseased classes is not equally distributed. This imbalance mirrors the prevalence, or real-world frequency, of the disease. The same pattern appears outside medicine too; in credit card fraud detection datasets, for example, you might see a hundred times as many normal examples as abnormal ones.
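As a result, it is easy to be fooled into believing a model performs very well when it actually does not. This can happen if simple metrics like accuracy_score are used. Accuracy isn't a great metric for these kinds of datasets because the labels are heavily skewed: a neural network that always outputs "normal" scores an accuracy equal to the share of normal examples. On a dataset where slightly over 90% of the examples are normal, that is slightly over 90% accuracy (and about 99% at the hundred-to-one ratio mentioned above), even though the network never detects a single diseased case.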
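To make the accuracy trap concrete, here is a minimal sketch in plain Python. The label counts (910 normal, 90 diseased) are hypothetical and chosen only to match the "slightly over 90%" figure; the accuracy helper computes the same quantity as scikit-learn's sklearn.metrics.accuracy_score, but is written out by hand so the example has no dependencies.

```python
# Hypothetical, imbalanced label set: 910 normal (0) and 90 diseased (1) examples.
# The counts are illustrative only, not taken from any real dataset.
y_true = [0] * 910 + [1] * 90

# A "model" that ignores its input and always predicts the majority class (normal).
y_pred = [0] * len(y_true)

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)

def recall(y_true, y_pred, positive=1):
    """True positives / (true positives + false negatives) for the positive class."""
    tp = sum(t == positive and p == positive for t, p in zip(y_true, y_pred))
    fn = sum(t == positive and p != positive for t, p in zip(y_true, y_pred))
    return tp / (tp + fn) if (tp + fn) else 0.0

print(f"accuracy: {accuracy(y_true, y_pred):.2f}")  # accuracy: 0.91
print(f"recall on the diseased class: {recall(y_true, y_pred):.2f}")  # recall on the diseased class: 0.00
```

The 91% accuracy looks respectable, but the recall of zero shows that not a single diseased example was found.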
More useful metrics include precision, recall, and the F1 score. Precision is the number of True Positives divided by the sum of True Positives and False Positives; it is a good metric to use when the cost of False Positives is high. Recall, on the other hand, is the number of True Positives divided by the sum of True Positives and False Negatives; it is a good metric to use when the cost of False Negatives is high. This is the case for most models in the medical field, where missing a disease is usually far more costly than a false alarm. Often, however, we need to take both False Positives and False Negatives into account, and that's what the F1 score does. It is the harmonic mean of precision and recall, striking a balance between the two, and is given by the formula F1 = 2 * (Precision * Recall) / (Precision + Recall).
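The definitions of precision, recall, and F1 translate directly into code. The sketch below works from raw confusion-matrix counts; the counts themselves are hypothetical, and returning 0.0 when a denominator is zero is a common convention that I am assuming here rather than the only possible choice. In practice, scikit-learn's precision_score, recall_score, and f1_score in sklearn.metrics compute the same quantities from label arrays.

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall and F1 from confusion-matrix counts.

    Returns 0.0 for a metric whose denominator is zero (e.g. precision when the
    model never predicts the positive class), a common convention.
    """
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        f1 = 0.0
    else:
        # Harmonic mean of precision and recall.
        f1 = 2 * (precision * recall) / (precision + recall)
    return precision, recall, f1

# Hypothetical counts for a screening model: 30 diseased patients caught,
# 10 healthy patients wrongly flagged, 20 diseased patients missed.
p, r, f1 = precision_recall_f1(tp=30, fp=10, fn=20)
print(f"precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
# precision=0.75 recall=0.60 f1=0.67

# The always-"normal" model: no true or false positives, every diseased case missed.
print(precision_recall_f1(tp=0, fp=0, fn=90))  # (0.0, 0.0, 0.0)
```

Because F1 is a harmonic mean, it is pulled toward the smaller of the two values, so a model cannot hide poor recall behind high precision (or vice versa).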
Another popular technique to deal with class imbalance is something called Resampling. It is the…