o
    nhM                     @   s,  d dl Z d dlZd dlZd dlZd dlZd dlZd dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZ d dlmZ ejejeZejedgZejed	Zd
gZe  dd Ze  dd Ze  dd Zdd ZG dd deZ dd Z!dd Z"G dd deZ#G dd deZ$dS )    N)Path)_build)get_cache_manager)_allocation)	GPUTarget)	GPUDriverincludelibcudac                  C   s   t d} | r
| gS tddg }dd | D }dd |D }t d}|r6|s6dd |d	D }d
}|rG|dt| 7 }|d7 }n|d7 }|d7 }tdd |D s\J ||S )NTRITON_LIBCUDA_PATHz/sbin/ldconfigz-pc                 S   s    g | ]}d |v r|  d qS )libcuda.so.1)split).0line r   q/var/www/html/construction_image-detection-poc/venv/lib/python3.10/site-packages/triton/backends/nvidia/driver.py
<listcomp>   s     z libcuda_dirs.<locals>.<listcomp>c                 S   s   g | ]}t j|qS r   )ospathdirname)r   locr   r   r   r      s    LD_LIBRARY_PATHc                 S   s&   g | ]}t jt j|d r|qS )r   r   r   existsjoin)r   dirr   r   r   r   !   s   & :zlibcuda.so cannot found!
z!Possible files are located at %s.z:Please create a symlink of libcuda.so to any of the files.z<Please make sure GPU is set up and then run "/sbin/ldconfig"z- (requires sudo) to refresh the linker cache.c                 s   s&    | ]}t jt j|d V  qdS )r   Nr   )r   r   r   r   r   	<genexpr>)   s   $ zlibcuda_dirs.<locals>.<genexpr>)	r   getenv
subprocesscheck_outputdecode
splitlinesr   strany)env_libcuda_pathlibslocsdirsenv_ld_library_pathmsgr   r   r   libcuda_dirs   s"   


r,   c                   C   s   t gt S N)libdevice_dirr,   r   r   r   r   library_dirs-   s   r/   c                  C   s,   ddl m} m}m} d|  | g| S )Nr   machinesystemarchitecture,)platformr1   r2   r3   r   r0   r   r   r   platform_key2   s   r6   c              	   C   sJ  t | t  d }t|}tddd }|	| d| }|d u rt
 W}tj|d}t|d}||  W d    n1 sJw   Y  t|||t tt}	t|	d}|j| | d| dd	}W d    n1 sxw   Y  W d    n1 sw   Y  d
d l}
|
j||}|
j|}|j| |S )Nzutf-8
EXT_SUFFIX.r   zmain.cwrbT)binaryr   )hashlibsha256r6   encode	hexdigestr   	sysconfigget_config_varr   get_filetempfileTemporaryDirectoryr   r   r   openwriter   r/   include_dir	librariesputreadimportlib.utilutilspec_from_file_locationmodule_from_specloaderexec_module)srcnamekeycacheext
cache_pathtmpdirsrc_pathfso	importlibspecmodr   r   r   compile_module_from_src8   s*   
 r^   c                       s$   e Zd Z fddZdd Z  ZS )	CudaUtilsc                    s"   t | dstt| | | _| jS )Ninstance)hasattrsuperr_   __new__r`   )cls	__class__r   r   rc   S   s   
zCudaUtils.__new__c                 C   sP   t ttjtd d}|j| _|j| _|j	| _	|j
| _
|j| _|j| _d S )Nzdriver.c
cuda_utils)r^   r   r   r   r   r   	read_textload_binaryget_device_propertiescuOccupancyMaxActiveClustersset_printf_fifo_sizefill_1d_tma_descriptorfill_2d_tma_descriptor)selfr]   r   r   r   __init__X   s   zCudaUtils.__init__)__name__
__module____qualname__rc   rp   __classcell__r   r   re   r   r_   Q   s    r_   c                 C   sx   | d dkrdS i dddddd	d
ddddddddddddddddddddddddd|  S )Nr   *CUdeviceptri1int32_ti8int8_ti16int16_ti32i64int64_tu1uint32_tu8uint8_tu16uint16_tu32u64uint64_tfp16floatbf16fp32f32fp64double	nvTmaDescCUtensorMapr   )tyr   r   r   	ty_to_cppg   sJ   	
r   c                    s  fdd fdd fddd fdd	| D }d
| }d t| }ttt|d}dd t|D }t|dkrUdd dd |	 D  nd}d dd |	 D }g }|	 D ]-\}}|d dkr}|
d| d qi|dkr|
d|  qi|dkr|
d|  qitt|}	d}
dd	 |	 D }dd	 |	 D }dd	 |	 D }	|	
d dt|dkrd| nd d d |	 d!|
  fd"d	|	 D  d#| d$| d%|
 | d|
 | d&t|dkrdd | nd d'}|S )(Nc                    s   t | trdt | S | S )Nr4   )
isinstancetupler   map)sig)_serialize_signaturer   r   r      s   
z+make_launcher.<locals>._serialize_signaturec                    sJ   t | trdt | }d| dS | d dkrdS | dv r!dS t| S )Nr4   []r   ru   z	PyObject*	constexprr   r   r   r   r   r   r   val_extracted_typer   r   r      s   
z&make_launcher.<locals>._extracted_typec                    sf   t | trdt | }d| dS | d dkrdS | dv r!dS dd	d
dddddddddt|  S )N ()r   ru   Or   rY   dlbhiLBHIK)r   r   longrz   r|   rx   r   r   r   r   r   r   r   	format_ofr   r   r      s*   
z make_launcher.<locals>.format_ofr   c                    s   g | ]} |qS r   r   )r   r   r   r   r   r      s    z!make_launcher.<locals>.<listcomp>iiiKKpOOOOOr4   c                 S      i | ]\}}||qS r   r   )r   r   sr   r   r   
<dictcomp>       z!make_launcher.<locals>.<dictcomp>r   , c                 s   s    | ]
\}}d | V  qdS )z&_argNr   r   r   r   r   r   r   r      s    z make_launcher.<locals>.<genexpr>c                 s   s.    | ]\}}|d krt | d| V  qdS )r   z argN)r   r   r   r   r   r      s   , ru   ptr_infoz.dev_ptrr   z*tma_ptrr   _argz
  c                 S   s:   g | ]\}}|d  dkrd| d| d| d| d	qS )r   ru   zDevicePtrInfo ptr_infoz = getPointer(_argr   z); if (!ptr_infoz.valid) return NULL;r   r   r   r   r   r      s
    c              	   S   s0   g | ]\}}|d krd| d| d| dqS )r   zCUtensorMap* tma_ptrz = getTmaDesc(_argz); if (!tma_ptrz) return NULL;r   r   r   r   r   r      s
    c                 S   s"   g | ]\}}|d krd| qS )r   z&argr   r   r   r   r   r      s   " z&global_scratchaB  
#include "cuda.h"
#include <stdbool.h>
#include <Python.h>
#include <dlfcn.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{
   if (code != CUDA_SUCCESS)
   {
      const char* prefix = "Triton Error [CUDA]: ";
      const char* str;
      cuGetErrorString(code, &str);
      char err[1024] = {0};
      strcat(err, prefix);
      strcat(err, str);
      PyGILState_STATE gil_state;
      gil_state = PyGILState_Ensure();
      PyErr_SetString(PyExc_RuntimeError, err);
      PyGILState_Release(gil_state);
   }
}

#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

static cuLaunchKernelEx_t getLaunchKernelExHandle() {
  // Open the shared library
  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  if (!handle) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");
    return NULL;
  }
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
    return NULL;
  }
  return cuLaunchKernelExHandle;
}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratchz) {
  void *params[] = { aO   };
  if (gridX*gridY*gridZ > 0) {
    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {
      CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
    } else if ((num_ctas == 1) && (0 != launch_cooperative_grid)) {
      CUlaunchAttribute launchAttr[1];
      CUlaunchAttribute coopAttr = { .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1};
      launchAttr[0] = coopAttr;

      CUlaunchConfig config;
      config.gridDimX = gridX;
      config.gridDimY = gridY;
      config.gridDimZ = gridZ;
      config.blockDimX = 32 * num_warps;
      config.blockDimY = 1;
      config.blockDimZ = 1;
      config.sharedMemBytes = shared_memory;
      config.hStream = stream;
      config.attrs = launchAttr;
      config.numAttrs = 1;

      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
      if (cuLaunchKernelExHandle == NULL) {
        cuLaunchKernelExHandle = getLaunchKernelExHandle();
      }
      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));

    } else {
      CUlaunchAttribute launchAttr[3];
      launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      launchAttr[0].value.clusterDim.x = clusterDimX;
      launchAttr[0].value.clusterDim.y = clusterDimY;
      launchAttr[0].value.clusterDim.z = clusterDimZ;
      launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      launchAttr[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;

      unsigned numAttrs = 2;
      if (0 != launch_cooperative_grid) {
        CUlaunchAttribute coopAttr = { .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1};
        launchAttr[2] = coopAttr;
        numAttrs = 3;
      }

      CUlaunchConfig config;
      config.gridDimX = gridX * clusterDimX;
      config.gridDimY = gridY * clusterDimY;
      config.gridDimZ = gridZ * clusterDimZ;
      config.blockDimX = 32 * num_warps;
      config.blockDimY = 1;
      config.blockDimZ = 1;
      config.sharedMemBytes = shared_memory;
      config.hStream = stream;
      config.attrs = launchAttr;
      config.numAttrs = numAttrs;
      static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
      if (cuLaunchKernelExHandle == NULL) {
        cuLaunchKernelExHandle = getLaunchKernelExHandle();
      }
      CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
    }
  }
}

typedef struct _DevicePtrInfo {
    CUdeviceptr dev_ptr;
    bool valid;
} DevicePtrInfo;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }
  if (obj == Py_None) {
    // valid nullptr
    return ptr_info;
  }
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!PyLong_Check(ret)) {
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      return ptr_info;
    }
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
    if(!ptr_info.dev_ptr)
      return ptr_info;
    uint64_t dev_ptr;
    int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == CUDA_ERROR_INVALID_VALUE) {
        PyErr_Format(PyExc_ValueError,
                     "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
        ptr_info.valid = false;
    } else if (status != CUDA_SUCCESS) {
        CUDA_CHECK(status);  // Catch any other cuda API errors
        ptr_info.valid = false;
    }
    ptr_info.dev_ptr = dev_ptr;
    Py_DECREF(ret);  // Thanks ChatGPT!
    return ptr_info;
  }
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  ptr_info.valid = false;
  return ptr_info;
}

static inline CUtensorMap* getTmaDesc(PyObject *obj) {
  if (sizeof(CUtensorMap*) != 8) {
    PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
    return NULL;
  }

  PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr");
  if (!method_handle) {
    PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
    return NULL;
  }

  PyObject *empty_tuple = PyTuple_New(0);
  if (!empty_tuple) {
    Py_DECREF(method_handle);
    PyErr_SetString(PyExc_SystemError, "Internal Python error!");
    return NULL;
  }
  PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
  Py_DECREF(empty_tuple);
  Py_DECREF(method_handle);
  if (!method_ret) {
    PyErr_SetString(PyExc_SystemError, "Internal Python error!");
    return NULL;
  }

  if (!PyLong_Check(method_ret)) {
    PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int");
    Py_DECREF(method_ret);
    return NULL;
  }

  uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret);
  Py_DECREF(method_ret);
  if (!ptr_as_uint) {
    PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()");
    return NULL;
  }
  if (ptr_as_uint % 64 != 0) {
    PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned");
    return NULL;
  }

  return (CUtensorMap*)(ptr_as_uint);
}

static void ensureCudaContext() {
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    // Ensure device context.
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }
}

static PyObject* launch(PyObject* self, PyObject* args) {
  // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
  ensureCudaContext();

  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
  c                    s$   g | ]\}} | d | dqS )z _arg;r   r   r   r   r   r     s   $ z
  if(!PyArg_ParseTuple(args, "a*  ", &gridX, &gridY, &gridZ,
                                           &_stream, &_function, &launch_cooperative_grid, &global_scratch_obj,
                                           &kernel_metadata, &launch_metadata,
                                           &launch_enter_hook, &launch_exit_hookat  )) {
    return NULL;
  }

  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }

  // extract launch metadata
  if (launch_enter_hook != Py_None){
    PyObject* args = Py_BuildValue("(O)", launch_metadata);
    PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
    Py_DECREF(args);
    if (!ret)
      return NULL;
  }

  CUdeviceptr global_scratch = 0;
  if (global_scratch_obj != Py_None) {
    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
    if (!global_scratch_info.valid) {
      return NULL;
    }
    global_scratch = global_scratch_info.dev_ptr;
  }

  // raise exception asap
  z
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratcha0  );
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {
    return NULL;
  }

  if(launch_exit_hook != Py_None){
    PyObject* args = Py_BuildValue("(O)", launch_metadata);
    PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
    Py_DECREF(args);
    if (!ret)
      return NULL;

  }

  Py_RETURN_NONE;
}

static PyMethodDef ModuleMethods[] = {
  {"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
  {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
  PyModuleDef_HEAD_INIT,
  "__triton_launcher",
  NULL, //documentation
  -1, //size
  ModuleMethods
};

PyMODINIT_FUNC PyInit___triton_launcher(void) {
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {
    return NULL;
  }
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}
)r   valuesr   listfilterboolr   	enumeratelenitemsappendrange)	constants	signatureargs_formatformat	args_list	arg_declsinternal_args_listr   r   paramsnewline	ptr_decls	tma_declsrQ   r   )r   r   r   r   make_launcher~   s   
,
./ h  i  l   
            5r   c                   @   s   e Zd Zdd Zdd ZdS )CudaLauncherc                    s   t drjnt }fdd  fdd| D }dd j D }t||td}|j| _|j| _|j	| _	|j
| _
d S )Nr   c                    s   t | tr jj| fS | S r-   )r   r$   fn	arg_namesindex)x)rQ   r   r   <lambda>   s    z'CudaLauncher.__init__.<locals>.<lambda>c                    s   i | ]	\}} ||qS r   r   r   idxvalue)arg_idxr   r   r     s    z)CudaLauncher.__init__.<locals>.<dictcomp>c                 S   r   r   r   r   r   r   r   r     r   __triton_launcher)ra   r   dictr   r   r   r^   launchglobal_scratch_sizeglobal_scratch_alignlaunch_cooperative_grid)ro   rQ   metadatar   r   r]   r   )r   rQ   r   rp     s   

zCudaLauncher.__init__c           
      G   sZ   | j dkr|| | }|| j  }t|| j|}	nd }	| j|||||| j|	g|R   d S Nr   )r   r   
_allocatorr   r   r   )
ro   gridXgridYgridZstreamfunctionargs	grid_size
alloc_sizeglobal_scratchr   r   r   __call__
  s   

$zCudaLauncher.__call__N)rq   rr   rs   rp   r   r   r   r   r   r     s    r   c                       sX   e Zd Z fddZdd Zdd Zdd Zed	d
 Zdd Z	dd Z
dd Z  ZS )
CudaDriverc                    s   t  | _t| _t   d S r-   )r_   utilsr   launcher_clsrb   rp   )ro   re   r   r   rp     s   zCudaDriver.__init__c                 C   s6   |   }| |}|d d |d  }d}td||S )Nr   
          r
   )get_current_deviceget_device_capabilityr   )ro   device
capability	warp_sizer   r   r   get_current_target  s
   
zCudaDriver.get_current_targetc                 C   s   dd l }|d|  S )Nr   r
   )torchr   r   ro   r   r   r   r   get_active_torch_device"  s   z"CudaDriver.get_active_torch_devicec                 C   s   dd l }|jS r   )r   r
   r   r   r   r   get_device_interface&  s   zCudaDriver.get_device_interfacec                  C   s6   zdd l } | j o| jjd u W S  ty   Y dS w )Nr   F)r   r
   is_availableversionhipImportError)r   r   r   r   	is_active*  s   zCudaDriver.is_activec                 C   s   ddl m} |S )Nr   )do_bench)triton.testingr  )ro   r  r   r   r   get_benchmarker2  s   zCudaDriver.get_benchmarkerc                 C   s&   dd l }d}|jt|d |jddS )Nr   i      r
   )dtyper   )r   emptyint)ro   r   
cache_sizer   r   r   get_empty_cache_for_benchmark6  s   z(CudaDriver.get_empty_cache_for_benchmarkc                 C   s   |   d S r-   )zero_)ro   rT   r   r   r   clear_cache?  s   zCudaDriver.clear_cache)rq   rr   rs   rp   r   r   r   staticmethodr  r  r  r  rt   r   r   re   r   r     s    
	r   )%	functoolsr   r@   r<   r    rC   pathlibr   triton.runtime.buildr   triton.runtime.cacher   triton.runtimer   triton.backends.compilerr   triton.backends.driverr   r   r   realpath__file__r   rG   r.   rH   	lru_cacher,   r/   r6   r^   objectr_   r   r   r   r   r   r   r   r   <module>   s>    


   