Add weights to deal with class imbalance
Currently all samples are weighted equally, so classes with more samples dominate the training signal and pull the model toward predicting them. As a result, the model performs poorly on the less frequent classes. To counteract this, you can assign a higher weight to the rarer classes.
Use the `weight` parameter of the `NLLLoss` class: https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html
A common heuristic for choosing the weights is inverse class frequency:

weight for class i = total number of training samples / number of training samples in class i
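A minimal sketch of this, assuming a hypothetical 3-class problem with made-up class counts (the counts and tensor shapes here are illustrative, not from your data):

```python
import torch
import torch.nn as nn

# Hypothetical per-class sample counts for an imbalanced 3-class dataset
class_counts = torch.tensor([900.0, 90.0, 10.0])
total = class_counts.sum()

# weight_i = total samples / samples in class i (inverse frequency)
weights = total / class_counts  # tensor([  1.11,  11.11, 100.00])

# Pass the weights to NLLLoss; rare-class mistakes now cost more
criterion = nn.NLLLoss(weight=weights)

# NLLLoss expects log-probabilities, e.g. from log_softmax
log_probs = torch.log_softmax(torch.randn(4, 3), dim=1)
targets = torch.tensor([0, 1, 2, 0])
loss = criterion(log_probs, targets)
```

Note that `NLLLoss` with `weight` computes a weighted mean by default (it divides by the sum of the weights of the targets in the batch), so the loss scale stays comparable across batches.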