import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import group, ReduceOp


def broadcast(tensor, src, group=group.WORLD):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of current
            process.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Received tensor from the broadcast op.
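
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> t = torch.ones(2) * (rank + 1)
        >>> broadcast(t, src=0)  # every rank receives rank 0's values
        tensor([1., 1.]) # Rank 0 and 1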

    """
    return _Broadcast.apply(src, group, tensor)


def gather(tensor, dst=0, group=group.WORLD):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int, optional): Destination rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
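
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> t = torch.ones(2) * (rank + 1)
        >>> gather(t, dst=0)  # the destination rank receives every rank's tensor
        (tensor([1., 1.]), tensor([2., 2.])) # Rank 0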
    """
    return _Gather.apply(dst, group, tensor)


def scatter(tensors, src=0, group=group.WORLD):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Arguments:
        tensors (list[Tensor]): List of tensors to scatter on the source rank.
            Receivers must pass ``None``.
        src (int, optional): Source rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output tensor from the scatter operation.
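
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks. In this sketch every
        rank passes a same-shaped list so the output can be sized; only the
        values on ``src`` are actually distributed.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> inputs = [torch.ones(2) * (i + 1) for i in range(2)]
        >>> scatter(inputs, src=0)
        tensor([1., 1.]) # Rank 0
        tensor([2., 2.]) # Rank 1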

    )_Scatterr   )tensorsr	   r   r
   r
   r   scatter/   s   r   c                 C   s   t |||| S )a  
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input of the collective.
        dst (int): Destination rank.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
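
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> t = torch.ones(2) * (rank + 1)
        >>> reduce(t, dst=0)  # rank 0 receives the element-wise sum
        tensor([3., 3.]) # Rank 0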

    )_Reducer   )r   r   opr   r
   r
   r   reduceC   s   r   c                 C   s   t j||| g|R  S )a  
    Reduces, then scatters a list of tensors to all processes in a group.

    Arguments:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
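
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``. Each rank contributes one input
        chunk per rank and receives the reduction of its own chunk index.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> inputs = [torch.ones(2) * (rank + 1) for _ in range(2)]
        >>> out = torch.empty(2)
        >>> reduce_scatter(out, inputs)  # every rank receives 1 + 2 = 3
        tensor([3., 3.]) # Rank 0 and 1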

    """
    return _Reduce_Scatter.apply(op, group, output, *input_list)


def all_gather(tensor, group=group.WORLD):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): Output of the collective.
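
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> t = torch.ones(2) * (rank + 1)
        >>> all_gather(t)  # every rank receives every rank's tensor
        (tensor([1., 1.]), tensor([2., 2.])) # Rank 0 and 1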

    """
    return _AllGather.apply(group, tensor)


def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
    """
    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.

    Args:
        output_tensor (Tensor): Output tensor. It should be correctly sized
            (``world_size`` times the size of ``input_tensor``) to hold the
            output of the collective.
        input_tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> # xdoctest: +SKIP("incorrect want text")
        >>> output_tensor = torch.zeros(2, dtype=torch.int64)
        >>> output_tensor
        [tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
        >>> tensor
        tensor([1]) # Rank 0
        tensor([2]) # Rank 1
        >>> dist.all_gather_base(output_tensor, tensor)
        >>> output_tensor
        tensor([1,2]) # Rank 0
        tensor([1,2]) # Rank 1

    .. warning::
        `_all_gather_base` is experimental and subject to change.
        It is the caller's responsibility to ensure the output_tensor
        is correctly sized.

    """
    return _AllGatherBase.apply(output_tensor, input_tensor, group)


def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
    """
    Each process scatters a list of input tensors to all processes in the group and returns the gathered list of tensors in the output list.

    Arguments:
        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): Output of the collective.
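
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``. Entry ``i`` of the input list is
        sent to rank ``i``.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> inputs = [torch.ones(1) * (2 * rank + i) for i in range(2)]
        >>> outputs = [torch.empty(1) for _ in range(2)]
        >>> all_to_all(outputs, inputs)
        (tensor([0.]), tensor([2.])) # Rank 0
        (tensor([1.]), tensor([3.])) # Rank 1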

    """
    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)


def all_to_all_single(
    output, input, output_split_sizes=None, input_split_sizes=None, group=group.WORLD
):
    """
    Each process splits input tensor and then scatters the split list to all processes in a group.

    The received tensors from all processes in the group are then concatenated and returned as a single output tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes: (list[Int], optional): Output split sizes for dim 0.
            If None or empty, dim 0 of the ``output`` tensor must divide
            equally by ``world_size``.
        input_split_sizes: (list[Int], optional): Input split sizes for dim 0.
            If None or empty, dim 0 of the ``input`` tensor must divide
            equally by ``world_size``.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
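
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``. With no split sizes given, dim 0 is
        split evenly across ranks.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> input = torch.arange(2) + 2 * rank
        >>> output = torch.empty(2, dtype=input.dtype)
        >>> all_to_all_single(output, input)
        tensor([0, 2]) # Rank 0
        tensor([1, 3]) # Rank 1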

    """
    return _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )


def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines in such a way that all get the final result.

    After the call the returned tensor is going to be bitwise
    identical in all processes.

    Arguments:
        tensor (Tensor): Input of the collective.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.
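
    Example:
        A minimal sketch, not part of the original docs; it assumes an already
        initialized default process group with 2 ranks, and ``rank`` below is
        shorthand for ``dist.get_rank()``.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> t = torch.ones(2) * (rank + 1)
        >>> all_reduce(t)  # every rank receives the element-wise sum
        tensor([3., 3.]) # Rank 0 and 1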

    """
    return _AllReduce.apply(op, group, tensor)


class _Broadcast(Function):
    @staticmethod
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank(group=group)
        # torch.distributed makes all the calls in place; work on a clone so
        # the caller's tensor is left untouched.
        tensor = tensor.clone()
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
            gx.zero_()
        return (None, None, gx)


class _Gather(Function):
    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # Allocate one correctly-sized buffer per rank for the gather.
        tensor_list = [
            torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
        ]

        tensor = tensor.contiguous()
        if dist.get_rank(group=group) == dst:
            dist.gather(tensor, tensor_list, dst, group=group)
        else:
            dist.gather(tensor, None, dst, group=group)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)


class _Scatter(Function):
    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        assert all(t.size() == tensors[0].size() for t in tensors)
        output = torch.zeros_like(tensors[0])
        if dist.get_rank(group=group) == src:
            dist.scatter(output, list(tensors), src, group=group)
        else:
            dist.scatter(output, None, src, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
        tensor = tensor.clone()
        dist.reduce(tensor, src, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # Collectives require contiguous tensors.
        tensor = tensor.contiguous()
        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
        # Collectives require contiguous tensors.
        tensor = tensor.contiguous()

        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
        ]

        dist.all_gather(out_tensor_list, tensor, group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            rank = dist.get_rank(group=ctx.group)
            gx = torch.empty_like(grad_outputs[rank])
            gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
        else:
            # Not all backends support reduce_scatter, so emulate it with an
            # all-to-all followed by a sum over the received chunks.
            tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
            gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
            gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AllGatherBase(Function):
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
            out_size = list(grad_output.size())
            if out_size[0] % world_size != 0:
                raise RuntimeError(
                    f"Tensor with dimensions: {out_size} "
                    f"does not have first dimension divisible by world_size: {world_size}"
                )
            out_size[0] = out_size[0] // world_size
            gx = torch.empty(
                out_size, device=grad_output.device, dtype=grad_output.dtype
            )
            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
        else:
            raise RuntimeError("Backend not supported!")
        return (None, gx, None)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        my_rank = dist.get_rank(group=group)
        tensors = tuple(t.contiguous() for t in tensors)
        # Gloo has no native all_to_all, so emulate it with one scatter per rank.
        if dist.get_backend(group=group) is dist.Backend.GLOO:
            for i in range(dist.get_world_size(group=group)):
                to_send = None
                if i == my_rank:
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(out_tensor_list, list(tensors), group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [
            torch.empty(
                size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
            )
            for size in ctx.input_tensor_size_list
        ]
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        # The split sizes swap roles for the backward all_to_all_single call.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
        )
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
        tensor = tensor.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)