Flash Attention

Related: Attention · Attention Mechanisms · Ring, Ulysses and Unified Attention · Gradient Checkpointing
Literature: FlashAttention Paper

An Efficient Attention Process

Standard Attention

We have the matrices $Q, K, V \in \mathbb{R}^{N \times d}$ stored in HBM (high-bandwidth memory).

  1. Load $Q$ and $K$ by blocks from HBM to the GPU's SRAM.
  2. Compute $S = QK^T$ in SRAM, then write $S$ back to HBM.
  3. Read $S$ from HBM, compute $P = \mathrm{softmax}(S)$, and write $P$ back to HBM.
  4. Load $P$ and $V$ by blocks from HBM, compute $O = PV$, then write $O$ back to HBM.
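As a minimal sketch, the four steps above can be written in NumPy. Everything here is illustrative: the function name `standard_attention` and the dimensions are assumptions, and the usual $1/\sqrt{d}$ scaling is omitted since the steps above omit it too. The point is that the full $N \times N$ matrices $S$ and $P$ are materialized, which is what forces the HBM round trips.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive attention: materializes the full N x N score matrix S
    and probability matrix P (the HBM round trips in steps 2-4)."""
    S = Q @ K.T                               # step 2: (N, N) scores
    S = S - S.max(axis=-1, keepdims=True)     # subtract row max for stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)     # step 3: row-wise softmax
    return P @ V                              # step 4: (N, d) output

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (8, 4)
```

For sequence length $N$, both $S$ and $P$ take $O(N^2)$ memory, which is exactly what FlashAttention avoids writing to HBM.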

Flash Attention

Goal - Minimize the number of HBM accesses

To compute a numerically stable softmax of a block $x \in \mathbb{R}^B$:

  1. $m(x) = \max_i x_i$ -> find the maximum value of the input vector.
  2. $f(x) = [e^{x_1 - m(x)}, \ldots, e^{x_B - m(x)}]$ -> subtract the max from each term before exponentiating.
  3. $l(x) = \sum_i f(x)_i$ -> the running sum for the denominator.

Then $\mathrm{softmax}(x) = f(x) / l(x)$.
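The three steps translate directly into NumPy (the name `stable_softmax` is illustrative). Subtracting $m(x)$ keeps every exponent $\le 0$, so large inputs that would overflow a naive `np.exp(x)` stay finite:

```python
import numpy as np

def stable_softmax(x):
    # Step 1: m(x) = max_i x_i
    m = np.max(x)
    # Step 2: f(x) = [e^{x_1 - m(x)}, ..., e^{x_B - m(x)}]
    f = np.exp(x - m)
    # Step 3: l(x) = sum_i f(x)_i, the softmax denominator
    l = np.sum(f)
    return f / l

# Inputs this large would overflow the naive exp(x) / sum(exp(x));
# the stabilized version stays finite and still sums to 1.
x = np.array([1000.0, 1001.0, 1002.0])
print(stable_softmax(x))
```

Note that the result is unchanged by the shift: the factor $e^{-m(x)}$ appears in both numerator and denominator and cancels.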

We would like a way to compute the softmax online, rather than only after all the input terms have been computed. If we can do that, we can compute the softmax without moving $S$ back and forth between SRAM and HBM.

Let $x = [x_1, x_2]$ be the concatenation of two blocks, and let $m(x) = \max(m(x_1), m(x_2))$. Then we can write $f(x)$ and $l(x)$ as follows:

$f(x) = [\, e^{m(x_1) - m(x)} f(x_1), \; e^{m(x_2) - m(x)} f(x_2) \,]$
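A small NumPy check of this decomposition (the helper names `block_stats` and `merge` are illustrative, not from the paper). Each block keeps its own $(m, f, l)$; merging rescales each block's values by $e^{m(x_i) - m(x)}$, and since $l$ is just the sum of $f$, it merges the same way:

```python
import numpy as np

def block_stats(x):
    """Per-block statistics: max m(x), scaled exponentials f(x), sum l(x)."""
    m = np.max(x)
    f = np.exp(x - m)
    return m, f, np.sum(f)

def merge(m1, f1, l1, m2, f2, l2):
    """Combine two blocks' statistics using the identity above."""
    m = max(m1, m2)                     # m(x) = max(m(x1), m(x2))
    f = np.concatenate([np.exp(m1 - m) * f1,
                        np.exp(m2 - m) * f2])
    l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2
    return m, f, l

x1 = np.array([1.0, 3.0])
x2 = np.array([2.0, 5.0])
x = np.concatenate([x1, x2])

m, f, l = merge(*block_stats(x1), *block_stats(x2))
direct = np.exp(x - x.max())
# blockwise result matches the one-pass softmax over the full vector
print(np.allclose(f / l, direct / direct.sum()))  # True
```

This is the key property FlashAttention exploits: blocks of $S$ can be processed one at a time in SRAM, carrying only the scalars $m$ and $l$ between blocks, so $S$ never needs to be written to HBM.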