Wednesday, January 15, 2025

Single Instruction, Multiple Data (SIMD) Java Example

import jdk.incubator.vector.*;
import java.util.Arrays;

/**
 * Demonstrates element-wise addition of two {@code double} arrays using the
 * incubating Java Vector API ({@code jdk.incubator.vector}).
 *
 * <p>Compile and run with {@code --add-modules jdk.incubator.vector}.
 */
public class VectorAdditionSIMD {

    /** Preferred {@code double} species for the running platform (the Vector API's maximal-bit-size species). */
    private static final VectorSpecies<Double> SPECIES = DoubleVector.SPECIES_PREFERRED;

    /** Number of {@code double} lanes processed by one vector operation of {@link #SPECIES}. */
    private static final int VECTOR_SIZE = SPECIES.length();

    /**
     * Returns the element-wise sum of two arrays of equal length.
     *
     * <p>Chunks that fill a whole vector are added with SIMD vector operations;
     * any remaining tail elements are added with a plain scalar loop.
     *
     * @param vector1 the first addend; must not be {@code null}
     * @param vector2 the second addend; must not be {@code null} and must have
     *                the same length as {@code vector1}
     * @return a new array whose element {@code i} is {@code vector1[i] + vector2[i]}
     * @throws IllegalArgumentException if either argument is {@code null} or the
     *                                  two arrays have different lengths
     */
    public static double[] addVectorsSIMD(double[] vector1, double[] vector2) {
        if (vector1 == null || vector2 == null) {
            throw new IllegalArgumentException("Vectors cannot be null.");
        }
        if (vector1.length != vector2.length) {
            throw new IllegalArgumentException("Vectors must have the same dimensions.");
        }

        final int length = vector1.length;
        final double[] result = new double[length];

        // loopBound rounds length down to a multiple of the vector length, so every
        // fromArray/intoArray call below stays in bounds without needing a mask.
        final int upperBound = SPECIES.loopBound(length);
        int i = 0;
        for (; i < upperBound; i += VECTOR_SIZE) {
            DoubleVector v1 = DoubleVector.fromArray(SPECIES, vector1, i);
            DoubleVector v2 = DoubleVector.fromArray(SPECIES, vector2, i);
            v1.add(v2).intoArray(result, i);
        }

        // Scalar tail: the elements that do not fill a whole vector.
        for (; i < length; i++) {
            result[i] = vector1[i] + vector2[i];
        }

        return result;
    }

    /**
     * Example usage: adds two 100-element arrays and prints the first ten
     * elements of each input and of the sum, then deliberately passes arrays of
     * different lengths to show the validation error being reported.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        double[] v1 = new double[100];
        double[] v2 = new double[100];
        double[] v3 = new double[7];

        for (int i = 0; i < v1.length; i++) {
            v1[i] = i * 1.0;
            v2[i] = i * 2.0;
        }

        for (int i = 0; i < v3.length; i++) {
            v3[i] = i * 3.0;
        }

        try {
            double[] sum = addVectorsSIMD(v1, v2);
            System.out.println("Vector 1: " + Arrays.toString(Arrays.copyOf(v1, 10)));
            System.out.println("Vector 2: " + Arrays.toString(Arrays.copyOf(v2, 10)));
            System.out.println("SIMD Sum: " + Arrays.toString(Arrays.copyOf(sum, 10)));

            // Intentionally mismatched lengths (100 vs 7): this call is expected to
            // throw, and the message is printed by the handler below.
            addVectorsSIMD(v1, v3);
        } catch (IllegalArgumentException e) {
            System.err.println("Error: " + e.getMessage());
        }
    }
}

To verify that SIMD instructions are actually generated, first install hsdis, the HotSpot disassembler plugin. It lets the JVM print the machine instructions produced by the JIT compiler, so we can see which instructions are used and which compiler produced them (we need C2, the server compiler, not C1). You can get a prebuilt binary (e.g. from chriswhocodes.com/hsdis). Copy the shared library (or DLL on Windows) to the appropriate location (e.g. on Linux, copy hsdis-amd64.so to /usr/lib/jvm/java-21-openjdk-amd64/lib/server/).

Compile:

javac --add-modules jdk.incubator.vector VectorAdditionSIMD.java

Run/examine output:

java -XX:+UnlockDiagnosticVMOptions -XX:+PrintAssembly -XX:CompileThreshold=1 -XX:-TieredCompilation --add-modules jdk.incubator.vector VectorAdditionSIMD > c2output.txt

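Now, within c2output.txt, you should find that the C2 compiler was used and that the code uses ymm (256-bit AVX) registers. This is a good sign that you are executing via SIMD. (On CPUs with AVX-512, the preferred vector species may be 512 bits wide, in which case you may see zmm registers instead.)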

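If all is well, you should see output along the lines of the example below. Note that these particular vmovdqu/vpxor/vptest instructions form a compare pattern, most likely from a JDK intrinsic such as an array or string comparison, rather than the addition itself. The vectorized addition in addVectorsSIMD should appear as vaddpd on ymm (or zmm) registers, so search for that as well: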

  0x00007f345ca993ca:   vmovdqu (%rdi,%rcx,1),%ymm0
  0x00007f345ca993cf:   vmovdqu (%rsi,%rcx,1),%ymm1
  0x00007f345ca993d4:   vpxor  %ymm1,%ymm0,%ymm0
  0x00007f345ca993d8:   vptest %ymm0,%ymm0
  0x00007f345ca993f1:   vmovdqu -0x20(%rdi,%rax,1),%ymm0
  0x00007f345ca993f7:   vmovdqu -0x20(%rsi,%rax,1),%ymm1
  0x00007f345ca993fd:   vpxor  %ymm1,%ymm0,%ymm0