if you made something 2x faster, you might have done something smart
if you made something 100x faster, you definitely just stopped doing something stupid
The article showed how a linear, O(N), algorithm running on a lowly 8-bit CPU can beat a cubic algorithm, O(N^3), running on a Cray supercomputer, when N is sufficiently large.
see https://www.cs.rpi.edu/~moorthy/Courses/CSCI2300/p865-bentle...
An example that is basically unrelated to complexity theory is talking to a distant service while keeping only a small number of requests in flight, e.g. because you were worried about load, didn't notice you were waiting for acks, created a new TCP connection for each request, had a small sendbuf, or somehow sent way too fast, got rate limited, and needed to retry.
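To make the "requests in flight" point concrete, here is a rough sketch, not a benchmark of any real service. `fake_request`, the 50 ms latency, and the limit of 20 are all made up for illustration; the round trip is simulated with `asyncio.sleep`.

```python
import asyncio
import time

async def fake_request(i, latency=0.05):
    # Hypothetical stand-in for one round trip to a distant service.
    await asyncio.sleep(latency)
    return i * 2

async def one_at_a_time(n):
    # Waits for each reply before sending the next request.
    return [await fake_request(i) for i in range(n)]

async def many_in_flight(n, limit=20):
    # Keeps up to `limit` requests outstanding at once.
    sem = asyncio.Semaphore(limit)

    async def bounded(i):
        async with sem:
            return await fake_request(i)

    return await asyncio.gather(*(bounded(i) for i in range(n)))

def timed(coro):
    # Runs a coroutine to completion and reports wall-clock time.
    start = time.perf_counter()
    result = asyncio.run(coro)
    return result, time.perf_counter() - start

slow_result, slow_time = timed(one_at_a_time(20))
fast_result, fast_time = timed(many_in_flight(20))
```

Both versions return the same answers; the second just stops waiting on one round trip at a time, so the total time is closer to one latency than to twenty.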
My long-term ambition is to replicate OpenAI's Dota 2 reinforcement learning work, since it's one of the most impactful (or at least most entertaining) uses of RL. It would be more or less impossible to translate the game logic into Jax, short of transpiling C++ to Jax somehow. Which isn't a bad idea: someone should make that.
It should also be noted that there's a long history of RL being done on accelerators. AlphaZero's chess evaluations ran entirely on TPUs. Pytorch CUDA graphs also make it easier to implement this kind of thing nowadays, since (again, as much as I love Jax) some Pytorch constructs are simply easier to use than turning everything into a functional programming paradigm.
All that said, you should really try out Jax. The fact that you can take gradients of any arbitrary function is just amazing, and you have complete control over what's JIT'ed into a GPU graph and what's not. It's a wonderful feeling compared to using Pytorch's accursed .backward() accumulation scheme.
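For anyone who hasn't tried it, a minimal sketch of what that looks like. The `loss` function and the data are toy values chosen for illustration; `jax.grad` and `jax.jit` are the real JAX transforms.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # An ordinary Python function: mean squared error of a linear model.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad returns a new function computing d(loss)/d(w);
# by default it differentiates with respect to the first argument.
grad_loss = jax.grad(loss)

# jax.jit is opt-in: only the function you wrap gets compiled.
fast_grad_loss = jax.jit(grad_loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad_loss(w, x, y)  # gradient has the same shape as w
```

Nothing is accumulated into a hidden `.grad` field here: the gradient is just the return value of a function.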
Can't wait for a framework that feels closer to pure arbitrary Python. Maybe AI can figure out how to do it.
While there is work on putting RL environments on accelerators, the main speedup from this work comes from also training many RL agents in parallel. This is largely because the neural networks we use in RL are relatively small and thus don't utilize the GPU very efficiently.
While this was always possible to do, Jax makes it way easier because we just need to call `jax.vmap` to get it to work.
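A rough sketch of the idea, assuming a toy regression loss in place of a real RL objective. `init_agent`, `train_step`, the learning rate, and the shapes are hypothetical; the point is only that an update written for one agent becomes an update for many agents by wrapping it in `jax.vmap`.

```python
import jax
import jax.numpy as jnp

def init_agent(key):
    # Hypothetical tiny "policy": a single weight vector.
    return jax.random.normal(key, (4,))

def train_step(params, obs, target):
    # One gradient step for ONE agent on a toy regression loss.
    def loss(p):
        return jnp.mean((obs @ p - target) ** 2)
    return params - 0.1 * jax.grad(loss)(params)

num_agents = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_agents)
params = jax.vmap(init_agent)(keys)        # shape (8, 4): one row per agent

obs = jnp.ones((num_agents, 16, 4))        # each agent gets its own batch
target = jnp.zeros((num_agents, 16))

# vmap maps the single-agent step over the leading axis of every argument;
# jit then compiles the whole batched update into one program.
batched_step = jax.jit(jax.vmap(train_step))
new_params = batched_step(params, obs, target)
```

Since each network is small, batching the agents this way is what keeps the accelerator busy.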
TPUs were used for neural network inference and training, but game logic as well as MCTS was on the CPU using C++.
JAX is awesome though, I use it for all my neural network stuff!
> Training proceeded for 700,000 steps (mini-batches of size 4,096) starting from randomly initialised parameters, using 5,000 first-generation TPUs to generate self-play games and 64 second-generation TPUs to train the neural networks. Further details of the training procedure are provided in the Methods.
It's almost like the author is claiming credit for creating Nvidia, when in fact he is just calling its APIs.
The reason we write about Jax is that doing this technique is really hard in PyTorch / Tensorflow. This is because:
1. Jax has vmap. (PyTorch does now too, but it is far more recent).
2. There are RL environments that others have written in pure Jax (see the blog post for four different repos of RL environments)
3. As m00x hints at, Jax replicates Numpy's API. This makes it way easier to use for non-neural network programming (e.g. RL environments).
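To illustrate points 1 and 3 together, here is a sketch of a made-up environment written with nothing but `jax.numpy` and then batched with `jax.vmap`. The `env_step` function and its 1-D "move toward the goal" dynamics are invented for this example and are not taken from any of the repos mentioned.

```python
import jax
import jax.numpy as jnp

def env_step(state, action):
    # Hypothetical 1-D environment: the agent moves along a line toward a goal.
    pos, goal = state
    new_pos = jnp.clip(pos + action, -10.0, 10.0)
    reward = -jnp.abs(goal - new_pos)          # closer to the goal is better
    done = jnp.abs(goal - new_pos) < 0.5
    return (new_pos, goal), reward, done

# The same single-environment function, run for many environments at once.
batched_env_step = jax.jit(jax.vmap(env_step))

num_envs = 1024
state = (jnp.zeros(num_envs), jnp.full(num_envs, 3.0))  # (positions, goals)
actions = jnp.ones(num_envs)
state, reward, done = batched_env_step(state, actions)
```

The body of `env_step` reads like ordinary Numpy code, which is the part that tends to be more awkward in frameworks built mainly around neural network layers.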
You could do the same with tensorflow and pytorch, but in my experience, with more difficulty since they're more opinionated about how you should do your operations.
JAX definitely makes it easier to do things that aren't on rails.
I wonder how Julia is placed for running reinforcement learning algorithms (efficiently), particularly in cases where the “environment” is nicely wrapped in Python to fit some standardized interface.