I'd assumed that overall performance in Stable Diffusion was limited by the code running on the GPU, with Python performance being a fairly minor factor, but I guess that's not the case?
The Python part only runs a handful of times so JIT vs. non-JIT doesn't really make a difference.
See, for example, XLA: https://www.tensorflow.org/xla
It looks like maybe nvFuser is an equivalent library for PyTorch? https://pytorch.org/blog/introducing-nvfuser-a-deep-learning...
https://old.reddit.com/r/MachineLearning/comments/xa75km/p_p...
The difference is that `x += y` modifies `x` in place, whereas `x = x + y` creates a new object. In other words, if anybody had a reference to `x` before the update, the "optimized" code could break things.
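The aliasing difference can be shown with a plain Python list, whose `+=` extends the existing object while `+` builds a new one. This is a minimal stdlib sketch, not PyTorch code; `rebind` and `in_place` are hypothetical helper names. PyTorch tensors follow the same pattern (`+=` writes into the existing tensor), but they are not used here.

```python
def rebind(x, y):
    x = x + y  # builds a new list and rebinds only the local name
    return x


def in_place(x, y):
    x += y  # list.__iadd__ extends the caller's list object
    return x


original = [1, 2]
alias = original  # a second reference to the same list

rebind(original, [3])
print(alias)  # [1, 2]: the caller's list is unchanged

in_place(original, [3])
print(alias)  # [1, 2, 3]: the shared object was modified
```

Any code still holding `alias` sees the second change, which is exactly the kind of reference the in-place version can break.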
I guess this is the kind of stuff that drew me to Rust. This kind of behavior gives me the creeps, just like Ruby’s conventions.
Even so, there are absolutely silly things which can hint JS JITs to optimize (or to not deoptimize). Like defining and instantiating a class rather than just creating POJOs with the same values, or assigning NaN instead of null to uninitialized numeric variables/properties. Conditional control flow can deopt, but generally performs better around different function calls than within a single function. Even creating and throwing errors for control flow (which is generally expensive, and terrible for maintenance) can be optimal if your try/catch is the whole body of the function it resides in. And all of those might vary between JITs.
It would be interesting to get a benchmark using CPython vs Nuitka related to this change.
`x = x + y` creates a copy of the array x, adds y to it, and then sets the variable x to that new array. In contrast, `x += y` adds the array y in place into the array x (so hopefully no other piece of code is relying on x being immutable). This kind of trade-off occurs in pretty much all programming; for instance, you see it whenever big-integer libraries are used in C++ or Rust.
> Changing this back to the original implementation fixed an error I was getting when doing textual inversion on Windows
https://github.com/lstein/stable-diffusion/commit/62863ac586...
A "safe" way to do this is still straightforward, I think.
from copy import copy
def _forward(self, x, context=None):
    x = x.contiguous() if x.device.type == 'mps' else x
    x = copy(x)
    x += self.attn1(self.norm1(x))
    x += self.attn2(self.norm2(x), context=context)
    x += self.ff(self.norm3(x))
    return x
It could be faster but I don't know what `x` is and I'm not going to guess. Also, `copy` may not be sufficient; `deepcopy` may be necessary. Again, I don't know what `x` is, so I can't figure that out. Pls use type annotations :)
def _forward(self, x, context=None):
    x = x.contiguous() if x.device.type == 'mps' else x
    x = x + self.attn1(self.norm1(x))
    x += self.attn2(self.norm2(x), context=context)
    x += self.ff(self.norm3(x))
    return x
I would only do that if I had seen it to be faster, though, and I'd add a comment on why the first line couldn't do +=.
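On the question of whether `copy` suffices or `deepcopy` is needed: it depends on how `x` stores its data. A minimal stdlib sketch, assuming a hypothetical `Wrapper` class whose `+=` mutates an internal list. This is not how `torch.Tensor` is implemented; it only illustrates why a shallow copy can still let `+=` leak into the original.

```python
from copy import copy, deepcopy


class Wrapper:
    """Hypothetical object whose += mutates an internal list buffer."""

    def __init__(self, values):
        self.buf = list(values)

    def __iadd__(self, other):
        for i, v in enumerate(other):
            self.buf[i] += v  # writes into the (possibly shared) buffer
        return self


original = Wrapper([1, 2])
shallow = copy(original)  # new Wrapper, but .buf is the same list object
shallow += [10, 10]
print(original.buf)  # [11, 12]: the original was modified through the shared list

original = Wrapper([1, 2])
deep = deepcopy(original)  # new Wrapper with its own copy of .buf
deep += [10, 10]
print(original.buf)  # [1, 2]: the original is untouched
```

Which one is "enough" is a property of the type, which is why knowing what `x` is matters.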
The problem here comes down to not knowing whether you’re allowed to modify a value in place, because it’s not clear who owns it. It wasn’t written down anywhere, and in stable-diffusion alone it was fine to mutate it, but textual inversion did something that made it not fine (perhaps passing it something it expected not to be mutated). This is a moderately common type of bug that can be extraordinarily difficult to diagnose (it’s unusually easy to pinpoint here because it promptly raises a RuntimeError), and it is statically impossible in Rust, because the whole “am I allowed to mutate it” question is resolved in the type system.
As everyone said, this is faster because x is modified in place. The reason it was not done in place originally is that you can't train a neural network if an operation overwrites, in place, values that are still needed. During training, the network goes back through all the operations that were performed to see how well they did, so that the weights can be adjusted using a secondary value called a gradient; this is done during the backward pass. If you replace something in place, you're essentially overwriting the input values that were passed to that function and, by extension, the output values of the function called before it, breaking the chain of the network. The only way around that is to also copy the inputs along with the gradients, which would cause an even worse performance hit and be a memory hog.
The breakage bug later in the issue is proof of this: when sampling to generate an image, only the forward pass is run on the network, but textual inversion requires you to train the network and therefore run the backward pass, which triggers the error since the dependency graph is broken. I should also note that technically the add operation should be safe to do in place, as it's reversible, but I'm not a PyTorch expert so I'm not sure exactly what's going on in there.
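PyTorch detects this situation with a per-tensor version counter: tensors saved for the backward pass remember the version they had, and an in-place write bumps it, so backward can refuse to use stale data. The sketch below is a toy imitation of that idea in plain Python; `Value`, `mul`, and `backward_mul` are hypothetical names, not PyTorch's API, and real autograd is far more involved. Multiplication is used because its gradient needs the original inputs; a plain add's gradient does not.

```python
class Value:
    """Toy stand-in for a tensor: a list of floats plus a version counter."""

    def __init__(self, data):
        self.data = list(data)
        self.version = 0

    def __iadd__(self, other):
        # In-place write: mutate the buffer and bump the version counter.
        for i, v in enumerate(other.data):
            self.data[i] += v
        self.version += 1
        return self


def mul(a, b):
    """Elementwise a * b that saves its inputs (and their versions) for backward."""
    out = Value(p * q for p, q in zip(a.data, b.data))
    saved = [(a, a.version), (b, b.version)]
    return out, saved


def backward_mul(grad_out, saved):
    """The gradient of a * b needs the original a and b, so check they are untouched."""
    (a, a_ver), (b, b_ver) = saved
    for t, ver in ((a, a_ver), (b, b_ver)):
        if t.version != ver:
            raise RuntimeError(
                f"a saved input was modified in place (version {t.version}, expected {ver})"
            )
    grad_a = [g * q for g, q in zip(grad_out, b.data)]
    grad_b = [g * p for g, p in zip(grad_out, a.data)]
    return grad_a, grad_b


x = Value([1.0, 2.0])
w = Value([3.0, 4.0])
out, saved = mul(x, w)
print(backward_mul([1.0, 1.0], saved))  # ([3.0, 4.0], [1.0, 2.0])

x += Value([10.0, 10.0])  # overwrites an input that backward still needs
try:
    backward_mul([1.0, 1.0], saved)
except RuntimeError as err:
    print(err)  # a saved input was modified in place (version 1, expected 0)
```

Inference only ever runs the forward half, so the check never fires there, which matches the "works for sampling, breaks for textual inversion" behavior described above.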
If the engineers that originally implemented the function intentionally chose the slower version, a quick comment as to why would have prevented this from happening in the first place.
One of the first things you're taught when learning PyTorch is that you're not really coding in Python, but creating a network graph that is loaded and executed on a GPU. Another piece of common sense is that you shouldn't use functions from the stdlib or NumPy, and should use the torch.* variants instead; not doing so will either cause undefined behavior, cause massive memory copies between the CPU and GPU, or, most likely, error out at runtime.
Note that this repo is a fork of the official repo. It's a community repo focused on inference, and thus doesn't care about training, so it has completely different considerations than the original code.
This is indeed the time to place a comment, yet so many people don't do that.
z = y + x
x = z
Basically, it creates an object `z` just to throw it away. `x += y` just adds y to x directly without any intermediary.
You could write this in any language pretty easily. For example, in Rust:
let x = "abc".to_string();
let y = "123".to_string();
let x = format!("{x}{y}"); // always allocates a new String (note: `x + &y` would actually reuse x's buffer)
as opposed to the more efficient:
let mut x = "abc".to_string();
let y = "123".to_string();
x.push_str(&y);
It's just using an operation that mutates in place vs. an immutable operation.
It’s the confusion, the idea that this is a trivial change, that's the overloading issue.
The problem is thinking `+` and `+=` are the same. They are not, and `+` should not be used when `+=` can be used.
Btw there's ongoing work to automatically optimize expressions like this. See the XLA compiler for example. Right now deep learning has a ton of seemingly obvious compute/memory optimisations that are not done automatically.
Jokes aside, this is PyTorch, so the actual work runs in compiled C++ or CUDA kernels; the difference likely comes from the different functions that are called for += vs. +
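At the Python level, `x += y` first tries `type(x).__iadd__` and only falls back to `__add__` (plus rebinding) when `__iadd__` is missing or returns `NotImplemented`, whereas `x = x + y` always goes through `__add__`. For PyTorch tensors, `__iadd__` performs an in-place add. The tracing classes below are a hypothetical illustration of that dispatch, not PyTorch code.

```python
calls = []


class Traced:
    """Records which operator hook Python invokes."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        calls.append("__add__")
        return Traced(self.value + other)  # returns a new object

    def __iadd__(self, other):
        calls.append("__iadd__")
        self.value += other  # mutates self
        return self


class AddOnly:
    """Defines only __add__, so += has to fall back to it."""

    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        calls.append("AddOnly.__add__")
        return AddOnly(self.value + other)


t = Traced(1)
before = t
t = t + 1
print(t is before)  # False: + produced a new object

t = Traced(1)
before = t
t += 1
print(t is before)  # True: += reused the same object

a = AddOnly(1)
before = a
a += 1
print(a is before)  # False: no __iadd__, so += fell back to __add__
print(calls)  # ['__add__', '__iadd__', 'AddOnly.__add__']
```

The fallback is also why `+` and `+=` cannot be assumed interchangeable: whether `+=` mutates depends entirely on the type.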
>>> import dis
>>> def f(x): x += 1
...
>>> def g(x): x = x + 1
...
>>> dis.dis(f)
  1           0 LOAD_FAST                0 (x)
              3 LOAD_CONST               1 (1)
              6 INPLACE_ADD
              7 STORE_FAST               0 (x)
             10 LOAD_CONST               0 (None)
             13 RETURN_VALUE
>>> dis.dis(g)
  1           0 LOAD_FAST                0 (x)
              3 LOAD_CONST               1 (1)
              6 BINARY_ADD
              7 STORE_FAST               0 (x)
             10 LOAD_CONST               0 (None)
             13 RETURN_VALUE
1. In PyTorch (and other array programming libraries like NumPy), the values being passed around are tensors/arrays (i.e. large chunks of memory). Thus, += is overloaded to mean "in-place write" to the arrays.
So, `+` vs `+=` is the equivalent of
a = [0.0] * 1000
b = [0.0] * 1000  # extra output array
for i in range(1000):
    b[i] = a[i] + 2
vs.
a = [0.0] * 1000
for i in range(1000):
    a[i] = a[i] + 2  # result written back into a
The main performance advantage comes from (a) not needing to allocate an extra array and (b) using less memory overall, so the various cache levels can work better. It has nothing to do with Python bytecode.
2. As for whether it generally makes sense to do this optimization manually... Usually, PyTorch users don't use in-place operations, as they're a bit uglier mathematically and have various foot-guns/restrictions that users find confusing. Generally, it's best to have this optimization be done automatically by an optimizing compiler.
3. PyTorch in general does support using in-place operations during training, albeit with some caveats.
(PS) 4. Putting everything on one line (as some folks suggest) is almost certainly not going to help performance - the primary performance bottlenecks here have almost nothing to do with CPU perf.
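The allocation point can be checked without a GPU. This stdlib sketch uses `array.array("d")` (unboxed doubles, loosely analogous to a tensor's buffer) and `tracemalloc` to compare the peak memory of an out-of-place add against an in-place one. It is an analogy, not a PyTorch measurement, and the exact byte counts will vary by Python version.

```python
import tracemalloc
from array import array

N = 100_000


def out_of_place(a):
    # b = a + 2: allocates a second N-element buffer
    return array("d", (v + 2 for v in a))


def in_place(a):
    # a += 2: writes results back into the existing buffer
    for i in range(len(a)):
        a[i] += 2
    return a


def peak_bytes(fn, a):
    """Peak traced allocation (in bytes) while running fn(a)."""
    tracemalloc.start()
    try:
        fn(a)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


data = array("d", [0.0]) * N
print("out-of-place peak:", peak_bytes(out_of_place, data))
print("in-place peak:", peak_bytes(in_place, data))
```

The out-of-place version's peak includes a whole second buffer (at least 8 bytes per element), while the in-place loop only touches small temporaries.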
> Generally, it's best to have this optimization be done automatically by an optimizing compiler.
What compiler should be optimizing this operation?
There are comments on the commit reporting errors under certain conditions.
There are many different paths to optimizing compilers that folks use with PyTorch. One with close integration is nvFuser (see https://www.reddit.com/r/MachineLearning/comments/xa75km/p_p...), although there are other compilers like ONNX Runtime.
Yes, handling autograd (during training) is a whole different thing, and not all compilers support that.
I'm wondering because recent versions have improved performance a lot: 3.11 is much faster than 3.10, and what's in 3.12 is already much faster than 3.11.
https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a...
Many times I have had to decide whether to make my Python code more legible or get free performance.
The thing I like about JavaScript is that I can _usually_ trust the JIT to make my code faster than I could, meaning I can focus entirely on writing clean code.
P.S. You can always hand-optimize. If you do, just comment the heck out of it.
This is rarely an option that has presented itself to me. If there's a clear performance issue in my code, then I probably picked the wrong algorithm or my code has a bug, unless for some reason I decided to do heavy calculations in raw Python. If you're doing operations on big chunks of data, you should always use something like NumPy or JAX.
Even in OP's issue, the clear reason is that it's doing an operation in place instead of creating a copy. For ML models this can only be done at inference time and not at training time, since you need to keep track of the whole network, hence why the code was in its unoptimized state.
This isn't a case of "The Python interpreter is bad"; it's just that the code is doing what the user asked it to do: create a completely new copy of the data, then overwrite the old copy with it. Immutable operations like this are slow; mutating the value (what += does) is fast.
Granted, a compiled language could recognize that you're doing this, but it also might not: are `+` and `+=` semantically identical such that the compiler can replace one with the other? Maybe? Probably not, if I had to guess. The correct answer is to just use the faster operation, as it is in all languages.
I don't know the type of `x`, but I'd suggest another optimization here would be to:
a) Preallocate the buffer before mutating it 3x (the current code is still likely forcing some allocations)
b) Reuse that buffer if it's so important, store it in `self` and clear it before use.
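One reading of (b), sketched with plain lists because the type of `x` is unknown: keep a scratch buffer on `self`, sized once, and write results into it instead of allocating on every call. `BufferedAdder` and `add_into` are hypothetical names. In PyTorch the closest real tools would be in-place methods or the `out=` argument that many ops (e.g. `torch.add`) accept, but those are not shown here.

```python
class BufferedAdder:
    """Hypothetical helper that reuses one preallocated output buffer."""

    def __init__(self, size):
        self._buf = [0.0] * size  # allocated once, reused on every call

    def add_into(self, x, y):
        if len(x) != len(self._buf) or len(y) != len(self._buf):
            raise ValueError("inputs must match the buffer size")
        buf = self._buf
        # Every slot is overwritten, so no separate clearing step is needed.
        for i in range(len(buf)):
            buf[i] = x[i] + y[i]
        return buf


adder = BufferedAdder(3)
first = adder.add_into([1.0, 2.0, 3.0], [10.0, 10.0, 10.0])
print(first)  # [11.0, 12.0, 13.0]
second = adder.add_into([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
print(second is first, first)  # True [1.0, 1.0, 1.0]
```

The trade-off is the same ownership question raised earlier in the thread: any caller that keeps the returned buffer will see it change on the next call.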
He's the author of the essay "How Perl Saved the Genome Project", the books "Network Programming with Perl" and "Writing Apache Modules with Perl and C", and a number of Perl packages including CGI.pm - which helped power the dot-com era - and GD.pm.
It would be interesting to check whether changing every expression to x=x+y gives performance more similar to += or to ...+x