
    /jK                       U d Z ddlmZ ddlZddlZddlmZ erddlmZ ddl	m
Z
 ddlmZ ddlmZmZ ddlZdd	lmZ d
dlmZ dgZdaded<   daded<   e
 G d d                      Zedad            Z	 dbdcdZdddZdedZdfd+Zdgd2Zdhd9Z  ed:          Z!did=Z"djd?Z#	 	 	 	 dkdldHZ$	 dmdndNZ%	 	 	 doddOdOddddPdpdRZ&ddOdOddddPdqdSZ'ddddTdrdWZ(	 	 	 	 	 	 dsddYdtdZZ)	 	 	 duddYdvd[Z*ddYdwd^Z+ ej,        d_e`           dS )xz
PROTOTYPE!
Flash Attention 3 implementation.
For fp8: only supports forward pass right now.
For fp16/bf16: supports forward and backward pass.
    )annotationsN)TYPE_CHECKING)Callable)	dataclass)cache)TypeVarTupleUnpack)Library   )	_registryregister_flash_attention_fa3zCallable | None_FA3_CUDA_FWD_FA3_CUDA_BWDc                  "    e Zd ZU ded<   ddZdS )
_FA3HandlezLibrary | NonelibraryreturnNonec                R    d | _         t          j                            d           d S )NF)r   torch_C_set_sdp_use_fa3)selfs    \/home/longshao/multi-rider-rag/.venv/lib/python3.11/site-packages/torch/nn/attention/_fa3.pyremovez_FA3Handle.remove*   s%    !!%(((((    N)r   r   )__name__
__module____qualname____annotations__r    r   r   r   r   &   s6         ) ) ) ) ) )r   r   devicetorch.devicer   intc                J    t           j                            |           \  }}|S N)r   cudaget_device_capability)r"   major_s      r   _get_device_majorr+   0   s     z//77HE1Lr   flash_attn_interfacemodule_pathstrc                    t          |            t          j                            d           t	          t                                S )z
    Register FA3 flash attention kernels with the PyTorch dispatcher.

    Args:
        module_path: Python module path to the FA3 implementation.
    T)_fa3_import_moduler   r   r   r   _fa3_register_kernelsr-   s    r   r   r   6   s?     {### 
Hd###+--...r   r   c                   t          j        |            t          t          j        d          st          d|  d          t          t          j        j        d          st          d|  d          t          t          j        j        d          st          d|  d          t          j        j        j        at          j        j        j	        a
d S )Nflash_attn_3zModule 'z' does not expose FA3 kernelsfwdz%' does not expose FA3 forward kernelsbwdz&' does not expose FA3 backward kernels)	importlibimport_modulehasattrr   opsRuntimeErrorr4   r5   r   r6   r   r2   s    r   r0   r0   G   s    K(((59n-- RPkPPPQQQ59)511 
I{III
 
 	
 59)511 
J{JJJ
 
 	
 I*.MI*.MMMr   r
   c                 x   t          ddd          } |                     dt          d           |                     dt          d           |                     dt          d           |                     dt
          d           |                     dt          d           |                     d	t          d           | S )
NatenIMPLCUDAz"_flash_attention_forward.quantizedz-_scaled_dot_product_flash_attention.quantized_flash_attention_forward#_scaled_dot_product_flash_attention_flash_attention_backward,_scaled_dot_product_flash_attention_backward)r
   impl!_fa3_flash_attention_forward_impl4_fa3_scaled_dot_product_flash_attention_forward_impl)_fa3_flash_attention_forward_impl_default<_fa3_scaled_dot_product_flash_attention_forward_impl_default"_fa3_flash_attention_backward_impl5_fa3_scaled_dot_product_flash_attention_backward_impl)libs    r   r1   r1   X   s    
&&&
)
)CHH,.OQW   HH7<  
 HH"$Mv   HH-D   HH(*LfUUUHH6=  
 Jr   querytorch.Tensortensorstuple[torch.Tensor, ...]	dropout_pfloat	cum_seq_qtorch.Tensor | None	q_descale	k_descale	v_descale
str | Nonec                   |dk    rdS t          d |D                       sdS t          d |D                       dk    rdS | j        t          j        k    r |||t          j        dt                     ||                                 d	k    rd
S ||                                 dk    rdS t          j	        
                                sdS t          | j                  dk    rdS d S )N        zdropout_p must be 0c              3  $   K   | ]}|j         V  d S r&   )is_cuda.0ts     r   	<genexpr>z,_fa3_common_support_error.<locals>.<genexpr>   s$      **Qqy******r   zinputs must be CUDA tensorsc                    h | ]	}|j         
S r!   )r"   r\   s     r   	<setcomp>z,_fa3_common_support_error.<locals>.<setcomp>   s    &&&AH&&&r   r   inputs must share devicezWhen using SDPA with fp8, descale tensor should always be used for accurate dequantization. Please use _scaled_dot_product_attention_quantized and provide the descale tensors.   zdense query must be 4D   zragged query must be 3DzCUDA not available	   z#FA3 requires compute capability 9.0)alllendtyper   float8_e4m3fnwarningswarnUserWarningdimr'   is_availabler+   r"   )rL   rN   rP   rR   rT   rU   rV   s          r   _fa3_common_support_errorro   t   s    C$$**'***** -,,
&&g&&&''1,,)){e)))Y.)2C+ 	
 	
 	
 UYY[[A--''!1!1((:""$$ $##&&!++444r   keyvaluereturn_debug_maskboolalibi_slopes	seqused_kc           	     t   |rdS |dS | |j         t          j        k    rdS |j        sdS t          j        t          j        t          j        ft          fd| ||hD                       sd S t          d | ||hD                       dk    rd	S t          | | ||f||||	|
          }|
|d
k    rdS |S d S )Nzreturn_debug_mask must be Falsezalibi_slopes not supportedzseqused_k must be int32zseqused_k must be CUDAc              3  *   K   | ]}|j         v V  d S r&   rh   r]   r^   supported_dtypess     r   r_   z-_fa3_forward_support_error.<locals>.<genexpr>   s+      HHqqw**HHHHHHr   inputs must be one of c                    h | ]	}|j         
S r!   rx   r\   s     r   ra   z-_fa3_forward_support_error.<locals>.<setcomp>   s    111AG111r   r   #all inputs must have the same dtyperb   z(query, key, value must be on same device)
rh   r   int32r[   ri   float16bfloat16rf   rg   ro   )rL   rp   rq   rP   rr   rt   ru   rR   rT   rU   rV   errorrz   s               @r   _fa3_forward_support_errorr      s     100++?ek)),,  	,+++U]ENKHHHHUC4GHHHHH ;:(8:::
11eS%011122a7744%	U E ...==4r   grad_outout	logsumexpwindow_size_left
int | Nonewindow_size_rightc
           	     f   |j         t          j        k    r	 dS |j         t          j        k    rdS t          j        t          j        ft          fd| ||||hD                       sd S t          d | ||||hD                       dk    rdS t          || |||||f||d d d           }
|
|
S d S )NzHFA3 backward does not support fp8 - use inference only (torch.no_grad())zlogsumexp dtype must be float32c              3  *   K   | ]}|j         v V  d S r&   rx   ry   s     r   r_   z._fa3_backward_support_error.<locals>.<genexpr>   s+      WWqqw**WWWWWWr   r{   c                    h | ]	}|j         
S r!   rx   r\   s     r   ra   z._fa3_backward_support_error.<locals>.<setcomp>   s    @@@AG@@@r   r   r}   )	rh   r   ri   float32r   r   rf   rg   ro   )r   rL   rp   rq   r   r   rP   rR   r   r   r   rz   s              @r   _fa3_backward_support_errorr      s     {e)))V	
 	
 %-''00u~6WWWWXuc5RU4VWWWWW ;:(8:::
@@hsE3?@@@AAQFF44%	5#uc95 E 4r   Ts
Unpack[Ts]tuple[Unpack[Ts]]c                 4    t          d | D                       S )Nc              3  B   K   | ]}|                     d d          V  dS )r      N)	transposer\   s     r   r_   z#_transpose_dense.<locals>.<genexpr>   s0      44qQ""444444r   )tuple)rN   s    r   _transpose_denser      s    44G444444r   xc                d    | -|                      d          dk    r|                                 n| S )z2Ensure tensor is contiguous in the last dimension.Nr   )stride
contiguous)r   s    r   _maybe_contiguousr      s,    ]qxx||q/@/@1<<>>>aGr   cu_seq_qcu_seq_kmax_qmax_kscalefloat | None	is_causal!tuple[torch.Tensor, torch.Tensor]c                   t           t          d          t          |           }t          |          }|j        t          j        k    rF|                    d          dk    r-|                    d          dk    r|                                nt          |          }t          |          }t          |          }t          |          }t          g |||ddd|||dd|||dddddd||||||	|	nd|
|
ndddddt	          j                    rdnddt          j	        
                                pdR  \  }}}}||                                fS )	zF
    Run the FA3 forward pass by calling the C++ kernel directly.
    NFA3 not registeredr   r   r   rY   T)r   r;   r   rh   r   ri   r   r   $are_deterministic_algorithms_enabledr   _get_sm_carveout_experimental)rL   rp   rq   r   r   r   r   r   r   r   r   ru   r   rT   rU   rV   qkvcu_seqlens_qcu_seqlens_ksoftmax_lse	out_accumsoftmax_lse_accums                           r   _fa3_run_forwardr      s   * /000%  A#A ;%---LL!!LL!! 	 u%%  %X..L$X..L!),,I5B #6	#6	#6 	
#6 		#6
 	#6 	#6 	#6 	#6 	#6 	#6 	#6 	#6 	#6 	#6 	#6  	!#6" 	##6$ 	%#6& 	'#6( 	)#6* 	+#6, 	-#6. 	/#60 	1#62 	3#64 -8b5#66 /:7#68 	
9#6: 	;#6< 	=#6> 	?#6@ 799@qA#6B 	C#6D 	..005AE#6 #6 #62Ci!2H &&((((r   Fmax_seqlen_qmax_seqlen_kdeterministic/tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                   t           t          d          t          |           }|                    d          dk    r|                                n|}|                    d          dk    r|                                n|}|                    d          dk    r|                                n|}t          |          }t          |          }t          j        |          }t          j        |          }t          j        |          }t          |||||||||||d d ||	|
|||d|t
          j                                        pd           |||fS )Nr   r   r   rY   r   )	r   r;   r   r   r   r   
empty_liker   r   )r   rL   rp   rq   r   r   r   r   r   r   r   r   r   r   r   doutr   r   r   olsedqdkdvs                           r   _fa3_run_backwardr   C  si   " /000 X&&D#ll2..!33AJJrNNa//SA#ll2..!33A#A
I
&
&C 
	!		B		!		B		!		B				


..005A-  0 r2:r   r   r   r   r   ru   rt   r   	cum_seq_kc                  t          | ||||	||||
||          }|t          d|           t          | |||||||||||||
||          \  }}t          j        dt          j        | j                  }t          j        dt          j        | j                  }t          j        d| j        | j                  }|||||fS )Nz)FA3 flash_attention forward unsupported: )r   )rh   r"   r!   r   )	r   r;   r   r   zerosuint64r"   emptyrh   )rL   rp   rq   rR   r   r   r   rP   r   rr   rT   rU   rV   r   r   r   ru   rt   r   r   r   	rng_statephilox_offset
debug_masks                           r   rE   rE   ~  s    , ' E NuNNOOO! HC$ DU\JJJIK%,u|LLLMQek%,GGGJYz99r   c
               F    t          | |||||||||	d d d |
|||||          S )Nr   )rE   )rL   rp   rq   rR   r   r   r   rP   r   rr   r   r   r   ru   rt   r   s                   r   rG   rG     sT    & -)+!'   r   )r   r   r   r   unusedc                   t          | ||||||
|||
  
        }|t          d|           t          j                    }t	          | |||||||||	||||nd||nd|          \  }}}|||fS )z0FA3 implementation of _flash_attention_backward.Nz*FA3 flash_attention backward unsupported: r   )r   r;   r   r   r   )r   rL   rp   rq   r   r   rR   r   r   r   rP   r   r   r   r   r   r   r   r   r   r   r   s                         r   rI   rI     s    * ( E OOOPPP>@@M",8b.: JBB" r2:r   rY   r   c	               4   t          | ||||d d d |||          }
|
t          d|
           t          | ||          \  }}}| j        t          j        k    rt          j        n| j        }t	          j        | |          }|                    dd          }|	                    d          }|	                    d          }t          |||d d ||||||	||||          \  }}}}}| 	                    d          }|	                    d          }||d d |||||f	S )NzFA3 SDPA forward unsupported: rx   r   r   )r   r   rT   rU   rV   )r   r;   r   rh   r   ri   r   r   r   sizerE   )rL   rp   rq   rT   rU   rV   rP   r   rr   r   r   r   r   r   	out_dtypeout_bhsdout_bshdmax_q_flashmax_k_flashr*   r   r   r   r   r   r   s                             r   rF   rF      sa    ' E CECCDDDuc511GAq!
 #(+1D"D"D%+IY777H!!!Q''H&&))K&&))K3T			4 4 40AsI}j" JJqMMEHHQKKE
 
r   c               4    t          | ||d d d ||||
  
        S )Nr   )rF   )rL   rp   rq   rP   r   rr   r   s          r   rH   rH   g  s:     @   r   philox_seedr   c                  t          | ||||||
ddd
  
        }|t          d|           t          | ||||          \  }}}}}t          ||||||dd||	|
||||          \  }}}t          |||          \  }}}|||fS )zCFA3 implementation of _scaled_dot_product_flash_attention_backward.NzFA3 SDPA backward unsupported: r   )r   r;   r   rI   )r   rL   rp   rq   r   r   rR   r   r   r   rP   r   r   r   r   r   
grad_out_tq_tk_tv_tout_tr   r   r   dq_outdk_outdv_outs                              r   rJ   rJ     s    & (%eS)YdD E DUDDEEE (8%eS( ($JS#u 4  JBB& .b"b99FFF66!!r   FA3)register_fn)r"   r#   r   r$   )r,   )r-   r.   r   r   )r-   r.   r   r   )r   r
   )rL   rM   rN   rO   rP   rQ   rR   rS   rT   rS   rU   rS   rV   rS   r   rW   )rL   rM   rp   rM   rq   rM   rP   rQ   rr   rs   rt   rS   ru   rS   rR   rS   rT   rS   rU   rS   rV   rS   r   rW   )r   rM   rL   rM   rp   rM   rq   rM   r   rM   r   rM   rP   rQ   rR   rS   r   r   r   r   r   rW   )rN   r   r   r   )r   rS   r   rS   )NNNN)"rL   rM   rp   rM   rq   rM   r   rS   r   rS   r   r$   r   r$   r   r   r   rs   r   r   r   r   ru   rS   r   rS   rT   rS   rU   rS   rV   rS   r   r   )F) r   rM   rL   rM   rp   rM   rq   rM   r   rM   r   rM   r   rS   r   rS   r   r   r   r   r   r   r   rs   r   r$   r   r$   r   rs   r   r   )NNN)&rL   rM   rp   rM   rq   rM   rR   rS   r   rS   r   r$   r   r$   rP   rQ   r   rs   rr   rs   rT   rS   rU   rS   rV   rS   r   r   r   r$   r   r$   ru   rS   rt   rS   r   rS   ) rL   rM   rp   rM   rq   rM   rR   rS   r   rS   r   r$   r   r$   rP   rQ   r   rs   rr   rs   r   r   r   r$   r   r$   ru   rS   rt   rS   r   rS   )"r   rM   rL   rM   rp   rM   rq   rM   r   rM   r   rM   rR   rS   r   rS   r   r$   r   r$   rP   rQ   r   rs   r   rM   r   rM   r   r   r   r   r   r   )NNNrY   FF)rL   rM   rp   rM   rq   rM   rT   rS   rU   rS   rV   rS   rP   rQ   r   rs   rr   rs   r   r   )rY   FF)rL   rM   rp   rM   rq   rM   rP   rQ   r   rs   rr   rs   r   r   )r   rM   rL   rM   rp   rM   rq   rM   r   rM   r   rM   rR   rS   r   rS   r   r$   r   r$   rP   rQ   r   rs   r   rM   r   rM   r   r   )-__doc__
__future__r   r7   rj   typingr   collections.abcr   dataclassesr   	functoolsr   typing_extensionsr   r	   r   torch.libraryr
    r   __all__r   r    r   r   r+   r   r0   r1   ro   r   r   r   r   r   r   r   rE   rG   rI   rF   rH   rJ   register_flash_attention_implr!   r   r   <module>r      s     # " " " " "                   )(((((( ! ! ! ! ! !       2 2 2 2 2 2 2 2  ! ! ! ! ! !       #
 "& % % % %!% % % % % ) ) ) ) ) ) ) )     ./ / / / /"/ / / /"   8" " " "J( ( ( (V# # # #L \$5 5 5 5H H H H$  $%)%)%)!J) J) J) J) J)x  8 8 8 8 8L &*%)%):: %)(,#):: :: :: :: :: ::R %)(,##' ' ' ' ' 't #'$(%8 8 8 8 8 8~ &*%)%)#D D D D D D DV #      P !2" 2" 2" 2" 2" 2"j (	 ';W X X X X X Xr   