I'm training a text classifier in PyTorch and I'm seeing a cyclical pattern in the loss curve that I can't explain: the loss drops sharply at the beginning of each epoch and then rises slowly for the rest of it. The overall convergence, however, looks fine. Here's how it looks:
The model is very basic. I'm using the Adam optimizer with default parameters, a learning rate of 0.001, and a batch size of 512. I've already checked and tried a number of things, so I'm running out of ideas, but I'm sure I've made a mistake somewhere.
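The training loop itself is the textbook pattern, roughly this sketch (`train_set` and `collate` are placeholders for my dataset and collate function; the loss shown is cross-entropy, as a representative classification loss):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

model = QuestionAnswerer()  # defined at the end of the post
criterion = nn.CrossEntropyLoss()  # representative choice for the sketch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # default betas/eps
loader = DataLoader(train_set, batch_size=512, shuffle=True, collate_fn=collate)

for question, length, answer in loader:
    optimizer.zero_grad()
    loss = criterion(model(question, length), answer)
    loss.backward()
    optimizer.step()
```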
Things I've made sure of (a quick sanity-check sketch follows the list):
- Data is delivered correctly (VQA v1.0 questions).
- The DataLoader is shuffling the dataset.
- The LSTM's hidden and cell states are zeroed correctly for every batch.
- Gradients aren't leaking through the input tensors.
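For concreteness, this is the kind of check I mean (sketch, reusing `loader` from the snippet above):

```python
# Inputs are LongTensors of word ids, so no gradient can flow into them.
question, length, answer = next(iter(loader))
assert question.dtype == torch.long and not question.requires_grad

# shuffle=True: the first batch differs from epoch to epoch (eyeballed this).
print(question[:5, 0])

# The LSTM is called without an explicit (h0, c0) in forward(), so PyTorch
# zero-initializes the hidden and cell states on every call; nothing carries
# over between batches.
```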
Things I've already tried (the exact variants are sketched below the list):
- Lowering the learning rate. The pattern remains, although its amplitude is smaller.
- Training without momentum (plain SGD). Gradient noise masks the pattern somewhat, but it's still there.
- Using a smaller batch size. The extra gradient noise can grow until it more or less hides the pattern, but that doesn't actually fix anything.
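The variants looked roughly like this (the exact values are illustrative, not the ones I swept):

```python
# Lower learning rate (same Adam otherwise) -- smaller amplitude, same shape:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative value

# Plain SGD, no momentum -- noisier curve, pattern still visible:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # illustrative value

# Smaller batches -- more gradient noise, pattern partially hidden:
loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collate)  # illustrative size
```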
The model:
```python
import torch
import torch.nn as nn

class QuestionAnswerer(nn.Module):
    def __init__(self):
        super(QuestionAnswerer, self).__init__()
        # Word embeddings; NULL_ID is the padding token.
        self._wemb = nn.Embedding(N_WORDS, HIDDEN_UNITS, padding_idx=NULL_ID)
        # Single-layer, sequence-first LSTM (batch_first=False).
        self._lstm = nn.LSTM(HIDDEN_UNITS, HIDDEN_UNITS)
        self._final = nn.Linear(HIDDEN_UNITS, N_ANSWERS)

    def forward(self, question, length):
        # question: (seq_len, B) LongTensor of word ids, padded with NULL_ID
        # length:   (B,) LongTensor of true question lengths
        B = length.size(0)
        embed = self._wemb(question)
        # LSTM output at the last non-padded step of each sequence.
        hidden = self._lstm(embed)[0][length - 1, torch.arange(B)]
        return self._final(hidden)
```
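For completeness, a quick smoke test of the shapes with made-up constants (in the real code N_WORDS, HIDDEN_UNITS, N_ANSWERS, and NULL_ID are defined elsewhere):

```python
N_WORDS, HIDDEN_UNITS, N_ANSWERS, NULL_ID = 10000, 256, 1000, 0  # made-up values

model = QuestionAnswerer()
question = torch.randint(1, N_WORDS, (20, 512))  # (seq_len, B), sequence-first
length = torch.randint(1, 21, (512,))            # true lengths in [1, 20]
logits = model(question, length)
print(logits.shape)  # torch.Size([512, 1000]) -> one score per answer class
```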