
I've built a neural network for MNIST digit recognition.

When I feed it images that are completely unlike its training data, it still tries to classify them as digits. Sometimes it confidently classifies nonsense inputs as a specific digit.

I am interested in the problem of rejecting nonsense inputs in general, and I want a solution that is effective for MNIST in particular. I think the most basic approach is simply to apply a confidence threshold to the standard network's output. I have also heard about a Bayesian approach.
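A minimal sketch of the confidence-threshold baseline, assuming the network outputs 10 logits per image as a NumPy array. The function names, the 0.9 threshold and the example logits are hypothetical. As described above, a network can be confidently wrong on nonsense inputs, so this baseline alone may not reject them.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def classify_or_reject(logits: np.ndarray, threshold: float = 0.9) -> np.ndarray:
    """Return the predicted digit for each row, or -1 (rejected) when the
    top softmax probability is below `threshold` (hypothetical default)."""
    probs = softmax(logits)
    preds = probs.argmax(axis=1)
    confidence = probs.max(axis=1)
    return np.where(confidence >= threshold, preds, -1)


# Two hypothetical rows of 10-class logits: one sharply peaked, one nearly flat.
logits = np.array([
    [0.1, 0.2, 8.0, 0.0, 0.1, 0.3, 0.0, 0.2, 0.1, 0.0],
    [1.0, 1.1, 0.9, 1.0, 1.2, 1.0, 0.8, 1.1, 1.0, 0.9],
])
print(classify_or_reject(logits))  # first row -> 2, second row -> -1 (rejected)
```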

I have read that for high-dimensional data such as these bitmap images, "almost all data is extrapolation". I think this makes rejecting out-of-distribution inputs a difficult problem, but it also seems like an important one, so I would appreciate any information on it. Thank you.

river
Have you considered adding a not-a-digit class and putting a bunch of not-a-digit examples in the training set? – user253751 Apr 24 '23 at 21:50

1 Answer


There is evidence that neural networks trained on a given dataset (or domain), such as MNIST, wrongly assign probability mass to data from a completely different domain, i.e. out-of-distribution (OOD) data.

This is an open research topic related to anomaly and novelty detection, as well as to out-of-domain generalization; the latter is about making correct predictions on OOD data without having trained on it.

In your case, you want to constrain the classifier to learn a representation of each digit that is disjoint from (shares no support with) OOD data. There is not yet a universal recipe for this, so here are some ideas:

  1. As suggested by @user253751, you can pick some OOD data, such as Fashion-MNIST and CIFAR-10, add an extra non-digit class, and retrain your classifier. This is a very simple approach that might improve OOD performance somewhat.
  2. A more involved approach would use a self-supervised representation learning technique, such as contrastive learning or SimSiam. Here, data augmentation helps the model better understand what a digit is. I think combining this method with approach 1 should be quite effective.
  3. You can use advanced data augmentations such as CutMix and MixUp to further regularize your model by explicitly training it in the "interpolation region" between classes. This can be applied to MNIST-only training as well as combined with methods 1 and 2.
  4. AE + classifier: first train an auto-encoder (AE; see this and VAEs) to reconstruct the digits, then train a classifier on its latent space. The AE should have learned a compact representation of the digits, so training a classifier on that representation may help. This is also reasonably simple to implement.
  5. In the same spirit as 2, there are more complex methods that involve clustering, such as Deep Embedded Clustering (DEC).
  6. Bayesian NN: I'll elaborate a bit on this since you mentioned it. A BNN learns a probability distribution over its weights (per layer), meaning that each forward pass samples a different realization of the weights. You can exploit the variance of the predictions (due to the sampled weights) to estimate a confidence score for each sample, in addition to the predicted class. In principle, you can accept only samples above a minimum confidence level and reject the others. I don't know how well this would perform, but you can try it. (If you use TensorFlow, I suggest TensorFlow Probability for an easy BNN implementation.)
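A minimal NumPy sketch of item 6, for illustration only; it is not TensorFlow Probability's API. `sample_forward` is a hypothetical toy Bayesian linear layer with a mean-field Gaussian over its weights (untrained parameters), and `uncertainty`/`accept` summarize the sampled predictions. Besides the variance of the predicted class probability mentioned above, it computes the mutual information between prediction and weights (predictive entropy minus expected entropy), a commonly used alternative score. The 0.1 rejection threshold is an assumption you would tune on held-out data.

```python
import numpy as np


def softmax(logits):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sample_forward(x, w_mu, w_sigma, n_samples, rng):
    """Toy Bayesian linear layer with a mean-field Gaussian over its weights.

    Each pass draws W ~ N(w_mu, w_sigma**2) elementwise and returns softmax(x @ W),
    so repeated passes give different predictions. w_mu and w_sigma are hypothetical,
    untrained parameters; a real BNN would learn them (e.g. by variational inference).
    """
    probs = []
    for _ in range(n_samples):
        w = w_mu + w_sigma * rng.standard_normal(w_mu.shape)
        probs.append(softmax(x @ w))
    return np.stack(probs)  # shape (T, N, C)


def uncertainty(sample_probs, eps=1e-12):
    """Summarize T sampled predictions of shape (T, N, C).

    Returns the predicted class, the variance across samples of that class's
    probability, and the mutual information between prediction and weights.
    """
    n = sample_probs.shape[1]
    mean_probs = sample_probs.mean(axis=0)                         # (N, C)
    pred = mean_probs.argmax(axis=-1)                              # (N,)
    var = sample_probs[:, np.arange(n), pred].var(axis=0)          # (N,)
    total = -(mean_probs * np.log(mean_probs + eps)).sum(axis=-1)  # entropy of the mean
    expected = -(sample_probs * np.log(sample_probs + eps)).sum(axis=-1).mean(axis=0)
    return pred, var, total - expected


def accept(sample_probs, max_mutual_info=0.1):
    """Keep an input only if the sampled networks largely agree (hypothetical threshold)."""
    _, _, mi = uncertainty(sample_probs)
    return mi < max_mutual_info


# Shapes only: 4 random "images" of 784 pixels, 10 classes, 20 weight samples.
rng = np.random.default_rng(0)
x = rng.random((4, 784))
w_mu = np.zeros((784, 10))
w_sigma = np.full((784, 10), 0.01)
print(sample_forward(x, w_mu, w_sigma, n_samples=20, rng=rng).shape)  # (20, 4, 10)

# Hand-built samples: T=3 passes over 2 inputs. For the first input every pass puts
# 0.91 on digit 3; for the second, each pass is just as confident but about a different digit.
T, C = 3, 10
agree = np.full((T, C), 0.01)
agree[:, 3] = 0.91
disagree = np.full((T, C), 0.01)
disagree[np.arange(T), np.arange(T)] = 0.91
samples = np.stack([agree, disagree], axis=1)  # (T, 2, C)
print(accept(samples))  # first input kept, second rejected
```

The second input shows where sampling can help and a plain confidence threshold on a single network may not: each sampled network is 91% confident, but they disagree with each other, so the mutual information is high and the input is rejected.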
Luca Anzalone
I want to add another approach: GANs are typically trained so that one side (the generator) produces an image of a digit and the other side (the discriminator) has to decide whether a given sample is real or fake. Doing this with the MNIST dataset might result in a discriminator that can tell digits from non-digits. – N. Kiefer Apr 26 '23 at 11:13
@N.Kiefer True, but the GAN's generator is trained by fooling the discriminator. Moreover, at least in theory, the optimum (i.e. the equilibrium of the min-max game between generator and discriminator) is reached when the discriminator outputs a probability of 0.5 for both fake and real data. This means the generator is of high quality, but also that the discriminator is unusable for this purpose. – Luca Anzalone Apr 26 '23 at 17:27
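The equilibrium argument in the previous comment follows from the standard GAN objective. A short derivation, using the conventional notation ($p_{\text{data}}$ for the real-data density, $p_g$ for the density of generated samples $G(z)$, $D$ for the discriminator; these symbols are not from the thread). In practice training rarely reaches this exact equilibrium, so this is the idealized case only.

```latex
V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
        = \int_x \big( p_{\text{data}}(x) \log D(x) + p_g(x) \log(1 - D(x)) \big) \, dx

% For fixed G, maximize a \log y + b \log(1 - y) pointwise in y:
% a / y - b / (1 - y) = 0  gives  y = a / (a + b), hence
D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}

% At the global optimum of the min-max game p_g = p_data, so
D^*_G(x) = \frac{1}{2} \quad \text{for all } x
```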