Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Matmul Thread Per Output Cell Vectorized

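Uses `stack_allocation` to stage a column of B. Note that no `address_space=AddressSpace.SHARED` is passed, so this is per-thread stack memory, not block-shared memory.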

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy, stack_allocation
from python import Python, PythonObject
from std.testing import assert_true
from algorithm import vectorize
from sys import simdwidthof, strided_load


comptime ROWS_A = 9
comptime COLS_A = 17
# A matmul needs matching inner dimensions: B has exactly as many rows as A
# has columns.
comptime ROWS_B = COLS_A
comptime COLS_B = 7
# The product C = A @ B is ROWS_A x COLS_B.
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Bounds for the random values that `fill_buffer` writes into A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

comptime dtype = DType.float32

# Threads per block as (x, y). In the kernel, x indexes output columns and
# y indexes output rows.
comptime THREADS = (5, 5)
# Blocks per grid as (x, y): the ceiling of output extent / block extent in
# each dimension, so every cell of C is covered by some thread.
comptime BLOCKS = (
    (COLS_C + THREADS[0] - 1) // THREADS[0],
    (ROWS_C + THREADS[1] - 1) // THREADS[1],
)

# Row-major layouts describing the flat buffers allocated in `main`.
comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# Mutable tensor views over the A, B and C buffers.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]
# A 1 x simdwidthof[dtype]() tensor passed as the kernel's `store` argument.
# NOTE(review): the kernel never reads or writes it.
comptime Storage = LayoutTensor[
    dtype, Layout.row_major(1, simdwidthof[dtype]()), MutableAnyOrigin
]


def matmul_thread_per_output_cell_vectorized(
    A: MatrixA, B: MatrixB, C: MatrixC, store: Storage
):
    """Compute one output cell C[i, j] of C = A @ B per GPU thread.

    Each in-range thread copies column j of B into a small scratch array.
    It then adds the dot product of row i of A with that copy into C[i, j],
    processing the dot product in SIMD-width chunks via `vectorize`.

    Args:
        A: Left operand tensor (layout_a, row-major ROWS_A x COLS_A).
        B: Right operand tensor (layout_b, row-major ROWS_B x COLS_B).
        C: Output tensor (layout_c, row-major ROWS_C x COLS_C). Results are
            accumulated with `+=`, so C should be zeroed before launch
            (`main` fills it with 0.0).
        store: Not referenced anywhere in this kernel.
            NOTE(review): appears to be a leftover; confirm before removing,
            since `main` passes it at launch.
    """
    var i = block_idx.y * block_dim.y + thread_idx.y  # Row of C (y dimension)
    var j = block_idx.x * block_dim.x + thread_idx.x  # Column of C (x dimension)
    # BLOCKS rounds the grid up to whole blocks, so threads past the edge of
    # C must do nothing.
    if i < ROWS_C and j < COLS_C:
        # Scratch array holding this thread's copy of column j of B.
        # NOTE(review): no `address_space` argument is given, so this is
        # presumably per-thread stack memory, not block-shared memory. Every
        # thread makes its own copy of the column.
        tile = stack_allocation[ROWS_B, Scalar[dtype]]()
        # ROWS_B x 1 tile of B at tile coordinate (0, j), i.e. column j.
        each_b_col = B.tile[ROWS_B, 1](0, j)
        for k in range(ROWS_B):
            tile[k] = each_b_col[k, 0][0]

        # Adds sum(A[i, idx + t] * tile[idx + t] for t < simd_width) into
        # C[i, j]. Reading row i of A at column offsets up to ROWS_B relies
        # on COLS_A == ROWS_B.
        @parameter
        def dotproduct[simd_width: Int](idx: Int):
            C[i, j] += (
                A.load[width=simd_width](i, idx)
                * tile.load[width=simd_width](idx)
            ).reduce_add()

        # Calls dotproduct over [0, ROWS_B) in chunks of simdwidthof[dtype]().
        # A tail shorter than that width is presumably handled by vectorize
        # with narrower widths.
        vectorize[dotproduct, simdwidthof[dtype]()](ROWS_B)


# Populate a host buffer with random values in [MATRIX_MIN_ELEM, MATRIX_MAX_ELEM].
def fill_buffer(buffer: HostBuffer[dtype]):
    """Overwrite every element of `buffer` with a random value.

    Values are drawn as Float64 between MATRIX_MIN_ELEM and MATRIX_MAX_ELEM
    (-5.0 and 5.0) and then converted to `dtype`.

    Args:
        buffer: Host-mapped buffer to fill in place.
    """
    # No explicit seed value is given, so the generator is reseeded on every
    # call (presumably from a time source).
    random.seed()
    for idx in range(len(buffer)):
        sample = random.random_float64(MATRIX_MIN_ELEM, MATRIX_MAX_ELEM)
        buffer[idx] = sample.cast[dtype]()[0]


def main():
    """Run the vectorized matmul kernel on random inputs and verify it.

    Steps:
    1. Create zero-filled device buffers for A, B and C.
    2. Fill A and B with random values on the host.
    3. Launch one GPU thread per output cell of C and wait for completion.
    4. Compare C with `numpy.matmul(A, B)` via `assert_allclose`.

    Any error raised along the way is printed and then re-raised, so a
    mismatch propagates out of `main` instead of being swallowed. Errors
    include device setup failures, Python/NumPy interop failures, and a
    failed comparison.
    """
    try:
        ctx = DeviceContext()

        # All buffers start at zero. C in particular must, because the
        # kernel accumulates into it with `+=`.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # One SIMD-width vector, passed as the kernel's `store` argument.
        store = ctx.enqueue_create_buffer[dtype](
            simdwidthof[dtype]()
        ).enqueue_fill(0.0)

        # Fill A and B through host mappings. The data is presumably copied
        # back to the device when each mapping is released.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)
        storage = Storage(store)

        ctx.enqueue_function[matmul_thread_per_output_cell_vectorized](
            matrix_a,
            matrix_b,
            matrix_c,
            storage,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Bring A, B and C back to the host and check C == A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Printing here: ", e)
        # Re-raise. Previously a failed check was only printed, so an
        # incorrect GPU result still ended the program normally.
        raise e


def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    """Check that C matches A @ B using NumPy.

    Each argument is a `(rows, cols, buffer)` triple. Each buffer is copied
    into a NumPy array and reshaped to `rows` x `cols` (NumPy's default
    C/row-major order). The check is `numpy.allclose(numpy.matmul(A, B), C)`
    with allclose's default tolerances.

    Raises:
        Error: If the comparison is false (raised by `assert_true`), or if
            NumPy cannot be imported or called.
    """
    np = Python.import_module("numpy")

    rows_a, cols_a, host_a = buff_a_with_dims
    rows_b, cols_b, host_b = buff_b_with_dims
    rows_c, cols_c, host_c = buff_c_with_dims

    lhs = reshape(to_ndarray(host_a), rows_a, cols_a)
    rhs = reshape(to_ndarray(host_b), rows_b, cols_b)
    actual = reshape(to_ndarray(host_c), rows_c, cols_c)

    # NOTE(review): the default atol (1e-8) is very tight for float32 results
    # near zero, where GPU and NumPy summation order can differ. Consider an
    # explicit atol if this check is ever flaky.
    expected = np.matmul(lhs, rhs)
    assert_true(np.allclose(expected, actual))
    print("Assertion was successful")


def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    """Copy a host buffer into a newly allocated 1-D NumPy array.

    Args:
        buffer: Host-visible buffer whose `len(buffer)` elements are copied.

    Returns:
        A new 1-D `numpy.ndarray` (float32) of length `len(buffer)` holding
        a copy of the buffer's contents.

    Raises:
        Error: If NumPy cannot be imported or the array cannot be created.
    """
    np = Python.import_module("numpy")
    # NOTE(review): the NumPy dtype is hard-coded to float32 while the copy
    # below is typed by the module-level `dtype`. These must agree, or the
    # element sizes of the two sides of the memcpy differ.
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    # Named `dst_ptr` so the local does not shadow the `ndarray_ptr` helper
    # it is obtained from.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    """Return `ndarray` reshaped to `rows` x `cols` via NumPy's `reshape`."""
    reshaped = ndarray.reshape(rows, cols)
    return reshaped


def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    """Return a typed raw pointer to the start of a NumPy array's data.

    The address is taken from the first item of the array's
    `__array_interface__["data"]` entry.

    Parameters:
        dtype: Element type of the returned pointer. It should match the
            array's NumPy dtype; this is not checked here.

    Args:
        ndarray: A NumPy array (any object exposing `__array_interface__`).

    Returns:
        A non-owning pointer into the array's buffer. The caller must keep
        `ndarray` alive while using it.
    """
    interface = ndarray.__array_interface__
    data_address = interface["data"][0]
    return data_address.unsafe_get_as_pointer[dtype]()

View source on GitHub