
    /j_                        d dl mZ d dlZd dlmZ ddlmZ ddlmZm	Z	 ddl
mZmZmZ ddlmZ  G d d	ej                  Z G d
 de          ZddZddZd d!dZd"dZ	 	 d#d$dZdS )%    )annotationsN   )LOGGER)bbox_iouprobiou)	xywh2xyxyxywhr2xyxyxyxy	xyxy2xywh)
TORCH_1_11c                       e Zd ZdZddddg dddfd fdZ ej                    d             Zd Zd Z	d Z
d ZddZd ZddZd Z xZS )TaskAlignedAssigneraG  A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        topk2 (int): Secondary topk value for additional filtering.
        num_classes (int): The number of object classes.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        stride (list): List of stride values for different feature levels.
        stride_val (int): The stride value used for select_candidates_in_gts.
        eps (float): A small value to prevent division by zero.
       P         ?g      @)          &.>Ntopkintnum_classesalphafloatbetastridelistepsc                   t                                                       || _        |p|| _        || _        || _        || _        || _        t          | j                  dk    r| j        d         n| j        d         | _	        || _
        dS )a  Initialize a TaskAlignedAssigner object with customizable hyperparameters.

        Args:
            topk (int, optional): The number of top candidates to consider.
            num_classes (int, optional): The number of object classes.
            alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
            beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
            stride (list, optional): List of stride values for different feature levels.
            eps (float, optional): A small value to prevent division by zero.
            topk2 (int, optional): Secondary topk value for additional filtering.
        r   r   N)super__init__r   topk2r   r   r   r   len
stride_valr   )	selfr   r   r   r   r   r   r!   	__class__s	           Z/home/longshao/multi-rider-rag/.venv/lib/python3.11/site-packages/ultralytics/utils/tal.pyr    zTaskAlignedAssigner.__init__   s    * 		]d
&
	,/,<,<q,@,@$+a..dkRSn    c                  
 |j         d         | _        |j         d         | _        |j        
| j        dk    ryt	          j        |d         | j                  t	          j        |          t	          j        |          t	          j        |d                   t	          j        |d                   fS 	 |                     ||||||          S # t          $ ry}dt          |                                          v rPt          j        d           d ||||||fD             } | j        | }	t          
fd|	D                       cY d}~S  d}~ww xY w)	a  Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).

        References:
            https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
        r   r   ).r   zout of memoryz7CUDA OutOfMemoryError in TaskAlignedAssigner, using CPUc                6    g | ]}|                                 S  )cpu).0ts     r&   
<listcomp>z/TaskAlignedAssigner.forward.<locals>.<listcomp>g   s     rrr1quuwwrrrr'   c              3  B   K   | ]}|                               V  d S N)to)r,   r-   devices     r&   	<genexpr>z.TaskAlignedAssigner.forward.<locals>.<genexpr>i   s-      ::aQTT&\\::::::r'   N)shapebsn_max_boxesr2   torch	full_liker   
zeros_like_forwardRuntimeErrorstrlowerr   warningtuple)r$   	pd_scores	pd_bboxes
anc_points	gt_labels	gt_bboxesmask_gtecpu_tensorsresultr2   s             @r&   forwardzTaskAlignedAssigner.forward>   sl   , /!$$?1-!q  	& 143CDD ++ ++ 6!233 6!233 		==Iz9iY`aaa 	 	 	#a&&,,..00XYYYrrIzS\^gip0qrrr&4::::6:::::::::::	s%   2C 
EA-E
E	E

Ec                   |                      ||||||          \  }}}	|                     ||	| j        |          \  }
}}|                     |||
|          \  }}}||z  }|                    dd          }|	|z                      dd          }||z  || j        z   z                      d                              d          }||z  }||||                                |
fS )a  Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
        T)dimkeepdim)get_pos_maskselect_highest_overlapsr6   get_targetsamaxr   	unsqueezebool)r$   r@   rA   rB   rC   rD   rE   mask_posalign_metricoverlapstarget_gt_idxfg_masktarget_labelstarget_bboxestarget_scorespos_align_metricspos_overlapsnorm_align_metrics                     r&   r:   zTaskAlignedAssigner._forwardl   s   $ ,0+<+<y)Y
G,
 ,
(, ,0+G+Gh 0,,
 ,
(w
 7;6F6FyR[]jls6t6t3}m 	 (--"d-CC 8+11b$1GG)L8<MPTPX<XY__`bccmmnpqq%(99m]GLLNNMYYr'   c                   |                      |||          }|                     ||||||z            \  }}	|                     ||                    dd| j                                                            }
|
|z  |z  }|||	fS )a  Get positive mask for each ground truth box.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
            align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
            overlaps (torch.Tensor): Overlaps between predicted vs ground truth boxes with shape (bs, max_num_obj, h*w).
        rK   )	topk_mask)select_candidates_in_gtsget_box_metricsselect_topk_candidatesexpandr   rT   )r$   r@   rA   rC   rD   rB   rE   mask_in_gtsrV   rW   	mask_topkrU   s               r&   rO   z TaskAlignedAssigner.get_pos_mask   s      33J	7SS!%!5!5iIW`bmpwbw!x!xh//WY[]_c_hHiHiHnHnHpHp/qq	{*W4x//r'   c                   |j         d         }|                                }t          j        | j        | j        |g|j        |j                  }t          j        | j        | j        |g|j        |j                  }t          j        d| j        | j        gt          j                  }	t          j	        | j                  
                    dd                              d| j                  |	d<   |                    d          |	d<   ||	d         d	d	|	d         f         |         ||<   |                    d                              d| j        dd          |         }
|                    d                              dd|d          |         }|                     ||
          ||<   |                    | j                  |                    | j                  z  }||fS )
a/  Compute alignment metric given predicted and ground truth bounding boxes.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).

        Returns:
            align_metric (torch.Tensor): Alignment metric combining classification and localization.
            overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
        rN   dtyper2      )rj   )endrK   r   r   N)r4   rT   r7   zerosr5   r6   rj   r2   longarangeviewre   squeezerS   iou_calculationpowr   r   )r$   r@   rA   rC   rD   rE   narW   bbox_scoresindpd_boxesgt_boxesrV   s                r&   rc   z#TaskAlignedAssigner.get_box_metrics   s    _R ,,..;)92>io^g^noook47D,<b"Aajaqrrrk1dgt'78
KKK$'***//A66==b$BRSSA""2&&A(QCF):;GDG &&q))00T5Er2NNwW&&q))00RR@@I 008DD"tz22X\\$)5L5LLX%%r'   c                t    t          ||dd                              d                              d          S )a
  Calculate IoU for horizontal bounding boxes.

        Args:
            gt_bboxes (torch.Tensor): Ground truth boxes.
            pd_bboxes (torch.Tensor): Predicted boxes.

        Returns:
            (torch.Tensor): IoU values between each pair of boxes.
        FT)xywhCIoUrK   r   )r   rq   clamp_r$   rD   rA   s      r&   rr   z#TaskAlignedAssigner.iou_calculation   s8     	95tDDDLLRPPWWXYZZZr'   c           
        t          j        || j        dd          \  }}|9|                    dd          d         | j        k                        |          }|                    | d           t          j        |j        t           j        |j	                  }t          j
        |ddddddf         t           j        |j	                  }t          | j                  D ]+}|                    d|dddd||dz   f         |           ,|                    |dk    d           |                    |j                  S )	a  Select the top-k candidates based on the given metrics.

        Args:
            metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
                the maximum number of objects, and h*w represents the total number of anchor points.
            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where topk
                is the number of top candidates to consider. If not provided, the top-k values are automatically
                computed based on the given metrics.

        Returns:
            (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        rK   TrL   largestN)rM   r   ri   r   )r7   r   maxr   	expand_asmasked_fill_rm   r4   int8r2   	ones_likerangescatter_add_r1   rj   )r$   metricsra   topk_metrics	topk_idxscount_tensoronesks           r&   rd   z*TaskAlignedAssigner.select_topk_candidates   sF    #(*WdiRQU"V"V"Vi%))"d);;A>ITTU^__I	z1--- {7=
9K[\\\yAAArr2%*YM]^^^ty!! 	L 	LA%%b)AAAqqq!a!e)O*DdKKKK!!,"2A666w}---r'   c                   t          j        | j        t           j        |j                  d         }||| j        z  z   }|                                                                |         }|                    d|j	        d                   |         }|
                    d           t          j        |j	        d         |j	        d         | j        ft           j        |j                  }|                    d|                    d          d           |dddddf                             dd| j                  }	t          j        |	dk    |d          }|||fS )	a@  Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the batch size and
                max_num_obj is the maximum number of objects.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive anchor points, with
                shape (b, h*w), where h*w is the total number of anchor points.
            fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive (foreground) anchor
                points.

        Returns:
            target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
            target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
            target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
        )rl   rj   r2   ).NrK   r   r   ri   rk   N)r7   ro   r5   int64r2   r6   rn   flattenrp   r4   r|   rm   r   scatter_rS   repeatwhere)
r$   rC   rD   rX   rY   	batch_indrZ   r[   r\   fg_scores_masks
             r&   rQ   zTaskAlignedAssigner.get_targets   sJ   $ LTWEK	HXYYYZcd	%	D4D(DD!((0022=A "r9?2+>??N 	Q  #]%8%;T=MN+ '
 
 

 	q-"9"9""="=qAAA AAAt,33Aq$:JKKNQ$6qIIm]::r'   c                   t          |          }|dddf         | j        d         k     }t          j        ||z                                  t          j        | j        |j        |j                  |dddf                   |dddf<   t          |          }|j
        d         }|j
        \  }}	}
|                    ddd                              dd          \  }}t          j        |d         |z
  ||d         z
  fd	                              ||	|d          }|                    d
                              |          S )a  Select positive anchor centers within ground truth bounding boxes.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes, shape (b, n_boxes, 1).
            eps (float, optional): Small value for numerical stability.

        Returns:
            (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).

        Notes:
            - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
            - Bounding box format: [x_min, y_min, x_max, y_max].
        .rk   Nr   ri   rK   r      rL      )r
   r   r7   r   rT   tensorr#   rj   r2   r   r4   rp   chunkcatamingt_)r$   
xy_centersrD   rE   r   gt_bboxes_xywhwh_mask	n_anchorsr5   n_boxes_ltrbbbox_deltass                 r&   rb   z,TaskAlignedAssigner.select_candidates_in_gts!  sF     #9-- abb)DKN:"'+w$$&&L0D^Mbccc37##
 #
sABBw
 n--	$Q'	"GQAq))//155BiD!1B!6Z=M8M NTUVVV[[\^`girtvww""&&s+++r'   c                L   |                     d          }|                                dk    r|                    d          dk                        d|d          }|                    d          }t          j        |j        |j        |j	                  }|
                    d|                    d          d           t          j        |||                                          }|                     d          }| j        | j        k    r~||z  }t          j        || j        dd          j        }t          j        |j        |j        |j	                  }	|	
                    d|d           ||	z  }|                     d          }|                    d          }
|
||fS )a  Select anchor boxes with highest IoU when assigned to multiple ground truths.

        Args:
            mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
            overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
            n_max_boxes (int): Maximum number of ground truth boxes.
            align_metric (torch.Tensor): Alignment metric for selecting best matches.

        Returns:
            target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
            fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
            mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
        rN   r   rK   ri   Tr   r   )sumr   rS   re   argmaxr7   rm   r4   rj   r2   r   r   r   r!   r   indices)r$   rU   rW   r6   rV   rY   mask_multi_gtsmax_overlaps_idxis_max_overlapstopk_idxrX   s              r&   rP   z+TaskAlignedAssigner.select_highest_overlaps@  s    ,,r"";;==1%//22Q6>>r;PRSSN'q11#k(.W_WfgggO$$Q(8(B(B1(E(EqIII{>?HMMSSUUHll2&&G:""'(2L$z,
TXYYYa{8>PXP_```Hb"2C888 Hll2&&G ++gx//r'   )r   r   r   r   r   r   r   r   r   r   r   r   r0   )r   )__name__
__module____qualname____doc__r    r7   no_gradrI   r:   rO   rc   rr   rd   rQ   rb   rP   __classcell__)r%   s   @r&   r   r      s        $ "{{      > U]__+ + _+Z$Z $Z $ZL0 0 04& & &B
[ 
[ 
[. . . .>'; '; ';R, , , ,>#0 #0 #0 #0 #0 #0 #0r'   r   c                      e Zd ZdZd Zd ZdS )RotatedTaskAlignedAssignerzSAssigns ground-truth objects to rotated bounding boxes using a task-aligned metric.c                n    t          ||                              d                              d          S )z)Calculate IoU for rotated bounding boxes.rK   r   )r   rq   r|   r}   s      r&   rr   z*RotatedTaskAlignedAssigner.iou_calculationi  s.    y),,44R88??BBBr'   c                j   |dddf         | j         d         k     }t          j        ||z                                  t          j        | j        |j        |j                  |dddf                   |dddf<   t          |          }|	                    dd          \  }}}}	||z
  }
|	|z
  }||z
  }|
|
z  
                    d	          }||z  
                    d	          }||
z  
                    d	          }||z  
                    d	          }|dk    ||k    z  |dk    z  ||k    z  S )
a  Select the positive anchor center in gt for rotated bounding boxes.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (b, n_boxes, 1).

        Returns:
            (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
        .rk   r   r   ri   r   rN   r   rK   )r   r7   r   rT   r   r#   rj   r2   r	   splitr   )r$   r   rD   rE   r   cornersabr   dabadapnorm_abnorm_ad	ap_dot_ab	ap_dot_ads                    r&   rb   z3RotatedTaskAlignedAssigner.select_candidates_in_gtsm  sK    C1H%A6#kw$$&&L		HXYYYc1Q3h
 
	#qs( !++]]1"]--
1aUU !^7--B-''7--B-''"WMMbM))	"WMMbM))	Q9#78INKy\cOcddr'   N)r   r   r   r   rr   rb   r*   r'   r&   r   r   f  s@        ]]C C Ce e e e er'   r         ?c           	     ^   g g }}| J | d         j         | d         j        }}t          t          |                     D ]F}||         }t	          | t
                    r| |         j        dd         n5t          | |         d                   t          | |         d                   f\  }	}
t          j	        |
||          |z   }t          j	        |	||          |z   }t          rt          j        ||d          nt          j        ||          \  }}|                    t          j        ||fd                              dd                     |                    t          j        |	|
z  df|||	                     Ht          j        |          t          j        |          fS )
zGenerate anchors from features.Nr   rk   r   )rl   r2   rj   ij)indexingrK   ri   )rj   r2   r   r"   
isinstancer   r4   r   r7   ro   r   meshgridappendstackrp   fullr   )featsstridesgrid_cell_offsetanchor_pointsstride_tensorrj   r2   ir   hwsxsys                r&   make_anchorsr     s   #%r=M!HNE!HO6E3u:: Y Y%/t%<%<fuQx~abb!!3uQxPQ{CSCSUXY^_`YabcYdUeUeBf1\ae<<<?OO\ae<<<?OO:D`B6666%.Y[]_J`J`BU["b266;;BBBCCCUZQ
F%PVWWWXXXX9]##UY}%=%===r'   TrK   c                    |                      d|          \  }}||z
  }||z   }|r$||z   dz  }||z
  }	t          j        ||	g|          S t          j        ||f|          S )z.Transform distance(ltrb) to box(xywh or xyxy).rk   )r   r7   r   )
distancer   rz   rL   r   r   x1y1x2y2c_xywhs
             r&   	dist2bboxr     sx    ^^As##FB2D2D *tq D[y$S)))9dD\3'''r'   r   torch.Tensorbboxreg_max
int | Nonereturnc                    |                     dd          \  }}t          j        | |z
  || z
  fd          }||                    d|dz
            }|S )z#Transform bbox(xyxy) to dist(ltrb).rk   rK   Nr   {Gz?)r   r7   r   r|   )r   r   r   r   r   dists         r&   	bbox2distr     s[    Ar""JD$9md*D=,@A2FFD{{1gn--Kr'   c                ^   |                      d|          \  }}t          j        |          t          j        |          }}||z
  dz                       d|          \  }}	||z  |	|z  z
  ||z  |	|z  z   }}
t          j        |
|g|          |z   }t          j        |||z   g|          S )a  Decode predicted rotated bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
        dim (int, optional): Dimension along which to split.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
    rk   r   r   )r   r7   cossinr   )	pred_dist
pred_angler   rL   r   r   r   r   xfyfxyxys                r&   	dist2rboxr     s     __QC_((FBy$$ei
&;&;CBw!m""1#"..FB8b3hS28 3qA	Aq6s	#	#	#m	3B9b"r'],,,,r'   r[   target_anglerL   r   c                   |                      d|          \  }}||z
  }|                     d|          \  }}	t          j        |          t          j        |          }}
||
z  |	|z  z   }| |z  |	|
z  z   }|                     d|          \  }}|dz  |z
  }|dz  |z
  }|dz  |z   }|dz  |z   }t          j        ||||g|          }||                    d|dz
            }|S )a[  Transform rotated bounding box (xywh) to distance (ltrb). This is the inverse of dist2rbox.

    Args:
        target_bboxes (torch.Tensor): Target rotated bounding boxes with shape (bs, h*w, 4), format [x, y, w, h].
        anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
        target_angle (torch.Tensor): Target angle with shape (bs, h*w, 1).
        dim (int, optional): Dimension along which to split.
        reg_max (int, optional): Maximum regression value for clamping.

    Returns:
        (torch.Tensor): Rotated distance with shape (bs, h*w, 4), format [l, t, r, b].
    rk   r   r   Nr   r   )r   r7   r   r   r   r|   )r[   r   r   rL   r   r   r   offsetoffset_xoffset_yr   r   r   r   r   r   target_ltarget_ttarget_rtarget_br   s                        r&   	rbox2distr     s   &    ,,FB-FaS11Hhy&&	,(?(?C	C(S.	(B
S8c>	)B88A38DAq1urzH1urzH1urzH1urzH9h(H=3GGGD{{1gn--Kr'   )r   )TrK   r0   )r   r   r   r   r   r   r   r   )rK   )rK   N)
r[   r   r   r   r   r   rL   r   r   r   )
__future__r   r7   torch.nnnn r   r   r   r   opsr   r	   r
   torch_utilsr   Moduler   r   r   r   r   r   r   r*   r'   r&   <module>r     sw   # " " " " "              & & & & & & & & 5 5 5 5 5 5 5 5 5 5 # # # # # #U0 U0 U0 U0 U0") U0 U0 U0p
&e &e &e &e &e!4 &e &e &eR> > > > 	( 	( 	( 	(    - - - -2 $ $ $ $ $ $ $r'   