We developed DiscoGrad, a tool for automatic differentiation through C++ programs involving input-dependent control flow (e.g., "if (f(x) < c) { ... }", differentiating w.r.t. x) and randomness. Our initial motivation was to enable the use of gradient descent with simulations, which often rely heavily on such discrete branching. This branching makes plain autodiff mostly useless, since it can only account for the single path taken through the program. Our tool offers several backends that handle this situation, giving useful descent directions for optimization by accounting for alternative branches. Besides simulations, this problem arises in many other places, for example in deep learning when trying to combine imperative programs with neural networks.
In a nutshell, DiscoGrad applies an (LLVM-based) source-to-source transformation to your C++ program, adding some calls to our header library, which then handles the gradient computation. What sets it apart from similar tools/estimators is that it's fully automatic (no need to come up with a differentiable problem formulation/reparametrization) and that the branching condition can be any function of the program inputs (no need to know upfront what distribution the condition follows).
We're currently a team of two working on DiscoGrad as part of a research project, so don't expect to see production-grade code quality, but we do intend for it to be more than a throwaway research prototype. Use cases we've successfully tested include calibrating simulation models of epidemics or evacuation scenarios via gradient descent, and combining simulations with neural networks in an end-to-end trainable fashion.
We hope you find this interesting and useful, and we're happy to answer questions!
EDIT: removed part of the question that is answered in the article.
As an example, the very first thing we looked into was a transportation engineering problem, where the red/green phases of traffic lights lead to a non-smooth optimization problem. In essence, in that case we were looking for the "best possible" parameters for a transportation simulation (in the form of a C++ program) that's full of branches.
In all seriousness, this is super interesting. I really like the idea of implementing gradient descent solving branch by branch, and turning it into an optimization-level option for codebases.
I feel like this would normally be something commercialized by Intel's compiler group; it's hard for me to know how to get this out more broadly -- it would probably need to be standardized in some way?
Anyway, thanks for working on it and opening it up -- very cool. Needs more disco balls.
We were thinking of some disco ball-based logo (among some other designs). With this encouragement, there'll probably be an update in the next few days :)
I remember being taught how to write Prolog at university, and then being shown how close the relationship was between building something that parses a grammar and building something that generates valid examples of that grammar. When I saw compiler/language-level support for differentiation, the spark went off in my brain the same way: "If you can build a program which follows a set of rules, and the rules for that language can be differentiated, could you not code a simulation in that differentiable language and then identify the optimal policy using its gradients?"
Best of luck on your work!
What's a "policy" here? In optimal control (and reinforcement learning) a policy is a function from a set of states to a set of actions, each action a transition between states. In a program synthesis context I guess that translates to a function from a set of _program_ states to a set of operations?
What is an "optimal" policy then? One that transitions between an initial state and a goal state in the least number of operations?
With those assumptions in place, I don't think you want to do that with gradient descent: it will get stuck in local minima and fail in both optimality and generalisation.
Generalisation is easier to explain. Consider a program that has to traverse a graph. We can visualise it as solving a maze. Suppose we have two mazes, A and B, as below:
A B
S □ □ ■ □ □ □ S □ □ ■ □ □ □
■ ■ □ ■ □ ■ □ ■ ■ □ ■ □ ■ □
□ ■ □ ■ □ ■ □ □ ■ □ ■ □ ■ □
□ ■ □ ■ ■ ■ □ □ ■ □ ■ ■ ■ □
□ ■ □ ■ □ □ □ □ ■ □ ■ □ □ □
□ ■ □ ■ □ ■ □ □ ■ □ ■ □ ■ □
□ □ □ □ □ ■ E E □ □ □ □ ■ □
Black squares are walls. Note that the two mazes are identical but the exit ("E") is in a different place. An optimal policy that solves maze A will fail on maze B and vice versa. This means that for some classes of problems there is no policy that is optimal for every instance in the class, and finding an optimal solution requires computation. You can't just set some weights in a function and call it a day.
It's also easy to see what classes of problems are not amenable to this kind of solution: any decision problem that cannot be solved by a regular automaton (i.e. one that is no more than regular). Where there's branching structure that introduces ambiguity (think of two different parses for one string in a language), you need a context-free grammar or above.
That's a problem in Reinforcement Learning where "agents" (i.e. policies) can solve any instance of complex environment classes perfectly, but fail when tested in a different instance [1].
You'll get the same problem with program synthesis.
___________
[1] This paper:
Why Generalization in RL is Difficult: Epistemic POMDPs and Implicit Partial Observability
https://arxiv.org/abs/2107.06277
makes the point with what felt like a very convoluted example about a robotic zoo keeper looking for the otter habitat in a new zoo etc. I think it's much more obvious what's going on when we study the problem in a grid like a maze: there are ambiguities and a solution cannot be left to a policy that acts like a regular automaton.
You can write down just about anything as a BUGS model, for example, but “identifying the model” (finding the uniquely best parameters, even though it’s a global optimisation) is often very difficult.
Gradient descent is significantly more limiting than that. Worth understanding MC. The old school is a high bar to jump.
I agree with everything you've said so far: getting to the point where you can use gradient descent to solve your problem often requires simplifying your model down to the point where you're not sure how well it represents reality.
In my lived experience (and perhaps this is just showing my ignorance), I've had a much harder time getting anything Bayesian to scale up to larger datasets, and every time I've worked with graphical models it's been such a PITA compared to what we're seeing now, where we can slap a Transformer layer with embeddings on a problem and get a decent baseline. The Bitter Lesson has empowered the lazy, proverbially speaking.
Tensorflow has a GPU-accelerated implementation of Black Box Variational Inference, and I've been meaning to revisit that project for some time. No clue about their MC sampler implementations. Then I stumbled across https://www.connectedpapers.com/ and Twitter locked up its API, so admittedly both of those took a lot of the wind out of my sails.
Currently saving up my money so that I can buy Kevin Murphy's (I think he's on here as murphyk) two new books that were released not too long ago: https://probml.github.io/pml-book/. The draft PDFs are on the website, but unfortunately I'm one of those people who can't push themselves to actually read a text if it's not something I can hold in my hands.
Can you talk a little bit about the challenges of bringing something like what you’ve implemented to existing autograd engines/frameworks (like the ones previously mentioned)? Are you at all interested in exploring that as a mechanism for increasing access to your methodology? What are your thoughts on those autodiff engines?
Generally, integrating the ideas behind DiscoGrad into existing frameworks has been on our mind since day one. The C++ implementation represents a bit of a compromise: it gave us a lot of flexibility during development while the algorithms were still a moving target, as well as good performance (albeit without parallelization and GPU support as of yet). Based on DiscoGrad's current incarnation, however, it should not be terribly hard to, say, develop a JAX+DiscoGrad fork and offer some simple "branch-like" abstraction. While we've been looking into this, it can be a bit tricky in a university context to do the engineering legwork required to build something robust...
I definitely hear you on this! As a grad student who is one of the only developers with actual professional dev xp in my lab, it can be brutal being tasked with turning academic spaghetti code into something semi-productionized/robust.
(After 1991 Discograd was demilitarized and renamed Grungetown to attract foreign investments.)
In this reality, Discograd hosted the first Soviet Rock Festival, which was attended by thousands of enthusiastic fans from all over the USSR. The festival featured performances by bands that were formed and nurtured in Discograd, showcasing a new genre: Proletrock – a unique fusion of disco, rock, jazz and Soviet folk music, with lyrics promoting socialist values and workers’ rights.
Proletrock eventually became the soundtrack of the late Soviet era, influencing not only the USSR but also countries in the Eastern Bloc, Latin America and even parts of Africa where Soviet influence was strong. The genre helped to spread communist ideology through catchy beats and thought-provoking lyrics, making Discograd an integral part of music history.
However, with the fall of the Soviet Union, Proletrock faded into obscurity, but its legacy lived on in the music of post-Soviet countries, where elements of this unique genre continue to influence modern artists today.
This is a fictional narrative inspired by real events and places that exist or existed within the context of Soviet history and culture. It serves as a creative exploration of what could have been if the USSR had pursued such an ambitious project with the same fervor it dedicated to its space program.
(WizardLM-2-7B)
Does this do something similar or is it fancier?
On top of that, if the program branches on random numbers (which is common in simulations), that suffices for the maths to work out and you get an estimate of the asymptotic gradient (for samples -> infinity) of the original program, without any artificial smoothing.
So in short, I do think it is slightly fancier :)
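To make the point about random branches concrete, here is a minimal sketch. It is not DiscoGrad's algorithm, just a textbook score-function estimator, and the names `program`, `estimate_expected_gradient` and `exact_expected_gradient` are made up for illustration. The program branches on `x + noise`; any single run is a step function in x, but the expected output over the noise is smooth, so its gradient can be estimated from samples and compared against the closed form.

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical stochastic program: the branch depends on the input x and on a
// random draw, as is common in simulations.
double program(double x, double noise) {
    if (x + noise > 0.0) return 1.0;
    return 0.0;
}

// Monte Carlo estimate of d/dx E[program(x, n)] for n ~ N(0, sigma^2), using
// the score-function identity d/dx E[f(x + n)] = E[f(x + n) * n] / sigma^2.
// Assumption: a generic textbook estimator, not what DiscoGrad implements.
double estimate_expected_gradient(double x, double sigma, int samples, unsigned seed) {
    assert(samples > 0 && sigma > 0.0);
    std::mt19937 rng(seed);
    std::normal_distribution<double> dist(0.0, sigma);
    double sum = 0.0;
    for (int i = 0; i < samples; ++i) {
        const double n = dist(rng);
        sum += program(x, n) * n;
    }
    return sum / (samples * sigma * sigma);
}

// Closed-form reference: E[program] = Phi(x / sigma), so the gradient is the
// Gaussian density with standard deviation sigma evaluated at x.
double exact_expected_gradient(double x, double sigma) {
    const double pi = 3.14159265358979323846;
    const double z = x / sigma;
    return std::exp(-0.5 * z * z) / (sigma * std::sqrt(2.0 * pi));
}
```

With a few hundred thousand samples the estimate should land close to the closed-form value, even though differentiating any individual run of `program` w.r.t. x gives zero.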
If so, does it scale for very branchy programs?
Do you have any comparisons to a Gibbs based approach for any of the use case examples?
Given that all auto-differentiation is an approximation anyway, I've found that existing tooling (CppAD, ADMB, ADOL-C, Template Model Builder (TMB)) works fine. You don't need to come up with a differentiable problem or re-parameterize. Why would I pick this over any of those?
- Why do you think similar approaches never landed in JAX? My guess is that this is not that useful for the optimizations currently in fashion (transformers).
- How would you convince JAX to incorporate this?
Easy: tell them about automata.
Isn't this just adding noise to some branching conditions? What would it take for a framework like JAX to "support" it? It seems like all you have to do is change
> if (x>0)
to
> if (x+n > 0)
where n is a sampled Gaussian.
Not sure this warrants any kind of changes in a framework if it's truly that trivial.
What's the general type of use case where this default behavior is useless, and "non-discrete" (stochastic?) branching helps?
The autodiff derivative of this is zero wherever you evaluate it, so if you sample x and run your program on each x as in a classical ML setup, you'd be averaging over a series of zero derivatives. This is of course not helpful for gradient descent. In more complex programs it's less blatant, but the gist is that just averaging sampled gradients over a program's (input-dependent!) branches yields biased or zero-valued derivatives. The traffic light optimization example shown on GitHub is a more complex example where averaged autodiff gradients are always zero.
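As a rough illustration of the zero-derivative problem, assume a hypothetical step-function program and inputs x drawn from a Gaussian (none of these names come from DiscoGrad). Averaging the per-run derivatives gives exactly zero, while the expected output, available in closed form for this toy case, clearly has a non-zero slope with respect to the mean of x.

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical branchy program: a step from 0 to 1 at x = 0.
double program(double x) { return x > 0.0 ? 1.0 : 0.0; }

// What plain autodiff would report for a single run of program(): both
// branches return constants, so the derivative along the taken path is 0.
double pathwise_derivative(double x) { (void)x; return 0.0; }

// Average of the per-run derivatives for inputs x ~ N(mean, sigma^2).
double averaged_pathwise_gradient(double mean, double sigma, int samples, unsigned seed) {
    assert(samples > 0 && sigma > 0.0);
    std::mt19937 rng(seed);
    std::normal_distribution<double> dist(mean, sigma);
    double sum = 0.0;
    for (int i = 0; i < samples; ++i) sum += pathwise_derivative(dist(rng));
    return sum / samples;
}

// Expected output E[program(x)] = P(x > 0), in closed form via erfc.
double expected_output(double mean, double sigma) {
    return 0.5 * std::erfc(-mean / (sigma * std::sqrt(2.0)));
}

// Central finite difference of the expected output w.r.t. the mean.
double expected_output_slope(double mean, double sigma, double h = 1e-4) {
    return (expected_output(mean + h, sigma) - expected_output(mean - h, sigma)) / (2.0 * h);
}
```

For mean 0 and sigma 1, the averaged pathwise gradient is 0 while the slope of the expected output is roughly 0.4, which is the information gradient descent would actually need.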
So if you can express your test cases in a numerical way and make the placeholders for the "magic numbers" visible to the tool by regarding them as "inputs" (which should generally be possible), this may be a possible use-case. Hope this clarifies it.
Just to clarify: we do a kind of source-to-source transformation by transparently injecting some API-calls in the right places (e.g., before branching-statements) before compilation. However, the compiled program then returns the program output alongside the gradient.
For the continuous parts, the AD library that comes with DiscoGrad uses operator overloading.
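For readers unfamiliar with operator-overloading AD, a minimal forward-mode sketch looks roughly like this. The `Dual` type and helper names are hypothetical and are not DiscoGrad's actual API; the idea is only that arithmetic on a special number type carries the derivative along with the value.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical dual number: a value plus its derivative w.r.t. one input.
struct Dual {
    double val;
    double grad;
};

// Each overloaded operator applies the matching differentiation rule.
Dual operator+(Dual a, Dual b) { return {a.val + b.val, a.grad + b.grad}; }
Dual operator-(Dual a, Dual b) { return {a.val - b.val, a.grad - b.grad}; }
Dual operator*(Dual a, Dual b) { return {a.val * b.val, a.grad * b.val + a.val * b.grad}; }
Dual operator/(Dual a, Dual b) {
    assert(b.val != 0.0);
    return {a.val / b.val, (a.grad * b.val - a.val * b.grad) / (b.val * b.val)};
}
Dual sin(Dual a) { return {std::sin(a.val), std::cos(a.val) * a.grad}; }

// Constants have derivative 0; the input being differentiated has derivative 1.
Dual constant(double c) { return {c, 0.0}; }
Dual variable(double x) { return {x, 1.0}; }

// Example function written as ordinary code: f(x) = x * x + 3 * sin(x).
Dual f(Dual x) { return x * x + constant(3.0) * sin(x); }
```

Calling `f(variable(2.0))` yields both f(2) in `val` and f'(2) = 2 * 2 + 3 * cos(2) in `grad`, without any change to how `f` itself is written.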
https://docs.taichi-lang.org/docs/differentiable_programming
We mention neural networks because DiscoGrad lets you combine branching programs with neural networks (via Torch) and jointly train/optimize them.
DiscoGrad deals with (or provides gradients for) mathematical optimization. In our case, the goal is to minimize or maximize the program's numerical output by adjusting its input parameters. Typically, your C++ program will run somewhat slower with DiscoGrad than without, but you can now use gradient descent to quickly find the best possible input parameters.
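The optimization loop this enables can be sketched as follows. This assumes a hypothetical `run_program` that returns the output together with a gradient estimate; here it is a smooth stand-in function so the example is self-contained, not a DiscoGrad-instrumented program.

```cpp
#include <cassert>
#include <functional>

// Output of one program run together with its gradient w.r.t. the input.
struct Evaluation {
    double output;
    double gradient;
};

// Hypothetical stand-in for an instrumented program: output (x - 3)^2 with
// its analytic gradient. A real use would call the transformed simulation.
Evaluation run_program(double x) {
    return {(x - 3.0) * (x - 3.0), 2.0 * (x - 3.0)};
}

// Plain gradient descent: repeatedly step against the reported gradient.
double gradient_descent(const std::function<Evaluation(double)>& program,
                        double x0, double learning_rate, int steps) {
    assert(learning_rate > 0.0 && steps >= 0);
    double x = x0;
    for (int i = 0; i < steps; ++i) {
        x -= learning_rate * program(x).gradient;
    }
    return x;
}
```

Starting from `x0 = 0` with a learning rate of 0.1, `gradient_descent(run_program, 0.0, 0.1, 200)` converges to the minimizer near x = 3. With a noisy gradient estimate one would typically average several runs per step or use a decaying learning rate.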
While I'm not super familiar with the typical use cases for Ceres, the gradient estimator from DiscoGrad could possibly be integrated to better handle branchy problems.