12-3

Inside ThunderKittens' Python Bindings

In this post, we'll look at pyutils.cuh. This is how ThunderKittens wraps CUDA kernels to be launchable from Python via PyTorch. Original code here!

An Introduction to the Libraries

ThunderKittens (TK) is a tile-based DSL embedded in C++/CUDA for writing GPU kernels. The goal of this post (and subsequent posts on TK) is to study key parts of TK closely enough that we build a mental model for understanding the rest. We assume some knowledge of C++ and CUDA.

Pybind11 is a library that creates Python bindings for C++ code. It allows us to call C++ functions from Python, expose C++ classes to Python, and pass data between C++ and Python. Connecting these languages is nice because Python (PyTorch specifically) is easier to write, and C++ executes faster.

pyutils.cuh uses pybind11 to wrap CUDA kernels so you can launch them from Python, convert PyTorch tensors to C++ GL/PGL objects, and enable multiGPU operations from Python.

Primary template:

template<typename T> struct from_object {
    static T make(pybind11::object obj) {
        return obj.cast<T>();
    }
    static T unwrap(pybind11::object obj, int dev_idx) {
        return make(obj); // Scalars should be passed in as a scalar
    }
};

The primary from_object casts Python objects into the target C++ type. Pybind11 has built-in casters for int and other scalars, strings, and std::pair/std::tuple; pybind11/stl.h (which we include) adds conversions for standard containers like std::vector, std::map, and std::set.

make and unwrap seem redundant here. Turns out, make is for single-GPU kernel launches and unwrap is for multi-GPU kernel launches, and these do become distinct when we look at the GL/PGL specializations.

Template specialization for GL:

template<ducks::gl::all GL> struct from_object<GL> {

ducks::gl::all is a concept that ensures this specialization matches GL types. A little joke on duck typing (if it quacks like a duck...)

Creating GLs from Tensors

Validation:

if (pybind11::hasattr(obj, "__class__") && obj.attr("__class__").attr("__name__").cast<std::string>() == "Tensor") {
    if (!obj.attr("is_contiguous")().cast<bool>()) {
        throw std::runtime_error("Tensor must be contiguous");
    }
    if (obj.attr("device").attr("type").cast<std::string>() == "cpu") {
        throw std::runtime_error("Tensor must be on CUDA device");
    }

Confirms that the input is a contiguous tensor that does not live on the CPU. TK operates on raw memory, so non-contiguous tensors (for example, transposed or strided views) are not supported.

Shape normalization:

std::array<int, 4> shape = {1, 1, 1, 1};
auto py_shape = obj.attr("shape").cast<pybind11::tuple>();
size_t dims = py_shape.size();
for (size_t i = 0; i < dims; ++i) {
    shape[4 - dims + i] = pybind11::cast<int>(py_shape[i]);
}

TK always works with 4 dimensions (representing batch, depth, height, width in that order). Lower-dimensional tensors are right-aligned and left-padded with 1s. For example, [32, 64] becomes [1, 1, 32, 64].
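
The padding rule is easy to check in isolation. Below is a small standalone sketch of the same right-alignment logic; normalize_shape is a hypothetical helper name, and the explicit bound check is an addition of this sketch, since 4 - dims would wrap around for more than four dimensions.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Right-align up to four dimensions, filling the leading slots with 1.
std::array<int, 4> normalize_shape(const std::vector<int>& dims) {
    if (dims.size() > 4) throw std::invalid_argument("expected at most 4 dimensions");
    std::array<int, 4> shape = {1, 1, 1, 1};
    const std::size_t offset = 4 - dims.size();  // number of leading 1s
    for (std::size_t i = 0; i < dims.size(); ++i) shape[offset + i] = dims[i];
    return shape;
}
```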

Pointer extraction:

uint64_t data_ptr = obj.attr("data_ptr")().cast<uint64_t>();
return make_gl<GL>(data_ptr, shape[0], shape[1], shape[2], shape[3]);

We extract the GPU memory address of the Tensor and pass it in along with dimensions to make_gl (in include/types/global/gl.cuh) to construct the GL.

Selecting GLs for Multi-GPU

return *lst[dev_idx].cast<std::shared_ptr<GL>>();

For multi-GPU launches, we expect a Python list of std::shared_ptr<GL> objects made by multigpu_make, and we select the entry at index dev_idx.

std::shared_ptr<GL> is a C++ smart pointer that holds a GL. We want a shared pointer to ensure that the descriptor is not deallocated when at least one language (Python/C++) is still using it. Note that we are managing the descriptor's lifetime, not the actual GPU memory.

A descriptor is a lightweight struct that describes a resource but doesn't contain it. The actual tensor data lives in HBM (potentially hundreds of megabytes). The GL, which is a descriptor, consists of just a pointer and shape metadata.

We can think of make as a factory and unwrap as a selector, picking the right C++ object from a prepared list, thanks to multigpu_make.
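
Here is a rough sketch of the factory/selector split, using a toy descriptor instead of a real GL. ToyGL, make_all, and select are hypothetical names; the point is only that shared_ptr copies share one small descriptor and that selection is a plain index.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Toy descriptor: an address plus shape metadata, never the data itself.
struct ToyGL {
    std::uint64_t data_ptr;
    int shape[4];
};

// "Factory": build one shared descriptor per device from raw addresses.
std::vector<std::shared_ptr<ToyGL>> make_all(const std::vector<std::uint64_t>& ptrs,
                                             int rows, int cols) {
    std::vector<std::shared_ptr<ToyGL>> out;
    for (std::uint64_t p : ptrs)
        out.push_back(std::make_shared<ToyGL>(ToyGL{p, {1, 1, rows, cols}}));
    return out;
}

// "Selector": copy out the descriptor that belongs to one device.
ToyGL select(const std::vector<std::shared_ptr<ToyGL>>& lst, std::size_t dev_idx) {
    return *lst.at(dev_idx);
}
```

Copying a shared_ptr only bumps a reference count, and copying the descriptor itself moves a few dozen bytes, regardless of how large the tensor it points to is.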

Template specialization for PGL:

A PGL (Parallel Global Layout) is a descriptor for a tensor that spans multiple GPUs: each device holds one buffer, and all buffers have the same shape. It's a single struct containing pointers to each GPU's chunk. This matters for multi-GPU operations (all-reduce, ring all-gather, direct memory access over NVLink). Since TK uses SPMD (single program, multiple data), every GPU runs the same kernel and needs to know where every other GPU's data lives. A PGL looks something like:

uint64_t data_ptrs[4]; // data address on each GPU
int common_shape[4];  // shared shape

Creating PGLs from Tensor Lists

Validation:

static_assert(!PGL::MULTICAST, "Multicast not yet supported on pyutils. Please initialize the multicast pointer manually.");
if (!pybind11::isinstance<pybind11::list>(obj))
    throw std::runtime_error("PGL from_object expected a Python list.");
pybind11::list tensors = pybind11::cast<pybind11::list>(obj);
if (tensors.size() != PGL::num_devices)
    throw std::runtime_error("Expected a list of " + std::to_string(PGL::num_devices) + " tensors");

Input must be a Python list of tensors with length matching PGL::num_devices. Multicast is not yet supported via pyutils.

Consistency check:

std::array<int, 4> shape = {1, 1, 1, 1};
uint64_t data_ptrs[PGL::num_devices];
for (int i = 0; i < PGL::num_devices; i++) {
    auto tensor = tensors[i];
    if (!pybind11::hasattr(tensor, "__class__") ||
        tensor.attr("__class__").attr("__name__").cast<std::string>() != "Tensor")
        throw std::runtime_error("Expected a list of torch.Tensor");
    if (!tensor.attr("is_contiguous")().cast<bool>())
        throw std::runtime_error("Tensor must be contiguous");
    if (tensor.attr("device").attr("type").cast<std::string>() == "cpu")
        throw std::runtime_error("Tensor must be on CUDA device");
    auto py_shape = tensor.attr("shape").cast<pybind11::tuple>();
    size_t dims = py_shape.size();
    if (dims > 4)
        throw std::runtime_error("Expected Tensor.ndim <= 4");
    for (size_t j = 0; j < dims; ++j) {
        if (i == 0)
            shape[4 - dims + j] = pybind11::cast<int>(py_shape[j]);
        else if (shape[4 - dims + j] != pybind11::cast<int>(py_shape[j]))
            throw std::runtime_error("All tensors must have the same shape");
    }
    data_ptrs[i] = tensor.attr("data_ptr")().cast<uint64_t>();
}

Iterates over all tensors to verify each is a contiguous CUDA tensor with ≤4 dimensions. The first tensor's shape becomes the reference; all others must match. Collects each tensor's memory address into data_ptrs[i].
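
Stripped of the pybind11 attribute lookups, the consistency check reduces to the loop below. This is a sketch under assumptions: ToyTensor and ToyPGL are hypothetical stand-ins, shapes are taken as already normalized to four dimensions, and the empty-list check is an addition of the sketch.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the tensor attributes read through pybind11.
struct ToyTensor {
    std::uint64_t data_ptr;
    std::array<int, 4> shape;  // already normalized to four dimensions
};

struct ToyPGL {
    std::vector<std::uint64_t> data_ptrs;  // one address per device
    std::array<int, 4> shape;              // shape shared by every device
};

ToyPGL make_toy_pgl(const std::vector<ToyTensor>& tensors) {
    if (tensors.empty()) throw std::runtime_error("expected at least one tensor");
    // The first tensor's shape is the reference for all the others.
    ToyPGL pgl{{}, tensors.front().shape};
    for (const ToyTensor& t : tensors) {
        if (t.shape != pgl.shape)
            throw std::runtime_error("All tensors must have the same shape");
        pgl.data_ptrs.push_back(t.data_ptr);
    }
    return pgl;
}
```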

Construction:

return make_pgl<PGL>(data_ptrs, shape[0], shape[1], shape[2], shape[3]);

Calls make_pgl with the full array of pointers so that the resulting C++ object knows every GPU's data location.

Accessing the PGL

static PGL unwrap(pybind11::object obj, int dev_idx) {
    return *obj.cast<std::shared_ptr<PGL>>();
}

Unlike GL's unwrap, this ignores dev_idx. A PGL already contains the complete map of all devices, so every GPU gets the same descriptor.

Class Registration

static std::unordered_set<std::string> registered;
template <typename T> static void register_pyclass(pybind11::module &m) {
  if constexpr (ducks::gl::all<T> || ducks::pgl::all<T>) {
    std::string _typename = typeid(T).name();
    if (registered.find(_typename) == registered.end()) {
      pybind11::class_<T, std::shared_ptr<T>>(m, _typename.c_str());
      registered.insert(_typename);
    }
  }
}

The registered set tracks which C++ classes have been exposed to Python.

register_pyclass is called by the binding macros (bind_kernel, bind_multigpu_kernel, etc.) for each kernel argument. The if constexpr check ensures we only register GL/PGL types - pybind11 already handles primitives like int or float.

If the type hasn't been registered yet, we create a Python class via pybind11::class_<T, std::shared_ptr<T>>. Declaring std::shared_ptr<T> as the holder type means Python keeps each instance alive through a shared_ptr, the same ownership model multigpu_make and unwrap use, so descriptors are shared by reference count rather than copied or owned through a raw pointer. The class name is the implementation-defined (usually mangled) typeid string - not user-facing, so this is fine.
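
The register-once pattern can be tried without pybind11. In this sketch, register_once is a hypothetical helper that only records the type name; in the real code the same guard wraps the pybind11::class_ call. It assumes, as holds on mainstream compilers, that distinct types produce distinct typeid names.

```cpp
#include <cassert>
#include <string>
#include <typeinfo>
#include <unordered_set>

// Returns true the first time a type is seen, false on every later call.
template<typename T>
bool register_once(std::unordered_set<std::string>& registered) {
    // typeid(T).name() is implementation-defined (mangled on GCC and Clang).
    return registered.insert(typeid(T).name()).second;
}

struct ToyGL {};   // hypothetical descriptor types
struct ToyPGL {};
```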

Preparing Arguments for Multi-GPU

template <typename T>
static pybind11::object multigpu_make(pybind11::object obj) {
  if constexpr (ducks::gl::all<T>) { ...
  } 
  else if constexpr (ducks::pgl::all<T>) { ...
  } 
  else { ...
  }
}

A factory that prepares Python arguments for multi-GPU distribution. The if constexpr conditions are evaluated at compile time - when T is an int, the GL/PGL branches are discarded and never instantiated, so code that would be ill-formed for int is harmless there.
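
A small standalone illustration of why discarded branches matter, assuming a hypothetical ToyPGL with a static num_devices. The else branch would not compile for int if it were instantiated, yet describe<int>() is fine.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

struct ToyPGL { static constexpr int num_devices = 4; };  // hypothetical

template<typename T>
std::string describe() {
    if constexpr (std::is_arithmetic_v<T>) {
        return "scalar";
    } else {
        // Only instantiated when T is not arithmetic, so int never
        // reaches this line, where T::num_devices would be ill-formed.
        return "descriptor for " + std::to_string(T::num_devices) + " devices";
    }
}
```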

GL Branch

if (!pybind11::isinstance<pybind11::list>(obj))
   throw std::runtime_error("multigpu_make [GL] expected a Python list.");
pybind11::list lst = pybind11::cast<pybind11::list>(obj);
std::vector<std::shared_ptr<T>> gls;
for (int i = 0; i < lst.size(); i++)
  gls.push_back(std::make_shared<T>(from_object<T>::make(lst[i])));
return pybind11::cast(gls);

Expects a Python list of tensors. Converts each tensor to a GL via from_object<T>::make, wraps each in a shared_ptr, and returns the vector, which pybind11/stl.h converts into a Python list. This is the list that unwrap later indexes into with dev_idx.

PGL Branch

return pybind11::cast(std::make_shared<T>(from_object<T>::make(obj)));

PGLs are already distributed (one struct holds all device pointers), so we just construct it and wrap in a shared_ptr.

Scalar Branch

return pybind11::cast(from_object<T>::make(obj));

For primitives (int, float, etc.). No shared_ptr is needed - scalars are lightweight and copied by value.

A note on std::make_shared, used in the GL and PGL branches: std::make_shared<T>(...) allocates a T on the heap and returns a std::shared_ptr<T> pointing to it. It's generally preferred over new because it typically places the object and its reference-count control block in a single allocation, and it's exception-safe.

Concepts

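Concepts (a C++20 feature) are named compile-time predicates on template arguments. They enforce type requirements upfront, so a mismatch fails at the point of use with a readable error instead of deep inside a template.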

Checking for Dynamic Shared Memory

template<typename T> concept has_dynamic_shared_memory = requires(T t) { 
    { t.dynamic_shared_memory() } -> std::convertible_to<int>; 
};

Checks that T has a .dynamic_shared_memory() method whose result converts to int. Many CUDA kernels fix their shared memory size at compile time. This concept identifies globals structs whose kernels request it at launch time instead.

Validating Multi-GPU Globals

template<typename T> concept is_multigpu_globals = requires { 
    { T::num_devices } -> std::convertible_to<std::size_t>;
    { T::dev_idx } -> std::convertible_to<std::size_t>;
} && T::num_devices >= 1;

Checks that T has num_devices and dev_idx members, and that num_devices is a compile-time constant of at least 1. Prevents accidentally using the multi-GPU launcher with single-GPU kernels.

Member Pointer Introspection

template<typename> struct trait;
template<typename MT, typename T> struct trait<MT T::*> { 
    using member_type = MT; 
    using type = T; 
};

A mechanism to extract types from a pointer-to-member. Given &MyGlobals::batch_size (which has type int MyGlobals::*), the compiler matches the pattern MT T::* and deduces MT = int, T = MyGlobals.

auto ptr = &MyGlobals::batch_size;
using DetectedType = trait<decltype(ptr)>::member_type;  // int

This lets the binding macros deduce argument types at compile time from the struct definition alone.
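
The deduction can be verified with static_assert. MyGlobals here is a hypothetical struct, and read is an illustrative helper showing that the same pointer-to-member also works at run time.

```cpp
#include <cassert>
#include <type_traits>

template<typename> struct trait;
template<typename MT, typename T> struct trait<MT T::*> {
    using member_type = MT;
    using type = T;
};

struct MyGlobals { int batch_size; float scale; };  // hypothetical

static_assert(std::is_same_v<trait<decltype(&MyGlobals::batch_size)>::member_type, int>);
static_assert(std::is_same_v<trait<decltype(&MyGlobals::scale)>::member_type, float>);
static_assert(std::is_same_v<trait<decltype(&MyGlobals::scale)>::type, MyGlobals>);

// The same pointer also works at run time to read the member it names.
template<typename MT, typename T>
MT read(const T& obj, MT T::*ptr) { return obj.*ptr; }
```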

Binding Single-GPU Kernels

template<auto kernel, typename TGlobal> 
static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
    m.def(name, [](object<decltype(member_ptrs)>... args, pybind11::kwargs kwargs) {
        // ...
    });
}

Generates a Python wrapper for a CUDA kernel. After this runs, Python can import the extension module and call the kernel by name, for example kittens.my_kernel(...) if the module is named kittens.

kernel is the GPU function pointer, TGlobal is the globals struct type.

m is the Python module, name is the function name, member_ptrs are pointers to each struct member.

Argument unpacking

TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};

Constructs the globals struct via aggregate initialization. The variadic pack zips Python arguments with their compile-time types - if you passed &int_var and &tensor_var, this expands to from_object<int>::make(arg0) and from_object<GL>::make(arg1).
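
The pack expansion is easier to follow with pybind11 taken out of the picture. The sketch below swaps pybind11::object for std::string and from_object for a hypothetical from_text, so every name other than trait is an assumption of the example. It also assumes that object<...> in the original is an alias template that ignores its argument, which is what text_arg imitates.

```cpp
#include <cassert>
#include <string>

template<typename> struct trait;
template<typename MT, typename T> struct trait<MT T::*> {
    using member_type = MT;
    using type = T;
};

// Stand-in for pybind11::object: every argument arrives as text,
// whatever type the pack element names.
template<typename> using text_arg = std::string;

// Hypothetical converters playing the role of from_object<T>::make.
template<typename T> struct from_text;
template<> struct from_text<int> {
    static int make(const std::string& s) { return std::stoi(s); }
};
template<> struct from_text<double> {
    static double make(const std::string& s) { return std::stod(s); }
};
template<> struct from_text<std::string> {
    static std::string make(const std::string& s) { return s; }
};

struct ToyGlobals { int batch; double scale; std::string label; };

// One lambda parameter per member pointer; each argument is converted
// with the converter selected by that member's type.
template<typename TGlobal>
auto make_builder(auto TGlobal::*... member_ptrs) {
    return [](text_arg<decltype(member_ptrs)>... args) {
        return TGlobal{
            from_text<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
    };
}
```

Calling make_builder<ToyGlobals>(&ToyGlobals::batch, &ToyGlobals::scale, &ToyGlobals::label) yields a three-argument callable. The member pointers are never dereferenced; only their types are used.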

Stream handling

cudaStream_t raw_stream = nullptr;
if (kwargs.contains("stream")) {
    uintptr_t stream_ptr = kwargs["stream"].attr("cuda_stream").cast<uintptr_t>();
    raw_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
}

Extracts the CUDA stream from a PyTorch stream object, allowing the kernel to run asynchronously alongside other PyTorch operations.

A CUDA stream is a work queue for the GPU. Operations in the same stream run sequentially; operations in different streams can run concurrently. PyTorch creates its own streams to manage operations - here, we inject the TK kernel into PyTorch's existing queue to keep things flowing.
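
The integer-to-handle conversion is just a pointer round trip. A tiny sketch with a made-up opaque handle type (toy_stream_t stands in for cudaStream_t, which is likewise a pointer to an incomplete struct):

```cpp
#include <cassert>
#include <cstdint>

// Opaque handle, standing in for cudaStream_t.
struct toy_stream_impl;
using toy_stream_t = toy_stream_impl*;

// What the Python side hands over: the handle's address as a plain integer.
std::uintptr_t to_integer(toy_stream_t s) {
    return reinterpret_cast<std::uintptr_t>(s);
}

// What the binding does: turn that integer back into a handle.
toy_stream_t from_integer(std::uintptr_t v) {
    return reinterpret_cast<toy_stream_t>(v);
}
```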

Kernel launch

if constexpr (has_dynamic_shared_memory<TGlobal>) {
    int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory();
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
    kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__, raw_stream>>>(__g__);
} else {
    kernel<<<__g__.grid(), __g__.block(), 0, raw_stream>>>(__g__);
}

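If TGlobal defines dynamic_shared_memory(), we call cudaFuncSetAttribute to raise the kernel's dynamic shared memory limit to the requested size, then pass that size as the third launch parameter. The opt-in is needed because a kernel can use at most 48KB of dynamic shared memory by default; Hopper allows up to 227KB per block. Otherwise, shared memory is assumed static and we pass 0.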

Binding Host Functions

template<auto function, typename TGlobal> 
static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) {
    m.def(name, [](object<decltype(member_ptrs)>... args) {
        TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
        function(__g__);
    });
}

The host-side sibling of bind_kernel. While bind_kernel launches code on the GPU, bind_function executes a C++ function on the CPU. Same argument conversion pattern - Python objects become the TGlobal struct via from_object - but it calls function(__g__) directly instead of a kernel launch.

Useful for debug utilities, print helpers, or complex wrappers that don't need GPU execution.

Multi-GPU Setup Utilities

Enabling Peer-to-Peer Access

m.def("enable_all_p2p_access", [](const std::vector<int>& device_ids) {
    // ...
    for (int i = 0; i < device_ids.size(); i++) {
        CUDACHECK(cudaSetDevice(device_ids[i]));
        for (int j = 0; j < device_ids.size(); j++) {
            if (i == j) continue;
            cudaDeviceEnablePeerAccess(device_ids[j], 0);
        }
    }
});

Enables peer-to-peer access so GPUs can directly read each other's memory via NVLink or PCIe, bypassing CPU RAM. Required for PGL operations.
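
Peer access is directional, which is why the loop visits every ordered pair rather than every unordered one. The sketch below only enumerates the pairs the loop would touch; peer_pairs is a hypothetical helper and makes no CUDA calls.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Every ordered (owner, peer) pair with owner != peer: n * (n - 1) pairs.
std::vector<std::pair<int, int>> peer_pairs(const std::vector<int>& device_ids) {
    std::vector<std::pair<int, int>> pairs;
    for (std::size_t i = 0; i < device_ids.size(); ++i)
        for (std::size_t j = 0; j < device_ids.size(); ++j)
            if (i != j) pairs.emplace_back(device_ids[i], device_ids[j]);
    return pairs;
}
```

For three devices this gives six pairs, so the number of enable calls grows quadratically with the device count.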

The Multi-GPU Executor

pybind11::class_<KittensClub, std::shared_ptr<KittensClub>>(m, "KittensClub")
    .def(pybind11::init([](const std::vector<int>& device_ids) {
        auto club = std::make_shared<KittensClub>(device_ids.data(), device_ids.size());
        club->execute([&](int dev_idx, cudaStream_t stream) {}); // warmup
        return club;
    }), pybind11::arg("device_ids"))

KittensClub (defined in pyutils/club.cuh) is a multi-GPU executor. The empty execute() call is a warmup - CUDA context initialization is lazy and slow, so we pay that cost upfront.

Second (stream-friendly) constructor

.def(pybind11::init([](const std::vector<int>& device_ids, 
                       const std::vector<pybind11::object>& streams) {
    std::vector<cudaStream_t> raw_streams(streams.size());
    for (size_t i = 0; i < streams.size(); ++i) {
        uintptr_t stream_ptr = streams[i].attr("cuda_stream").cast<uintptr_t>();
        raw_streams[i] = reinterpret_cast<cudaStream_t>(stream_ptr);
    }
    auto club = std::make_shared<KittensClub>(device_ids.data(), raw_streams.data(), device_ids.size());
    club->execute([&](int dev_idx, cudaStream_t stream) {}); // warmup
    return club;
}), pybind11::arg("device_ids"), pybind11::arg("streams"))

Accepts PyTorch streams so kernels run in sync with other PyTorch operations. The pybind11::arg tags enable keyword arguments in Python:

club = kittens.KittensClub(
    device_ids=[0, 1, 2], 
    streams=[s1, s2, s3]
)

Binding Multi-GPU Kernels

template<auto kernel, typename TGlobal> 
static void bind_multigpu_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {

The multi-GPU equivalent of bind_kernel. Takes Python arguments, distributes them across GPUs, and launches the same kernel on each device simultaneously via KittensClub.

Validation

static_assert(is_multigpu_globals<TGlobal>, 
    "Multigpu globals must have a member num_devices >= 1 and dev_idx");

Ensures TGlobal has num_devices and dev_idx - prevents using single-GPU globals with the multi-GPU launcher.

Class registration

(register_pyclass<typename trait<decltype(member_ptrs)>::member_type>(m), ...);

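Fold expression over the comma operator that calls register_pyclass once per argument type, registering all GL/PGL types with Python. Without this, pybind11 raises a TypeError as soon as it has to convert an unregistered GL/PGL to or from Python.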
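
A comma fold in miniature, with hypothetical names throughout: visit plays the role of register_pyclass and records only descriptor types, skipping scalars the way the if constexpr guard does.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <vector>

struct ToyGL {};  // hypothetical descriptor type

// Visit one type: record it only if it is a descriptor, skip scalars.
template<typename T>
void visit(std::vector<std::string>& log) {
    if constexpr (std::is_same_v<T, ToyGL>) log.push_back("ToyGL");
}

// Comma fold: expands to (visit<T0>(log), visit<T1>(log), ...), left to right.
template<typename... Ts>
std::vector<std::string> visit_all() {
    std::vector<std::string> log;
    (visit<Ts>(log), ...);
    return log;
}
```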

Debug helper

m.def((std::string("make_globals_")+name).c_str(), 
    [](object<decltype(member_ptrs)>... args) -> std::vector<pybind11::object> {
        return {multigpu_make<typename trait<decltype(member_ptrs)>::member_type>(args)...};
    });

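Creates a helper function make_globals_<kernel_name> that runs multigpu_make on each argument and returns the results as a list. This is more than a debugging aid: unwrap casts its inputs to std::shared_ptr<GL> or std::shared_ptr<PGL>, so the objects this helper returns are what the kernel binding below expects to receive. It is also a convenient place to verify the conversion before launch.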

Kernel binding

m.def(name, [](std::shared_ptr<KittensClub> club, object<decltype(member_ptrs)>... args) {
    std::vector<TGlobal> __g__;
    for (int i = 0; i < TGlobal::num_devices; i++) {
        __g__.emplace_back(from_object<typename trait<decltype(member_ptrs)>::member_type>::unwrap(args, i)...);
        __g__.back().dev_idx = i;
    }
    // ...
});

Builds a globals struct for each GPU. unwrap scatters the data - scalars are copied to every GPU, GL lists are indexed by dev_idx, PGLs are shared as-is.
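
The scatter step, reduced to plain C++ with made-up types: a scalar is copied to every device while a per-device list is indexed by dev_idx. ToyGlobals, the two unwrap overloads, and scatter are hypothetical stand-ins for the from_object<...>::unwrap calls.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-device globals: one scalar plus one per-device buffer id.
struct ToyGlobals {
    int scale;      // same on every device
    int buffer_id;  // differs per device
    int dev_idx;
};

// Scalars ignore dev_idx; per-device lists are indexed by it.
int unwrap(int scalar, std::size_t) { return scalar; }
int unwrap(const std::vector<int>& per_device, std::size_t dev_idx) {
    return per_device.at(dev_idx);
}

// Build one globals struct per device, as the kernel binding does.
std::vector<ToyGlobals> scatter(int scale, const std::vector<int>& buffers) {
    std::vector<ToyGlobals> g;
    for (std::size_t i = 0; i < buffers.size(); ++i)
        g.push_back(ToyGlobals{unwrap(scale, i), unwrap(buffers, i), static_cast<int>(i)});
    return g;
}
```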

Kernel launch

if constexpr (has_dynamic_shared_memory<TGlobal>) {
    club->execute([&](int dev_idx, cudaStream_t stream) {
        int __dynamic_shared_memory__ = (int)__g__[dev_idx].dynamic_shared_memory();
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
        kernel<<<__g__[dev_idx].grid(), __g__[dev_idx].block(), __dynamic_shared_memory__, stream>>>(__g__[dev_idx]);
    });
} else {
    club->execute([&](int dev_idx, cudaStream_t stream) {
        kernel<<<__g__[dev_idx].grid(), __g__[dev_idx].block(), 0, stream>>>(__g__[dev_idx]);
    });
}

club->execute dispatches the kernel across all GPUs, each with its own globals struct, grid/block configuration, and stream.

Wrap-Up

pyutils.cuh defines how to convert Python objects to C++ types (from_object), preprocess Python arguments for distribution (multigpu_make), generate wrappers that Python can call (bind_kernel for single-GPU and bind_multigpu_kernel for multi-GPU), and set up multi-GPU infra (enable_all_p2p_access and KittensClub).

How it all flows: Python calls a kernel launch, bind_kernel receives Tensors as pybind11::objects, from_object converts each object to C++ (descriptors and scalars), the globals struct is aggregate-initialized, and the CUDA kernel launches with __g__ as its argument.