Identity and ArithTuple Tensors in CuTeDSL
Recently, I came across a curious concept called the "Identity Tensor" while writing kernels in CuTeDSL: a tensor that maps each of its coordinates to itself. Sounds simple, but is it?
In practice, identity tensors are used inside CuTeDSL kernels for out-of-bounds checks, building predicate tensors, and so on. In this blog, I'll share my intuition for the identity tensor and why it feels a bit weird at first but makes sense later.
Introduction
Let's write a simple host function that takes a 2D tensor and prints the corresponding identity tensor.
import numpy as np
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def test(mX: cute.Tensor):
    cute.printf("\nGlobal tensor mX: {}", mX)
    idX = cute.make_identity_tensor(mX.shape)
    cute.printf("\nIdentity tensor: {}", idX)

x = np.random.rand(14 * 1024).reshape((14, 1024))
test(from_dlpack(x))
The above code prints:
Global tensor mX: raw_ptr(0x000000003673a4d0: f64, generic, align<8>) o (14,1024):(1024,1) =
( 0.211140, 0.971345, 0.472651, 0.753005, 0.545098, 0.065438, 0.015243, 0.624164, 0.574682, 0.733167, 0.239575, 0.820586, 0.626296, 0.138826, 0.425476, 0.048658, 0.969981, 0.549983, 0.035949, 0.652512, 0.686649, 0.288750, 0.328552, 0.193020, 0.630206, 0.983489, 0.107112, 0.847828, 0.432965, 0.471674, 0.381877, [...] )
Identity tensor: (0,0) o (14,1024):(1@0,1@1)
A tensor in CuTeDSL is defined by an Engine and a Layout. The Engine holds an iterator which can be dereferenced to access the data. The Layout contains the tensor's shape and strides (how many elements to jump in each dimension to reach the next value).
In the above example, mX's shape is (14, 1024) with strides (1024, 1), indicating that mX is a row-major tensor. The coordinates of this tensor range from 0 to 13 in dim/mode-0, and from 0 to 1023 in dim/mode-1. The Engine here is a pointer to the start of the tensor's data.
To access an element of mX at coordinate (1, 1), we compute the dot product of the coordinate with the strides of mX, which gives us an offset into the tensor. Adding that offset to the Engine's data pointer and dereferencing the result yields the value:
mX[1, 1] = *(ptr + (1, 1) . (1024, 1))
         = *(ptr + (1 * 1024) + (1 * 1))
         = *(ptr + 1025)
         = 0.770638
where ptr is the Engine's base pointer printed above.
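The coordinate-to-offset arithmetic above can be sketched in plain Python/NumPy (not CuTeDSL; the helper name crd2offset is mine, chosen for illustration):

```python
# Plain-Python sketch of how a coordinate and strides produce a linear
# offset into a row-major tensor. `crd2offset` is an illustrative name,
# not a CuTeDSL function.
import numpy as np

x = np.random.rand(14, 1024)  # row-major, strides (1024, 1) in elements

def crd2offset(coord, strides):
    # Dot product of the coordinate with the strides.
    return sum(c * s for c, s in zip(coord, strides))

offset = crd2offset((1, 1), (1024, 1))
print(offset)                      # 1025
print(x.flat[offset] == x[1, 1])   # True: same element, linear vs 2D index
```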
So far so good, but what about the identity tensor? Its Engine, (0, 0), is not a data pointer in memory, and its strides are (1@0, 1@1). When you print the tensor, you get:
tensor((0,0) o (14,1024):(1@0,1@1), data=
[[ (0,0), (0,1), (0,2), ..., (0,1021), (0,1022), (0,1023), ],
[ (1,0), (1,1), (1,2), ..., (1,1021), (1,1022), (1,1023), ],
[ (2,0), (2,1), (2,2), ..., (2,1021), (2,1022), (2,1023), ],
...
[ (11,0), (11,1), (11,2), ..., (11,1021), (11,1022), (11,1023), ],
[ (12,0), (12,1), (12,2), ..., (12,1021), (12,1022), (12,1023), ],
[ (13,0), (13,1), (13,2), ..., (13,1021), (13,1022), (13,1023), ]])
As you can see, it contains every coordinate the mX tensor can have, each mapped to itself. But what does an Engine of (0, 0) and strides of (1@0, 1@1) mean?
What's the @?
The stride pattern "a@b" (read "a at b") is not an "integer" stride but a basis vector (think of the x and y axes of an infinite 2D plane): a tuple whose value at dimension "b" is "a" and zero everywhere else. For example:
1@0 = (1, 0, 0, ...) and 1@1 = (0, 1, 0, ...)
The pattern can be nested as well, because CuTeDSL layouts can be nested. Reading from left to right:
1@0@1 = "1 at 0 at 1" = (0, (1, 0, ...), 0, ...)
1@1@0 = "1 at 1 at 0" = ((0, 1, 0, ...), 0, ...)
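Here is a tiny plain-Python model of this notation (illustrative only; CuTeDSL has its own internal representation for scaled bases, and the helper names basis and nest are mine):

```python
# Toy model of the "a@b" basis notation. Not CuTeDSL API.
def basis(a, mode, rank=2):
    """Tuple with `a` at position `mode` and 0 elsewhere, e.g. 1@0 -> (1, 0)."""
    return tuple(a if i == mode else 0 for i in range(rank))

def nest(inner, mode, rank=2):
    """Place `inner` at position `mode`, modeling nested bases like 1@0@1."""
    return tuple(inner if i == mode else 0 for i in range(rank))

print(basis(1, 0))           # (1, 0)       i.e. 1@0
print(basis(1, 1))           # (0, 1)       i.e. 1@1
print(nest(basis(1, 0), 1))  # (0, (1, 0))  i.e. 1@0@1
```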
Tensors like the identity tensor are called Arithmetic Tuple Tensors in CuTeDSL. Instead of a data pointer, the engine is a tuple, and instead of integers, the strides are basis vectors. Also, these tensors are not stored in registers or local memory on the GPU; the values are computed on the fly!
These "strides as basis vectors" can be used to compute which element an identity tensor contains at a coordinate (m, n). For m = 2 and n = 15:
idX[2, 15] = (0, 0) + (2, 15) * (1@0, 1@1)
= (0, 0) + 2 * (1, 0, 0, ...) + 15 * (0, 1, 0, ...)
= (0, 0) + (2, 0, ...) + (0, 15, ...)
= (2, 15, ...)
= (2, 15)
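The same arithmetic can be reproduced in plain Python by writing the basis vectors out as rank-2 tuples, scaling each one by its coordinate, and summing from the engine (illustrative sketch, not CuTeDSL):

```python
# Plain-Python version of idX[2, 15]: engine + coord . strides, where the
# strides are basis vectors written out as tuples.
def scale(c, v):
    """Multiply a basis vector elementwise by a coordinate."""
    return tuple(c * x for x in v)

def add(u, v):
    """Elementwise tuple addition."""
    return tuple(a + b for a, b in zip(u, v))

engine = (0, 0)
strides = ((1, 0), (0, 1))  # (1@0, 1@1) written out as rank-2 tuples
coord = (2, 15)

result = engine
for c, s in zip(coord, strides):
    result = add(result, scale(c, s))
print(result)  # (2, 15)
```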
This, by itself, may not seem useful at the global level. But when you tile the global tensor into smaller sub-tiles that each GPU thread works on, the matching sub-tiles of the identity tensor tell each thread which global coordinates it is processing.
A Practical Example
As an example, take a row-wise kernel like RMS/Layer norm over a 2D tensor of shape (M = 14, N = 1024). Each block should process a few complete rows. There are TN threads assigned to each row, so a block of BLOCK_DIM threads covers BLOCK_DIM // TN rows. If TN = 32 and BLOCK_DIM = 128, then each block processes 128 / 32 = 4 rows.
Now, each of the TN = 32 threads assigned to a row loads its part of that row in tiles. Thread 0 may load the first 8 elements in a vectorized manner (0...7), thread 1 the next 8 elements (8...15), and so on. Each thread loads 8 elements 1024 / (8 * 32) = 4 times to cover the full row.
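A quick sanity check of this arithmetic in plain Python (variable names are mine, chosen to mirror the prose):

```python
# Tiling arithmetic for the row-wise kernel described above.
M, N = 14, 1024        # global tensor shape
TN, BLOCK_DIM = 32, 128  # threads per row, threads per block
VEC = 8                # elements per vectorized load

rows_per_block = BLOCK_DIM // TN     # 4 rows handled by each block
loads_per_thread = N // (VEC * TN)   # 4 vectorized loads to cover a row
print(rows_per_block, loads_per_thread)  # 4 4
```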
In code, we first tile the global tensor at the block level using CuTeDSL's local_tile function. Each block tile has shape (4, 1024). We can then use CuTeDSL's tiled copy, thread slice, and partition functions to see what a single thread processes in each block:
import numpy as np
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack

@cute.jit
def test(mX: cute.Tensor):
    tile_rows, tile_cols = 4, 1024
    rows_per_block, threads_per_row = 4, 32
    tiler_mn = (rows_per_block, tile_cols)

    # example block/thread indices
    bidx = 0
    tidx = 0

    cute.printf("Global input tensor mX:\n {}", mX)
    idX = cute.make_identity_tensor(mX.shape)
    cute.printf("Global identity tensor:\n {}", idX)

    gX = cute.local_tile(mX, tiler_mn, (bidx, 0))
    cX = cute.local_tile(idX, tiler_mn, (bidx, 0))
    cute.printf("Block level input tensor gX:\n {}", gX)
    cute.printf("Block level coordinate tensor cX:\n {}", cX)

    thr_layout = cute.make_ordered_layout((rows_per_block, threads_per_row), order=(1, 0))
    val_layout = cute.make_layout((1, 8))  # each thread loads 8 elements
    copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mX.element_type, num_bits_per_copy=128)
    tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)

    thr_copy_X = tiled_copy.get_slice(tidx)
    tXgX = thr_copy_X.partition_S(gX)
    tXcX = thr_copy_X.partition_S(cX)
    cute.printf("Thread slice input tensor tXgX:\n {}", tXgX)
    cute.printf("Thread slice identity tensor:\n {}", tXcX)

x = np.random.rand(14 * 1024).reshape((14, 1024))
test(from_dlpack(x))
gives us the output:
Global input tensor mX:
raw_ptr(0x000000000a9c0700: f64, generic, align<8>) o (14,1024):(1024,1) =
( 0.140136, 0.974619, 0.368508, 0.730187, 0.882944, 0.554125, 0.433968, 0.520788, 0.569155, 0.636642, 0.605161, 0.190401, 0.072166, 0.979344, 0.309837, 0.294862, 0.077016, 0.589228, 0.207305, 0.526463, 0.957993, 0.552704, 0.301575, 0.132150, 0.584438, 0.878755, 0.390871, 0.410519, 0.298568, 0.802744, 0.442688, [...] )
Global identity tensor:
(0,0) o (14,1024):(1@0,1@1)
Block level input tensor gX:
raw_ptr(0x000000000a9c0700: f64, generic, align<8>) o (4,1024):(1024,1) =
( 0.140136, 0.974619, 0.368508, 0.730187, 0.309837, 0.294862, 0.077016, 0.589228, 0.298568, 0.802744, 0.442688, 0.255719, 0.969421, 0.839816, 0.460844, 0.345255, 0.241419, 0.301533, 0.175595, 0.319782, 0.432365, 0.044018, 0.747581, 0.874302, 0.478939, 0.887515, 0.205933, 0.532725, 0.137578, 0.413298, 0.463574, [...] )
Block level coordinate tensor cX:
(0,0) o (4,1024):(1@0,1@1)
Thread slice input tensor tXgX:
raw_ptr(0x000000000a9c0700: f64, generic, align<8>) o ((2,4),1,4):((1,2),0,256) =
( 0.140136, 0.309837, 0.298568, 0.969421, 0.241419, 0.432365, 0.478939, 0.137578, 0.972452, 0.454237, 0.925056, 0.184337, 0.313381, 0.775815, 0.245599, 0.388027, 0.616751, 0.258569, 0.782289, 0.103342, 0.361084, 0.470297, 0.965911, 0.841517, 0.739463, 0.532083, 0.183096, 0.768398, 0.738167, 0.250136, 0.824740, [...] )
Thread slice identity tensor:
(0,0) o ((2,4),1,4):((1@1,2@1),0,256@1)
The thread slice of the identity tensor here is:
Engine = (0, 0) and Layout = ((2, 4), 1, 4) : ((1@1, 2@1), 0, 256@1)
I have intentionally chosen the global tensor to have 14 rows. That is not divisible by the rows per block, i.e. 4: with ceil(14 / 4) = 4 blocks we cover 16 rows, so the last two rows are out of bounds. CuTeDSL, by default, rounds up the leftovers. How do we let the threads copy only the elements that lie within the global tensor, and do nothing when a thread lands on out-of-bounds elements?
With 4 blocks, the extra rows fall in the block with index 3. Let's see what tXcX contains for the first thread of each row in block 3:
Thread 0 slice identity tensor:
(12,0) o ((2,4),1,4):((1@1,2@1),0,256@1)
Thread 32 slice identity tensor:
(13,0) o ((2,4),1,4):((1@1,2@1),0,256@1)
Thread 64 slice identity tensor:
(14,0) o ((2,4),1,4):((1@1,2@1),0,256@1)
Thread 96 slice identity tensor:
(15,0) o ((2,4),1,4):((1@1,2@1),0,256@1)
As you can see, we can get the start of the global coordinate each thread processes by reading the first element in the tXcX tensor:
global_row, global_col = tXcX[(0, 0), 0, 0]
For thread 0 inside block 3, the global row is 12. For thread 32 inside block 3, the global row becomes 13, and so on. We can now copy only the required elements and skip the OOB reads with a simple check:
if global_row < mX.shape[0]:
    cute.copy(...)
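The effect of this predicate can be sketched with plain NumPy (the loop over blocks and rows stands in for the kernel's tiled copies; `out` and the loop structure are illustrative, not CuTeDSL):

```python
# NumPy sketch of the predicated copy: the grid is rounded up to cover all
# rows, and the row check skips the out-of-bounds work.
import numpy as np

M, N, rows_per_block = 14, 1024, 4
num_blocks = -(-M // rows_per_block)  # ceil(14 / 4) = 4 blocks

x = np.random.rand(M, N)
out = np.zeros((num_blocks * rows_per_block, N))  # 16 rows; last 2 are OOB

for bidx in range(num_blocks):
    for r in range(rows_per_block):
        global_row = bidx * rows_per_block + r  # what tXcX[(0, 0), 0, 0] reveals
        if global_row < M:                      # the predicate
            out[global_row] = x[global_row]     # stand-in for cute.copy
```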
We can also iterate through the last dim/mode of the identity tensor tXcX to see where each thread's 8-element loads start within the columns, since the stride there is 256@1:
Thread 0:
tXcX[(0, 0), 0, 0]: (12,0)
tXcX[(0, 0), 0, 1]: (12,256)
tXcX[(0, 0), 0, 2]: (12,512)
tXcX[(0, 0), 0, 3]: (12,768)
Thread 1:
tXcX[(0, 0), 0, 0]: (12,8)
tXcX[(0, 0), 0, 1]: (12,264)
tXcX[(0, 0), 0, 2]: (12,520)
tXcX[(0, 0), 0, 3]: (12,776)
Here, thread 0 in block 3 processes the elements in global row 12 and global columns starting at 0, 256, 512, and 768. Thread 1 processes the same global row but starts at global column 8, 264, 520, and 776.
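The (row, column) starts listed above follow a simple pattern that can be reproduced in plain Python (thread_coords is an illustrative helper of mine, not a CuTeDSL function; it assumes the thread layout from the kernel, where threads 0-31 share a row):

```python
# Reproduce the per-thread (row, column) load starts for block 3.
TN, VEC, rows_per_block, bidx = 32, 8, 4, 3

def thread_coords(tidx, num_loads=4):
    row = bidx * rows_per_block + tidx // TN  # threads 0..31 -> row 12, etc.
    lane = tidx % TN
    # Each load covers VEC elements; consecutive loads jump VEC * TN columns.
    return [(row, lane * VEC + k * VEC * TN) for k in range(num_loads)]

print(thread_coords(0))  # [(12, 0), (12, 256), (12, 512), (12, 768)]
print(thread_coords(1))  # [(12, 8), (12, 264), (12, 520), (12, 776)]
```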
This is just one example of what identity tensors can do in CuTeDSL. They show up in almost every kernel: bounds checking, predicate tensors, TMA loads, and so on. I hope this article helped you build an intuition for them. Thanks for reading!