Standard attention requires O(seq_len*dk + seq_len^2) HBM accesses, whereas the Att(i) computation with FA requires O(seq_len^2*dk^2/SRAM_size) HBM accesses.
Q, K, V computation remains the same, and the ATTN(0,n)*Wo projection also remains the same. In a smaller model with N=12, D=768, dk=64, seq_len=1k, SRAM=32KB, ..., the FA optimization would roughly translate to 0.5M vs 4.5M HBM accesses per head (Att(i)). That's a ~10x improvement, but in the grand scheme of things it becomes ~91M vs ~45M per attention layer, so only a ~2x net improvement.
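For a rough sanity check, the per-head access counts can be reproduced in a few lines. The constant factors and the FP32 element size are my assumptions; only the asymptotics come from the FlashAttention analysis:

```python
# Back-of-envelope HBM access counts per attention head (counts of scalar
# elements read/written, not bytes). Standard attention: O(N*d + N^2)
# accesses; FlashAttention: O(N^2 * d^2 / M), with M = SRAM size in elements.
seq_len, dk = 1024, 64
sram_bytes = 32 * 1024
bytes_per_el = 4                      # FP32 (assumed)
M = sram_bytes // bytes_per_el        # SRAM capacity in elements

# Standard attention materializes S = QK^T and P = softmax(S) in HBM,
# so the N^2-sized matrices get written and re-read a few times.
standard = 2 * seq_len * dk + 4 * seq_len**2   # rough constant factor

# FlashAttention tiles Q/K/V through SRAM and never materializes S or P.
flash = seq_len**2 * dk**2 // M

print(f"standard ~{standard/1e6:.1f}M, flash ~{flash/1e6:.1f}M accesses")
```

With these assumptions the script lands on ~4.3M vs ~0.5M, close to the 4.5M vs 0.5M figures above.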
> This is why "prefill" is compute intensive instead of memory intensive.
Yes, I think I agree, and I have corrected myself elsewhere in the thread. The original thought I wanted to convey in my initial comment, which somehow got lost throughout the discussion, is that prefill/training will benefit from FlashAttention/MLA but inference will not. I can agree that the formulation "only when memory access time dominates the compute in attention implementation" was wrong.
> During token generation ... memory bandwidth is the deciding factor for token generation.
The Llama3-70B MLP layer takes roughly 1 TFLOPs of compute and 0.6 GB of bandwidth for 1024 tokens. Assuming that 1023 entries are taken from the KV-cache, the attention-layer computation for a single token will take ~0.6 GFLOPs and ~0.2 GB of bandwidth. Loading the rest of the values from the KV-cache at FP16 precision takes another 1023*0.1MB, or ~0.1 GB.
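Tallying those per-layer bandwidth contributions (all figures are the rough estimates above, not derived from exact Llama3-70B dimensions):

```python
# Per-layer memory traffic for generating one token, using the estimates
# above (rough figures, not exact Llama3-70B weight/KV sizes).
mlp_weights_gb = 0.6                 # MLP weights read once per step
attn_weights_gb = 0.2                # attention projection weights
kv_cache_gb = 1023 * 0.1 / 1024      # ~0.1 MB per cached token, per layer

total_gb = mlp_weights_gb + attn_weights_gb + kv_cache_gb
print(f"~{total_gb:.1f} GB of memory traffic per layer per generated token")
```

That sums to ~0.9 GB, i.e. the ~1 GB per-layer figure below.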
So, ~1 TFLOPs of compute and ~1 GB of bandwidth per Transformer layer. On hardware such as the H100, this still looks like a compute-bound problem to me. OTOH, on a CPU with 15 TFLOPS of compute but <1 TB/s of memory bandwidth, it becomes a memory-bound problem. Or no?
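One way to answer that is a quick roofline check. The H100 figures (~990 TFLOPS dense FP16, ~3.35 TB/s HBM3) are my assumed specs; the CPU numbers are the ones from the comment:

```python
# Roofline sanity check for the numbers above (hardware figures are rough
# assumed specs, not measured values).
flops_per_layer = 1e12      # ~1 TFLOPs of compute per Transformer layer
bytes_per_layer = 1e9       # ~1 GB of memory traffic per layer

intensity = flops_per_layer / bytes_per_layer   # arithmetic intensity, FLOPs/byte

# Machine balance = peak FLOP/s divided by peak bytes/s. A workload is
# compute-bound when its arithmetic intensity exceeds the machine balance.
h100_balance = 990e12 / 3.35e12   # H100 SXM: dense FP16 over HBM3 (assumed)
cpu_balance = 15e12 / 1e12        # the hypothetical CPU above

print(f"intensity: {intensity:.0f} FLOPs/byte")
print(f"H100 balance: {h100_balance:.0f}, CPU balance: {cpu_balance:.0f}")
print("H100:", "compute-bound" if intensity > h100_balance else "memory-bound")
print("CPU:", "compute-bound" if intensity > cpu_balance else "memory-bound")
```

At ~1000 FLOPs/byte the workload sits above both machines' balance points under these assumptions, so the regime question comes down to whether the per-layer FLOP and byte estimates hold up.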