We're working on a sequence-to-sequence problem in PyTorch and use cross-entropy to compute the loss between the output sequence and the target sequence (a minimal sketch of the setup is below). This works fine and penalizes the model correctly. However, we have the added constraint that a prediction must not contain repeated indices, e.g.
good: [1, 2, 3, 4, 5]
bad: [1, 2, 2, 4, 5]
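For context, here is a minimal sketch of how we currently compute the loss. The shapes and names are illustrative, not our actual code:

```python
import torch
import torch.nn as nn

# Illustrative shapes only: logits of shape (batch, seq_len, vocab_size),
# integer targets of shape (batch, seq_len).
batch_size, seq_len, vocab_size = 8, 5, 100
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# nn.CrossEntropyLoss expects (N, C) logits and (N,) targets,
# so the sequence dimension is flattened into the batch dimension.
criterion = nn.CrossEntropyLoss()
ce_loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
```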
We would like an additional penalty term, added on top of the cross-entropy loss, that further punishes the model for producing duplicate indices.
How would I construct this additional loss term in PyTorch?
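To make the intent concrete, the rough idea we have in mind is a soft, differentiable duplicate count built from the predicted probabilities, something along these lines (just a sketch; whether this is a sensible construction is exactly what we're asking):

```python
import torch
import torch.nn.functional as F

def duplicate_penalty(logits):
    # logits: (batch, seq_len, vocab_size)
    # Turn the logits into per-step probabilities, sum them over the
    # sequence to get an "expected count" per vocabulary index, and
    # penalize any expected count above 1 (i.e. expected duplicates).
    probs = F.softmax(logits, dim=-1)            # (batch, seq_len, vocab)
    expected_counts = probs.sum(dim=1)           # (batch, vocab)
    excess = torch.clamp(expected_counts - 1.0, min=0.0)
    return excess.sum(dim=-1).mean()             # scalar penalty
```

The total loss would then be something like `loss = ce_loss + dup_weight * duplicate_penalty(logits)`, where `dup_weight` is a hyperparameter we would have to tune.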
PS: Yes, it's true that we could just hack the generation code to never emit duplicate indices and incorporate that into our beam search, but I would first like to see whether training with this additional constraint produces a better model!