The CUDA Runtime and Driver APIs allow you to use “blocking synchronization” where the CPU will go to sleep while waiting for synchronization with the device. However, it seems that PyTorch doesn’t expose this functionality in any of its Python APIs:
https://github.com/pytorch/pytorch/issues/28224
If you try using ctypes to call into libcudart.so to set the device flags as described in the above issue, you’ll have to call torch.cuda.init() for it to work, and unfortunately it won’t work if PyTorch is launching kernels from other threads.
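A minimal sketch of that ctypes route (the soname, the error handling, and the exact ordering relative to context creation are assumptions and may vary by CUDA/PyTorch version):

```python
import ctypes
import torch

# CUDA runtime constant from driver_types.h: cudaDeviceScheduleBlockingSync == 0x04
CUDA_DEVICE_SCHEDULE_BLOCKING_SYNC = 0x04

torch.cuda.init()  # per the note above, PyTorch's CUDA context needs to exist

# The soname may differ on your system (e.g. libcudart.so.12).
libcudart = ctypes.CDLL("libcudart.so")
ret = libcudart.cudaSetDeviceFlags(ctypes.c_uint(CUDA_DEVICE_SCHEDULE_BLOCKING_SYNC))
assert ret == 0, f"cudaSetDeviceFlags returned CUDA error {ret}"
# Caveat from above: this may not help if PyTorch launches kernels from other threads.
```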
At the driver-API level, the functions that create or configure a context (and which an `LD_PRELOAD` shim could interpose, per the dlsym(3)/ld.so(8) references and the sketch below) are:

* `cuCtxCreate`
* `cuCtxCreate_v3`
* `cuCtxSetFlags`
* `cuDevicePrimaryCtxRetain`
* `cuDevicePrimaryCtxSetFlags`
... and make sure that the three least significant bits of any `flags` variable are set to `CU_CTX_SCHED_BLOCKING_SYNC`.

cuDevicePrimaryCtxSetFlags: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PR...
dlsym(3): https://man.archlinux.org/man/dlsym.3.en
ld.so(8): https://man.archlinux.org/man/ld.so.8.en
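Here is a hedged sketch of the flag handling such a shim would need, expressed in Python via ctypes against the driver API rather than as an actual LD_PRELOAD library (the constants are from cuda.h; the soname and call ordering are assumptions):

```python
import ctypes

CU_CTX_SCHED_BLOCKING_SYNC = 0x04  # cuda.h
CU_CTX_SCHED_MASK = 0x07           # the three least significant (scheduling) bits

def force_blocking_sync(flags: int) -> int:
    # What an interposer would do to any incoming flags value:
    # clear the scheduling bits, then set CU_CTX_SCHED_BLOCKING_SYNC.
    return (flags & ~CU_CTX_SCHED_MASK) | CU_CTX_SCHED_BLOCKING_SYNC

# Direct driver-API route from Python; the soname is libcuda.so.1 on some systems.
libcuda = ctypes.CDLL("libcuda.so")
assert libcuda.cuInit(0) == 0

device = ctypes.c_int(0)
assert libcuda.cuDeviceGet(ctypes.byref(device), 0) == 0
# Ordering relative to PyTorch's own context creation matters; see the docs linked above.
assert libcuda.cuDevicePrimaryCtxSetFlags(device, ctypes.c_uint(force_blocking_sync(0))) == 0
```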
Also, this kind of ultra-wide buffering consumes a ton of memory bandwidth for each operation, instead of keeping a small working set in cache/registers. FLOPs keep scaling almost without limit, whereas memory bandwidth is comparatively flat, so this is increasingly a losing game; just because it's faster than glacial Python doesn't mean it's fast compared to a language that actually concerns itself with performance, or a more cache-aware approach.
For an extreme example of how you can even sometimes beat ultra optimised GPU ML libraries in this way, check out https://github.com/NVlabs/tiny-cuda-nn
> I studied the CUDA traces closely and found that vectorization does indeed reduce many aspects of the GPU workload, greatly reducing the number of operations and decreasing the total amount of time spent on the fundamental computations of the algorithm. However it also introduces overhead (mentioned above) by interspersing operations that permute and reorder the tensors, or splitting them into groups then concatenating results. Sometimes the reduced “fundamental” time outweighs the additional overhead, while other times the overhead outweighs the reduction in fundamental time.
Here are some examples not included in the blog post:
- Total time spent in aten::cdist kernel
  - Baseline: 2.834s (4900 calls)
  - Vectorized: 2.686s (500 calls)
- Total time spent in aten::mul kernel
  - Baseline: 5.745s (80700 calls)
  - Vectorized: 5.555s (8100 calls)
This nice little win applies to tons of other kernels, almost across the board. As you point out, CPU intuition suggests this should have been slower, so this was an interesting outcome.

On the other hand, some specific increases occur:
- Total time spent in aten::cat kernel
  - Baseline: 0.680s
  - Vectorized: 1.849s
So working in fewer, larger batches doesn't only enable outrunning the GPU. It decreases the total GPU workload... then adds some overhead. But some of this overhead could be removed with custom CUDA kernels, so I think this is an interesting direction even if you solve the CPU problem some other way.

(The pow(x, 2) is only there in the toy code, not my actual kernel, so I didn't performance-tune it.)
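For reference, per-kernel totals like these can be pulled from a PyTorch trace roughly as follows (a generic torch.profiler sketch, not the measurement code actually used; `workload_step` is a placeholder):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def workload_step():
    # Placeholder for the real computation being traced.
    x = torch.randn(512, 64, device="cuda")
    return (torch.cdist(x, x) * 2.0).sum()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(100):
        workload_step()
    torch.cuda.synchronize()

# Aggregated time per operator (aten::cdist, aten::mul, aten::cat, ...).
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```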
I'm surprised this is necessary; I thought modern vectorization on both CPUs and GPUs handled heterogeneous cases like this handily, with conditional execution (on SIMT GPUs) or mask registers (on SIMD CPUs).
I will note that these grouped operations occasionally cause a net loss in performance compared to "naive" looping, since they involve calling PyTorch's `x.view(...)`, which is usually ~instant but sometimes adds extra CUDA operations on the backward pass. Grouping always reduces the time spent in aten::add, but adds these extra ops. A really smart vectorizer would use heuristics to decide how/whether to group operations according to the target hardware; my current vectorizer just does the grouping every time.
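To make "grouping" concrete, here is a toy illustration (not Vexpr's actual transformation) of replacing a loop of small adds with one batched add plus views:

```python
import torch

xs = [torch.randn(100, device="cuda", requires_grad=True) for _ in range(8)]
ys = [torch.randn(100, device="cuda") for _ in range(8)]

# "Naive" looping: one aten::add launch per pair.
naive = [x + y for x, y in zip(xs, ys)]

# Grouped: a single large aten::add, then views back into per-pair results.
# The views are ~free in the forward pass but can add extra ops on backward.
stacked = torch.stack(xs) + torch.stack(ys)
grouped = [stacked[i].view(-1) for i in range(len(xs))]
```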
Yes, I'm off doing my own thing now. Deep Learning went so much further than I ever expected, and now I'm drawn to all the things that can be built today. Who knows, maybe I'll swing back into neuroscience in a few years. (Still friends with my old coworkers / bosses.)
Fun fact, I had to put in extra work to get torch.compile working with my code, for understandable reasons. My library, Vexpr, literally runs an interpreter inside of Python, reading a big tree-like namedtuple-of-namedtuples "expression" data structure and evaluating it recursively. That data structure was way too fancy for torch.compile's guards, so I actually wrote code [1] that converts a Vexpr expression into a big Python code string and evals it, factoring the interpreter out of the code, then I pass that eval'd string into torch.compile.
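The rough idea, as a hypothetical toy (not Vexpr's real code; `Expr` and `to_source` are made-up names): render the expression tree to Python source, eval it into an ordinary function, and hand that to torch.compile:

```python
import torch
from collections import namedtuple

# Stand-in for a Vexpr-style expression node (the real structure is a
# namedtuple-of-namedtuples tree).
Expr = namedtuple("Expr", ["op", "args"])

def to_source(expr):
    # Recursively render an expression tree as Python source text.
    if isinstance(expr, Expr):
        args = ", ".join(to_source(a) for a in expr.args)
        return f"torch.{expr.op}({args})"
    return expr if isinstance(expr, str) else repr(expr)

# sum(x * 2) as a tree, rendered to plain code and eval'd, so torch.compile
# only ever sees ordinary Python rather than the interpreter.
tree = Expr("sum", (Expr("mul", ("x", 2.0)),))
source = f"lambda x: {to_source(tree)}"   # "lambda x: torch.sum(torch.mul(x, 2.0))"
fn = eval(source, {"torch": torch})
compiled = torch.compile(fn)
print(compiled(torch.randn(4)))
```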
One torch.compile capability I would be excited to see is compatibility with torch.vmap. One selling point of Vexpr is that you can use vmap with it, so I was sad when I found I couldn't use vmap and still support torch.compile. This made me convert a bunch of my GP kernels [2] to be batch-aware. (This missing capability is also understandable -- both vmap and compile are new.)
Anyway, I'm a fan of what y'all are doing!
[1] https://github.com/outergroup/vexpr/blob/e732e034768443386f9...

[2] https://github.com/outergroup/outer-loop-cookbook/blob/5d94c...
> One torch.compile capability I would be excited to see is compatibility with torch.vmap
We added support for torch.func.vmap, iirc - check out test_higher_order_ops.py, grep for vmap.
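If so, composing the two looks roughly like this (a hedged sketch; `scalar_kernel` is an illustrative stand-in, and behaviour depends on the PyTorch version):

```python
import torch

def scalar_kernel(x1, x2, lengthscale):
    # Toy RBF-style kernel on single points; purely illustrative.
    return torch.exp(-((x1 - x2) ** 2).sum() / lengthscale)

# vmap lifts the single-point function over a batch dimension...
batched = torch.func.vmap(scalar_kernel, in_dims=(0, 0, None))
# ...and, in recent PyTorch releases, the result can also be torch.compile'd.
compiled = torch.compile(batched)

x1 = torch.randn(8, 3)
x2 = torch.randn(8, 3)
print(compiled(x1, x2, torch.tensor(1.0)).shape)  # torch.Size([8])
```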