https://pytorch.org/blog/peak-performance-minimized-memory/
> "The integration involves modifying the TransformerDecoder module in torchtune to bypass the linear layer computation, allowing the Liger Fused Linear Cross Entropy Loss to handle the forward projection weights."
Is this the same thing you discuss above?
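
To make the question concrete, here is a toy sketch of how I read the quoted passage: the decoder skips its output projection and hands the hidden states plus the projection weights to the loss, which computes logits one chunk at a time so the full tokens-by-vocab logits matrix never exists at once. This is only an illustration in plain Python lists, not the Liger or torchtune code. The names `unfused_loss`, `fused_linear_cross_entropy`, and `chunk_size` are hypothetical, and it covers the forward pass only.

```python
import math

def linear(h, W):
    # Project one hidden vector (length d) through W (vocab x d) to get logits.
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def cross_entropy(logits, target):
    # Numerically stable log-sum-exp minus the target logit.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]

def unfused_loss(hidden, W, targets):
    # Usual path: the decoder applies the output projection to every token,
    # so the full (tokens x vocab) logits are held in memory before the loss.
    all_logits = [linear(h, W) for h in hidden]
    return sum(cross_entropy(z, t) for z, t in zip(all_logits, targets)) / len(targets)

def fused_linear_cross_entropy(hidden, W, targets, chunk_size=2):
    # Assumed "bypass" path: the loss receives hidden states and W directly
    # and only ever holds chunk_size rows of logits at a time.
    total = 0.0
    for start in range(0, len(hidden), chunk_size):
        chunk_h = hidden[start:start + chunk_size]
        chunk_t = targets[start:start + chunk_size]
        chunk_logits = [linear(h, W) for h in chunk_h]
        total += sum(cross_entropy(z, t) for z, t in zip(chunk_logits, chunk_t))
    return total / len(targets)

# Made-up toy data: 4 tokens, hidden size 3, vocabulary of 5.
hidden = [[0.1, -0.2, 0.3], [0.5, 0.4, -0.1], [-0.3, 0.2, 0.7], [0.0, 0.1, -0.4]]
W = [[0.2, -0.1, 0.4], [-0.5, 0.3, 0.1], [0.3, 0.3, -0.2],
     [0.1, -0.4, 0.6], [-0.2, 0.5, 0.0]]
targets = [0, 3, 1, 4]

loss_unfused = unfused_loss(hidden, W, targets)
loss_fused = fused_linear_cross_entropy(hidden, W, targets)
```

Both paths should give the same mean loss up to floating-point rounding. The only difference in this sketch is how many rows of logits are alive at once, which is the memory saving I assume the blog post is describing.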