Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Dumb matrix multiplication

Use one GPU thread for each cell of the output matrix

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from std.testing import assert_true

# Matrix shapes for C = A * B, where A is ROWS_A x COLS_A, B is ROWS_B x COLS_B
# and C is ROWS_C x COLS_C. COLS_A must equal ROWS_B for the product to exist.
comptime ROWS_A = 64
comptime COLS_A = 16
comptime ROWS_B = 16
comptime COLS_B = 8
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Bounds passed to the random generator when filling the input matrices.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

# Element type shared by every buffer and tensor in this example.
comptime dtype = DType.float32
# Number of threads per block, as (x, y).
comptime THREADS = (5, 5)
# Number of blocks in the grid, as (x, y). Ceiling division so the grid covers
# all COLS_C columns along x and all ROWS_C rows along y even when the matrix
# dimensions are not multiples of the block dimensions.
comptime BLOCKS = (
    (COLS_C + THREADS[0] - 1) // THREADS[0],
    (ROWS_C + THREADS[1] - 1) // THREADS[1],
)

# Row-major layouts for the three matrices.
comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# Tensor types pairing the element type with each matrix's layout.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]


# GPU kernel: each thread computes one cell C[i, j] of the product A * B.
#
# Parameters:
#   a, b, c: Layouts supplied at the launch site. They are not referenced in
#            the body; the tensor types are fixed by MatrixA/MatrixB/MatrixC.
# Args:
#   A: Left operand.
#   B: Right operand.
#   C: Output. The dot product is added onto the value already stored in each
#      cell, so the caller is expected to zero C beforehand.
def matmul_thread_per_output_cell[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    var i = block_idx.y * block_dim.y + thread_idx.y  # Row of C
    var j = block_idx.x * block_dim.x + thread_idx.x  # Column of C

    # The grid is rounded up to whole blocks, so threads that fall outside the
    # matrix must do nothing.
    if i < ROWS_C and j < COLS_C:
        # Accumulate in a local and store once, instead of reading and
        # writing C[i, j] through the tensor on every iteration.
        var acc = C[i, j]
        for k in range(ROWS_B):
            acc += A[i, k] * B[k, j]
        C[i, j] = acc


# Overwrite every element of `buffer` with a pseudo-random value requested
# from the range MATRIX_MIN_ELEM..MATRIX_MAX_ELEM and converted to `dtype`.
def fill_buffer(buffer: HostBuffer[dtype]):
    # Seed the generator on every call (no explicit seed value is passed).
    random.seed()
    var count = len(buffer)
    for idx in range(count):
        # Draw as Float64, then narrow to the buffer's element type.
        var sample = random.random_float64(MATRIX_MIN_ELEM, MATRIX_MAX_ELEM)
        buffer[idx] = sample.cast[dtype]()[0]


# Entry point: builds random A and B on the device, launches the matmul
# kernel, and checks the result against NumPy. Any error raised along the way
# is caught and printed rather than propagated.
def main():
    try:
        ctx = DeviceContext()

        # Allocate the three device buffers, zero-filled. C must start at zero
        # because the kernel adds onto the existing cell values.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Fill the inputs with random values through a host mapping.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # Wrap the flat buffers in 2-D tensor views for the kernel.
        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        ctx.enqueue_function[
            matmul_thread_per_output_cell[layout_a, layout_b, layout_c]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        # Wait for the enqueued work before reading the result on the host.
        ctx.synchronize()

        # Map all three buffers to the host and verify C against NumPy.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        # Covers device errors, Python interop errors and a failed assertion.
        print("Matmul example failed: ", e)


# Verify the computed product against NumPy.
#
# Each argument is a (rows, cols, buffer) triple. The three buffers are
# converted to 2-D NumPy arrays and the function asserts that
# matmul(A, B) is element-wise close to C, printing a message on success.
#
# Raises if the assertion fails or if a Python interop call errors.
# NOTE(review): np.allclose runs with its default tolerances; confirm they are
# loose enough for float32 accumulation.
def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    np = Python.import_module("numpy")

    a_rows, a_cols, a_buff = buff_a_with_dims
    lhs = reshape(to_ndarray(a_buff), a_rows, a_cols)

    b_rows, b_cols, b_buff = buff_b_with_dims
    rhs = reshape(to_ndarray(b_buff), b_rows, b_cols)

    c_rows, c_cols, c_buff = buff_c_with_dims
    actual = reshape(to_ndarray(c_buff), c_rows, c_cols)

    # Reference result computed by NumPy from the same inputs.
    expected = np.matmul(lhs, rhs)
    assert_true(np.allclose(expected, actual))
    print("Assertion was successful")


# Copy the contents of a host buffer into a newly created 1-D NumPy array and
# return that array.
#
# The array is created as float32, which matches the file-level `dtype`; the
# two must be kept in sync for the raw copy below to be valid.
#
# Raises if a Python interop call errors.
def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    np = Python.import_module("numpy")
    count = len(buffer)
    ndarray = np.zeros(count, dtype=np.float32)
    # Named dst_ptr (not ndarray_ptr) so the local does not shadow the
    # ndarray_ptr helper function called on this very line.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    # Element-wise raw copy: `count` elements from the buffer into the array.
    memcpy(dst_ptr, src_ptr, count)
    return ndarray


# Return `ndarray` reshaped to (rows, cols) by calling its Python `reshape`
# method. Raises if the Python call errors.
def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    var reshaped = ndarray.reshape(rows, cols)
    return reshaped


# Return a typed pointer to the data of a NumPy array.
#
# Reads the first entry of the array's `__array_interface__["data"]` item and
# reinterprets it as a pointer to `dtype` scalars. The pointer is unchecked:
# the caller must ensure the array's element type matches `dtype` and that the
# array outlives every use of the pointer.
#
# Raises if a Python interop call errors.
def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    var array_info = ndarray.__array_interface__
    var address = array_info["data"][0]
    return address.unsafe_get_as_pointer[dtype]()

View source on GitHub