see https://github.com/pytorch-labs/torchtrain/blob/main/train.py#L87 we should have the following metrics associated with the train steps: 1. gpu memory usage 2. wps 3. loss