Porting FlashAttention to Metal will be quite hard, because for performance reasons the authors do a lot of shenanigans to respect the memory hierarchy.
Thankfully, you can probably do something slower but better adapted to your memory constraints.
If you relax this need for performance and allow some re-computation, you can write a qkvatt function that takes Q, K, V and a buffer to store the resulting attention output, and computes it without needing any extra memory.
The algorithm is still quadratic in time with respect to the attention horizon (although with a bigger constant, around 2x or 3x, due to the re-computation). But it needs no extra memory allocation, which makes it easy to parallelize.
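To make this concrete, here is a minimal sketch of the idea for a single query row, in plain Python rather than a Metal kernel. Pass 1 computes the softmax max and normalizer with a running update (two scalars of state); pass 2 recomputes each score (the ~2x constant mentioned above) and accumulates directly into the output buffer. The name `qkvatt_row` and its signature are illustrative assumptions, not FlashAttention's actual API.

```python
import math

def qkvatt_row(q, K, V, out):
    """Attention output for one query row with O(1) extra memory.

    Pass 1: running max `m` and normalizer `s` over all scores.
    Pass 2: recompute each score and accumulate weighted values
    into the caller-provided `out` buffer. Illustrative sketch.
    """
    d = len(q)
    m, s = -math.inf, 0.0
    # Pass 1: online max/normalizer, no horizon-sized buffer needed.
    for k in K:
        x = sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
        if x > m:
            s = s * math.exp(m - x) + 1.0  # rescale old sum to new max
            m = x
        else:
            s += math.exp(x - m)
    # Pass 2: recompute scores, accumulate output in place.
    for j in range(len(out)):
        out[j] = 0.0
    for k, v in zip(K, V):
        x = sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
        w = math.exp(x - m) / s
        for j in range(len(v)):
            out[j] += w * v[j]
    return out
```

In a Metal kernel each thread (or threadgroup) would own one query row, so rows parallelize trivially with no shared writable state in the forward pass.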
Alternatively, you can use an extra memory buffer of size O(attention horizon × number of threads in parallel), like FlashAttention does, to avoid the re-computation.
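The buffered trade-off is a small change to the same sketch: each thread gets a scratch buffer of length equal to the attention horizon, caches the raw scores in pass 1, and reuses them in pass 2. Again, the names here are illustrative assumptions.

```python
import math

def qkvatt_row_buffered(q, K, V, out, buf):
    """Same as the recomputing version, but `buf` (length = attention
    horizon, one buffer per parallel thread) caches the raw scores so
    pass 2 skips the dot products. Illustrative sketch.
    """
    d = len(q)
    m, s = -math.inf, 0.0
    for t, k in enumerate(K):
        x = sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
        buf[t] = x  # cache the score instead of recomputing it later
        if x > m:
            s = s * math.exp(m - x) + 1.0
            m = x
        else:
            s += math.exp(x - m)
    for j in range(len(out)):
        out[j] = 0.0
    for t, v in enumerate(V):
        w = math.exp(buf[t] - m) / s  # reuse cached score
        for j in range(len(v)):
            out[j] += w * v[j]
    return out
```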
The backward pass works the same way: you need no extra memory if you are willing to do some re-computation, or memory linear in the attention horizon to avoid it.
One interesting thing to notice about the backward pass is that it doesn't use the attention probabilities from the forward pass, so they don't need to be preserved (you only need to preserve Q, K, V).
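A reference sketch of that backward pass, in dense (unfused) form to keep the math visible: given Q, K, V and the upstream gradient dO, it recomputes P = softmax(QKᵀ/√d) from scratch rather than loading a stored attention matrix. The helper names are assumptions for this sketch.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def transpose(A):
    return [list(r) for r in zip(*A)]

def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def attention_backward(Q, K, V, dO):
    """Backward of O = softmax(Q K^T / sqrt(d)) V, recomputing the
    attention matrix P from Q, K, V instead of storing it. Sketch."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    S = [[s * scale for s in row] for row in matmul(Q, transpose(K))]
    P = softmax_rows(S)                  # the re-computation
    dV = matmul(transpose(P), dO)
    dP = matmul(dO, transpose(V))
    # Softmax backward per row: dS_i = P_i * (dP_i - <P_i, dP_i>),
    # with the 1/sqrt(d) factor folded in for the dQ/dK products below.
    dS = []
    for p_row, dp_row in zip(P, dP):
        dot = sum(p * dp for p, dp in zip(p_row, dp_row))
        dS.append([p * (dp - dot) * scale for p, dp in zip(p_row, dp_row)])
    dQ = matmul(dS, K)
    dK = matmul(transpose(dS), Q)
    return dQ, dK, dV
```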
One little caveat of the backward pass (which you only need for training) is that it needs atomic_add to be easy to parallelize. This means it will be hard on Metal (afaik they don't have atomics for floats, though they do have atomics for integers, so you can probably use fixed-point numbers).
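The fixed-point workaround looks like this, sketched in Python rather than Metal: scale each float contribution to an integer, accumulate with integer additions (which would be integer atomic_fetch_add calls on the GPU), then scale back. The SCALE value is a hypothetical choice; in practice you'd pick it from your gradients' dynamic range. A side benefit: integer addition is associative, so the result is deterministic regardless of thread ordering, unlike float atomics.

```python
SCALE = 1 << 24  # fractional bits; a hypothetical choice for this sketch

def to_fixed(x):
    return round(x * SCALE)

def from_fixed(n):
    return n / SCALE

def fixed_point_accumulate(values):
    """Simulate the integer-atomic workaround: each GPU thread would
    atomic_add its contribution as a scaled integer into a shared int
    buffer, since integer atomics are available where float atomics
    are not. Sequential here; the point is the encoding."""
    acc = 0
    for v in values:
        acc += to_fixed(v)  # on the GPU: atomic_fetch_add on an int buffer
    return from_fixed(acc)
```

The cost is a fixed absolute precision (here ~2⁻²⁴ per addend) and a risk of overflow if the accumulated magnitude exceeds what the integer width minus the fractional bits can hold.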