Wait, this actually proves their point. It sounds like you didn’t train a model, but rather ran a bunch of ops. That’s fine, but the answer to their question would be "we didn’t actually measure loss or try to get a useful result, because the goal was to demonstrate raw throughput."
I agree that raw throughput is the metric to aim for, since figuring out how to put it to use is an exercise for the user. But it’s probably best to be straightforward about that. The reason MLPerf measures "time until loss reaches X for a ResNet classifier on ImageNet" is precisely because it gives information about performance at scale: if you didn’t train anything, you haven’t actually achieved the "fastest training run". You’ve achieved the largest throughput, which is similar but not the same.
And I don’t think this is a pedantic distinction. Just throw LARS on it (the MLPerf code you used in 2019 is at https://github.com/shawwn/daxx-lightning fwiw, and it still ran on pods the last time I tried) and see how it performs in practice.
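For reference, wiring LARS into a JAX train step is only a few lines with optax. This is a minimal sketch: the toy least-squares model and the hyperparameters are placeholders, not the MLPerf ResNet configuration.

```python
import jax
import jax.numpy as jnp
import optax

# Toy least-squares model, just to exercise the optimizer end to end.
def loss_fn(params, batch):
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

params = {"w": jnp.ones((8, 1)) * 0.1}

# LARS scales each layer's update by roughly ||w|| / ||grad||, which is
# what made large-batch training stable in the MLPerf-era submissions.
optimizer = optax.lars(learning_rate=1.0, momentum=0.9, weight_decay=1e-4)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # LARS needs the current params to compute its per-layer trust ratio.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

batch = {"x": jnp.ones((32, 8)), "y": jnp.zeros((32, 1))}
params, opt_state, loss = train_step(params, opt_state, batch)
```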
EDIT: reading over https://github.com/google/maxtext, it looks pretty delightful. I was in the TPU scene back in 2020, and there was no way to do ahead-of-time compilation. Restarting training runs was a major pain point once LLMs became the focus, and I kept pestering James Bradbury to please add it. Happy to see that it finally made its way in.
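For anyone who wasn’t around then: the pain was that every restart re-traced and re-compiled the whole train step at step 0. The modern JAX AOT path looks roughly like this (a minimal sketch; the toy step is a stand-in for a real train step, and MaxText’s own compilation flow may differ):

```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    grads = jax.grad(lambda p: jnp.mean((batch @ p) ** 2))(params)
    return params - 0.1 * grads

params = jnp.ones((16, 4))
batch = jnp.ones((32, 16))

# Trace, lower, and compile up front, decoupled from the training loop,
# instead of paying for tracing + compilation on the first step.
compiled = jax.jit(train_step).lower(params, batch).compile()
params = compiled(params, batch)
```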
It sounds like MaxText is the right approach, but until you actually try to train a model (that is, to achieve a low loss on a specific dataset), you can’t know whether the code works. This isn’t theoretical. I spent over a year debugging Google’s public BigGAN code (compare_gan) and discovered why it never worked: the batch norm gamma parameter was initialized to zero instead of one, so every activation was multiplied by zero from the start, which severely crippled the model.
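For concreteness, here’s the shape of that bug in a toy batch norm (illustrative JAX, not the actual compare_gan TensorFlow code):

```python
import jax.numpy as jnp

def batch_norm(x, gamma, beta, eps=1e-5):
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Correct init: gamma = 1, so normalized activations pass straight through.
ok = batch_norm(x, gamma=jnp.ones(2), beta=jnp.zeros(2))

# The compare_gan-style bug: gamma = 0 zeroes every activation at init,
# so the network starts out dead and has to claw its way back.
broken = batch_norm(x, gamma=jnp.zeros(2), beta=jnp.zeros(2))  # all zeros
```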
A bug like that could easily be lurking in MaxText, and you can’t know until you try to train a useful LLM. Notably, compare_gan seemed to work: the authors reported that they couldn’t replicate the performance of the official BigGAN paper, but the samples looked sort of reasonable. The model was quietly crippled, and no one knew why until someone went through that rigorous debugging process.
If you need help with this, let me know. There are challenges in training an actual LLM that aren’t present in throughput runs like these. For example, you need a big dataset. The Pile is a good starting point, and it gives a nice comparison baseline (e.g. against GPT-J, which was trained on it).
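If it helps, the Pile ships as newline-delimited JSON compressed with zstandard, so a shard reader is short. A sketch, assuming the standard .jsonl.zst layout; the path below is a placeholder:

```python
import io
import json
import zstandard as zstd

def iter_pile_documents(path):
    """Yield the raw text of each document in one Pile shard (.jsonl.zst)."""
    with open(path, "rb") as f:
        stream = zstd.ZstdDecompressor().stream_reader(f)
        # Each line is a JSON object with "text" and "meta" fields.
        for line in io.TextIOWrapper(stream, encoding="utf-8"):
            yield json.loads(line)["text"]

# Placeholder path; the train shards are named 00.jsonl.zst .. 29.jsonl.zst.
for doc in iter_pile_documents("pile/train/00.jsonl.zst"):
    print(doc[:80])
    break
```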
Alternatively, post a link to a tensorboard.dev showing the loss curves for your training runs. I suspect the reason you didn’t is that you didn’t have a real dataset. That’s ok, but it doesn’t prove that MaxText works; only empirical evidence can do that.
In other words, DavidSJ was precisely right: it’s an impressive-looking benchmark, which doesn’t actually help your customers train LLMs in practice. They’ll need to solve this problem eventually, and the optimizer is certainly one aspect. The other is the quantized INT8 training. It may sound impressive to say it gives a 1.4x step-time speedup, but that’s useless if it harms loss convergence. How do you know it doesn’t? This isn’t an easy question to answer unless you run MLPerf or some other known-stable baseline, which I’m a little shocked no one has done yet.
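To make the failure mode concrete: INT8 training quantizes the matmul operands, and the rounding error that buys you the speedup is the same error that can quietly degrade the loss curve. A toy sketch of symmetric per-tensor quantization (not MaxText’s actual AQT implementation):

```python
import jax
import jax.numpy as jnp

def quantize_int8(x):
    # Symmetric per-tensor quantization: map [-max|x|, max|x|] to [-127, 127].
    scale = jnp.max(jnp.abs(x)) / 127.0 + 1e-8
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

def int8_matmul(a, b):
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    # Cast up and accumulate in int32 (real INT8 hardware multiplies in
    # int8 and accumulates in int32), then rescale back to float.
    acc = jnp.matmul(qa.astype(jnp.int32), qb.astype(jnp.int32))
    return acc.astype(jnp.float32) * sa * sb

a = jax.random.normal(jax.random.PRNGKey(0), (128, 128))
b = jax.random.normal(jax.random.PRNGKey(1), (128, 128))

# This rounding error is exactly what could hurt convergence; only an
# end-to-end training run tells you whether it actually does.
err = jnp.max(jnp.abs(int8_matmul(a, b) - a @ b))
print(err)
```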