I have been playing with TPUs for a couple of months now, and to be honest I don't understand how people can use them in production for inference:
- almost no resources online showing how to run modern generative models like Mistral, Yi 34B, etc. on TPUs
- poor compatibility between JAX and PyTorch
- very hard to understand the memory consumption of the TPU chips (no nvidia-smi equivalent)
- rotating IP addresses on TPU VMs
- almost impossible to get my hands on a TPU v5
Is it only me? Or did I miss something?
I totally understand that TPUs can be useful for training though.
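On the "no nvidia-smi equivalent" point: JAX does expose per-device allocator statistics through `Device.memory_stats()`. Below is a minimal sketch, assuming a JAX version where that method returns a dict with `bytes_in_use`, `peak_bytes_in_use` and `bytes_limit` keys on TPU. Key names may differ across backends and versions, and CPU devices may return `None`. The helper names are hypothetical. It only reports what the JAX allocator in the current process sees, so it is not a full replacement for nvidia-smi.

```python
from typing import Optional


def format_memory_stats(stats: Optional[dict]) -> str:
    """Turn a Device.memory_stats() dict into a one-line, nvidia-smi-style summary."""
    if not stats:
        # Some backends (e.g. CPU) report no allocator stats at all.
        return "memory stats unavailable"
    gib = 1024 ** 3
    used = stats.get("bytes_in_use", 0) / gib
    peak = stats.get("peak_bytes_in_use", 0) / gib
    limit = stats.get("bytes_limit", 0) / gib
    return f"{used:.2f} GiB used / {limit:.2f} GiB limit (peak {peak:.2f} GiB)"


def print_device_memory() -> None:
    # Imported lazily so format_memory_stats can be used without JAX installed.
    import jax

    # One line per chip/core attached to this host, as seen by this process.
    for device in jax.local_devices():
        print(device, format_memory_stats(device.memory_stats()))
```

On a TPU VM you would call `print_device_memory()` from inside the same process that holds the model, for example between inference steps, since another process cannot see this allocator's numbers.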
https://www.prnewswire.com/news-releases/google-announces-ex...
"Partnership includes important new collaborations on AI safety standards, committing to the highest standards of AI security, and use of TPU v5e accelerators for AI inference"
1. They have bought far less from Nvidia than the other hyperscalers, and they literally can’t vomit without saying “AI”. They have to be running those models on something. They have purchased huge amounts of chips from fabs, and what else would those be?
2. They have said they use them. Should be pretty obvious here.
3. They maintain a whole software stack for them, they design the chips, etc., yet they don’t really try to sell the TPU. Why else would they do this?
Google was doing AI before any other company even thought about it. They are on the 6th generation of TPU hardware.
I don't think there is any maturity issue, just an availability issue because they are all being used internally.
If you aren't internal, the documentation, support, and even just general bug fixing is impossible to get.
This is not even remotely true. SRI was working on AI in various forms long before Google existed.
That said, it wouldn’t be too difficult to port most models to JAX, load in the existing weights, and export the result for serving. Should you bother? IMO, no, unless we’re talking really large-scale inference. Your time and money are almost certainly better spent iterating on the models.
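To give a sense of what "load in the existing weights" involves at the smallest scale, here is a hedged sketch that converts one PyTorch `nn.Linear` entry from a state dict (already turned into NumPy arrays, e.g. via `tensor.numpy()`) into a JAX pytree. The `<prefix>.weight` / `<prefix>.bias` key naming matches what `state_dict()` produces for a Linear layer. The function names are hypothetical. A real LLM port also has to reproduce attention, normalization, positional embeddings, etc. exactly, which is where it gets hard.

```python
import numpy as np
import jax
import jax.numpy as jnp


def linear_params_from_torch(state: dict, prefix: str) -> dict:
    """Convert one PyTorch nn.Linear's weights into a JAX-friendly pytree.

    PyTorch stores `weight` as (out_features, in_features) and computes
    x @ weight.T + bias, so we transpose once at load time to get an
    (in_features, out_features) kernel.
    """
    weight = np.asarray(state[f"{prefix}.weight"])
    bias = np.asarray(state[f"{prefix}.bias"])
    return {"kernel": jnp.asarray(weight.T), "bias": jnp.asarray(bias)}


@jax.jit
def linear(params: dict, x: jax.Array) -> jax.Array:
    # Same math as torch.nn.Linear, with the kernel already transposed.
    return x @ params["kernel"] + params["bias"]
```

Doing the transpose once at load time keeps the jitted forward pass a plain matmul. The params dict is an ordinary pytree, so `jax.jit` traces it without any extra wrapping.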
--> We tried such ports at https://kwatch.io (the company I work for), and it turned out to be much harder than expected (at least for us). I don't think many people are capable of porting an LLM built on PyTorch + GPU to JAX + TPU.
I'll just leave this here: https://jax.readthedocs.io/en/latest/pallas/index.html
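For readers who haven't opened the link: Pallas lets you write custom kernels in Python that compile for TPU (and GPU). Below is a minimal element-wise add in the style of the Pallas quickstart. It assumes a recent JAX where `jax.experimental.pallas` is available. The API is still experimental and may change. `interpret=True` is used so the sketch also runs on CPU. `pallas_add` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at the kernel's input/output buffers; read both inputs, write the sum.
    o_ref[...] = x_ref[...] + y_ref[...]


def pallas_add(x: jax.Array, y: jax.Array, interpret: bool = True) -> jax.Array:
    # With no grid/BlockSpecs, the whole array is processed as a single block.
    # interpret=True runs the kernel through the Pallas interpreter (works on CPU);
    # pass interpret=False to compile it for an actual TPU or GPU.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)
```

For real kernels you would add a `grid` and `BlockSpec`s to tile the arrays into chunks that fit on-chip memory. That tiling is where the TPU-specific work happens.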
https://aerospace.org/article/aerospaces-slingshot-1-demonst...
Speaking of SBCs, before the Raspberry Pi I was looking at the Orange Pi 5, which has a Rockchip RK3588S with an NPU (Neural Processing Unit). This was the first I had heard of such a thing, and I was curious what exactly it does and how. Unfortunately, there's very little support for the Orange Pi and not a large community around it, so I couldn't find any feedback on how well it worked or what it did.
http://www.orangepi.org/html/hardWare/computerAndMicrocontro...
It seems to me that Google does not really want to sell TPUs, only to showcase their AI work and maybe get some early-adopter feedback. It must be quite a challenge for them to build a dynamic community around JAX and TPUs if TPUs remain a vendor-locked product...
Just as an example, over a decade ago I replaced a few cases filled with racks and a SAN that made up a compute cluster with one box (plus SAN) and a backup box (both boxes were basically the same in case one failed). In effect, dozens of servers were replaced by a two-CPU box with a couple of Tesla cards (probably one A100 later). The entire model had to be rewritten, but it was not that bad. I wanted to do it with AMD cards, but there was no easy way.
I would also say that modern networking has made all kinds of stuff more interesting (also lining Nvidia's pockets). Those TPUs do not make sense to me. I have no idea how to use them. They should release their version of CUDA.
https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#...