I usually call it "head parallelism" (it is a type of tensor parallelism, but parallelized for small clusters and specific to attention). It is what you described: split the work by heads and send the input to the respective Q, K, V shards. Each shard can do its Q / K / V projections, RoPE, QK norm, whatever, and the attention itself, all locally. The out projection is done inside the shard too, but the partial results then need an all-reduce (sum) across shards so that every participating shard ends up with the final out-projection output. After that, each shard carries on with whatever comes next on its own.
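To make that concrete, here is a minimal NumPy sketch of the idea, simulating the shards with a loop in a single process. The function names (`attention`, `head_parallel_attention`) and the weight layout (heads stored contiguously along the columns of `wq`/`wk`/`wv` and the rows of `wo`) are my own assumptions for illustration, RoPE and QK norm are left out, and a plain `sum` stands in for the all-reduce that a real cluster would do over the network.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, wq, wk, wv, wo, n_heads):
    """Causal multi-head attention followed by the out projection."""
    seq = x.shape[0]
    d_head = wq.shape[1] // n_heads

    def split_heads(t):
        # (seq, n_heads * d_head) -> (n_heads, seq, d_head)
        return t.reshape(seq, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)  # causal mask
    out = softmax(scores) @ v                   # (n_heads, seq, d_head)
    out = out.transpose(1, 0, 2).reshape(seq, n_heads * d_head)
    return out @ wo

def head_parallel_attention(x, wq, wk, wv, wo, n_heads, n_shards):
    """Each simulated shard owns a contiguous group of heads."""
    assert n_heads % n_shards == 0
    heads_per_shard = n_heads // n_shards
    d_head = wq.shape[1] // n_heads
    cols = heads_per_shard * d_head
    partials = []
    for rank in range(n_shards):
        sl = slice(rank * cols, (rank + 1) * cols)
        # Column slices of Wq/Wk/Wv and the matching row slice of Wo
        # are all this shard needs; x itself is the same on every shard.
        partials.append(
            attention(x, wq[:, sl], wk[:, sl], wv[:, sl], wo[sl, :],
                      heads_per_shard)
        )
    # Stand-in for the all-reduce (sum): every shard would receive this total.
    return sum(partials)

rng = np.random.default_rng(0)
d_model, n_heads, seq = 16, 4, 5
x = rng.standard_normal((seq, d_model))
wq, wk, wv, wo = (rng.standard_normal((d_model, d_model)) for _ in range(4))

full = attention(x, wq, wk, wv, wo, n_heads)
sharded = head_parallel_attention(x, wq, wk, wv, wo, n_heads, n_shards=2)
print(np.allclose(full, sharded))
```

The sketch only checks that the per-shard partial out projections sum to the unsharded result. It says nothing about how the all-reduce cost compares with the compute saved.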
What I am asking, however, is whether that will speed up decoding as linearly as it would prefilling.