
TL;DR: given two tensors $t_1$ and $t_2$, both with shape $(c,h,w)$, how should the distance between them be measured?


More Info: I'm working on a project in which I'm trying to distinguish between an anomalous sample (specifically from MNIST) and a "regular" sample (specifically from CIFAR10). The solution I chose is to consider the feature maps produced by a ResNet and use kNN. More specifically:

  • I embed the entire CIFAR10_TRAIN data to obtain a dataset of activations with shape $(N,c,h,w)$, where $N$ is the size of CIFAR10_TRAIN.
  • I embed two new test samples $t_C$ and $t_M$ from CIFAR10_TEST and MNIST_TEST respectively (both with shape $(c,h,w)$), in the same way as the training data.
  • (!) I find the k-Nearest-Neighbours of $t_C$ and $t_M$ w.r.t. the embedding of the training data.
  • I calculate the mean distance to the $k$ nearest neighbours.
  • Given some predefined threshold, I classify $t_C$ and $t_M$ as regular or anomalous, hoping that the distance for $t_M$ would be higher, as it represents an out-of-distribution (O.O.D.) sample.

Notice that in (!) I need some distance measure, but this is not trivial as these are tensors, not vectors.
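
For concreteness, here is a minimal sketch of steps (!) through the thresholding, with a placeholder pairwise-distance function (knn_score, pairwise_dist, and the variable names are illustrative, not my actual code):

import torch

def knn_score(test_emb, train_emb, pairwise_dist, k=2):
    # test_emb: (M, c, h, w) embedded test samples, train_emb: (N, c, h, w)
    # pairwise_dist must return an (M, N) matrix of distances
    d = pairwise_dist(test_emb, train_emb)
    knn_d, _ = torch.topk(d, k, dim=1, largest=False)  # k smallest distances per test sample
    return knn_d.mean(dim=1)                            # mean distance to the k nearest neighbours

# a sample is flagged as anomalous when its score exceeds the predefined threshold:
# is_anomalous = knn_score(test_emb, train_emb, some_dist) > threshold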


What I've Tried: a trivial solution is to flatten each tensor to shape $(c\cdot h\cdot w)$ and then use plain $\ell_2$ distance, but the results turned out pretty bad (I could not distinguish regular from anomalous samples this way). Hence: is there a better way of measuring this distance?
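
In code, the flatten-and-$\ell_2$ attempt is roughly the following (a sketch of the idea, not my exact implementation; the shapes are placeholders):

import torch

train_emb = torch.rand(1000, 64, 8, 8)  # (N, c, h, w) embedded CIFAR10_TRAIN (placeholder shapes)
test_emb = torch.rand(2, 64, 8, 8)      # (M, c, h, w) embedded test samples

# flatten each (c, h, w) activation to a vector and take plain L2 distances
train_flat = train_emb.reshape(train_emb.shape[0], -1)  # (N, c*h*w)
test_flat = test_emb.reshape(test_emb.shape[0], -1)     # (M, c*h*w)
d = torch.cdist(test_flat, train_flat, p=2)             # (M, N) pairwise L2 distances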

Hadar Sharvit

1 Answer


You could try an earth mover's distance in 2D or 3D over the image. For example, you could follow this example but call it sequentially along each dimension. The idea would be something like the following (untested and written on my cell phone):

import torch

def cumsum_3d(a):
    # cumulative sums over the last three dimensions give a 3D "CDF" of the tensor
    a = torch.cumsum(a, -1)
    a = torch.cumsum(a, -2)
    a = torch.cumsum(a, -3)
    return a

def norm_3d(a):
    # normalize each sample so its last three dimensions sum to 1
    return a / torch.sum(a, dim=(-1, -2, -3), keepdim=True)

def emd_3d(a, b):
    # mean squared difference between the cumulative distributions:
    # an EMD-like distance over the last three dimensions
    a = norm_3d(a)
    b = norm_3d(b)
    return torch.mean(torch.square(cumsum_3d(a) - cumsum_3d(b)), dim=(-1, -2, -3))

This should also work with batched data. I would also try normalizing the images first (so they each sum to 1) unless you want to account for changes in intensity.
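
For example (equally untested, with made-up shapes), scoring one test embedding against a whole batch of training embeddings should broadcast like this:

import torch

train_emb = torch.rand(100, 64, 8, 8)  # (N, c, h, w) embedded training set
t = torch.rand(64, 8, 8)               # one (c, h, w) test embedding

d = emd_3d(t, train_emb)  # broadcasts to one distance per training sample
print(d.shape)            # torch.Size([100])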

John St. John
  • can this function be used within torch.cdist, to account for pairwise distances of every pair? – Hadar Sharvit Jul 05 '22 at 12:11
  • It looks like torch.cdist only supports different values of p for the L_p distance. It doesn’t look like it supports applying a function to all pairs. Sounds like you want to make a kernel matrix? You could use something like https://pytorch.org/docs/stable/generated/torch.combinations.html and then stack the 0,1 elements from the tuples into two tensors, then run this? – John St. John Jul 05 '22 at 13:53
  • 1
    Or better you could flatten after doing the 3d cumsum in the example above. Then you could use cdist with p=2 to replace the subtract and square steps. You would give the function your list of per-sample flattened cumsum tensors (samples x flat) twice. Then that would be all pairwise distances with something kind of like an EMD. – John St. John Jul 05 '22 at 13:56
  • Wow, this has worked tremendously well. I highly appreciate your response. – Hadar Sharvit Jul 05 '22 at 14:18
  • Nice! Glad I could help. – John St. John Jul 05 '22 at 14:27
  • You could do the same process using torch.cdist and the embedding output (before the classification layer) of a ResNet or whatever pretrained/frozen image model. You might want to divide each embedded vector by its norm first, though, so you have a cosine similarity rather than a pure L2 if you do that. Embeddings from a good image model should be better than EMD, but I bet EMD is a pretty good baseline. – John St. John Jul 05 '22 at 15:30
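
A minimal sketch of the flatten-after-cumsum approach from the comments above, letting torch.cdist with p=2 handle the pairwise part (the helper name, shapes, and k=5 are illustrative):

import torch

def emd_features(a):
    # normalize each sample to sum to 1, take the 3D cumulative sum, then flatten
    a = a / torch.sum(a, dim=(-1, -2, -3), keepdim=True)
    for d in (-1, -2, -3):
        a = torch.cumsum(a, d)
    return a.reshape(a.shape[0], -1)  # (samples, c*h*w)

train_emb = torch.rand(1000, 64, 8, 8)  # (N, c, h, w) embedded training set
test_emb = torch.rand(2, 64, 8, 8)      # (M, c, h, w) embedded test samples

train_f = emd_features(train_emb)              # (N, c*h*w)
test_f = emd_features(test_emb)                # (M, c*h*w)
d = torch.cdist(test_f, train_f, p=2)          # (M, N) pairwise EMD-like distances
knn_d, _ = torch.topk(d, k=5, dim=1, largest=False)
score = knn_d.mean(dim=1)                      # mean distance to the 5 nearest neighbours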