In [ ]:
import Pkg
Pkg.activate("colab6")
Pkg.add(["CUDA", "KernelAbstractions", "Adapt", "NVTX"])
In [ ]:
versioninfo()
In [ ]:
using CUDA, KernelAbstractions, Adapt
Different layers of abstraction¶
In [ ]:
N = 100000
a = 0.5
X_cpu = rand(Float64, N)
Y_cpu = zeros(Float64, N)
X = CuVector(X_cpu)
Y = CuVector(Y_cpu)
Vendor-specific¶
In [ ]:
function saxpy!(a,X,Y)
i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
if i <= length(Y)
@inbounds Y[i] = a * X[i] + Y[i]
end
return nothing
end
@cuda threads=32 blocks=cld(length(Y), 32) saxpy!(a, X, Y)
Y
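As a quick sanity check (a sketch, assuming Y still holds the result of the launch above and started out as zeros), we can compare against the host data:
In [ ]:
# Y started as zeros, so after the kernel it should equal a .* X
@assert isapprox(Array(Y), a .* X_cpu)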
KernelAbstractions¶
In [ ]:
using KernelAbstractions
using CUDA
@kernel function kernel_saxpy!(a, @Const(X), Y)
I = @index(Global)
@inbounds Y[I] = a * X[I] + Y[I]
end
kernel_saxpy!(CUDABackend())(a, X, Y, ndrange=length(Y))
Y
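Because the kernel is written against KernelAbstractions, the same code can be instantiated for other backends as well. A minimal sketch using the CPU backend with fresh host arrays (names introduced here for illustration):
In [ ]:
# Sketch: run the identical kernel on the CPU backend
X_host = rand(Float64, N)
Y_host = zeros(Float64, N)
kernel_saxpy!(CPU())(a, X_host, Y_host, ndrange=length(Y_host))
KernelAbstractions.synchronize(CPU())
@assert isapprox(Y_host, a .* X_host)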
Array abstractions¶
In [ ]:
Y .= a .* X .+ Y
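The broadcast form is also the most portable of the three: the identical expression works unchanged on the host arrays. A tiny sketch:
In [ ]:
# Same expression, plain CPU arrays
Y_cpu .= a .* X_cpu .+ Y_cpu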
KernelAbstractions.jl¶
Summary -- How to use KernelAbstractions?¶
- Use @kernel function mykernel(args...) end to write a GPU-style program
- Instantiate the kernel for a backend: kernel = mykernel(backend)
- Backends come from vendor-specific libraries
- Use KA.allocate(backend, ...) to obtain memory
- Launch the kernel with kernel(args..., ndrange=...), specifying the grid to execute over
In [ ]:
function vadd(a, b, c)
for i in eachindex(c)
c[i] = a[i] + b[i]
end
end
a = rand(N)
b = rand(N)
c = similar(a)
vadd(a, b, c)
In [ ]:
import KernelAbstractions as KA
@kernel function vadd_kernel(a, b, c)
i = @index(Global)
c[i] = a[i] + b[i]
end
In [ ]:
backend = CUDABackend()
a = KA.allocate(backend, Float32, N)
b = KA.allocate(backend, Float32, N)
c = similar(a)
vadd_kernel(backend)(a, b, c; ndrange=size(c))
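Note that KA.allocate returns uninitialized memory, so the values in a and b are arbitrary. A quick check of the launch (a sketch) is to synchronize, copy everything back to the host, and compare; synchronization is covered in the next section:
In [ ]:
KA.synchronize(backend)
@assert isapprox(Array(c), Array(a) .+ Array(b))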
Asynchronous operations¶
GPU operations are asynchronous with respect to the host! They are ordered with respect to each other, but special care must be taken when combining Julia's task-based programming with GPU programming.
The JuliaGPU ecosystem synchronizes the GPU on access: when you move data to or from the GPU, all outstanding kernels are waited on first.
When benchmarking you need to synchronize the device!
@benchmark begin
vadd_kernel(backend)(a, b, c; ndrange=size(c))
KA.synchronize(backend)
end
Otherwise you are only measuring the launch of the kernel.
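As a runnable version of that snippet (a sketch, assuming BenchmarkTools.jl has been added to the environment alongside the other packages):
In [ ]:
using BenchmarkTools  # assumed to be installed in this environment
@benchmark begin
    vadd_kernel(backend)(a, b, c; ndrange=size(c))
    KA.synchronize(backend)
end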
What makes an application portable?¶
- Can I run it on a different compute architecture?
- Different CPU architectures
- We live in a world with multiple GPU vendors
- Does it compute the same thing?
- Can I develop on one platform and move to another later?
- Does it achieve the same performance?
- Can I take advantage of platform specific capabilities?
Adapt.jl¶
Adapt.jl is a lightweight dependency that you can use to convert complex structures from CPU to GPU.
using Adapt
adapt(CuArray, ::Adjoint{Array})::Adjoint{CuArray}
struct Model{T<:Number, AT<:AbstractArray{T}}
data::AT
end
Adapt.adapt_structure(to, x::Model) = Model(adapt(to, x.data))
cpu = Model(rand(64, 64));
using CUDA
gpu = adapt(CuArray, cpu)
Model{Float64, CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}}(...)
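Adapt also converts in the other direction; a short sketch continuing the example above:
host = adapt(Array, gpu)
typeof(host.data)   # Matrix{Float64}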
GPU kernel -- transpose¶
In [ ]:
nreps = 3
N = 2048
T = Float32
TILE_DIM = 32
BLOCK_ROWS = 8
Naive kernels¶
In [ ]:
@kernel function simple_copy_kernel!(output, @Const(input))
I, J = @index(Global, NTuple)
@inbounds output[I, J] = input[I, J]
end
In [ ]:
@kernel function simple_transpose_kernel!(output, @Const(input))
I, J = @index(Global, NTuple)
@inbounds output[J, I] = input[I, J]
end
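Before benchmarking, a quick correctness check is useful. A sketch, using the backend, N, and T defined above and loading Random for rand!:
In [ ]:
using Random
let
    input = rand!(allocate(backend, T, N, N))
    output = similar(input)
    simple_transpose_kernel!(backend)(output, input, ndrange=size(output))
    KernelAbstractions.synchronize(backend)
    # the kernel copies values bitwise, so exact equality holds
    @assert Array(output) == permutedims(Array(input))
end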
Using local memory¶
In [ ]:
@kernel unsafe_indices = true function lmem_copy_kernel!(
output, @Const(input),
::Val{BANK} = Val(1),
) where {BANK}
I, J = @index(Global, NTuple)
i, j = @index(Local, NTuple)
N = @uniform @groupsize()[1]
M = @uniform @groupsize()[2]
# +1 to avoid bank conflicts on shared memory
tile = @localmem eltype(output) (N + BANK, M)
@inbounds tile[i, j] = input[I, J]
@synchronize
@inbounds output[I, J] = tile[i, j]
end
In [ ]:
@kernel unsafe_indices = true function lmem_transpose_kernel!(
output, @Const(input),
::Val{BANK} = Val(1),
) where {BANK}
gi, gj = @index(Group, NTuple)
i, j = @index(Local, NTuple)
N = @uniform @groupsize()[1]
M = @uniform @groupsize()[2]
# +1 to avoid bank conflicts on shared memory
tile = @localmem eltype(output) (N + BANK, M)
# Manually calculate global indexes
# Later on we need to pivot the group index
I = (gi - 1) * N + i
J = (gj - 1) * M + j
@inbounds tile[i, j] = input[I, J]
@synchronize
# Pivot the group index
I = (gj - 1) * M + i
J = (gi - 1) * N + j
@inbounds output[I, J] = tile[j, i]
end
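The same kind of check works for the local-memory transpose; note that it is launched with an explicit (TILE_DIM, TILE_DIM) workgroup, so N must be divisible by TILE_DIM (a sketch):
In [ ]:
let
    input = rand!(allocate(backend, T, N, N))
    output = similar(input)
    kernel! = lmem_transpose_kernel!(backend, (TILE_DIM, TILE_DIM))
    kernel!(output, input, Val(1), ndrange=size(output))
    KernelAbstractions.synchronize(backend)
    @assert Array(output) == permutedims(Array(input))
end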
Local memory + processing multiple elements per lane¶
In [ ]:
using KernelAbstractions.Extras: @unroll
In [ ]:
@kernel unsafe_indices=true function coalesced_copy_kernel!(
output, @Const(input),
::Val{BANK} = Val(1),
) where {BANK}
gi, gj = @index(Group, NTuple)
i, j = @index(Local, NTuple)
TILE_DIM = @uniform @groupsize()[1]
BLOCK_ROWS = @uniform @groupsize()[2]
# +1 to avoid bank conflicts on shared memory
tile = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
# Can't use @index(Global), because we use a smaller ndrange
I = (gi - 1) * TILE_DIM + i
J = (gj - 1) * TILE_DIM + j
@unroll for k in 0:BLOCK_ROWS:(TILE_DIM - 1)
@inbounds tile[i, j + k] = input[I, J + k]
end
@synchronize
@unroll for k in 0:BLOCK_ROWS:(TILE_DIM - 1)
@inbounds output[I, J + k] = tile[i, j + k]
end
end
In [ ]:
@kernel unsafe_indices = true function coalesced_transpose_kernel!(
output, @Const(input),
::Val{BANK} = Val(1),
) where {BANK}
gi, gj = @index(Group, NTuple)
i, j = @index(Local, NTuple)
TILE_DIM = @uniform @groupsize()[1]
BLOCK_ROWS = @uniform @groupsize()[2]
# +1 to avoid bank conflicts on shared memory
tile = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
# Can't use @index(Global), because we use a smaller ndrange
I = (gi - 1) * TILE_DIM + i
J = (gj - 1) * TILE_DIM + j
@unroll for k in 0:BLOCK_ROWS:(TILE_DIM - 1)
@inbounds tile[i, j + k] = input[I, J + k]
end
@synchronize
# Transpose block offsets
I = (gj - 1) * TILE_DIM + i
J = (gi - 1) * TILE_DIM + j
@unroll for k in 0:BLOCK_ROWS:(TILE_DIM - 1)
@inbounds output[I, J + k] = tile[j + k, i]
end
end
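And for the coalesced variant, which processes TILE_DIM ÷ BLOCK_ROWS elements per lane, the ndrange has to be shrunk accordingly (a sketch):
In [ ]:
let
    input = rand!(allocate(backend, T, N, N))
    output = similar(input)
    kernel! = coalesced_transpose_kernel!(backend, (TILE_DIM, BLOCK_ROWS))
    block_factor = div(TILE_DIM, BLOCK_ROWS)
    kernel!(output, input, Val(1), ndrange=(N, div(N, block_factor)))
    KernelAbstractions.synchronize(backend)
    @assert Array(output) == permutedims(Array(input))
end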
Benchmark harness¶
In [ ]:
using NVTX, Random
In [ ]:
backend = CUDABackend()
In [ ]:
CUDA.@profile for block_dims in ((TILE_DIM, TILE_DIM), (TILE_DIM * TILE_DIM, 1), (1, TILE_DIM * TILE_DIM))
for (name, kernel) in (
("copy", simple_copy_kernel!(backend, block_dims)),
("transpose", simple_transpose_kernel!(backend, block_dims)),
)
NVTX.@range "Simple $name $block_dims" let
input = rand!(allocate(backend, T, N, N))
output = similar(input)
# compile kernel
kernel(output, input, ndrange = size(output))
for rep in 1:nreps
kernel(output, input, ndrange = size(output))
end
KernelAbstractions.synchronize(backend)
end
end
end
In [ ]:
# Benchmark localmem
CUDA.@profile for (name, kernel) in (
("copy", lmem_copy_kernel!(backend, (TILE_DIM, TILE_DIM))),
("transpose", lmem_transpose_kernel!(backend, (TILE_DIM, TILE_DIM))),
)
for bank in (true, false)
NVTX.@range "Localmem $name ($TILE_DIM, $TILE_DIM) bank=$bank" let
input = rand!(allocate(backend, T, N, N))
output = similar(input)
# compile kernel
kernel(output, input, Val(Int(bank)), ndrange = size(output))
for rep in 1:nreps
kernel(output, input, Val(Int(bank)), ndrange = size(output))
end
KernelAbstractions.synchronize(backend)
end
end
end
In [ ]:
# Benchmark localmem + multiple elements per lane
CUDA.@profile for (name, kernel) in (
("copy", coalesced_copy_kernel!(backend, (TILE_DIM, BLOCK_ROWS))),
("transpose", coalesced_transpose_kernel!(backend, (TILE_DIM, BLOCK_ROWS))),
)
for bank in (true, false)
NVTX.@range "Localmem + multiple elements $name ($TILE_DIM, $BLOCK_ROWS) bank=$bank" let
input = rand!(allocate(backend, T, N, N))
output = similar(input)
# We want the same number of blocks as with a (TILE_DIM, TILE_DIM) workgroup,
# but our workgroups are (TILE_DIM, BLOCK_ROWS), so we shrink the second
# ndrange dimension by that factor; otherwise we would launch too many blocks.
block_factor = div(TILE_DIM, BLOCK_ROWS)
ndrange = (N, div(N, block_factor))
# compile kernel
kernel(output, input, Val(Int(bank)), ndrange = ndrange)
for rep in 1:nreps
kernel(output, input, Val(Int(bank)), ndrange = ndrange)
end
KernelAbstractions.synchronize(backend)
end
end
end
Matrix multiply¶
In [ ]:
@kernel function naive_matmul_kernel!(output, a, b)
i, j = @index(Global, NTuple)
# creating a temporary sum variable for matrix multiplication
tmp_sum = zero(eltype(output))
for k in 1:size(a)[2]
tmp_sum += a[i, k] * b[k, j]
end
output[i, j] = tmp_sum
end
In [ ]:
# Creating a wrapper function that launches the kernel with error checks
function naive_matmul!(output, a, b)
if size(a)[2] != size(b)[1]
println("Matrix size mismatch!")
return nothing
end
backend = KernelAbstractions.get_backend(a)
kernel! = naive_matmul_kernel!(backend)
kernel!(output, a, b, ndrange = size(output))
return
end
In [ ]:
let
a = rand!(allocate(backend, Float32, 256, 123))
b = rand!(allocate(backend, Float32, 123, 45))
output = KernelAbstractions.zeros(backend, Float32, 256, 45)
naive_matmul!(output, a, b)
@assert isapprox(output, a * b)
end
In [ ]:
@kernel unsafe_indices = true function coalesced_matmul_kernel!(
output, @Const(A), @Const(B),
::Val{BANK} = Val(1),
) where {BANK}
gi, gj = @index(Group, NTuple)
i, j = @index(Local, NTuple)
TILE_DIM = @uniform @groupsize()[1]
# +1 to avoid bank conflicts on shared memory
tile1 = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
tile2 = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
# private variable for tile output
outval = @private eltype(output) 1
@inbounds outval[1] = -zero(eltype(output))
@uniform N = size(output, 1)
@uniform M = size(output, 2)
@uniform R = size(A, 2)
# number of tiles depends on inner dimension
@uniform NUM_TILES = div(R + TILE_DIM - 1, TILE_DIM)
# loop over all tiles needed for this calculation
for t in 0:(NUM_TILES - 1)
# Can't use @index(Global), because we use a smaller ndrange
I = (gi - 1) * TILE_DIM + i
J = (gj - 1) * TILE_DIM + j
# load inputs into tiles, with bounds checking for non-square matrices
if I <= N && t * TILE_DIM + j <= R
@inbounds tile1[i, j] = A[I, t * TILE_DIM + j]
else
@inbounds tile1[i, j] = 0.0
end
if t * TILE_DIM + i <= R && J <= M
@inbounds tile2[i, j] = B[t * TILE_DIM + i, J]
else
@inbounds tile2[i, j] = 0.0
end
# wait for all tiles to be loaded
@synchronize
# get global values again
I = (gi - 1) * TILE_DIM + i
J = (gj - 1) * TILE_DIM + j
# accumulate this tile's contribution in a temporary to allow for vectorization
out = zero(eltype(output))
@simd for k in 1:TILE_DIM
@inbounds out += tile1[i, k] * tile2[k, j]
end
outval[1] += out
@synchronize
end
# get global indices again
I = (gi - 1) * TILE_DIM + i
J = (gj - 1) * TILE_DIM + j
# save if inbounds
if I <= N && J <= M
@inbounds output[I, J] = outval[1]
end
end
In [ ]:
# Creating a wrapper function that launches the kernel with error checks
function coalesced_matmul!(output, a, b)
if size(a)[2] != size(b)[1]
println("Matrix size mismatch!")
return nothing
end
backend = KernelAbstractions.get_backend(a)
kernel! = coalesced_matmul_kernel!(backend, (TILE_DIM, TILE_DIM))
kernel!(output, a, b, ndrange = size(output))
return
end
In [ ]:
let
a = rand!(allocate(backend, Float32, 256, 123))
b = rand!(allocate(backend, Float32, 123, 45))
output = KernelAbstractions.zeros(backend, Float32, 256, 45)
coalesced_matmul!(output, a, b)
@assert isapprox(output, a * b)
end
In [ ]:
import LinearAlgebra
Exercise¶
- Vary N, R, M
- Vary T
In [ ]:
let
N = 1024
R = 512
M = 2048
T = Float64
A = rand!(allocate(backend, T, N, R))
B = rand!(allocate(backend, T, R, M))
output_naive = KernelAbstractions.zeros(backend, T, N, M)
output_coalesced = KernelAbstractions.zeros(backend, T, N, M)
output_mul = KernelAbstractions.zeros(backend, T, N, M)
CUDA.@profile for _ in 1:nreps
NVTX.@range "Naive Matmul" begin
naive_matmul!(output_naive, A, B)
KernelAbstractions.synchronize(backend)
end
NVTX.@range "Coalesced Matmul" begin
coalesced_matmul!(output_coalesced, A, B)
KernelAbstractions.synchronize(backend)
end
NVTX.@range "LinearAlgebra.mul!" begin
LinearAlgebra.mul!(output_mul, A, B)
KernelAbstractions.synchronize(backend)
end
end
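# Sanity check (sketch): all three implementations should agree up to floating-point rounding
@assert isapprox(Array(output_naive), Array(output_mul))
@assert isapprox(Array(output_coalesced), Array(output_mul))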
end