Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Mojo 🔥 Programming

CPU & GPU algorithm implementations in the Mojo programming language.


Categories

  • Arrays & Hashing — 12 problems
  • Strings — 7 problems
  • Greedy & Two Pointers — 2 problems
  • Search & Sort — 6 problems
  • Linked Lists — 1 problem
  • Dynamic Programming — 3 problems
  • Bit Manipulation & Math — 2 problems
  • Grid & Matrix — 3 problems
  • Mojo Concepts — 6 problems
  • GPU programming — 12 problems
  • Utilities — 1 problem
  • CUDA — 1 problem

Full source on GitHub

Two Sum

# Two Sum Problem:
# Given an array of integers `nums` and a target value `target`,
# return the indices of two numbers such that they add up to `target`.


def two_sum(nums: List[Int], target: Int) -> Tuple[Int, Int]:
    """Return indices `(i, j)`, `i < j`, with `nums[i] + nums[j] == target`.

    Single pass with a value -> index map: O(n) time and space. The search
    stops at the first pair completed while scanning left to right.

    Args:
        nums: Input integers.
        target: Desired sum.

    Returns:
        The first matching `(i, j)`, or `(-1, -1)` if no pair exists.
    """
    # Fewer than two elements cannot form a pair.
    if len(nums) <= 1:
        return (-1, -1)

    var value_indices = Dict[Int, Int]()  # value -> index where it was seen

    for idx in range(len(nums)):
        # The value that would complete a pair with nums[idx].
        diff = target - nums[idx]

        # Complement seen at an earlier index: return immediately instead of
        # scanning on and letting a later pair overwrite this one.
        if diff in value_indices:
            return (value_indices.get(diff).value(), idx)

        # Remember the current value for later elements.
        value_indices[nums[idx]] = idx

    # No pair sums to target.
    return (-1, -1)

def two_sum_costly(nums: List[Int], target: Int) -> Tuple[Int, Int]:
    """Brute-force Two Sum: test every index pair `(i, j)` with `i < j`.

    O(n^2) time, O(1) extra space. Returns the first matching pair in
    lexicographic index order, or `(-1, -1)` when no pair sums to `target`.
    """
    var n = len(nums)
    if n > 1:
        for first in range(n - 1):
            for second in range(first + 1, n):
                if nums[first] + nums[second] == target:
                    return (first, second)
    return (-1, -1)


def main() raises:
    # std.testing assertions always run. debug_assert may be compiled out,
    # depending on the build's assertion mode.
    from std.testing import assert_equal

    nums = [2, 7, 11, 15]

    # 2 + 7 == 9 -> indices (0, 1)
    target = 9
    indices = two_sum(nums, target)
    assert_equal(indices[0], 0, "Assertion failed")
    assert_equal(indices[1], 1, "Assertion failed")

    # 7 + 11 == 18 -> indices (1, 2)
    target = 18
    indices = two_sum(nums, target)
    assert_equal(indices[0], 1, "Assertion failed")
    assert_equal(indices[1], 2, "Assertion failed")

    # No pair sums to 100 -> sentinel (-1, -1); the old check expected -2.
    target = 100
    indices = two_sum(nums, target)
    assert_equal(indices[0], -1, "Assertion failed")
    assert_equal(indices[1], -1, "Assertion failed")

View source on GitHub

3Sum

Along with each value, this solution also returns that element's index in the original (unsorted) input.

import random
from collections import InlineArray

# Fixed-size array type holding one triplet: three (value, original_index) tuples.
# original_index is the element's position in the caller's input before sorting.
comptime Triplet = InlineArray[(Int, UInt), 3]


# Sort predicate for (value, original_index) tuples: orders by value only.
# The original index is ignored, so the relative order of equal values
# depends on the sort algorithm used.
@parameter
def compare_fn(left: (Int, UInt), right: (Int, UInt)) -> Bool:
    var lhs_value = left[0]
    var rhs_value = right[0]
    return lhs_value < rhs_value


# Function to find all unique triplets that sum up to 0
def find_triplets(nums: List[Int]) -> List[Triplet]:
    """Return all triplets of `nums` whose values sum to 0, without duplicates.

    Each triplet element is a (value, original_index) pair. The index is the
    element's position in `nums` before sorting. Each combination of values is
    reported only once. Sorting plus one two-pointer scan per pivot gives
    O(n^2) time.
    """
    # Create a new list of tuples (value, original index)
    numbers = List[(Int, UInt)](capacity=len(nums))
    for idx in range(len(nums)):
        numbers.append((nums[idx], UInt(idx)))  # Keep original index

    # Sort the list based on the value of elements
    sort[compare_fn](numbers)

    # List to store the result triplets
    outcome = []

    # Iterate through the sorted list using a fixed pivot at index `idx`;
    # the last two positions are left for the `left`/`right` pointers
    for idx in range(len(numbers) - 2):
        # Early stopping: values are sorted, so a positive pivot means every
        # remaining candidate triplet is positive and cannot sum to 0
        if numbers[idx][0] > 0:
            break
        # Skip duplicates: avoid repeated first elements to prevent repeated triplets
        if idx > 0 and numbers[idx][0] == numbers[idx - 1][0]:
            continue

        # Two-pointer approach starts here
        left, right = idx + 1, len(numbers) - 1
        while left < right:
            # Calculate sum of the triplet
            sum = numbers[idx][0] + numbers[left][0] + numbers[right][0]
            if sum == 0:
                # Found a valid triplet that sums to 0
                outcome.append(
                    Triplet(numbers[idx], numbers[left], numbers[right])
                )
                # Move both pointers inward
                left += 1
                right -= 1

                # Skip duplicates after finding a valid triplet, so the same
                # (pivot, left, right) value combination is not emitted again
                while left < right and numbers[left - 1][0] == numbers[left][0]:
                    left += 1
                while (
                    left < right and numbers[right][0] == numbers[right + 1][0]
                ):
                    right -= 1
            elif sum < 0:
                # Sum is too small, move left pointer to increase it
                left += 1
            else:
                # Sum is too large, move right pointer to decrease it
                right -= 1

    return outcome


# Helper function to nicely print triplet values and their original indices
def pretty_print(triplets: List[Triplet]):
    """Print each triplet as three "value, index" lines followed by a blank line."""
    for triplet in triplets:
        # `triplet[]` dereferences the loop reference; [k][0] is the k-th
        # element's value and [k][1] its index in the original input
        print("Elem1 value: ", triplet[][0][0], ", index: ", triplet[][0][1])
        print("Elem2 value: ", triplet[][1][0], ", index: ", triplet[][1][1])
        print("Elem3 value: ", triplet[][2][0], ", index: ", triplet[][2][1])
        print()


# Entry point: builds a sample list containing duplicates, shuffles it so the
# input is unordered, then prints every zero-sum triplet with its original indices.
def main():
    values = [-3, -3, -2, -1, 0, 1, 2, 2, 3]
    random.shuffle(values)
    pretty_print(find_triplets(values))

View source on GitHub

4Sum

from collections import InlineArray

# Fixed-size array holding the four values of one quadruplet.
comptime Quadruplet = InlineArray[Int, 4]

# Find all unique quadruplets in `nums` whose four values sum to `target`.
def quadruplets(mut nums: List[Int], target: Int) -> List[Quadruplet]:
    """Return every distinct quadruplet of values from `nums` that sums to `target`.

    Side effect: `nums` is sorted in place (it is a `mut` argument) unless it
    is empty. Two fixed outer indices plus a two-pointer scan give O(n^3)
    time. The values inside each quadruplet are in non-decreasing order.
    """
    result = []
    n = len(nums)
    if n == 0:
        return result
    sort(nums)
    for a in range(n - 3):
        # Skip repeated first values so no quadruplet is emitted twice.
        if a > 0 and nums[a - 1] == nums[a]:
            continue
        for b in range(a + 1, n - 2):
            # Skip repeated second values for the same first value.
            if b > a + 1 and nums[b - 1] == nums[b]:
                continue
            lo = b + 1
            hi = n - 1
            while lo < hi:
                total = nums[a] + nums[b] + nums[lo] + nums[hi]
                if total < target:
                    lo += 1  # need a larger sum
                elif total > target:
                    hi -= 1  # need a smaller sum
                else:
                    result.append(
                        Quadruplet(nums[a], nums[b], nums[lo], nums[hi])
                    )
                    lo += 1
                    hi -= 1
                    # Step past duplicates of the values just used.
                    while lo < hi and nums[lo - 1] == nums[lo]:
                        lo += 1
                    while lo < hi and nums[hi + 1] == nums[hi]:
                        hi -= 1
    return result


def main():
    # Sample input; quadruplets() sorts it in place.
    nums = [1, 0, -1, 0, -2, 2]
    target = 0
    result = quadruplets(nums, target)
    # Print each quadruplet's four values on one line.
    for each in result:
        print(each[][0], each[][1], each[][2], each[][3])

View source on GitHub

Product Of Array Except Self

# Function to compute the product of all elements in the list except the one at each index
def product_except_self(nums: List[Int]) -> List[Int]:
    """Return `result` where `result[i]` is the product of every `nums[j]` with `j != i`.

    Two passes, O(n) time, O(1) extra space beyond the output. A left-to-right
    pass stores prefix products; a right-to-left pass multiplies in suffix
    products. No division is used, so zeros in `nums` are handled. An empty
    input yields an empty list.
    """
    # If the input list is empty, return it as is
    if len(nums) == 0:
        return nums

    # Named `length` rather than `len`: a local `len` would shadow the builtin
    # `len` that this function also calls.
    length = len(nums)

    # Initialize the result list with all 1s. This will store our final answer.
    result = [1] * length

    # prefix_product holds the product of all elements to the *left* of idx
    prefix_product = 1
    for idx in range(length):
        result[idx] = prefix_product
        prefix_product *= nums[idx]

    # suffix_product holds the product of all elements to the *right* of idx
    suffix_product = 1
    # Iterate from right to left
    for idx in range(length - 1, -1, -1):
        result[idx] *= suffix_product
        suffix_product *= nums[idx]

    return result


# Entry point
def main():
    # Example input
    nums = [1, 2, 3, 4]
    # Call the function to get result
    result = product_except_self(nums)  # Output should be [24, 12, 8, 6]
    # Print the result using the list's string representation
    print(result.__str__())

View source on GitHub

Max Sum Sub Array

# Given an integer array nums, find the subarray with the largest sum, and return its sum.


# Kadane's algorithm over all non-empty contiguous subarrays.
def max_sum_sub_array(nums: List[Int]) -> Int:
    """Return the largest sum of any non-empty contiguous subarray of `nums`.

    Returns 0 for an empty list. One pass: O(n) time, O(1) space.
    """
    if len(nums) == 0:
        return 0

    var best = nums[0]
    var ending_here = nums[0]  # best sum of a subarray ending at the current index
    for i in range(1, len(nums)):
        var x = nums[i]
        # Either extend the subarray ending at i - 1 or start fresh at i.
        ending_here = x if ending_here + x < x else ending_here + x
        if ending_here > best:
            best = ending_here
    return best


def main():
    # Classic example: the subarray [4, -1, 2, 1] has the largest sum, 6.
    nums = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
    max_sum = max_sum_sub_array(nums)
    # NOTE(review): debug_assert may be compiled out depending on the build's
    # assertion mode; confirm these checks actually run.
    debug_assert(max_sum == 6, "Assertion failed")
    # Here the whole array is optimal: 5 + 4 - 1 + 7 + 8 = 23.
    nums = [5,4,-1,7,8]
    max_sum = max_sum_sub_array(nums)
    debug_assert(max_sum == 23, "Assertion failed")

View source on GitHub

Max Average Subarray

# Find the maximum average of any contiguous subarray of length k (window size) from the array.

def find_max_average(read nums: List[Int], window_size: UInt) -> Float16:
    """Return the largest average over all contiguous windows of `window_size` elements.

    Sliding window: the first window is summed directly. Each later step adds
    the entering element and subtracts the leaving one, giving O(n) total.

    Returns 0.0 when `nums` is empty, `window_size` is 0, or `window_size`
    exceeds `len(nums)` (no complete window exists).

    Note: the result is a Float16, so averages are rounded to half precision.
    """
    length = len(nums)
    # Without the size check, the first loop below would read past the end
    # of `nums` when the window is wider than the input.
    if length == 0 or window_size == 0 or Int(window_size) > length:
        return 0.0  # Return 0 if input is invalid

    var max_average: Float16 = 0.0
    window_sum = 0

    # Compute sum of the first 'window_size' elements
    for idx in range(window_size):
        window_sum += nums[idx]
    max_average = Float16(window_sum / window_size)  # Initialize max average

    # Slide the window over the array
    for idx in range(window_size, length):
        window_sum += nums[idx]                     # Add next element
        window_sum -= nums[idx - window_size]       # Remove the element going out of window
        average = Float16(window_sum / window_size) # Current window average
        max_average = max(max_average, average)     # Update max average if needed

    return max_average


def main():
    nums = [1, 12, -5, -6, 50, 3]
    window_size = 4
    # Best window is [12, -5, -6, 50]: sum 51, average 51 / 4 = 12.75.
    max_average = find_max_average(nums, window_size)
    debug_assert(max_average == 12.75, "Assertion failed")  # Test case check

View source on GitHub

Max Subarray Product

# Function to find the maximum product of a contiguous subarray
def max_subarray_product(read nums: List[Int]) -> Int:
    """Return the maximum product of any non-empty contiguous subarray of `nums`.

    Tracks both the largest and the smallest product of a subarray ending at
    the current index, because multiplying by a negative number swaps them.
    A zero splits the array: the subarray [0] is itself a candidate with
    product 0, and no product may extend across it.

    Returns 0 for an empty list. O(n) time, O(1) space.
    """
    if len(nums) == 0:
        return 0

    # Seed with a real subarray product: the first element on its own.
    max_product = nums[0]
    # Largest / smallest product of a subarray ending just before idx.
    # 1 acts as "empty", so the next element starts a fresh subarray.
    curr_max, curr_min = 1, 1

    for idx in range(len(nums)):
        num = nums[idx]

        if num == 0:
            # [0] is a valid subarray, so 0 is a candidate answer; this is
            # what makes inputs like [0, -1] or [-2, 0, -1] return 0.
            max_product = max(max_product, 0)
            curr_max, curr_min = 1, 1
            continue

        # curr_max is overwritten first, so keep its previous value for curr_min.
        prev_max = curr_max
        curr_max = max(prev_max * num, curr_min * num, num)
        curr_min = min(prev_max * num, curr_min * num, num)

        # Update the global max product
        max_product = max(max_product, curr_max)

    return max_product


# Entry point
def main():
    nums = [2, 3, -2, 4]  # Expected maximum product subarray: [2, 3] => 6
    max_product = max_subarray_product(nums)
    print(max_product)
    # NOTE(review): debug_assert may be compiled out depending on the build's
    # assertion mode; the printed value above is the visible check.
    debug_assert(max_product == 6, "Assertion failed")

View source on GitHub

Water Container Max Area

def max_area(heights: List[Int]) -> Int:
    """Return the largest water area that two lines of `heights` can contain.

    The area between lines i < j is (j - i) * min(heights[i], heights[j]).
    A two-pointer scan from both ends runs in O(n). At each step the shorter
    line moves inward: keeping it can never give a larger area at a smaller
    width. Returns 0 when fewer than two lines are given.
    """
    var best = 0
    var lo = 0
    var hi = len(heights) - 1
    while lo < hi:
        var left_h = heights[lo]
        var right_h = heights[hi]
        var width = hi - lo
        if left_h <= right_h:
            # The left line limits the height; it cannot do better further in.
            best = max(best, width * left_h)
            lo += 1
        else:
            best = max(best, width * right_h)
            hi -= 1
    return best


from std.testing import assert_equal


def main() raises:
    # `raises` is required: assert_equal raises on failure (every other
    # assert_equal-based main in this collection is declared the same way).

    # Best container: lines at indices 1 and 8 (heights 8 and 7),
    # width 7 * height 7 = 49.
    heights = [1, 8, 6, 2, 5, 4, 8, 3, 7]
    mx_area = max_area(heights)
    assert_equal(mx_area, 49, "Assertion failed")

    # Two lines of height 1, one apart: area 1.
    heights = [1, 1]
    mx_area = max_area(heights)
    assert_equal(mx_area, 1, "Assertion failed")

View source on GitHub

Largest Number

Arrange non-negative integers to form the largest possible number and return it as a string.

def largest_number(nums: List[Int]) raises -> String:
    """Arrange non-negative integers to form the largest number, as a string.

    Sorts the decimal strings with `compare_fn` (`a` goes before `b` when
    `a + b > b + a`), then concatenates them. If the best arrangement starts
    with "0", every input is zero and the answer is "0". An empty list yields "".

    The result is built purely as a string. It is never converted to `Int`,
    so answers with more digits than an `Int` can hold are returned intact.
    """
    if len(nums) == 0:
        return ""
    strs = List[String](capacity=len(nums))
    for idx in range(len(nums)):
        strs.append(String(nums[idx]))
    sort[compare_fn](strs)
    # Any non-zero digit string sorts ahead of "0" (e.g. "50" > "05"), so a
    # leading "0" means all elements are "0"; collapse "00...0" to "0".
    if strs[0] == "0":
        return "0"
    result = StringSlice("").join(strs)
    return String(result)


# Sort predicate for largest_number: `left` goes first when placing it before
# `right` yields the larger concatenation (e.g. "9" before "34" since
# "934" > "349", and "3" before "30" since "330" > "303").
@parameter
def compare_fn(left: String, right: String) -> Bool:
    return left + right > right + left


from std.testing import assert_true


def main() raises:
    # "2" must precede "10": "210" > "102".
    nums = [10, 2]
    result = largest_number(nums)
    assert_true(result == "210", "Assertion failed")

    # Mixed prefixes: "3" before "30" and "34" ordered by concatenation.
    nums = [3, 30, 34, 5, 9]
    result = largest_number(nums)
    assert_true(result == "9534330", "Assertion failed")

    # All zeros collapse to a single "0".
    nums = [0, 0, 0, 0, 0]
    result = largest_number(nums)
    assert_true(result == "0", "Assertion failed")

    # Empty input yields the empty string.
    nums = List[Int]()
    result = largest_number(nums)
    assert_true(result == "", "Assertion failed")

View source on GitHub

Remove duplicates from sorted array

This implementation mutates the array in place.

Post de-duplication, only the unique elements are retained.

Any leftover entries after the last unique element are then removed from the list.

def remove_duplicates(mut nums: List[Int]) -> None:
    """Deduplicate a sorted list in place, keeping one copy of each value.

    Two-pointer compaction: each value that differs from its predecessor is
    copied into the next free slot. Afterwards the list is truncated to the
    compacted prefix. O(n) time, O(1) extra space. Only neighbouring elements
    are compared, so `nums` is expected to be sorted.
    """
    var n = len(nums)
    if n < 2:
        return  # 0 or 1 elements: already unique

    var keep = 1  # number of unique values kept so far == next slot to fill
    for scan in range(1, n):
        if nums[scan] != nums[scan - 1]:
            nums[keep] = nums[scan]
            keep += 1

    # Drop the leftover tail beyond the kept prefix.
    while len(nums) > keep:
        _ = nums.pop()


# Import testing helper for assertions
from std.testing import assert_equal


def main() raises:
    # Each test validates that the function keeps only unique sorted elements
    # and truncates the list to exactly those elements.
    nums = [1, 1]
    remove_duplicates(nums)
    assert_equal(nums, [1], "Assertion failed")

    nums = [1, 1, 1]
    remove_duplicates(nums)
    assert_equal(nums, [1], "Assertion failed")

    nums = [1, 1, 1, 2]
    remove_duplicates(nums)
    assert_equal(nums, [1, 2], "Assertion failed")

    # Last value is itself duplicated.
    nums = [1, 1, 1, 2, 3, 3, 3, 5, 5, 6, 6, 8, 8, 8, 9, 10, 10]
    remove_duplicates(nums)
    assert_equal(nums, [1, 2, 3, 5, 6, 8, 9, 10], "Assertion failed")

    # Last value is unique.
    nums = [1, 1, 1, 2, 3, 3, 3, 5, 5, 6, 6, 8, 8, 8, 9, 10, 10, 11]
    remove_duplicates(nums)
    assert_equal(nums, [1, 2, 3, 5, 6, 8, 9, 10, 11], "Assertion failed")

View source on GitHub

Merge nums2 into nums1 (in-place)

Given two sorted arrays nums1 (size m + n, with m valid elements followed by n zeros) and nums2 (size n), merge them in-place into nums1 as one sorted array.

def merge(mut nums1: List[Int], nums2: List[Int]):
    """Merge sorted `nums2` into sorted `nums1` in place.

    `nums1` holds its valid sorted elements followed by `len(nums2)`
    placeholder slots. Filling `nums1` from the back, always with the larger
    of the two current tails, ensures no valid element is overwritten before
    it has been read. O(m + n) time, O(1) extra space. Nothing happens if
    either list is empty.
    """
    if len(nums1) == 0 or len(nums2) == 0:
        return

    var i = len(nums1) - len(nums2) - 1  # last valid element of nums1
    var j = len(nums2) - 1  # last element of nums2
    var slot = len(nums1) - 1  # next position to fill, moving leftwards

    # Loop until nums2 is used up; whatever remains of nums1's valid prefix
    # is then already in its final position.
    while j >= 0:
        if i >= 0 and nums1[i] >= nums2[j]:
            nums1[slot] = nums1[i]
            i -= 1
        else:
            nums1[slot] = nums2[j]
            j -= 1
        slot -= 1


from std.testing import assert_equal


def main() raises:
    # nums2 has both the smallest (3) and largest (19) values.
    nums1 = [5, 8, 11, 13, 0, 0, 0]
    nums2 = [3, 9, 19]
    merge(nums1, nums2)
    assert_equal(nums1, [3, 5, 8, 9, 11, 13, 19], "Assertion failed")
    # Duplicate value 2 appears in both inputs.
    nums1 = [1, 2, 3, 0, 0, 0]
    nums2 = [2, 5, 6]
    merge(nums1, nums2)
    assert_equal(nums1, [1, 2, 2, 3, 5, 6], "Assertion failed")

    # Empty nums2 leaves nums1 unchanged.
    nums1 = [1]
    nums2 = []
    merge(nums1, nums2)
    assert_equal(nums1, [1], "Assertion failed")

View source on GitHub

Sum 1D Tensor

from layout import Layout, LayoutTensor
from algorithm import vectorize
from sys import simdwidthof


def summer[
    type: DType, layout: Layout, //, simdwidth: Int = simdwidthof[type]()
](
    tensor: LayoutTensor[type, layout, MutableAnyOrigin],
    start: Int = 0,
    end: Int = layout.size(),
) -> Scalar[type]:
    """Sum the elements of row 0 of `tensor` at column offsets [start, end).

    `type` and `layout` are inferred from `tensor`. `simdwidth` defaults to
    `simdwidthof[type]()` and is the vector width passed to `vectorize`.

    NOTE(review): the default `end` is `layout.size()` (the total element
    count), which matches the row length only for single-row layouts such as
    `row_major(1, N)` used in `main`. Confirm before using other layouts.
    """
    result = Scalar[type](0)

    # Closure invoked by `vectorize` for each chunk of `simd_width` elements
    # starting at offset `idx`; it accumulates into the captured `result`.
    @parameter
    def sum[simd_width: Int](idx: Int):
        result += tensor.load[width=simd_width](0, start + idx).reduce_add()

    vectorize[sum, simdwidth](end - start)
    return result


def main():
    from math import iota

    comptime elems_count = 1 << 10
    # Backing storage, filled by iota with 0, 1, ..., elems_count - 1.
    var array = InlineArray[Scalar[DType.uint32], elems_count](fill=0)
    iota(array.unsafe_ptr(), elems_count)
    # View the same buffer as a 1 x elems_count row-major tensor.
    tensor = LayoutTensor[
        DType.uint32, Layout.row_major(1, elems_count), MutableAnyOrigin
    ](array.unsafe_ptr())
    # Sum the whole row with the default SIMD width; the expected value is
    # 1023 * 1024 / 2 = 523776. (Unused `start`/`end` locals and the
    # commented-out debug calls were removed.)
    result = summer(tensor)
    print(result)

View source on GitHub

String to Integer (atoi)

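Implement the atoi(string s) function, which converts a string to a signed integer. This variant is more lenient than the classic one: spaces between digits are skipped, so "13   37" becomes 1337.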

def atoi(s: String) raises -> Int:
    """Convert `s` to a signed integer using this collection's lenient rules.

    Rules (as exercised by `main`):
      * Leading spaces are skipped.
      * The first non-space character must be '+', '-' or a digit; otherwise
        the result is 0.
      * After that, digits are collected and spaces are skipped, so
        "13   37" parses as 1337; any other character ends the number.
      * Input without digits (e.g. "", "   ", "-") yields 0.
      * No 32-bit clamping is applied.

    Raises:
        Whatever `Int(...)` raises for a non-numeric string. Only digit
        characters are buffered, so this is not expected in practice.
    """
    # Return 0 for an empty string (edge case)
    if len(s) == 0:
        return 0

    buffer = String()  # optional '-' followed by the accepted digits
    digits = String("0123456789")  # Valid digit characters
    idx = 0  # Index for scanning the string
    prelude = True  # still before the sign / first digit

    while idx < len(s):
        # Phase 1: skip leading spaces.
        while prelude and idx < len(s) and s[idx] == " ":
            idx += 1
        if idx >= len(s):
            # Nothing but spaces remained; stop before indexing past the end.
            break

        if prelude:
            # The number must start with a sign or a digit.
            if s[idx] == "-" or s[idx] == "+" or s[idx] in digits:
                if s[idx] != "+":  # '+' is implied; only '-' is kept
                    buffer += s[idx]
                idx += 1
                prelude = False
                continue
            # Invalid character before the number starts, e.g. "words and 987".
            break

        # Phase 2: collect digits, ignore interior spaces, stop at anything else.
        if s[idx] in digits:
            buffer += s[idx]
        elif s[idx] != " ":
            break
        idx += 1

    # `buffer` is empty, "-", or an optional '-' followed by digits. Every
    # digit, including the first, is accumulated, so single-digit inputs such
    # as "5" or "+7" convert correctly instead of yielding 0.
    negative = len(buffer) > 0 and buffer[0] == "-"
    number = 0
    for pos in range(1 if negative else 0, len(buffer)):
        number = number * 10 + Int(buffer[pos])

    return -number if negative else number


from std.testing import assert_equal


def main() raises:
    # Sign, then digits separated by spaces; parsing stops at 'c'.
    s = "   -           13   37    c0d3"
    number = atoi(s)
    assert_equal(number, -1337)

    # Interior spaces between digit groups are skipped.
    s1 = "13   37    c0d3"
    number = atoi(s1)
    assert_equal(number, 1337)

    # Parsing stops at the first non-digit, non-space character.
    s2 = "1337c0d3"
    number = atoi(s2)
    assert_equal(number, 1337)

    # Leading zeros are absorbed.
    s3 = "   -042"
    number = atoi(s3)
    assert_equal(number, -42)

    s4 = "42"
    number = atoi(s4)
    assert_equal(number, 42)

    # A '-' after the first digit ends the number.
    s5 = "0-1"
    number = atoi(s5)
    assert_equal(number, 0)

    # A non-numeric first character yields 0.
    s6 = "words and 987"
    number = atoi(s6)
    assert_equal(number, 0)

    s7 = "    words and 987"
    number = atoi(s7)
    assert_equal(number, 0)

    # Explicit '+' sign.
    s8 = "+987"
    number = atoi(s8)
    assert_equal(number, 987)

    # Spaces are allowed between the sign and the digits.
    s9 = " + 98700 www"
    number = atoi(s9)
    assert_equal(number, 98700)

    s10 = " - 98700 www"
    number = atoi(s10)
    assert_equal(number, -98700)

View source on GitHub

Longest Substr No Char Repeats

# Longest Substring Without Repeating Characters
# Given a string s, find the length of the longest substring without duplicate characters.
from collections import Set


# Function to compute the length of the longest substring without repeating characters.
# Takes a StringLiteral `s` and returns an unsigned 16-bit integer.
def len_longest_substr_no_char_repeats(s: StringLiteral) raises -> UInt16:
    """Return the length of the longest substring of `s` with no repeated character.

    Sliding window [start, end]: the set `window` holds exactly the window's
    characters. Whenever s[end] is already present, the window shrinks from
    the left until that duplicate is gone. Returns 0 for an empty string.
    """
    if len(s) == 0:
        return 0

    window = Set(s[0])  # characters currently inside the window
    start = 0
    best = 1  # a single character is always repeat-free

    for end in range(1, len(s)):
        # Shrink from the left until s[end] is no longer a duplicate.
        while s[end] in window:
            window.remove(s[start])
            start += 1
        window.add(s[end])
        if len(window) > best:
            best = len(window)

    return best


def main() raises:
    # Expected outputs are noted beside each call.
    s = "abcabcbb"
    print(len_longest_substr_no_char_repeats(s))  # 3 ("abc")
    s = "bbbbb"
    print(len_longest_substr_no_char_repeats(s))  # 1 ("b")
    s = "pwwkew"
    print(len_longest_substr_no_char_repeats(s))  # 3 ("wke")

View source on GitHub

Longest Palindromic Substring

Given a string s, return the longest palindromic substring of s.

def longest_palindrome(s: String) -> String:
    """Return the longest palindromic substring of `s`.

    Every index is tried as the center of an odd-length palindrome, and every
    gap between neighbours as the center of an even-length one; `find_longest`
    expands each center outward. On ties, the first palindrome found is kept.
    """
    if len(s) == 0:
        return s

    var best = String("")
    for center in range(len(s)):
        # Odd length: centered on s[center].
        var lo = center
        var hi = center
        find_longest(lo, hi, s, best)
        # Even length: centered between s[center] and s[center + 1].
        lo = center
        hi = center + 1
        find_longest(lo, hi, s, best)
    return best


def find_longest(mut left: Int, mut right: Int, s: String, mut longest: String):
    """Expand a palindrome outward from center `s[left..right]`, updating `longest`.

    Widens while both ends are in range and the end characters match.
    Whenever the current palindrome is strictly longer than `longest`, it
    replaces it. `left` and `right` are mutated by the expansion.
    """
    while 0 <= left and right < len(s) and s[left] == s[right]:
        if right - left + 1 > len(longest):
            longest = s[left : right + 1]
        left -= 1
        right += 1


from std.testing import assert_true


def main() raises:
    # "bab" and "aba" tie; the first found ("bab") is kept.
    var s: String = "babad"
    var expected: String = "bab"
    result = longest_palindrome(s)
    assert_true(result == expected, "Assertion failed")

    # Even-length palindrome.
    s = "cbbd"
    expected = "bb"
    result = longest_palindrome(s)
    assert_true(result == expected, "Assertion failed")

    # The whole string is a palindrome.
    s = "racecar"
    expected = "racecar"
    result = longest_palindrome(s)
    assert_true(result == expected, "Assertion failed")

View source on GitHub

Interleaving String

Determine if s3 is an interleaving of s1 and s2.

def is_interleave(s1: String, s2: String, s3: String) -> Bool:
    """Return True if `s3` is an interleaving of `s1` and `s2`.

    Bottom-up DP over suffixes: dp[i][j] is True when the suffixes s1[i:] and
    s2[j:] can be interleaved to form s3[i + j:]. O(len(s1) * len(s2)) time
    and space.
    """
    # Every character of s1 and s2 must be used exactly once.
    if len(s1) + len(s2) != len(s3):
        return False
    dp = List[List[Bool]](
        length=len(s1) + 1, fill=List[Bool](length=len(s2) + 1, fill=False)
    )
    # Base case: two empty suffixes form the empty suffix of s3.
    dp[len(s1)][len(s2)] = True
    # Iterate from the end so dp[i + 1][j] and dp[i][j + 1] are already final.
    for i in range(len(s1), -1, -1):
        for j in range(len(s2), -1, -1):
            # Take s1[i] as the next character of s3 ...
            if i < len(s1) and s1[i] == s3[i + j] and dp[i + 1][j]:
                dp[i][j] = True
            # ... or take s2[j]. Both branches only ever set True, so the
            # base-case cell is never reset to False.
            if j < len(s2) and s2[j] == s3[i + j] and dp[i][j + 1]:
                dp[i][j] = True

    return dp[0][0]


from std.testing import assert_true, assert_false


def main() raises:
    # Valid interleaving.
    var s1: String = "aabcc"
    var s2: String = "dbbca"
    var s3: String = "aadbbcbcac"
    result = is_interleave(s1, s2, s3)
    assert_true(result, "Assertion failed")

    # Same lengths, but no valid interleaving exists.
    s1 = "aabcc"
    s2 = "dbbca"
    s3 = "aadbbbaccc"
    result = is_interleave(s1, s2, s3)
    assert_false(result, "Assertion failed")

    # All empty: trivially an interleaving.
    s1 = ""
    s2 = ""
    s3 = ""
    result = is_interleave(s1, s2, s3)
    assert_true(result, "Assertion failed")

View source on GitHub

Last word length

Return the length of the last word in a given space-separated string

def last_word_length(s: String) -> Int:
    """Return the length of the last space-separated word in `s` (0 if none).

    Only the ' ' character counts as a separator. Trailing spaces are skipped
    first, then the word is measured back to the preceding space or the start
    of the string.
    """
    var end = len(s) - 1
    # Skip trailing spaces; `end` lands on the last character of the word.
    while end >= 0 and s[end] == " ":
        end -= 1
    # Walk back to the space before the word (or to -1).
    var start = end
    while start >= 0 and s[start] != " ":
        start -= 1
    # For "" or an all-space input, both stay at -1 and the length is 0.
    return end - start


def main() raises:
    from std.testing import assert_true

    # Plain two-word string.
    result = last_word_length("Hello World")
    assert_true(result == 5, "Assertion failed")

    # Only spaces: there is no word.
    result = last_word_length("         ")
    assert_true(result == 0, "Assertion failed")

    # Trailing spaces are ignored.
    result = last_word_length("   fly me   to   the moon  ")
    assert_true(result == 4, "Assertion failed")

    result = last_word_length("luffy is still joyboy")
    assert_true(result == 6, "Assertion failed")

View source on GitHub

Find All Anagrams in a String

Return all start indices of anagrams of string p in string s.

from collections import Dict


def find_anagrams(s: String, p: String) -> List[Int]:
    """Return the start indices of all substrings of `s` that are anagrams of `p`.

    Uses a sliding window of length len(p) with character-count dictionaries:
    `pdict` counts `p` and `sdict` counts the current window of `s`. Keys whose
    count drops to 0 are removed, so `equals` can compare the dictionaries
    directly.
    """
    if len(s) < len(p):
        return List[Int]()
    pdict = Dict[String, Int]()
    sdict = Dict[String, Int]()
    # Count p and the first window s[0:len(p)].
    for i in range(len(p)):
        pdict[p[i]] = 1 + pdict.get(p[i], 0)
        sdict[s[i]] = 1 + sdict.get(s[i], 0)
    result = [0] if equals(pdict, sdict) else List[Int]()
    # Window is s[left:right]; each step adds s[right] and drops s[left].
    left, right = 0, len(p)
    while right < len(s):
        sdict[s[right]] = 1 + sdict.get(s[right], 0)
        sdict[s[left]] = sdict.get(s[left], 1) - 1
        # Remove exhausted keys so a zero count does not make the dicts differ.
        if sdict.get(s[left]) and sdict.get(s[left]).value() == 0:
            try:
                _ = sdict.pop(s[left])
            except e:
                # NOTE(review): the key was just checked, so this should be
                # unreachable; errors are only printed, not propagated.
                print(e)
        right += 1
        left += 1
        # The window now starts at `left`.
        if equals(pdict, sdict):
            result.append(left)
    return result


def equals(read d1: Dict[String, Int], read d2: Dict[String, Int]) -> Bool:
    """Return True if both dictionaries hold exactly the same key/count pairs."""
    # Different sizes can never match.
    if len(d1) != len(d2):
        return False
    # Equal sizes, so checking every d1 entry against d2 is sufficient.
    for entry in d1.items():
        var other = d2.get(entry[].key)
        if not other or other.value() != entry[].value:
            return False
    return True


from std.testing import assert_equal


def main() raises:
    # Identical strings: a single match at index 0.
    var s: String = "abc"
    var p: String = "abc"
    result = find_anagrams(s, p)
    assert_equal(result, [0], "Assertion failed")

    # Reversed string is still an anagram.
    s = "cba"
    p = "abc"
    result = find_anagrams(s, p)
    assert_equal(result, [0], "Assertion failed")

    # Window "baa" at index 1 is not an anagram.
    s = "cbaa"
    p = "abc"
    result = find_anagrams(s, p)
    assert_equal(result, [0], "Assertion failed")

    s = "cbaacb"
    p = "abc"
    result = find_anagrams(s, p)
    assert_equal(result, [0, 3], "Assertion failed")

    s = "cbaebabacd"
    p = "abc"
    result = find_anagrams(s, p)
    assert_equal(result, [0, 6], "Assertion failed")

    # Overlapping matches.
    s = "abab"
    p = "ab"
    result = find_anagrams(s, p)
    assert_equal(result, [0, 1, 2], "Assertion failed")

View source on GitHub

Word Search

Check whether a word can be spelled from sequentially adjacent cells of a grid. Adjacent means horizontally or vertically neighbouring, and no cell may be used more than once.

from collections import Set

# One board row: a list of single-character cells.
comptime SString = List[StaticString]
# The whole board: a list of rows.
comptime SStrings = List[SString]


def present(board: SStrings, word: StaticString) raises -> Bool:
    """Return True if `word` can be traced through adjacent cells of `board`.

    Tries every cell as a starting point and delegates the depth-first
    search to `trace`. An empty board never contains the word.
    """
    var n_rows = len(board)
    if n_rows == 0:
        return False
    var n_cols = len(board[0])
    # Cells on the current search path; `trace` restores it after each attempt.
    var seen = Set[String]()
    for r in range(n_rows):
        for c in range(n_cols):
            if trace(n_rows, n_cols, board, word, 0, r, c, seen):
                return True
    return False


def trace(
    rows: UInt,
    cols: UInt,
    board: SStrings,
    word: StaticString,
    idx: UInt,
    row: UInt,
    col: UInt,
    mut visited: Set[String],
) raises -> Bool:
    """Return True if `word[idx:]` can be spelled starting at cell (row, col).

    Depth-first search with backtracking. Each step moves to a horizontally
    or vertically adjacent cell, and a cell already on the current path
    (kept in `visited`) may not be reused. `visited` is restored before
    returning.
    """
    if len(word) == idx:
        return True  # every character has been matched
    # `row`/`col` are unsigned, so `< 0` can never hold. A step up or left
    # from 0 is expected to wrap to a large value that fails the `>=` checks.
    # NOTE(review): confirm that UInt subtraction wraps rather than traps.
    if row >= rows or col >= cols or board[row][col] != word[idx]:
        return False
    # The separator keeps keys unique. Without it, (1, 11) and (11, 1) both
    # map to "111" and wrongly block each other on boards with 11 or more
    # rows or columns.
    cell = String(row) + "," + String(col)
    if cell in visited:
        return False
    visited.add(cell)
    exists = (
        trace(rows, cols, board, word, idx + 1, row + 1, col, visited)
        or trace(rows, cols, board, word, idx + 1, row - 1, col, visited)
        or trace(rows, cols, board, word, idx + 1, row, col + 1, visited)
        or trace(rows, cols, board, word, idx + 1, row, col - 1, visited)
    )
    visited.remove(cell)  # backtrack so other paths may use this cell
    return exists


from std.testing import assert_true, assert_false


def main() raises:
    # "ABCB" would need to reuse the 'B' cell, so it must not be found.
    board = SStrings(
        SString("A", "B", "C", "E"),
        SString("S", "F", "C", "S"),
        SString("A", "D", "E", "E"),
    )
    word1 = "ABCB"
    result = present(board, word1)
    assert_false(result, "Assertion failed")
    # "SEE" runs down and left from the 'S' in row 1.
    board = SStrings(
        SString("A", "B", "C", "E"),
        SString("S", "F", "C", "S"),
        SString("A", "D", "E", "E"),
    )
    word2 = "SEE"
    result = present(board, word2)
    assert_true(result, "Assertion failed")
    # "ABCCED" turns corners without reusing any cell.
    board = SStrings(
        SString("A", "B", "C", "E"),
        SString("S", "F", "C", "S"),
        SString("A", "D", "E", "E"),
    )
    word3 = "ABCCED"
    result = present(board, word3)
    assert_true(result, "Assertion failed")

View source on GitHub

Buy And Sell Stock

# Function to calculate the maximum profit from a list of stock prices
# where you are allowed to make only one buy and one sell transaction.
# You must buy before you sell.


def max_profit(prices: List[UInt]) -> UInt:
    """Return the best profit from a single buy followed by a later sell.

    Walks the prices once, tracking the cheapest price seen so far and the
    largest gain from selling at the current price. Returns 0 when no
    profitable trade exists, including when there are fewer than two prices.
    """
    if len(prices) < 2:
        return 0
    var best: UInt = 0
    cheapest = prices[0]
    for day in range(1, len(prices)):
        price = prices[day]
        if price > cheapest:
            best = max(best, price - cheapest)
        else:
            # A price no higher than the current low becomes the new buy point.
            cheapest = price
    return best


from std.testing import assert_equal


def main() raises:
    # `debug_assert` only fires when assertions are enabled at build time, so
    # these checks used to be silent. `assert_equal` always checks.
    # The prices are built as List[UInt] to match `max_profit`'s parameter type.

    # First test: Best profit is buying at 1 and selling at 6 => profit = 5
    prices = List[UInt](7, 1, 5, 3, 6, 4)
    assert_equal(max_profit(prices), UInt(5), "Assertion failed")

    # Second test: No profitable day to sell => profit = 0
    prices = List[UInt](7, 6, 4, 3, 1)
    assert_equal(max_profit(prices), UInt(0), "Assertion failed")

View source on GitHub

Buy And Sell Stock 2

# On each day, you may decide to buy and/or sell the stock. You can only hold
# at most one share of the stock at any time. However, you can buy it then
# immediately sell it on the same day.

# Find and return the maximum profit you can achieve.
# Note: the problem statement is confusing. With a single price per day, buying
# and selling on the same day yields no profit!

def total_profit(prices: List[UInt]) -> UInt:
    """Return the maximum profit when any number of buy/sell trades is allowed.

    Holding at most one share at a time, the optimum collects every
    day-over-day price increase. The result is therefore the sum of all
    positive consecutive differences (0 for fewer than two prices).
    """
    var gained: UInt = 0
    for today in range(1, len(prices)):
        # Skip days where the price did not rise; nothing to collect.
        if not (prices[today] > prices[today - 1]):
            continue
        gained += prices[today] - prices[today - 1]
    return gained


from std.testing import assert_equal


def main() raises:
    # `debug_assert` only fires when assertions are enabled at build time;
    # `assert_equal` always checks.

    # First test: buy at 1, sell at 5 (+4); buy at 3, sell at 6 (+3) => profit = 7
    prices = List[UInt](7, 1, 5, 3, 6, 4)
    assert_equal(total_profit(prices), UInt(7), "Assertion failed")

    # Second test: No profitable day to sell => profit = 0
    prices = List[UInt](7, 6, 4, 3, 1)
    assert_equal(total_profit(prices), UInt(0), "Assertion failed")

    # 3rd test: Best profit is buying every day and selling the next day => 4
    prices = List[UInt](1, 2, 3, 4, 5)
    assert_equal(total_profit(prices), UInt(4), "Assertion failed")

View source on GitHub

Min In Sorted Rotated Arr

# Given a sorted, rotated array `nums`, find and return the minimum element.
# The solution uses binary search to achieve O(log n) time complexity.

def find_min(nums: List[Int]) -> Int:
    """Return the minimum element of a sorted array that may have been rotated.

    Binary search, O(log n) for distinct values.

    Args:
        nums: A sorted list rotated by an unknown amount.

    Returns:
        The smallest element, or `Int.MIN` as a sentinel when `nums` is empty.
    """
    # Handle edge case: empty list
    if len(nums) == 0:
        return Int.MIN  # Sentinel: there is no minimum to report

    # Handle edge case: single-element list
    if len(nums) == 1:
        return nums[0]

    # Initialize binary search pointers
    left, right = 0, len(nums) - 1
    curr_min = nums[0]  # Smallest value seen so far

    while left <= right:
        # If the current window is sorted, its smallest element is at `left`.
        if nums[left] <= nums[right]:
            return min(curr_min, nums[left])

        # Compute the mid index
        mid = (left + right) // 2

        # Both branches below exclude `mid` from the window, so fold it into
        # the running minimum here. Otherwise a minimum sitting at `mid`
        # (e.g. [3, 1, 2]) would be lost.
        curr_min = min(curr_min, nums[mid])

        # Determine which side is unsorted (contains the pivot)
        if nums[mid] >= nums[left]:
            # Left half is sorted, so the minimum must be in the right half.
            left = mid + 1
        else:
            # Right half is sorted, so the minimum is in the left half or is
            # `mid` itself (already recorded in curr_min).
            right = mid - 1

    # Fallback return - not expected to be reached for a rotated sorted array
    return curr_min

from std.testing import assert_equal


def main() raises:
    # `debug_assert` only fires when assertions are enabled at build time;
    # `assert_equal` always checks and reports the mismatching values.
    nums = List[Int]()
    assert_equal(find_min(nums), Int.MIN, "Assertion failed")  # Empty-input sentinel
    nums = [4, 5, 6, 7, 0, 1, 2]
    assert_equal(find_min(nums), 0, "Assertion failed")
    nums = [4]
    assert_equal(find_min(nums), 4, "Assertion failed")
    nums = [4, 5]
    assert_equal(find_min(nums), 4, "Assertion failed")
    nums = [5, 4]
    assert_equal(find_min(nums), 4, "Assertion failed")
    nums = [3, 4, 5, 1, 2]
    assert_equal(find_min(nums), 1, "Assertion failed")
    nums = [11, 13, 15, 17]
    assert_equal(find_min(nums), 11, "Assertion failed")
    nums = [11, 13, 15, 17, 1, 1, 2, 2]
    assert_equal(find_min(nums), 1, "Assertion failed")

View source on GitHub

Search Sorted Rotated Arr

# Search in Rotated Sorted Array

# Function to search for a target in a rotated sorted array
comptime ItemType = ComparableCollectionElement


def find[ItmType: ItemType](read items: List[ItmType], target: ItmType) -> Int:
    """Return the index of `target` in a rotated sorted list, or -1 if absent.

    Intended for distinct elements. Each step compares against the midpoint,
    determines which half of the window is still in sorted order, and keeps
    the half whose value range could contain `target`.
    """
    lo = 0
    hi = len(items) - 1  # -1 for an empty list, so the loop is skipped

    while lo <= hi:
        mid = (lo + hi) // 2
        if items[mid] == target:
            return mid

        if items[mid] >= items[lo]:
            # items[lo..mid] is in order: go left only if target lies inside it.
            within_left = not (target < items[lo] or target > items[mid])
            if within_left:
                hi = mid - 1
            else:
                lo = mid + 1
        else:
            # items[mid..hi] is in order: go right only if target lies inside it.
            within_right = not (target > items[hi] or target < items[mid])
            if within_right:
                lo = mid + 1
            else:
                hi = mid - 1

    return -1


from std.testing import assert_equal


def main() raises:
    # All three checks are active now. `debug_assert` is compiled out unless
    # assertions are enabled at build time, so `assert_equal` is used instead.

    # Example 1: Target exists in the array; expected output 4 (index of 0)
    items = [4, 5, 6, 7, 0, 1, 2]
    target = 0
    assert_equal(find(items, target), 4, "Assertion failed")

    # Example 2: Target does not exist; expected output -1
    items = [4, 5, 6, 7, 0, 1, 2]
    target = 3
    assert_equal(find(items, target), -1, "Assertion failed")

    # Example 3: Single-element array, target not present; expected output -1
    items = [1]
    target = 0
    assert_equal(find(items, target), -1, "Assertion failed")

View source on GitHub

Search Sorted Rotated Arr With Duplicates

Generic function to find the index of a target element in a rotated sorted list.

Works for any type that implements ComparableCollectionElement (e.g., Int, Float64, String).

# Define an alias for types that support comparison operations.
comptime ItemType = ComparableCollectionElement


def find[ItmType: ItemType](read items: List[ItmType], target: ItmType) -> Int:
    """Return an index of `target` in a rotated sorted list that may hold duplicates.

    Binary search that, at each step, identifies which half of the window is
    sorted and keeps the half that can contain `target`. When both ends and
    the midpoint are equal, the sorted half cannot be identified. In that case
    the window shrinks by one element from each end (worst case O(n)).

    Args:
        items: A non-decreasing list rotated by an unknown amount.
        target: The value to look for.

    Returns:
        The index of some occurrence of `target`, or -1 if it is absent.
    """
    if len(items) == 0:
        return -1
    left, right = 0, len(items) - 1

    while left <= right:
        mid = left + (right - left) // 2

        # Check the midpoint and both window ends before narrowing.
        if items[mid] == target:
            return mid
        if items[left] == target:
            return left
        if items[right] == target:
            return right

        if items[left] == items[mid] and items[mid] == items[right]:
            # Cannot tell which half is sorted. Neither end is the target,
            # so both can be dropped safely.
            left += 1
            right -= 1
        elif items[left] <= items[mid]:
            # items[left..mid] is sorted: search it only if target lies inside.
            if items[left] < target and target < items[mid]:
                right = mid - 1
            else:
                left = mid + 1
        else:
            # items[mid..right] is sorted: search it only if target lies inside.
            if items[mid] < target and target < items[right]:
                left = mid + 1
            else:
                right = mid - 1

    # Target not found
    return -1


from std.testing import assert_equal


def main() raises:
    # Rotated list with runs of duplicates, and the index each probe should hit.
    items = [7, 7, 8, 9, 10, 10, 12, 1, 2, 3, 3, 4, 4, 5, 5, 6]
    targets = [12, 7, 3, 10, 1, 6, 4]
    expected_indices = [6, 0, 9, 5, 7, 15, 11]

    results = List[Int](capacity=len(targets))
    for probe in range(len(targets)):
        results.append(find(items, targets[probe]))
    assert_equal(expected_indices, results, "Assertion failed!")

    # Single element equal to the target.
    items = [6]
    target = 6
    assert_equal(find(items, target), 0, "Assertion failed")

    # Every element equals the target: the first midpoint is returned.
    items = [7, 7, 7]
    target = 7
    assert_equal(find(items, target), 1, "Assertion failed")

    # Target smaller than every element.
    items = [7, 8, 4, 5]
    target = 3
    assert_equal(find(items, target), -1, "Assertion failed")

View source on GitHub

Find First/Last

Find first and last index of a target value in a sorted array

def find_first_last(arr: List[Int], target: Int) -> (Int, Int):
    """Return (first, last) indices of `target` in sorted `arr`, or (-1, -1).

    Two binary searches are run. The first locates the last occurrence; the
    second searches only up to that index for the first occurrence.
    """
    last = _edge_index(arr, target, 0, len(arr) - 1, find_last=True)
    if last == -1:
        return (-1, -1)
    first = _edge_index(arr, target, 0, last, find_last=False)
    return (first, last)


def _edge_index(
    arr: List[Int], target: Int, lo: Int, hi: Int, find_last: Bool
) -> Int:
    """Binary-search arr[lo..hi] for the last (or first) index equal to `target`, or -1."""
    var found = -1
    var low = lo
    var high = hi
    while low <= high:
        mid = (low + high) // 2
        value = arr[mid]
        if value < target:
            low = mid + 1
        elif value > target:
            high = mid - 1
        else:
            found = mid
            # Keep narrowing toward the requested edge of the run.
            if find_last:
                low = mid + 1
            else:
                high = mid - 1
    return found


from std.testing import assert_true


def main() raises:
    # Target repeated twice in the middle of the array.
    arr = [5, 7, 7, 8, 8, 10]
    target = 8
    first, last = find_first_last(arr, target)
    assert_true(first == 3 and last == 4, "Assertion failed")

    # Target absent from the array.
    target = 6
    first, last = find_first_last(arr, target)
    assert_true(first == -1 and last == -1, "Assertion failed")

    # Target occurring exactly once.
    arr = [5, 7, 7, 8, 10]
    target = 8
    first, last = find_first_last(arr, target)
    assert_true(first == 3 and last == 3, "Assertion failed")

View source on GitHub

Merge Intervals

Merge overlapping intervals

@parameter
def compare_fn(interval1: (Int, Int), interval2: (Int, Int)) -> Bool:
    """Return True if `interval1` starts before `interval2` (comparator for `sort`)."""
    return interval1[0] < interval2[0]


def merge_intervals(mut intervals: List[(Int, Int)]) -> List[(Int, Int)]:
    """Return the union of `intervals` as a list of non-overlapping intervals.

    Sorts `intervals` in place by start point, then sweeps once. Each interval
    either extends the last merged interval (when it starts at or before that
    interval's end) or begins a new one.
    """
    merged = List[(Int, Int)]()
    if len(intervals) == 0:
        return merged

    sort[compare_fn](intervals)

    merged.append(intervals[0])
    for i in range(1, len(intervals)):
        start, end = intervals[i]
        last_start, last_end = merged[-1]
        if start > last_end:
            # Disjoint from everything merged so far: start a new interval.
            merged.append(intervals[i])
        else:
            merged[len(merged) - 1] = (last_start, max(last_end, end))

    return merged


from std.testing import assert_true


def _assert_intervals_equal(
    result: List[(Int, Int)], expected: List[(Int, Int)]
) raises:
    """Check that `result` and `expected` hold the same intervals in the same order."""
    # Compare lengths first. Iterating over `result` alone lets a result with
    # missing intervals pass, and one with extra intervals read past the end
    # of `expected`.
    assert_true(len(result) == len(expected), "Assertion failed")
    for i in range(len(result)):
        assert_true(
            result[i][0] == expected[i][0] and result[i][1] == expected[i][1],
            "Assertion failed",
        )


def main() raises:
    # Overlapping pair followed by two disjoint intervals.
    intervals = List[(Int, Int)]((1, 3), (2, 6), (8, 10), (15, 18))
    expected = List[(Int, Int)]((1, 6), (8, 10), (15, 18))
    result = merge_intervals(intervals)
    _assert_intervals_equal(result, expected)

    # Intervals that only touch at an endpoint are merged.
    intervals = List[(Int, Int)]((1, 4), (4, 5))
    expected = List[(Int, Int)]((1, 5))
    result = merge_intervals(intervals)
    _assert_intervals_equal(result, expected)

View source on GitHub

Shapes

trait Shape(ComparableCollectionElement):
    """A comparable, collection-storable shape that can report its area."""

    def area(self) -> UInt:
        """Return the area of the shape."""
        ...

@value
struct Rectangle(Shape):
    """A rectangle; rectangles are ordered and compared by their area."""

    var length: UInt
    var width: UInt

    def area(self) -> UInt:
        """Return width times length."""
        return self.width * self.length

    def __lt__(self, other: Self) -> Bool:
        return self.area() < other.area()

    def __le__(self, other: Self) -> Bool:
        # a <= b  <=>  not (b < a)
        return not (other.area() < self.area())

    def __eq__(self, other: Self) -> Bool:
        return self.area() == other.area()

    def __ne__(self, other: Self) -> Bool:
        return not (self.area() == other.area())

    def __gt__(self, other: Self) -> Bool:
        # a > b  <=>  b < a
        return other.area() < self.area()

    def __ge__(self, other: Self) -> Bool:
        # a >= b  <=>  not (a < b)
        return not (self.area() < other.area())

View source on GitHub

Generic singly linked list

A singly linked list with parametric polymorphism.

Currently supports adding multiple elements at once via the append method.

from memory import Pointer, UnsafePointer

comptime ElementType = CollectionElement  # Trait bound for values stored in the list


@value
struct Node[
    T: ElementType,
]:
    """One list node: a value plus a raw pointer to the next node (null at the tail)."""

    comptime NextNode = UnsafePointer[Self]
    var value: T
    var next: Self.NextNode

    def __init__(
        out self,
        owned value: T,
    ):
        """Create a tail node holding `value`."""
        # Move rather than copy, matching the two-argument constructor.
        self.value = value^
        self.next = Self.NextNode()

    def __init__(
        out self,
        owned value: T,
        next: Optional[Self.NextNode],
    ):
        """Create a node holding `value` that links to `next`, or to nothing if None."""
        self.value = value^
        self.next = next.value() if next else Self.NextNode()

    def __bool__(self) -> Bool:
        # Any constructed node is truthy.
        return True

    def __str__[
        ElementType: WritableCollectionElement
    ](self: Node[ElementType]) -> String:
        """Render the node's value as a string."""
        return String.write(self.value)


struct LinkedList[T: ElementType](Sized):
    """A singly linked list of `T` values.

    The first node is stored inline in `head`. Every later node is
    heap-allocated by `append` and released by `__del__`.
    """

    var head: Optional[Node[T]]
    var len: UInt  # Number of elements in the list

    def __init__(out self):
        """Create an empty list."""
        self.head = None
        self.len = 0

    def __len__(self) -> Int:
        return self.len

    def __init__(out self, *elems: T):
        """Create a list holding `elems` in order."""
        self = Self()
        self.append(elems)

    def __del__(owned self):
        """Free the heap-allocated nodes (every node after the inline head)."""
        # Without this, every node allocated in `append` leaked.
        if self.head:
            var node = self.head.value().next
            while node:
                var following = node[].next
                node.destroy_pointee()  # Run the node's destructor (and its value's)
                node.free()
                node = following

    def append(mut self, *elems: T):
        """Append `elems` to the end of the list, in order."""
        self.append(elems)

    def append(mut self, elems: VariadicListMem[T]):
        if len(elems) == 0:
            return
        next = 0  # Index of the first element still to be appended
        var current: UnsafePointer[Node[T]]
        if self.head is None:
            # The first element becomes the inline head node.
            self.head = Optional(Node(elems[0]))
            current = UnsafePointer(to=self.head.value())
            next = 1
            self.len += 1
        else:
            # Walk to the current tail.
            curr = UnsafePointer(to=self.head.value())
            while curr and curr[].next:
                curr = curr[].next
            current = curr
        for i in range(next, len(elems)):
            node = Node(elems[i])
            current[].next = UnsafePointer[Node[T]].alloc(1)
            current[].next.init_pointee_move(node)
            current = current[].next
            self.len += 1

    def __str__[
        ElementType: WritableCollectionElement
    ](self: LinkedList[ElementType]) -> String:
        """Render the list as `[a, b, c]`."""
        if self.len == 0:
            return String("[]")
        else:
            s = String("[")
            current = self.head.value()
            s.write(current.value)
            # Every remaining element is preceded by a separator.
            for _ in range(1, self.len):
                next = current.next[]
                s.write(", ")
                s.write(next.value)
                current = next
            s.write("]")
            return s

    def __iter__(self) -> _LinkedListIter[T, __origin_of(self)]:
        return _LinkedListIter(Pointer(to=self))


@value
struct _LinkedListIter[
    mut: Bool, //,
    ElementType: CollectionElement,
    origin: Origin[mut],
]:
    """Forward iterator over a LinkedList, yielding a pointer to each value."""

    var src: Pointer[LinkedList[ElementType], origin]
    var curr: UnsafePointer[Node[ElementType]]  # Next node to yield; null when exhausted
    var moved: Int  # Number of elements yielded so far

    def __init__(out self, src: Pointer[LinkedList[ElementType], origin]):
        self.src = src
        self.moved = 0
        # An empty list has no head node. Start exhausted instead of calling
        # `.value()` on an empty Optional, which aborts.
        if src[].head:
            self.curr = UnsafePointer(to=src[].head.value())
        else:
            self.curr = UnsafePointer[Node[ElementType]]()

    def __iter__(self) -> Self:
        return self

    # Misspelled form of `__iter__`, kept for backward compatibility.
    def __itr__(self) -> Self:
        return self

    def __next__(mut self) -> Pointer[ElementType, origin]:
        out = Pointer[ElementType, origin](to=self.curr[].value)
        self.moved += 1
        self.curr = self.curr[].next
        return out

    def __has_next__(self) -> Bool:
        return self.curr.__bool__()

    def __len__(self) -> Int:
        return self.src[].len - self.moved


def main():
    # An empty list renders as [].
    linkedlist = LinkedList[Int]()
    print(linkedlist.__str__())

    # Construct from one element, then from several.
    linkedlist = LinkedList(1)
    print(linkedlist.__str__())
    linkedlist = LinkedList(1, 2, 3)
    print(linkedlist.__str__())

    # Append several elements at once, then show the length and contents.
    linkedlist.append(4, 5, 6)
    print(linkedlist.len)
    print(linkedlist.__str__())

    # Iteration yields a pointer to each stored value.
    for element in linkedlist:
        print(element[].__str__())

View source on GitHub

Longest Common Subsequence Recursive

Return the length of the longest common subsequence between two strings, or 0 if none exists

def longest_subseq(mut text1: String, mut text2: String) raises -> Int:
    """Return the length of the longest common subsequence of two strings.

    Plain (exponential-time) recursion over prefix lengths. The inputs are
    left unmodified; `mut` is kept only for signature compatibility. Slicing
    the strings in place previously truncated the caller's variables.

    Returns:
        The LCS length, or 0 if either string is empty or nothing is shared.
    """
    return _lcs_of_prefixes(text1, text2, len(text1), len(text2))


def _lcs_of_prefixes(
    text1: String, text2: String, len1: Int, len2: Int
) raises -> Int:
    """Return the LCS length of the prefixes text1[:len1] and text2[:len2]."""
    if len1 == 0 or len2 == 0:
        return 0
    if text1[len1 - 1] == text2[len2 - 1]:
        # Matching last characters always belong to some LCS.
        return 1 + _lcs_of_prefixes(text1, text2, len1 - 1, len2 - 1)
    # Otherwise drop the last character of one string or the other.
    return max(
        _lcs_of_prefixes(text1, text2, len1 - 1, len2),
        _lcs_of_prefixes(text1, text2, len1, len2 - 1),
    )


from std.testing import assert_equal


def main() raises:
    # Partial overlap: "ace" is a subsequence of "abcde".
    var text1: String = "abcde"
    var text2: String = "ace"
    assert_equal(longest_subseq(text1, text2), 3, "Assertion failed")

    # Identical strings share their full length.
    text1 = "abc"
    text2 = "abc"
    assert_equal(longest_subseq(text1, text2), 3, "Assertion failed")

    # No characters in common.
    text1 = "abc"
    text2 = "xyz"
    assert_equal(longest_subseq(text1, text2), 0, "Assertion failed")

View source on GitHub

Longest Common Subsequence Dynamic

Return the length of the longest common subsequence between two strings, or 0 if none exists

def longest_subseq(mut text1: String, mut text2: String) raises -> Int:
    """Return the LCS length of `text1` and `text2` using a bottom-up table.

    table[i][j] holds the LCS length of the prefixes text1[:i] and text2[:j],
    so the answer is the bottom-right cell. O(len1 * len2) time and space.
    """
    if len(text1) == 0 or len(text2) == 0:
        return 0
    n1 = len(text1)
    n2 = len(text2)
    table = List[List[Int]](
        length=n1 + 1, fill=List[Int](length=n2 + 1, fill=0)
    )
    for i in range(1, n1 + 1):
        for j in range(1, n2 + 1):
            if text1[i - 1] == text2[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[n1][n2]


from std.testing import assert_equal


def main() raises:
    # Partial overlap: "ace" is a subsequence of "abcde".
    var text1: String = "abcde"
    var text2: String = "ace"
    assert_equal(longest_subseq(text1, text2), 3, "Assertion failed")

    # Identical strings share their full length.
    text1 = "abc"
    text2 = "abc"
    assert_equal(longest_subseq(text1, text2), 3, "Assertion failed")

    # No characters in common.
    text1 = "abc"
    text2 = "xyz"
    assert_equal(longest_subseq(text1, text2), 0, "Assertion failed")

View source on GitHub

Combination sum

Find all unique combinations of numbers from candidates (reusable unlimited times) that sum to target.

def combination_sum(candidates: List[Int], target: Int) -> List[List[Int]]:
    """Return every combination of `candidates` (each reusable) that sums to `target`.

    Args:
        candidates: Values that may be used any number of times.
        target: The required sum.

    Returns:
        The combinations found by `find_combinations`, or an empty list when
        there are no candidates.
    """
    combinations = List[List[Int]]()
    if len(candidates) == 0:
        return combinations

    # Give the scratch list an explicit element type; a bare `[]` literal
    # gives the compiler nothing to infer `List[Int]` from.
    var curr_combination = List[Int]()
    find_combinations(candidates, 0, curr_combination, 0, target, combinations)
    return combinations

def find_combinations(
    candidates: List[Int],
    curr_index: Int,
    mut curr_combination: List[Int],
    total: Int,
    target: Int,
    mut combinations: List[List[Int]],
):
    """Backtracking helper for `combination_sum`.

    At each step, either reuse candidates[curr_index] (index stays put) or
    move on to the next candidate. Whenever `total` reaches `target`, a
    sorted copy of `curr_combination` is appended to `combinations`.
    """
    if total == target:
        found = curr_combination.copy()
        sort(found)  # Canonical order makes results easy to compare
        combinations.append(found^)
        return
    if curr_index >= len(candidates) or total > target:
        return
    candidate = candidates[curr_index]
    # Reusing a non-positive candidate never pushes `total` past `target`, so
    # the include branch would recurse without bound. Only positive values
    # may be taken (the problem assumes positive candidates).
    if candidate > 0:
        curr_combination.append(candidate)
        find_combinations(
            candidates,
            curr_index,
            curr_combination,
            total + candidate,
            target,
            combinations,
        )
        _ = curr_combination.pop()
    find_combinations(
        candidates,
        curr_index + 1,
        curr_combination,
        total,
        target,
        combinations,
    )


from std.testing import assert_true


def _occurrences(result: List[List[Int]], expected: List[Int]) -> Int:
    """Count how many entries of `result` equal `expected`."""
    count = 0
    for each in result:
        if each[] == expected:
            count += 1
    return count


def main() raises:
    candidates = [2, 3, 6, 7]
    target = 7
    var result: List[List[Int]] = combination_sum(candidates, target)
    # Exactly [2, 2, 3] and [7], each once. The length check rejects extras.
    assert_true(len(result) == 2, "assertion failed")
    expected = [2, 2, 3]
    assert_true(_occurrences(result, expected) == 1, "assertion failed")
    expected = [7]
    assert_true(_occurrences(result, expected) == 1, "assertion failed")

    candidates = [2, 3, 5]
    target = 8
    result = combination_sum(candidates, target)
    # Exactly [2, 2, 2, 2], [2, 3, 3] and [3, 5], each once.
    assert_true(len(result) == 3, "assertion failed")
    expected = [2, 2, 2, 2]
    assert_true(_occurrences(result, expected) == 1, "assertion failed")
    expected = [2, 3, 3]
    assert_true(_occurrences(result, expected) == 1, "assertion failed")
    expected = [3, 5]
    assert_true(_occurrences(result, expected) == 1, "assertion failed")

    # No combination can reach an odd target with only 2s.
    candidates = [2]
    target = 1
    result = combination_sum(candidates, target)
    assert_true(len(result) == 0, "assertion failed")

View source on GitHub

Num Ones

def count_bits(mut num: Int) -> Int:
	"""Return the number of set bits in `num`, by repeated halving.

	Works on a local copy, so the caller's `num` is no longer zeroed (`mut`
	is kept only for signature compatibility). Assumes `num >= 0`: `>>` on a
	negative Int presumably sign-extends, so the loop would never reach 0.
	"""
	var remaining = num
	result = 0
	while remaining:
		result += remaining % 2
		remaining = remaining >> 1
	return result

def count_bits2(mut num: Int) -> Int:
	"""Return the number of set bits in `num`, testing the low bit with a mask.

	Works on a local copy, so the caller's `num` is no longer zeroed (`mut`
	is kept only for signature compatibility). Assumes `num >= 0`: for a
	negative value `>>=` presumably sign-extends and the loop never ends.
	"""
	var remaining = num
	result = 0
	while remaining:
		result += remaining & 1
		remaining >>= 1
	return result

from std.testing import assert_equal

def main() raises:
	# Both bit counters must produce the same answers for the same inputs.
	values = [11, 128, 2147483645]
	expected = [3, 1, 30]

	for i in range(len(values)):
		num = values[i]
		assert_equal(expected[i], count_bits(num))

	for i in range(len(values)):
		num = values[i]
		assert_equal(expected[i], count_bits2(num))

View source on GitHub

Count 1s For Each Entry

Return an array where each element at index i (0 ≤ i ≤ n) is the number of 1's in the binary representation of i.

def count_bits(n: Int) -> List[Int]:
    """Return the popcount of every integer from 0 to n inclusive.

    Each entry reuses an earlier one: i >> 1 drops the lowest bit of i and has
    already been computed, so add 1 back if that dropped bit was set.
    """
    result = List(length=n + 1, fill=0)
    for i in range(1, n + 1):
        result[i] = result[i >> 1] + (i & 1)
    return result


from std.testing import assert_equal


def main() raises:
    # 0, 1, 2 in binary: 0, 1, 10.
    assert_equal([0, 1, 1], count_bits(2), "Assertion failed")
    # Up to 5 (101), crossing the power of two at 4.
    assert_equal([0, 1, 1, 2, 1, 2], count_bits(5), "Assertion failed")

View source on GitHub

Game Of Life

from gridv1 import Grid as GridV1
from gridv2 import Grid as GridV2
from utils import Variant
import random

# A grid in either representation: GridV1 (nested lists) or GridV2 (flat byte buffer).
comptime Grid = Variant[GridV1, GridV2]


def run(owned grid: Grid) raises -> None:
    # Normalise to GridV1 (converting a GridV2 if needed), then step through
    # generations until the user enters 'q'.
    var inner: GridV1
    if grid.isa[GridV2]():
        inner = GridV1(grid[GridV2])
        print("Received a grid of type V2 - converted to V1")
    else:
        inner = grid[GridV1]
        print("Received a grid of type V1")
    while True:
        print("Current mutation:\n\n")
        print(inner)
        print()
        print()
        reply = input("Enter 'q' to quit or press <Enter> to continue: ")
        if reply == "q":
            break
        inner.mutate()


def main() raises -> None:
    # Pick one of the two grid representations at random and hand it to run().
    random.seed()
    var grid: Grid
    use_v1 = random.random_ui64(0, 1)
    if use_v1:
        grid = Grid(GridV1.new(16, 16))
    else:
        grid = Grid(GridV2.new(None, 16, 16))
    run(grid^)

View source on GitHub

Gridv1

import random
from gridv2 import Grid as GridV2


# Grid is a 2D structure holding cell states (0: dead, 1: alive)
# It supports string conversion, output writing, and cell access/update
@value
struct Grid(Stringable, Writable):
    var data: List[List[Int, True]]  # 2D grid of integers (1 = alive, 0 = dead)

    # Constructor to initialize the grid with given data
    def __init__(out self, data: List[List[Int, True]]):
        self.data = data

    # Build a nested-list grid from a GridV2. GridV2 keeps its cells in a flat
    # row-major buffer: cell (row, col) is at offset row * cols + col.
    @implicit
    def __init__(out self, source: GridV2):
        data = source.data
        rows = source.rows
        cols = source.cols
        grid = List[List[Int, True]]()
        for row in range(rows):
            curr_row = List[Int, True]()
            for col in range(cols):
                curr_row.append(Int((data + (row * cols + col))[]))
            grid.append(curr_row)
        self.data = grid^

    # Get the number of rows in the grid
    def row_count(self) -> Int:
        if self.data:
            return len(self.data)
        else:
            return 0

    # Get the number of columns in the grid (0 when there are no rows)
    def col_count(self) -> Int:
        # Check for rows first: `self.data[0]` is out of bounds on an empty grid.
        if self.data:
            return len(self.data[0])
        return 0

    # Convert the grid to a string for pretty-printing
    def __str__(self) -> String:
        capacity = self.row_count() * self.col_count()
        if capacity == 0:
            return String()
        s = String(capacity=capacity)
        row_index = 0
        for row in self.data:
            for col in row[]:
                if col[] == 1:
                    s += "*"  # Alive cell represented by '*'
                else:
                    s += " "  # Dead cell is blank
            if row_index != self.row_count() - 1:
                s += "\n"  # Line break between rows
            row_index += 1
        return s

    # Allow writing the grid to any output writer
    def write_to[W: Writer](self, mut writer: W) -> None:
        writer.write(self.__str__())

    # Access cell at (row, col)
    def __getitem__(self, row: Int, col: Int) -> Int:
        return self.data[row][col]

    # Update cell at (row, col)
    def __setitem__(mut self, row: Int, col: Int, value: Int) -> None:
        self.data[row][col] = value

    # Static method to create a random grid with specified size
    @staticmethod
    def new(rows: Int, cols: Int) -> Self:
        random.seed()
        data = List[List[Int, True]](capacity=rows)
        for row in range(rows):
            record = List[Int, True](capacity=cols)
            for col in range(cols):
                # Initialize each cell randomly to 0 or 1
                record.append(Int(random.random_si64(0, 1)))
            data.append(record)
        return Self(data)

    # Perform one step of mutation (Conway's Game of Life rules). Neighbour
    # indices wrap around the edges via `%`.
    def mutate(mut self):
        rows = self.row_count()
        cols = self.col_count()
        # Every cell must be judged against the *current* generation, so the
        # next generation is built in a separate buffer. Updating in place
        # would let already-updated cells skew their neighbours' counts.
        next_gen = List[List[Int, True]](capacity=rows)
        for row in range(rows):
            above = (row - 1) % rows
            below = (row + 1) % rows
            next_row = List[Int, True](capacity=cols)
            for col in range(cols):
                left = (col - 1) % cols
                right = (col + 1) % cols

                # Count live neighbors using 8-connected grid
                alive_neighbours = (
                    self[above, left]
                    + self[above, col]
                    + self[above, right]
                    + self[row, right]
                    + self[below, right]
                    + self[below, col]
                    + self[below, left]
                    + self[row, left]
                )

                alive = self[row, col] == 1
                if alive and (alive_neighbours == 2 or alive_neighbours == 3):
                    next_row.append(1)  # A live cell with 2 or 3 neighbours survives
                elif not alive and alive_neighbours == 3:
                    next_row.append(1)  # A dead cell with exactly 3 neighbours is born
                else:
                    next_row.append(0)  # All other cells are (or stay) dead
            next_gen.append(next_row^)
        self.data = next_gen^


def run(owned grid: Grid) raises -> None:
    # Show the current generation and advance it until the user enters 'q'.
    while True:
        print("Current mutation:\n\n")
        print(grid)
        print()
        print()
        reply = input("Enter 'q' to quit or press <Enter> to continue: ")
        if reply == "q":
            break
        grid.mutate()


def main() raises -> None:
    # Print a seeded GridV2, then the same cells after conversion to this
    # module's Grid. (Call run() with a grid for the interactive loop.)
    grid_2 = GridV2.new(42, 16, 16)
    print(grid_2)
    print("Implicit conversion\n\n")
    print(Grid(grid_2))

View source on GitHub

Gridv2

import random
from collections import Optional
from memory import UnsafePointer, memcpy, memset_zero
from gridv1 import Grid as GridV1

struct Grid(Stringable, Writable):
    # Cells are stored as a flat row-major buffer of rows * cols bytes
    # (1 = alive, 0 = dead); cell (row, col) is at offset row * cols + col.
    var data: UnsafePointer[UInt8]
    var rows: Int
    var cols: Int

    def __init__(out self, rows: Int, cols: Int):
        """Allocate an all-dead grid of the given size."""
        self.rows = rows
        self.cols = cols
        self.data = UnsafePointer[UInt8].alloc(rows * cols)
        # `alloc` does not initialize the memory; start with every cell dead.
        memset_zero(self.data, rows * cols)

    def __init__(out self, source: GridV1):
        """Copy a GridV1 (nested lists) into a flat buffer."""
        rows = len(source.data)
        # Guard the column lookup: `source.data[0]` is out of bounds when empty.
        cols = len(source.data[0]) if rows > 0 else 0
        self = Self(rows, cols)
        for row in range(rows):
            for col in range(cols):
                value = UInt8(source[row, col])
                (self.data + row * cols + col)[] = value

    def __copyinit__(out self, existing: Self):
        self.rows = existing.rows
        self.cols = existing.cols
        count = self.rows * self.cols
        self.data = UnsafePointer[UInt8].alloc(count)
        memcpy(dest=self.data, src=existing.data, count=count)

    def __moveinit__(out self, owned existing: Self):
        self.data = existing.data
        self.rows = existing.rows
        self.cols = existing.cols

    def __del__(owned self):
        self.data.free()

    def __str__(self) -> String:
        capacity = self.rows * self.cols
        if capacity == 0:
            return String()
        s = String(capacity=capacity)
        for row in range(self.rows):
            for col in range(self.cols):
                if self[row, col] == 1:
                    s += "*"  # Alive cell represented by '*'
                else:
                    s += " "  # Dead cell is blank
            if row != self.rows - 1:
                s += "\n"  # Line break between rows
        return s

    def write_to[W: Writer](self, mut writer: W) -> None:
        writer.write(self.__str__())

    def __getitem__(self, row: Int, col: Int) -> UInt8:
        return (self.data + row * self.cols + col)[]

    def __setitem__(mut self, row: Int, col: Int, value: UInt8) -> None:
        (self.data + row * self.cols + col)[] = value

    @staticmethod
    def new(seed: Optional[Int], rows: Int, cols: Int) -> Self:
        """Create a grid with random 0/1 cells, seeding the RNG with `seed` if given."""
        if seed:
            random.seed(seed.value())
        else:
            random.seed()
        grid = Self(rows, cols)
        random.randint(grid.data, rows * cols, 0, 1)
        return grid

    def mutate(mut self) -> None:
        """Advance the grid by one generation of Conway's Game of Life.

        Neighbour indices wrap around the edges via `%`.
        """
        rows = self.rows
        cols = self.cols
        # Build the next generation in a scratch buffer so every cell is judged
        # against the current generation, not against partially updated cells.
        next_gen = UnsafePointer[UInt8].alloc(rows * cols)
        for row in range(rows):
            above = (row - 1) % rows
            below = (row + 1) % rows
            for col in range(cols):
                left = (col - 1) % cols
                right = (col + 1) % cols
                alive_neighbours = (
                    self[above, left]
                    + self[above, col]
                    + self[above, right]
                    + self[row, right]
                    + self[below, right]
                    + self[below, col]
                    + self[below, left]
                    + self[row, left]
                )
                var state = self[row, col]
                if state == 1:
                    # Under- or over-population kills a live cell;
                    # 2 or 3 neighbours keep it alive.
                    if alive_neighbours < 2 or alive_neighbours > 3:
                        state = 0
                elif alive_neighbours == 3:
                    state = 1  # Reproduction
                (next_gen + row * cols + col)[] = state
        self.data.free()
        self.data = next_gen


def run(owned grid: Grid) raises -> None:
    """Interactively show `grid` and advance it one generation per <Enter>.

    Typing 'q' at the prompt stops the loop.
    """
    while True:
        print("Current mutation:\n\n")
        print(grid)
        print()
        print()
        var answer = input("Enter 'q' to quit or press <Enter> to continue: ")
        if answer == "q":
            break
        grid.mutate()


def main() raises -> None:
    # Random grid from the GridV1 implementation (defined elsewhere).
    var original = GridV1.new(16, 16)
    # run(original)  # uncomment to step through generations interactively
    print(original)
    print("Implicit conversion\n\n")
    # Build a `Grid` from the GridV1 value and print the result.
    var converted = Grid(original)
    print(converted)

View source on GitHub

Cyclic Reference 1

# Reference1 can call Reference2 without issues, even though Reference2 also calls Reference1
from cyclic_reference_2 import Reference2


struct Reference1:
    """Calls into Reference2, whose module in turn imports this one."""

    @staticmethod
    def print(s: String):
        """Print `s`; used by Reference2 to show the cyclic call works."""
        print(s)

    def __init__(out self):
        Reference2.print("Reference2 inside Reference1 constructor")


def main():
    # Constructing Reference1 triggers the call into Reference2.
    _ = Reference1()

View source on GitHub

Cyclic Reference 2

# Reference2 can call Reference1 without issues, even though Reference1 also calls Reference2

from cyclic_reference_1 import Reference1


struct Reference2:
    """Calls into Reference1, whose module in turn imports this one."""

    @staticmethod
    def print(s: String):
        """Print `s`; used by Reference1 to show the cyclic call works."""
        print(s)

    def __init__(out self):
        Reference1.print("Reference1 inside Reference2 constructor")


def main():
    # Constructing Reference2 triggers the call into Reference1.
    _ = Reference2()

View source on GitHub

Buffer Reduce

from buffer import NDBuffer
from algorithm import vectorize
from sys import simdwidthof


def summer[
    type: DType, //, simdwidth: Int = simdwidthof[type]()
](buffer: NDBuffer[type=type, rank=1]) -> Scalar[type]:
    """Return the sum of every element of a rank-1 NDBuffer.

    `vectorize` drives the closure over the buffer in `simdwidth`-wide
    steps (and narrower steps for any remainder). Each step loads a SIMD
    vector and folds it into the running total.
    """
    var total = Scalar[type](0)

    @parameter
    def accumulate[width: Int](offset: Int):
        total += buffer.load[width=width](offset).reduce_add()

    vectorize[accumulate, simdwidth](len(buffer))
    return total


from collections import InlineArray
from math import iota


def main() raises:
    comptime elem_count = 30
    # Inline array of 30 float64 values, filled with 0, 1, ..., 29 by iota.
    var storage = InlineArray[Scalar[DType.float64], elem_count](
        uninitialized=True
    )
    iota(storage.unsafe_ptr(), elem_count)

    # Rank-1 NDBuffer view over the array with a static length of elem_count.
    var view = NDBuffer[DType.float64, 1, _, elem_count](storage)

    print(summer(view))

View source on GitHub

Custom Struct Compare

from search_sorted_rotated_arr import find
from shapes import Rectangle


def main() raises:
    """Find r2 in a sorted-then-rotated list of Rectangles and verify its index.

    Raises:
        Error: if `find` does not return index 5.
    """
    # Rectangles in ascending order of width * height (noted after each).
    r4 = Rectangle(4, 10)  # 40
    r5 = Rectangle(4, 12)  # 48
    r6 = Rectangle(5, 10)  # 50
    r7 = Rectangle(8, 8)  # 64
    r1 = Rectangle(3, 2)  # 6
    r2 = Rectangle(4, 4)  # 16
    r3 = Rectangle(4, 8)  # 32
    # Rectangles are sorted and rotated in the list
    items = [r4, r5, r6, r7, r1, r2, r3]
    item_index = find(items, r2)
    # debug_assert is compiled out unless assertions are enabled at build
    # time, so check explicitly to make the example verify its result.
    if item_index != 5:
        raise Error("Assertion failed")

View source on GitHub

SIMD Select

Select based on a SIMD (Single Instruction, Multiple Data) mask

This example demonstrates how to use SIMD.select() to perform element-wise conditional selection between two SIMD vectors based on a boolean mask.

from std.testing import assert_true


def main() raises:
    # Lane-wise boolean mask of width 4:
    #   mask[i] == True  -> take left[i]
    #   mask[i] == False -> take right[i]
    mask = SIMD[DType.bool, 4](False, True, False, True)

    # Two uint8 vectors whose non-zero lanes are exactly where each would be chosen.
    left = SIMD[DType.uint8, 4](0, 42, 0, 42)
    right = SIMD[DType.uint8, 4](42, 0, 42, 0)

    # result[i] = left[i] if mask[i] else right[i]
    #   -> result = [42, 42, 42, 42]
    result = mask.select(left, right)

    # A SIMD built from one scalar broadcasts it: [42, 42, 42, 42].
    expected = SIMD[DType.uint8, 4](42)
    assert_true(all(result == expected))

View source on GitHub

Check device cores

Check physical and logical cores of the device

from sys import num_physical_cores, num_logical_cores

def main():
    # Query both counts first, then report them.
    var physical = num_physical_cores()
    var logical = num_logical_cores()
    print("    Physical Cores : ", physical)
    print("    Logical Cores  : ", logical)

View source on GitHub

Add 10

Implement a kernel that adds 10 to each position of vector a and stores it in vector out.

from gpu.host import DeviceContext
from memory import UnsafePointer
from gpu import thread_idx

# Number of elements in the input and output vectors.
comptime SIZE = 4
# A single block is launched.
comptime BLOCKS_PER_GRID = 1
# Exactly one thread per element, so the kernel needs no bounds guard.
comptime THREADS_PER_BLOCK = SIZE
# Element type of every buffer in this example.
comptime dtype = DType.float32


def add_10(
    output: UnsafePointer[Scalar[dtype]], array: UnsafePointer[Scalar[dtype]]
):
    """Write array[tid] + 10 into output[tid], where tid is this thread's x index.

    The destination was previously named `out`. `out` is also Mojo's
    argument-convention keyword (as in `out self`), so `output` is used.
    No bounds check: main launches exactly SIZE threads in one block.
    """
    tid = thread_idx.x
    output[tid] = array[tid] + 10


def main() raises:
    """Run add_10 on [0, 1, ..., SIZE-1] and print the device result.

    The unused `expected` device buffer that was allocated here has been
    removed.
    """
    ctx = DeviceContext()
    d_array_buff = ctx.enqueue_create_buffer[dtype](SIZE)
    d_out_buff = ctx.enqueue_create_buffer[dtype](SIZE)

    # Zero the output so any slot the kernel misses shows up as 0.
    _ = d_out_buff.enqueue_fill(0)

    # Input values 0 .. SIZE-1, written through a host mapping.
    with d_array_buff.map_to_host() as h_array_buff:
        for i in range(SIZE):
            h_array_buff[i] = i

    ctx.enqueue_function[add_10](
        d_out_buff.unsafe_ptr(),
        d_array_buff.unsafe_ptr(),
        grid_dim=BLOCKS_PER_GRID,
        block_dim=THREADS_PER_BLOCK,
    )

    ctx.synchronize()

    with d_out_buff.map_to_host() as h_out_buff:
        print(h_out_buff)

View source on GitHub

Add a constant 10

Implement a kernel that adds 10 to each position of the 2D matrix a and stores the result in the 2D matrix out.

from gpu.host import DeviceContext
from memory import UnsafePointer
from gpu import thread_idx, block_dim
from std.testing import assert_equal

# The matrix is SIZE x SIZE, stored as a flat buffer of SIZE * SIZE elements.
comptime SIZE = 2
comptime BLOCKS_PER_GRID = 1
# 3 x 3 = 9 threads for 4 elements: more threads than data, so the kernel guards.
comptime THREADS_PER_BLOCK = (3,3)
comptime dtype = DType.float32


def add_10_2d(
    output: UnsafePointer[Scalar[dtype]],
    array: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    """Add 10 to each element of a flat `size` x `size` matrix.

    The 3-D thread coordinates within the block are linearised into `tid`.
    Threads whose `tid` falls past the `size * size` elements do nothing.
    The destination was previously named `out`. `out` is also Mojo's
    argument-convention keyword, so `output` is used.
    """
    tid = (
        thread_idx.z * (block_dim.y * block_dim.x)
        + thread_idx.y * block_dim.x
        + thread_idx.x
    )
    if tid < size * size:
        output[tid] = array[tid] + 10


def main():
    """Run add_10_2d on a SIZE x SIZE matrix and check it against the host result.

    Any raised error is caught and printed.
    """
    try:
        ctx = DeviceContext()
        d_array_buff = ctx.enqueue_create_buffer[dtype](SIZE * SIZE).enqueue_fill(0)
        d_out_buff = ctx.enqueue_create_buffer[dtype](SIZE * SIZE).enqueue_fill(0)
        expected = ctx.enqueue_create_host_buffer[dtype](SIZE * SIZE).enqueue_fill(0)

        # Input holds 0, 1, ..., SIZE*SIZE - 1 in row-major order; the
        # expected output is each input value plus 10.
        with d_array_buff.map_to_host() as h_array_buff:
            for i in range(SIZE):
                for j in range(SIZE):
                    var flat = i * SIZE + j
                    h_array_buff[flat] = flat
                    expected[flat] = h_array_buff[flat] + 10
            print("Input: ", h_array_buff)

        ctx.enqueue_function[add_10_2d](
            d_out_buff.unsafe_ptr(),
            d_array_buff.unsafe_ptr(),
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )

        ctx.synchronize()

        with d_out_buff.map_to_host() as h_out_buff:
            print(h_out_buff)
            print(expected)
            for idx in range(SIZE * SIZE):
                assert_equal(h_out_buff[idx], expected[idx])

    except e:
        print(e)

View source on GitHub

Add constant to 2D Layout tensor

Implement a kernel that adds 10 to each position of 2D LayoutTensor a and stores it in 2D LayoutTensor out.

from gpu.host import DeviceContext
from gpu import thread_idx
from layout import Layout, LayoutTensor
from math import iota


# Side length of the square SIZE x SIZE tensors.
comptime SIZE = 2
comptime BLOCKS_PER_GRID = 1
# 3 x 3 threads for a 2 x 2 tensor: the kernel must bounds-check row/col.
comptime THREADS_PER_BLOCK = (3, 3)
comptime dtype = DType.float32
# Row-major layout shared by the input and output tensors.
comptime layout = Layout.row_major(SIZE, SIZE)


def add_10_2dlayout(
    output: LayoutTensor[mut=True, dtype, layout],
    a: LayoutTensor[mut=True, dtype, layout],
    size: Int,
):
    """Set output[row, col] = a[row, col] + 10 for this thread's (row, col).

    Row comes from thread_idx.y and column from thread_idx.x. Threads outside
    the `size` x `size` tensor do nothing. The destination was previously named
    `out`, which is also Mojo's argument-convention keyword.
    """
    row = thread_idx.y
    col = thread_idx.x
    if row < size and col < size:
        output[row, col] = a[row, col] + 10


def main():
    """Fill a SIZE x SIZE tensor with 0, 1, ..., run add_10_2dlayout, and print
    the output buffer. Any raised error is caught and printed."""
    try:
        ctx = DeviceContext()

        buffer_a = ctx.enqueue_create_buffer[dtype](SIZE * SIZE).enqueue_fill(
            0.0
        )
        buffer_out = ctx.enqueue_create_buffer[dtype](SIZE * SIZE).enqueue_fill(
            0.0
        )

        # Input values 0, 1, ..., SIZE*SIZE - 1.
        with buffer_a.map_to_host() as h_buffer_a:
            iota(h_buffer_a.unsafe_ptr(), SIZE * SIZE)

        # Named out_tensor / a_tensor; `out` is Mojo's argument-convention
        # keyword, so it is avoided as a variable name.
        out_tensor = LayoutTensor[mut=True, dtype, layout](buffer_out)
        a_tensor = LayoutTensor[mut=True, dtype, layout](buffer_a)

        ctx.enqueue_function[add_10_2dlayout](
            out_tensor,
            a_tensor,
            SIZE,
            grid_dim=(BLOCKS_PER_GRID, BLOCKS_PER_GRID),
            block_dim=THREADS_PER_BLOCK,
        )

        ctx.synchronize()

        with buffer_out.map_to_host() as h_buffer_out:
            print(h_buffer_out)
    except e:
        print(e)

View source on GitHub

Add 10

Implement a kernel that adds 10 to each position of vector a and stores it in vector out.

More threads than data — guard against out-of-bounds access.

### Add 10
### Implement a kernel that adds 10 to each position of vector a and stores it in vector out.
### More threads than data — guard against out-of-bounds access.

from gpu.host import DeviceContext
from memory import UnsafePointer
from gpu import thread_idx, block_dim, block_idx
from std.testing import assert_equal

# Number of data elements.
comptime SIZE = 4
comptime BLOCKS_PER_GRID = 1
# 8 threads for 4 elements: the kernel must skip tid >= SIZE.
comptime THREADS_PER_BLOCK = (8, 1)
comptime dtype = DType.float32


def add_10_with_guard(
    output: UnsafePointer[Scalar[dtype]], array: UnsafePointer[Scalar[dtype]]
):
    """Write array[tid] + 10 into output[tid] for tid < SIZE.

    tid is this thread's linear index within its block. Threads with
    tid >= SIZE do nothing, which matters because more threads than
    elements are launched. The destination was previously named `out`,
    which is also Mojo's argument-convention keyword.
    """
    tid = (
        thread_idx.z * (block_dim.y * block_dim.x)
        + thread_idx.y * block_dim.x
        + thread_idx.x
    )

    if tid < SIZE:
        output[tid] = array[tid] + 10


def main() raises:
    """Run add_10_with_guard on [0, ..., SIZE-1] and assert each output is input + 10."""
    ctx = DeviceContext()
    d_array_buff = ctx.enqueue_create_buffer[dtype](SIZE)
    d_out_buff = ctx.enqueue_create_buffer[dtype](SIZE)
    expected = ctx.enqueue_create_host_buffer[dtype](SIZE)
    _ = d_out_buff.enqueue_fill(0)

    with d_array_buff.map_to_host() as h_array_buff:
        for idx in range(SIZE):
            h_array_buff[idx] = idx

    ctx.enqueue_function[add_10_with_guard](
        d_out_buff.unsafe_ptr(),
        d_array_buff.unsafe_ptr(),
        grid_dim=BLOCKS_PER_GRID,
        block_dim=THREADS_PER_BLOCK,
    )
    ctx.synchronize()

    # Host-side reference: input value plus 10.
    for idx in range(SIZE):
        expected[idx] = idx + 10
    print(expected)

    with d_out_buff.map_to_host() as h_out_buff:
        print(h_out_buff)
        for idx in range(SIZE):
            assert_equal(h_out_buff[idx], expected[idx])

View source on GitHub

Vector Add (Flexible Kernel with Grid-Stride Loop)

Demonstrates a vector addition kernel with grid-stride loops, SIMD vectorization, and loop unrolling — runnable on both CPU and GPU (when available).

See Grid-stride loop

from std.gpu.host import DeviceContext, HostBuffer, DeviceAttribute
from std.gpu import thread_idx, block_idx, block_dim, grid_dim
from std.testing import assert_almost_equal
from utils import Timer
from std.random import random_float64, seed
from std.sys import has_accelerator, simd_width_of


# GPU kernel: element-wise vector addition with grid-stride loop, SIMD loads,
# and compile-time loop unrolling. Each thread processes CHUNK_SIZE elements
# per iteration, then advances by the total grid stride.
#
# Parameters:
#   result: output pointer (mutably addressed)
#   a, b: input pointers (immutably addressed)
#   size: number of elements in each vector
#
# Template parameters:
#   dtype: element data type (e.g. DType.float32)
#   simd_width: SIMD width, auto-detected from dtype
#   simd_vectors_per_thread: number of SIMD vectors per thread per grid step
#
def vector_add[
    dtype: DType,
    simd_width: Int = simd_width_of[dtype](),
    simd_vectors_per_thread: Int = 4 * simd_width,
](
    result: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    a: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    b: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    size: Int,
):
    # Global thread index across the 1-D grid.
    var tid = block_idx.x * block_dim.x + thread_idx.x
    # Total number of launched threads (grid_dim.x blocks of block_dim.x).
    var grid_stride = grid_dim.x * block_dim.x

    # Elements one thread handles per loop iteration. With the default
    # simd_vectors_per_thread (4 * simd_width) this is 4 * simd_width**2.
    comptime CHUNK_SIZE = simd_vectors_per_thread * simd_width
    # =========================================================
    # Each thread processes CHUNK_SIZE elements
    # =========================================================
    var start_index = (
        tid * CHUNK_SIZE
    )  # Start index for each thread per grid_stride

    while start_index < size:
        # Unrolled at compile time: one iteration per SIMD vector in the chunk.
        comptime for vector in range(simd_vectors_per_thread):
            var i = start_index + vector * simd_width

            # Bound check for this vector
            if i + simd_width <= size:
                # Load whole vectors, add up and store
                result.store[width=simd_width](
                    i, a.load[width=simd_width](i) + b.load[width=simd_width](i)
                )
            else:  # fewer than simd_width elements remain at i (possibly none)
                # Scalar tail; range(i, size) is empty once i >= size.
                for j in range(i, size):
                    result.store[width=1](
                        j, a.load[width=1](j) + b.load[width=1](j)
                    )

        # Skip past the chunks owned by every thread in the grid.
        start_index += grid_stride * CHUNK_SIZE


# CPU reference implementation: simple sequential element-wise vector addition.
#
# Parameters:
#   result: output host buffer
#   a, b: input host buffers
#   size: number of elements
#
def vector_add_cpu[
    dtype: DType,
    //,
](
    result: HostBuffer[dtype],
    a: HostBuffer[dtype],
    b: HostBuffer[dtype],
    size: Int,
):
    # Sequential reference: result[k] = a[k] + b[k] for the first `size` elements.
    for k in range(size):
        result[k] = a[k] + b[k]


# Fill a host buffer with random float64 values cast to the target dtype.
# Optionally accepts a seed for reproducible results.
#
# Parameters:
#   buffer_a: host buffer to fill
#   init_seed: optional RNG seed (deterministic if provided)
#   min, max: range for random values
#
def fill[
    dtype: DType,
    //,
](
    buffer_a: HostBuffer[dtype],
    init_seed: Optional[Int] = None,
    min: Float64 = 1.0,
    max: Float64 = 10.0,
):
    # Seed with init_seed when given; otherwise call seed() with no argument.
    if not init_seed:
        seed()
    else:
        seed(init_seed.value())
    var count = len(buffer_a)
    for idx in range(count):
        var sample = random_float64(min, max)
        buffer_a[idx] = sample.cast[dtype]()


# Benchmark vector addition on CPU and (if available) GPU, then validate that
# all GPU results match the CPU reference within a small tolerance.
#
def main() raises:
    comptime dtype = DType.float32
    # 100M elements per vector.
    var size = 100000000
    var cpu_ctx = DeviceContext(api="cpu")

    var lhs_host_buffer = cpu_ctx.enqueue_create_host_buffer[dtype](size)
    var rhs_host_buffer = cpu_ctx.enqueue_create_host_buffer[dtype](size)
    var result_host_buffer = cpu_ctx.enqueue_create_host_buffer[dtype](size)

    # Different fixed seeds so lhs and rhs differ but are reproducible.
    fill(lhs_host_buffer, init_seed=42)
    fill(rhs_host_buffer, init_seed=123)
    # CPU reference result, also used to validate the GPU output below.
    with Timer("CPU execution took: "):
        vector_add_cpu(
            result_host_buffer, lhs_host_buffer, rhs_host_buffer, size
        )
    cpu_ctx.synchronize()

    # The GPU path is only compiled when an accelerator is present.
    comptime if has_accelerator():
        var gpu_ctx = DeviceContext()

        var result_gpu_buffer = gpu_ctx.enqueue_create_buffer[dtype](size)
        var lhs_gpu_buffer = gpu_ctx.enqueue_create_buffer[dtype](size)
        var rhs_gpu_buffer = gpu_ctx.enqueue_create_buffer[dtype](size)

        # NOTE(review): these copies are enqueued from host buffers created on
        # cpu_ctx; confirm they are ordered before the kernel on gpu_ctx.
        lhs_host_buffer.enqueue_copy_to(dst=lhs_gpu_buffer)
        rhs_host_buffer.enqueue_copy_to(dst=rhs_gpu_buffer)

        var max_blocks_per_sm = gpu_ctx.get_attribute(
            DeviceAttribute.MAX_BLOCKS_PER_MULTIPROCESSOR
        )
        var sm_count = gpu_ctx.get_attribute(
            DeviceAttribute.MULTIPROCESSOR_COUNT
        )
        var threads_per_block = 256
        var max_threads_per_sm = gpu_ctx.get_attribute(
           DeviceAttribute.MAX_THREADS_PER_MULTIPROCESSOR
        )
        # Blocks per SM allowed by the thread limit at this block size.
        var max_blocks = max_threads_per_sm // threads_per_block
        # Resident-block limit per SM times SM count, oversubscribed 4x. The
        # kernel's stride loop covers elements beyond one pass of the grid.
        var blocks_count = min(max_blocks_per_sm, max_blocks) * sm_count * 4
        print("Max block per sm: ", max_blocks, "sm count: ", sm_count)
        print(
            "Launching",
            blocks_count,
            "blocks with",
            threads_per_block,
            "threads per block",
        )

        # synchronize() is inside the timer so kernel completion is measured.
        with Timer("GPU execution took: "):
            gpu_ctx.enqueue_function[vector_add[dtype]](
                result_gpu_buffer.unsafe_ptr(),
                lhs_gpu_buffer.unsafe_ptr(),
                rhs_gpu_buffer.unsafe_ptr(),
                size,
                grid_dim=blocks_count,
                block_dim=threads_per_block,
            )
            gpu_ctx.synchronize()

        # Element-by-element comparison against the CPU reference.
        with result_gpu_buffer.map_to_host() as gpu_result:
            for i in range(size):
                assert_almost_equal(gpu_result[i], result_host_buffer[i])

View source on GitHub

Layout Basics

from gpu.host import DeviceContext
from layout import Layout, LayoutTensor

# Tensor shape: HEIGHT rows x WIDTH columns.
comptime HEIGHT = 2
comptime WIDTH = 3
comptime dtype = DType.float32
# Row-major layout for the HEIGHT x WIDTH tensor.
comptime layout = Layout.row_major(HEIGHT, WIDTH)
# Launch a single thread in a single block.
comptime BLOCKS_PER_GRID = 1
comptime THREADS_PER_BLOCK = 1


def kernel[
    dtype: DType, layout: Layout
](tensor: LayoutTensor[mut=True, dtype, layout]):
    """Print `tensor`, increment its [0, 0] element by 1.0, then print it again."""
    print("Before\n")
    print(tensor)
    tensor[0, 0] = tensor[0, 0] + 1.0
    print()
    print("After\n")
    print(tensor)


def main() raises:
    """Copy i**2 values to a CUDA buffer, run `kernel` on it, and print context info.

    Changes from the original:
      * `raises` is declared, because several calls here raise and there is
        no try/except.
      * Explicit synchronizes are added. The zero-fill is enqueued on `ctx`
        and the host-to-device copy on `cpu_ctx`, and nothing otherwise
        orders them relative to each other or to the kernel launched on `ctx`.
    """
    ctx = DeviceContext(api="cuda")
    cpu_ctx = DeviceContext(api="cpu")
    buffer = ctx.enqueue_create_buffer[dtype](HEIGHT * WIDTH).enqueue_fill(0)
    cpu_buffer = cpu_ctx.enqueue_create_host_buffer[dtype](HEIGHT * WIDTH)
    # Finish the fill and the host allocation before writing to or copying
    # into either buffer, so the fill cannot land after the copy.
    ctx.synchronize()
    cpu_ctx.synchronize()

    for i in range(HEIGHT * WIDTH):
        cpu_buffer[i] = i**2

    cpu_buffer.enqueue_copy_to(buffer)
    # The copy was enqueued on cpu_ctx; wait for it before ctx's kernel reads buffer.
    cpu_ctx.synchronize()

    tensor = LayoutTensor[mut=True, dtype, layout](buffer.unsafe_ptr())

    ctx.enqueue_function[kernel[dtype, layout]](
        tensor, grid_dim=BLOCKS_PER_GRID, block_dim=THREADS_PER_BLOCK
    )

    ctx.synchronize()

    print(ctx.name())
    print(ctx.api())
    print(cpu_ctx.api())
    # Overwrites only the host copy's first element; the device buffer is unaffected.
    cpu_buffer.unsafe_ptr()[] = 98.0
    print(cpu_buffer)

View source on GitHub

Dumb matrix multiplication

Simulate the truly dumb CPU-style triple for-loop matrix multiplication on a single GPU thread.

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from std.testing import assert_true


# C (ROWS_C x COLS_C) = A (ROWS_A x COLS_A) @ B (ROWS_B x COLS_B);
# requires COLS_A == ROWS_B, ROWS_C == ROWS_A and COLS_C == COLS_B.
comptime ROWS_A = 8
comptime COLS_A = 16
comptime ROWS_B = 16
comptime COLS_B = 8
comptime ROWS_C = 8
comptime COLS_C = 8


# Range passed to random_float64 when filling A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

comptime dtype = DType.float32
# Num threads per block
comptime THREADS = 1
# Total numbers blocks in the grid
comptime BLOCKS = 1

comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)

# alias Matrix = LayoutTensor[dtype, _, MutableAnyOrigin]
# Not referenced by this example, which passes raw pointers to the kernel.
comptime Matrix = LayoutTensor[mut=True, dtype, _]


def naive_matmaul(
    A: UnsafePointer[Scalar[dtype]],
    B: UnsafePointer[Scalar[dtype]],
    C: UnsafePointer[Scalar[dtype]],
):
    """Single-thread triple-loop matmul on row-major raw pointers: C += A @ B.

    Only global thread 0 does any work; C must be zeroed beforehand.
    """
    var tid = block_idx.x * block_dim.x + thread_idx.x
    if tid != 0:
        return
    for i in range(ROWS_A):
        for j in range(COLS_B):
            for k in range(COLS_A):
                C[i * COLS_C + j] += A[i * COLS_A + k] * B[k * COLS_B + j]


# Initialize the matrix buffer with values in the range 0 to 100
def fill_buffer(buffer: HostBuffer[dtype]):
    """Fill `buffer` with random values between MATRIX_MIN_ELEM and
    MATRIX_MAX_ELEM (-5.0 and 5.0), cast to `dtype`.

    The previous comment claimed a range of 0 to 100, which did not match
    the constants used.
    """
    # Seeding is disabled here, so the RNG's existing state is used
    # (presumably the same values on every run - confirm).
    # random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


def main():
    """Fill A and B randomly, run naive_matmaul on the GPU, and verify C with NumPy.

    Any raised error is caught and printed.
    """
    try:
        ctx = DeviceContext()

        # Zero-filled device buffers. C must start at zero because the
        # kernel accumulates into it with +=.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # matrix_a = LayoutTensor[dtype, layout_a, MutableAnyOrigin](buffer_a)
        # matrix_b = LayoutTensor[dtype, layout_b, MutableAnyOrigin](buffer_b)
        # matrix_c =  LayoutTensor[dtype, layout_c, MutableAnyOrigin](buffer_c)

        ctx.enqueue_function[naive_matmaul](
            buffer_a.unsafe_ptr(),
            buffer_b.unsafe_ptr(),
            buffer_c.unsafe_ptr(),
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Map all three matrices back and compare C with NumPy's A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Prininting here: ", e)


def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    """Raise unless C is element-wise close to A @ B according to numpy.allclose.

    Each argument is (rows, cols, flat row-major host buffer).
    """
    a_rows, a_cols, a_buff = buff_a_with_dims
    b_rows, b_cols, b_buff = buff_b_with_dims
    c_rows, c_cols, c_buff = buff_c_with_dims

    lhs = reshape(to_ndarray(a_buff), a_rows, a_cols)
    rhs = reshape(to_ndarray(b_buff), b_rows, b_cols)
    actual = reshape(to_ndarray(c_buff), c_rows, c_cols)

    np = Python.import_module("numpy")
    expected = np.matmul(lhs, rhs)
    assert_true(np.allclose(expected, actual))
    print("Assertion was successful")


def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    """Copy `buffer` into a new 1-D NumPy float32 array and return it.

    The NumPy dtype is float32 to match the module-level `dtype`.
    The destination pointer is now named `dst_ptr`. The old local was named
    `ndarray_ptr`, which shadowed the helper function used to initialise it.
    """
    np = Python.import_module("numpy")
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    # Copy len(buffer) elements from the host buffer into NumPy's storage.
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    """Return `ndarray` reshaped to (rows, cols) via NumPy's `reshape`."""
    reshaped = ndarray.reshape(rows, cols)
    return reshaped


def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    """Return a raw pointer to `ndarray`'s data, typed as Scalar[dtype]."""
    # __array_interface__["data"] is (address, read_only); take the address.
    address = ndarray.__array_interface__["data"][0]
    return address.unsafe_get_as_pointer[dtype]()

View source on GitHub

Matrix multiplication 1 GPU thread per output column

Simulate the dumb CPU-style matrix multiplication with one GPU thread per output column.

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from std.testing import assert_true

# C (ROWS_C x COLS_C) = A (ROWS_A x COLS_A) @ B (ROWS_B x COLS_B);
# requires COLS_A == ROWS_B.
comptime ROWS_A = 33
comptime COLS_A = 13
comptime ROWS_B = 13
comptime COLS_B = 8
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Range passed to random_float64 when filling A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

comptime dtype = DType.float32
# Num threads per block (one per column of C)
comptime THREADS = COLS_C
# Total numbers blocks in the grid
comptime BLOCKS = 1

comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# Row-major LayoutTensor types for the three matrices.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]


def naive_matmul_one_thread_per_col[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    """Each thread accumulates one column of A @ B into C (C must start zeroed)."""
    var col = block_idx.x * block_dim.x + thread_idx.x
    # Threads beyond the last column of C have nothing to do.
    if col >= COLS_C:
        return
    for i in range(ROWS_A):
        for k in range(COLS_A):
            C[i, col] += A[i, k] * B[k, col]


# Initialize the matrix buffer with values in the range 0 to 100
def fill_buffer(buffer: HostBuffer[dtype]):
    """Fill `buffer` with random values between MATRIX_MIN_ELEM and
    MATRIX_MAX_ELEM (-5.0 and 5.0), cast to `dtype`.

    The previous comment claimed a range of 0 to 100, which did not match
    the constants used.
    """
    # Reseeded on every call. NOTE(review): confirm that two back-to-back calls
    # get different seeds, so A and B do not end up with identical contents.
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


def main():
    """Fill A and B randomly, run the one-thread-per-column matmul on the GPU,
    and verify C with NumPy. Any raised error is caught and printed."""
    try:
        ctx = DeviceContext()

        # Zero-filled device buffers. C must start at zero because the
        # kernel accumulates into it with +=.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # Wrap the device buffers as row-major LayoutTensors for the kernel.
        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # One block with THREADS == COLS_C threads: one thread per column.
        ctx.enqueue_function[
            naive_matmul_one_thread_per_col[layout_a, layout_b, layout_c]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Map all three matrices back and compare C with NumPy's A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Prininting here: ", e)


def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    """Raise unless C is element-wise close to A @ B according to numpy.allclose.

    Each argument is (rows, cols, flat row-major host buffer).
    """
    a_rows, a_cols, a_buff = buff_a_with_dims
    b_rows, b_cols, b_buff = buff_b_with_dims
    c_rows, c_cols, c_buff = buff_c_with_dims

    lhs = reshape(to_ndarray(a_buff), a_rows, a_cols)
    rhs = reshape(to_ndarray(b_buff), b_rows, b_cols)
    actual = reshape(to_ndarray(c_buff), c_rows, c_cols)

    np = Python.import_module("numpy")
    expected = np.matmul(lhs, rhs)
    assert_true(np.allclose(expected, actual))
    print("Assertion was successful")


def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    """Copy `buffer` into a new 1-D NumPy float32 array and return it.

    The NumPy dtype is float32 to match the module-level `dtype`.
    The destination pointer is now named `dst_ptr`. The old local was named
    `ndarray_ptr`, which shadowed the helper function used to initialise it.
    """
    np = Python.import_module("numpy")
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    # Copy len(buffer) elements from the host buffer into NumPy's storage.
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    """Return `ndarray` reshaped to (rows, cols) via NumPy's `reshape`."""
    reshaped = ndarray.reshape(rows, cols)
    return reshaped


def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    """Return a raw pointer to `ndarray`'s data, typed as Scalar[dtype]."""
    # __array_interface__["data"] is (address, read_only); take the address.
    address = ndarray.__array_interface__["data"][0]
    return address.unsafe_get_as_pointer[dtype]()

View source on GitHub

Dumb matrix multiplication (1 GPU thread per output row)

Simulate the CPU-style matrix multiplication with 1 GPU thread per row

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from std.testing import assert_true

# C (ROWS_C x COLS_C) = A (ROWS_A x COLS_A) @ B (ROWS_B x COLS_B);
# requires COLS_A == ROWS_B.
comptime ROWS_A = 64
comptime COLS_A = 16
comptime ROWS_B = 16
comptime COLS_B = 8
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Range passed to random_float64 when filling A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

comptime dtype = DType.float32
# Num threads per block (one per row of C)
comptime THREADS = ROWS_C
# Total numbers blocks in the grid
comptime BLOCKS = 1

comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# Row-major LayoutTensor types for the three matrices.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]


def naive_matmul_one_thread_per_row[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    """Each thread accumulates one row of A @ B into C (C must start zeroed)."""
    var row = block_idx.x * block_dim.x + thread_idx.x
    # Threads beyond the last row of A/C have nothing to do.
    if row >= ROWS_A:
        return
    for j in range(COLS_B):
        for k in range(COLS_A):
            C[row, j] += A[row, k] * B[k, j]


# Initialize the matrix buffer with values in the range 0 to 100
def fill_buffer(buffer: HostBuffer[dtype]):
    """Fill `buffer` with random values between MATRIX_MIN_ELEM and
    MATRIX_MAX_ELEM (-5.0 and 5.0), cast to `dtype`.

    The previous comment claimed a range of 0 to 100, which did not match
    the constants used.
    """
    # Reseeded on every call. NOTE(review): confirm that two back-to-back calls
    # get different seeds, so A and B do not end up with identical contents.
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


def main():
    """Fill A and B randomly, run the one-thread-per-row matmul on the GPU,
    and verify C with NumPy. Any raised error is caught and printed."""
    try:
        ctx = DeviceContext()

        # Zero-filled device buffers. C must start at zero because the
        # kernel accumulates into it with +=.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # Wrap the device buffers as row-major LayoutTensors for the kernel.
        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # One block with THREADS == ROWS_C threads: one thread per row.
        ctx.enqueue_function[
            naive_matmul_one_thread_per_row[layout_a, layout_b, layout_c]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Map all three matrices back and compare C with NumPy's A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Prininting here: ", e)


def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    """Raise unless C is element-wise close to A @ B according to numpy.allclose.

    Each argument is (rows, cols, flat row-major host buffer).
    """
    a_rows, a_cols, a_buff = buff_a_with_dims
    b_rows, b_cols, b_buff = buff_b_with_dims
    c_rows, c_cols, c_buff = buff_c_with_dims

    lhs = reshape(to_ndarray(a_buff), a_rows, a_cols)
    rhs = reshape(to_ndarray(b_buff), b_rows, b_cols)
    actual = reshape(to_ndarray(c_buff), c_rows, c_cols)

    np = Python.import_module("numpy")
    expected = np.matmul(lhs, rhs)
    assert_true(np.allclose(expected, actual))
    print("Assertion was successful")


def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    """Copy `buffer` into a new 1-D NumPy float32 array and return it.

    The NumPy dtype is float32 to match the module-level `dtype`.
    The destination pointer is now named `dst_ptr`. The old local was named
    `ndarray_ptr`, which shadowed the helper function used to initialise it.
    """
    np = Python.import_module("numpy")
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    # Copy len(buffer) elements from the host buffer into NumPy's storage.
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    """Return `ndarray` reshaped to (rows, cols) via NumPy's `reshape`."""
    reshaped = ndarray.reshape(rows, cols)
    return reshaped


def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    """Return a raw pointer to `ndarray`'s data, typed as Scalar[dtype]."""
    # __array_interface__["data"] is (address, read_only); take the address.
    address = ndarray.__array_interface__["data"][0]
    return address.unsafe_get_as_pointer[dtype]()

View source on GitHub

Dumb matrix multiplication

Simulate the truly dumb CPU-style triple for-loop matrix multiplication on a single GPU thread, using LayoutTensor.

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from std.testing import assert_true

# C (ROWS_C x COLS_C) = A (ROWS_A x COLS_A) @ B (ROWS_B x COLS_B);
# requires COLS_A == ROWS_B.
comptime ROWS_A = 64
comptime COLS_A = 16
comptime ROWS_B = 16
comptime COLS_B = 8
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Range passed to random_float64 when filling A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

comptime dtype = DType.float32
# Num threads per block (only thread 0 does work)
comptime THREADS = 1
# Total numbers blocks in the grid
comptime BLOCKS = 1

comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# Row-major LayoutTensor types for the three matrices.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]


# Naive C += A @ B, executed entirely by one GPU thread.
#
# Each thread computes its global index; only global thread 0 runs the
# classic triple loop, and every other thread returns immediately. The layout
# parameters `a`, `b`, `c` are accepted but unused: the body works on the
# file-level MatrixA / MatrixB / MatrixC types directly.
def naive_matmul_single_thread_layout_tensor[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    var global_tid = block_idx.x * block_dim.x + thread_idx.x

    # Only thread 0 does any work.
    if global_tid != 0:
        return

    for row in range(ROWS_A):
        for col in range(COLS_B):
            for inner in range(COLS_A):
                C[row, col] += A[row, inner] * B[inner, col]


# Fill `buffer` with random values from random.random_float64 in the range
# MATRIX_MIN_ELEM to MATRIX_MAX_ELEM (-5.0 to 5.0), cast to `dtype`.
# The generator is reseeded (with no explicit seed) on every call.
def fill_buffer(buffer: HostBuffer[dtype]):
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


# Entry point: create random A and B on the device, run the single-thread
# matmul kernel, then check C against numpy.matmul on the host.
#
# NOTE(review): every error, including a failed check in assert_allclose, is
# caught by the `except` below and only printed.
def main():
    try:
        ctx = DeviceContext()

        # Device buffers for A, B and C, each zero-filled.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Populate A and B with random values through host-mapped views.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # Wrap the device buffers as LayoutTensors for the kernel.
        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # Launch on a BLOCKS x THREADS (1 x 1) grid.
        ctx.enqueue_function[
            naive_matmul_single_thread_layout_tensor[
                layout_a, layout_b, layout_c
            ]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        # Wait for the queued work to finish before reading results.
        ctx.synchronize()

        # Map all three buffers back to the host and compare C with A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Prininting here: ", e)


# Verify that C matches A @ B (within float32 tolerance) using NumPy.
#
# Each argument is a (rows, cols, host_buffer) triple. Each buffer is copied
# into a NumPy array and reshaped to (rows, cols). The NumPy product of A and
# B is then compared with C via numpy.allclose, with C as the second
# argument, as in the original comparison.
#
# Raises: if a NumPy call fails (e.g. incompatible shapes for reshape or
# matmul), or if C does not match the reference product.
def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    a_rows, a_cols, a_buff = buff_a_with_dims
    matrix_a = reshape(to_ndarray(a_buff), a_rows, a_cols)

    b_rows, b_cols, b_buff = buff_b_with_dims
    matrix_b = reshape(to_ndarray(b_buff), b_rows, b_cols)

    c_rows, c_cols, c_buff = buff_c_with_dims
    matrix_c = reshape(to_ndarray(c_buff), c_rows, c_cols)
    np = Python.import_module("numpy")
    expected = np.matmul(matrix_a, matrix_b)
    # The GPU and NumPy both compute in float32 but sum in different orders.
    # Entries that nearly cancel to zero can therefore differ by a few ulps of
    # the partial sums, which the default atol (1e-8) would reject. Use an
    # absolute tolerance suited to float32 sums of this magnitude.
    assert_true(
        np.allclose(expected, matrix_c, rtol=1e-5, atol=1e-4),
        "GPU matmul result does not match numpy.matmul",
    )
    print("Assertion was successful")


# Copy a HostBuffer[dtype] into a newly allocated flat (1-D) NumPy array.
#
# The NumPy array is created as float32 and memcpy copies len(buffer)
# elements of Scalar[dtype] into it, so this is only correct while the
# file-level `dtype` is DType.float32.
#
# Raises: if importing NumPy or any NumPy call fails.
def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    np = Python.import_module("numpy")
    # NOTE: element type must match the file-level `dtype` (DType.float32).
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    # Distinct local names so nothing shadows the `ndarray_ptr` helper.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


# Return `ndarray` reshaped to `rows` x `cols` using NumPy's `reshape`.
# Raises: if the element count does not equal rows * cols.
def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    var reshaped = ndarray.reshape(rows, cols)
    return reshaped


# Return a typed pointer to the start of a NumPy array's data buffer.
#
# The address is the first entry of the array's `__array_interface__["data"]`
# tuple, reinterpreted as a pointer to Scalar[dtype]. Nothing checks that the
# array's element type matches `dtype`; the caller must guarantee it, and the
# pointer is only meaningful while `ndarray` is alive.
def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    var array_iface = ndarray.__array_interface__
    var data_address = array_iface["data"][0]
    return data_address.unsafe_get_as_pointer[dtype]()

View source on GitHub

Dumb matrix multiplication

Use one GPU thread for each cell of the output matrix

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from std.testing import assert_true

# Matrix shapes for C = A @ B: COLS_A must equal ROWS_B, and C takes its
# rows from A and its columns from B.
comptime ROWS_A = 64
comptime COLS_A = 16
comptime ROWS_B = 16
comptime COLS_B = 8
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Bounds passed to random.random_float64 when filling A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

# Element type of all three matrices.
comptime dtype = DType.float32
# Threads per block as (x, y); the kernel maps x to columns and y to rows of C.
comptime THREADS = (5, 5)
# Blocks per grid dimension: ceil(COLS_C / THREADS[0]) in x and
# ceil(ROWS_C / THREADS[1]) in y, so the grid covers every cell of C.
comptime BLOCKS = (
    (COLS_C + THREADS[0] - 1) // THREADS[0],
    (ROWS_C + THREADS[1] - 1) // THREADS[1],
)

# Row-major layouts for A, B and C.
comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# LayoutTensor types used to view the device buffers as A, B and C.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]


# One GPU thread per output cell of C.
#
# The thread at global position (x, y) computes
# C[row, col] += sum over k of A[row, k] * B[k, col],
# where row comes from the y dimension and col from the x dimension. The grid
# is rounded up to whole blocks, so threads that land outside C do nothing.
# The layout parameters `a`, `b`, `c` are accepted but unused by the body.
def matmul_thread_per_output_cell[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    var row = block_idx.y * block_dim.y + thread_idx.y
    var col = block_idx.x * block_dim.x + thread_idx.x

    # Skip threads that fall outside the bounds of C.
    if row >= ROWS_C or col >= COLS_C:
        return

    for k in range(ROWS_B):
        C[row, col] += A[row, k] * B[k, col]


# Fill `buffer` with random values from random.random_float64 in the range
# MATRIX_MIN_ELEM to MATRIX_MAX_ELEM (-5.0 to 5.0), cast to `dtype`.
# The generator is reseeded (with no explicit seed) on every call.
def fill_buffer(buffer: HostBuffer[dtype]):
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


# Entry point: create random A and B on the device, run the
# one-thread-per-output-cell kernel, then check C against numpy.matmul.
#
# NOTE(review): every error, including a failed check in assert_allclose, is
# caught by the `except` below and only printed.
def main():
    try:
        ctx = DeviceContext()

        # Device buffers for A, B and C, each zero-filled.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Populate A and B with random values through host-mapped views.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # Wrap the device buffers as LayoutTensors for the kernel.
        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # Launch a 2-D grid (BLOCKS) of 2-D blocks (THREADS) covering C.
        ctx.enqueue_function[
            matmul_thread_per_output_cell[layout_a, layout_b, layout_c]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        # Wait for the queued work to finish before reading results.
        ctx.synchronize()

        # Map all three buffers back to the host and compare C with A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Prininting here: ", e)


# Verify that C matches A @ B (within float32 tolerance) using NumPy.
#
# Each argument is a (rows, cols, host_buffer) triple. Each buffer is copied
# into a NumPy array and reshaped to (rows, cols). The NumPy product of A and
# B is then compared with C via numpy.allclose, with C as the second
# argument, as in the original comparison.
#
# Raises: if a NumPy call fails (e.g. incompatible shapes for reshape or
# matmul), or if C does not match the reference product.
def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    a_rows, a_cols, a_buff = buff_a_with_dims
    matrix_a = reshape(to_ndarray(a_buff), a_rows, a_cols)

    b_rows, b_cols, b_buff = buff_b_with_dims
    matrix_b = reshape(to_ndarray(b_buff), b_rows, b_cols)

    c_rows, c_cols, c_buff = buff_c_with_dims
    matrix_c = reshape(to_ndarray(c_buff), c_rows, c_cols)
    np = Python.import_module("numpy")
    expected = np.matmul(matrix_a, matrix_b)
    # The GPU and NumPy both compute in float32 but sum in different orders.
    # Entries that nearly cancel to zero can therefore differ by a few ulps of
    # the partial sums, which the default atol (1e-8) would reject. Use an
    # absolute tolerance suited to float32 sums of this magnitude.
    assert_true(
        np.allclose(expected, matrix_c, rtol=1e-5, atol=1e-4),
        "GPU matmul result does not match numpy.matmul",
    )
    print("Assertion was successful")


# Copy a HostBuffer[dtype] into a newly allocated flat (1-D) NumPy array.
#
# The NumPy array is created as float32 and memcpy copies len(buffer)
# elements of Scalar[dtype] into it, so this is only correct while the
# file-level `dtype` is DType.float32.
#
# Raises: if importing NumPy or any NumPy call fails.
def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    np = Python.import_module("numpy")
    # NOTE: element type must match the file-level `dtype` (DType.float32).
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    # Distinct local names so nothing shadows the `ndarray_ptr` helper.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


# Return `ndarray` reshaped to `rows` x `cols` using NumPy's `reshape`.
# Raises: if the element count does not equal rows * cols.
def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    var reshaped = ndarray.reshape(rows, cols)
    return reshaped


# Return a typed pointer to the start of a NumPy array's data buffer.
#
# The address is the first entry of the array's `__array_interface__["data"]`
# tuple, reinterpreted as a pointer to Scalar[dtype]. Nothing checks that the
# array's element type matches `dtype`; the caller must guarantee it, and the
# pointer is only meaningful while `ndarray` is alive.
def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    var array_iface = ndarray.__array_interface__
    var data_address = array_iface["data"][0]
    return data_address.unsafe_get_as_pointer[dtype]()

View source on GitHub

Matmul Thread Per Output Cell Vectorized

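Each thread copies its column of B into a scratch buffer obtained via `stack_allocation` (called without a shared-memory address space), then computes the dot product with SIMD loads via `vectorize`.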

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy, stack_allocation
from python import Python, PythonObject
from std.testing import assert_true
from algorithm import vectorize
from sys import simdwidthof, strided_load


# Matrix shapes for C = A @ B: COLS_A must equal ROWS_B, and C takes its
# rows from A and its columns from B. The sizes are deliberately odd so the
# SIMD loop has a remainder to handle.
comptime ROWS_A = 9
comptime COLS_A = 17
comptime ROWS_B = 17
comptime COLS_B = 7
comptime ROWS_C = ROWS_A
comptime COLS_C = COLS_B

# Bounds passed to random.random_float64 when filling A and B.
comptime MATRIX_MIN_ELEM = -5.0
comptime MATRIX_MAX_ELEM = 5.0

# Element type of all three matrices.
comptime dtype = DType.float32
# Threads per block as (x, y); the kernel maps x to columns and y to rows of C.
comptime THREADS = (5, 5)
# Blocks per grid dimension: ceil(COLS_C / THREADS[0]) in x and
# ceil(ROWS_C / THREADS[1]) in y, so the grid covers every cell of C.
comptime BLOCKS = (
    (COLS_C + THREADS[0] - 1) // THREADS[0],
    (ROWS_C + THREADS[1] - 1) // THREADS[1],
)

# Row-major layouts for A, B and C.
comptime layout_a = Layout.row_major(ROWS_A, COLS_A)
comptime layout_b = Layout.row_major(ROWS_B, COLS_B)
comptime layout_c = Layout.row_major(ROWS_C, COLS_C)


# LayoutTensor types used to view the device buffers as A, B and C.
comptime MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
comptime MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
comptime MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]
# 1 x simdwidthof[dtype]() tensor type for the extra `store` buffer that
# main passes to the kernel.
comptime Storage = LayoutTensor[
    dtype, Layout.row_major(1, simdwidthof[dtype]()), MutableAnyOrigin
]


# One GPU thread per output cell of C, with the dot product done in SIMD
# chunks.
#
# Each in-bounds thread (i = row from y, j = column from x) first copies
# column j of B (ROWS_B elements) into a scratch buffer from
# `stack_allocation`. It then uses `vectorize` to add the products of row i
# of A and that buffer into C[i, j], one chunk at a time.
# NOTE(review): `stack_allocation` is called without an explicit address
# space; confirm whether this buffer is per-thread or shared before
# describing it as shared memory.
# NOTE(review): the `store` argument is not read by the body.
def matmul_thread_per_output_cell_vectorized(
    A: MatrixA, B: MatrixB, C: MatrixC, store: Storage
):
    var i = block_idx.y * block_dim.y + thread_idx.y  # Rows
    var j = block_idx.x * block_dim.x + thread_idx.x  # Columns
    # Threads outside C (grid rounded up to whole blocks) do nothing.
    if i < ROWS_C and j < COLS_C:
        # Gather column j of B into a contiguous buffer so it can be read
        # with SIMD loads alongside row i of A.
        tile = stack_allocation[ROWS_B, Scalar[dtype]]()
        each_b_col = B.tile[ROWS_B, 1](0, j)
        for k in range(ROWS_B):
            tile[k] = each_b_col[k, 0][0]

        # Add the dot product of a `simd_width`-wide slice of row i of A
        # (starting at column idx) and the matching slice of the gathered
        # column into C[i, j].
        @parameter
        def dotproduct[simd_width: Int](idx: Int):
            C[i, j] += (
                A.load[width=simd_width](i, idx)
                * tile.load[width=simd_width](idx)
            ).reduce_add()

        # Cover all ROWS_B positions in chunks at most
        # simdwidthof[dtype]() wide.
        vectorize[dotproduct, simdwidthof[dtype]()](ROWS_B)


# Fill `buffer` with random values from random.random_float64 in the range
# MATRIX_MIN_ELEM to MATRIX_MAX_ELEM (-5.0 to 5.0), cast to `dtype`.
# The generator is reseeded (with no explicit seed) on every call.
def fill_buffer(buffer: HostBuffer[dtype]):
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


# Entry point: create random A and B on the device, run the vectorized
# one-thread-per-output-cell kernel, then check C against numpy.matmul.
#
# NOTE(review): every error, including a failed check in assert_allclose, is
# caught by the `except` below and only printed.
def main():
    try:
        ctx = DeviceContext()

        # Device buffers for A, B and C, each zero-filled.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Zero-filled buffer of simdwidthof[dtype]() elements, passed to the
        # kernel as `store`.
        store = ctx.enqueue_create_buffer[dtype](
            simdwidthof[dtype]()
        ).enqueue_fill(0.0)

        # Populate A and B with random values through host-mapped views.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        # Wrap the device buffers as LayoutTensors for the kernel.
        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)
        storage = Storage(store)

        # Launch a 2-D grid (BLOCKS) of 2-D blocks (THREADS) covering C.
        ctx.enqueue_function[matmul_thread_per_output_cell_vectorized](
            matrix_a,
            matrix_b,
            matrix_c,
            storage,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        # Wait for the queued work to finish before reading results.
        ctx.synchronize()

        # Map all three buffers back to the host and compare C with A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Prininting here: ", e)


# Verify that C matches A @ B (within float32 tolerance) using NumPy.
#
# Each argument is a (rows, cols, host_buffer) triple. Each buffer is copied
# into a NumPy array and reshaped to (rows, cols). The NumPy product of A and
# B is then compared with C via numpy.allclose, with C as the second
# argument, as in the original comparison.
#
# Raises: if a NumPy call fails (e.g. incompatible shapes for reshape or
# matmul), or if C does not match the reference product.
def assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    a_rows, a_cols, a_buff = buff_a_with_dims
    matrix_a = reshape(to_ndarray(a_buff), a_rows, a_cols)

    b_rows, b_cols, b_buff = buff_b_with_dims
    matrix_b = reshape(to_ndarray(b_buff), b_rows, b_cols)

    c_rows, c_cols, c_buff = buff_c_with_dims
    matrix_c = reshape(to_ndarray(c_buff), c_rows, c_cols)
    np = Python.import_module("numpy")
    expected = np.matmul(matrix_a, matrix_b)
    # The GPU and NumPy both compute in float32 but sum in different orders.
    # Entries that nearly cancel to zero can therefore differ by a few ulps of
    # the partial sums, which the default atol (1e-8) would reject. Use an
    # absolute tolerance suited to float32 sums of this magnitude.
    assert_true(
        np.allclose(expected, matrix_c, rtol=1e-5, atol=1e-4),
        "GPU matmul result does not match numpy.matmul",
    )
    print("Assertion was successful")


# Copy a HostBuffer[dtype] into a newly allocated flat (1-D) NumPy array.
#
# The NumPy array is created as float32 and memcpy copies len(buffer)
# elements of Scalar[dtype] into it, so this is only correct while the
# file-level `dtype` is DType.float32.
#
# Raises: if importing NumPy or any NumPy call fails.
def to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    np = Python.import_module("numpy")
    # NOTE: element type must match the file-level `dtype` (DType.float32).
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    # Distinct local names so nothing shadows the `ndarray_ptr` helper.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


# Return `ndarray` reshaped to `rows` x `cols` using NumPy's `reshape`.
# Raises: if the element count does not equal rows * cols.
def reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    var reshaped = ndarray.reshape(rows, cols)
    return reshaped


# Return a typed pointer to the start of a NumPy array's data buffer.
#
# The address is the first entry of the array's `__array_interface__["data"]`
# tuple, reinterpreted as a pointer to Scalar[dtype]. Nothing checks that the
# array's element type matches `dtype`; the caller must guarantee it, and the
# pointer is only meaningful while `ndarray` is alive.
def ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    var array_iface = ndarray.__array_interface__
    var data_address = array_iface["data"][0]
    return data_address.unsafe_get_as_pointer[dtype]()

View source on GitHub

Timer utility

A simple context-manager timer (used with `with`) that prints the elapsed wall-clock time in nanoseconds.

from std.time import global_perf_counter_ns


# Timer struct used via Python-style `with` blocks.
#
# Records the time at `__enter__` and prints the elapsed duration at `__exit__`.
# The label prefix is a `StringSlice` parameterized on its origin, so the
# label is referenced rather than copied into a new allocation.
#
# Usage:
#   with Timer("My computation: "):
#       do_something()
#   # Prints: "My computation:  12345000 nanoseconds"
#
@fieldwise_init
struct Timer[origin: Origin, //](ImplicitlyCopyable):
    # Counter value (from global_perf_counter_ns) captured on `__enter__`.
    var start_time: UInt64
    # Label printed before the elapsed time.
    var prefix: StringSlice[Self.origin]

    # Create a timer with the given label; the start time is set on entry.
    def __init__(out self, prefix: StringSlice[Self.origin]):
        self.start_time = 0
        self.prefix = prefix

    # Record the start time and return a copy of the timer.
    def __enter__(mut self) -> Self:
        self.start_time = global_perf_counter_ns()
        return self

    # Print the label followed by the time elapsed since `__enter__`.
    def __exit__(mut self):
        # global_perf_counter_ns counts nanoseconds, so this is in ns.
        elapsed_ns = global_perf_counter_ns() - self.start_time
        print(self.prefix, elapsed_ns, "nanoseconds")

View source on GitHub

Histogram

Program to compute the histogram of a 1D array on the GPU

from gpu.host import DeviceContext, HostBuffer, DeviceBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from math import ceildiv
from memory import UnsafePointer
from layout import Layout, LayoutTensor
from os import Atomic
from os.atomic import Consistency

# Element type of both the input values and the bin counters.
comptime dtype = DType.int64
# How many numbers to bin? 2 ^ 20 (default)
comptime ELEMS_COUNT = 1 << 20
# How many bins?
comptime NUM_BINS = 10
# Number of threads per block.
comptime THREADS = 256
# Total number of blocks in the grid: enough for one thread per element.
comptime BLOCKS = ceildiv(ELEMS_COUNT, THREADS)

# Bounds passed to random.random_ui64 when generating the input values.
comptime MAX_ELEM = 101
comptime MIN_ELEM = 1

# Width of each bin: (MAX_ELEM - MIN_ELEM + 1) // NUM_BINS = 101 // 10 = 10.
# Because of the floor division, MAX_ELEM computes an index of NUM_BINS;
# bin_index clamps it into the last bin.
comptime BIN_WIDTH = (MAX_ELEM - MIN_ELEM + 1) // NUM_BINS
# 1-D row-major layout of the input tensor.
comptime input_layout = Layout.row_major(ELEMS_COUNT)


# GPU kernel: each thread bins one input element into `output`.
#
# Thread `tid` reads input[tid], maps it to a bin with `bin_index`, and
# atomically increments that bin's counter. The grid is rounded up to whole
# blocks, so threads with tid >= total_elems do nothing.
#
# Parameters:
#   input: the elements to bin.
#   output: pointer to NUM_BINS counters, expected to be zeroed before launch.
#   total_elems: number of valid elements in `input`.
def histogram(
    input: LayoutTensor[dtype, input_layout, MutableAnyOrigin],
    output: UnsafePointer[Scalar[dtype]],
    total_elems: Int,
):
    var tid = block_idx.x * block_dim.x + thread_idx.x

    if tid < total_elems:
        var elem = input[tid]
        # Named `bin_idx` so the local does not shadow the `bin_index`
        # function called on this line.
        var bin_idx = bin_index(elem[0])
        # Atomic because many threads may increment the same bin. This uses
        # fetch_add's default ordering; an explicit relaxed ordering
        # (Consistency.MONOTONIC) is presumably sufficient for a pure counter.
        # TODO confirm before switching.
        _ = Atomic.fetch_add(output + bin_idx, 1)


# Fill `buffer` with random integers from random.random_ui64(MIN_ELEM,
# MAX_ELEM), i.e. in the range 1 to 101, cast to `dtype`.
# The generator is reseeded (with no explicit seed) on every call.
def fill_buffer(buffer: HostBuffer[dtype]):
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_ui64(MIN_ELEM, MAX_ELEM).cast[dtype]()[0]


# Map an element to its bin index: bins are BIN_WIDTH wide, starting at
# MIN_ELEM. Indices >= NUM_BINS are clamped to the last bin (NUM_BINS - 1),
# and negative indices (elements below MIN_ELEM) are clamped to bin 0.
@always_inline
def bin_index(elem: Int64) -> Int:
    # Named `idx` so the local does not shadow this function's own name.
    var idx = Int((elem - MIN_ELEM) // BIN_WIDTH)
    if idx >= NUM_BINS:
        idx = NUM_BINS - 1
    elif idx < 0:
        idx = 0
    return idx


# Entry point: generate ELEMS_COUNT random values, histogram them on the GPU
# into NUM_BINS atomic counters, then print the bins and the device name.
#
# NOTE(review): every error is caught by the `except` below and only printed.
def main():
    try:
        ctx = DeviceContext()

        # Input values (filled on the host below) and zero-initialized bins.
        elements = ctx.enqueue_create_buffer[dtype](ELEMS_COUNT)
        bins = ctx.enqueue_create_buffer[dtype](NUM_BINS).enqueue_fill(0)

        with elements.map_to_host() as host_elements:
            fill_buffer(host_elements)

        input_tensor = LayoutTensor[dtype, input_layout, MutableAnyOrigin](
            elements
        )
        # NOTE(review): leftover alternative that passed the bins as a
        # LayoutTensor; `output_layout` is not defined in this file.
        # output_tensor = LayoutTensor[mut=True, dtype, output_layout](bins)

        # One thread per element; the bins are passed as a raw pointer so
        # the kernel can apply atomic adds to them.
        ctx.enqueue_function[histogram](
            input_tensor,
            bins.unsafe_ptr(),
            ELEMS_COUNT,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        # Wait for the queued work to finish before reading the bins.
        ctx.synchronize()

        with bins.map_to_host() as bins_host:
            print(bins_host)

        print(ctx.name())
    except e:
        print("Prininting here: ", e)

View source on GitHub