- Technical highlight of the release: AVX-512 masking is extremely handy in implementing unrolled BLAS Level 2 operations __for small inputs, resulting in up to 5x faster kernels than OpenBLAS__.
- Semantic highlight of the release: Bilinear forms now __support complex numbers__ as inputs, extending the kernels' applicability to Computational Physics.
---
Bilinear Forms are essential in Scientific Computing. Some of the most computationally intensive cases arise in Quantum systems and their simulations, as discussed on [`r/Quantum`](https://www.reddit.com/r/quantum/comments/1gy1s7d/seeking_advice_on_opensource_hardwareaccelerated/). This PR adds support for complex inputs, making these kernels more broadly applicable.
```math
\text{BilinearForm}(a, b, M) = a^T M b
```
In Python, you can compute this by calling two NumPy functions back to back, ideally reusing a buffer for the intermediate result:
```py
import numpy as np

ndim = 128
dtype = np.float32

# Reusable buffer for the intermediate vector-matrix product
temporary_vector = np.empty((ndim,), dtype=dtype)

first_quantum_state = np.random.randn(ndim).astype(dtype)
second_quantum_state = np.random.randn(ndim).astype(dtype)
interaction_matrix = np.random.randn(ndim, ndim).astype(dtype)

np.matmul(first_quantum_state, interaction_matrix, out=temporary_vector)
result: float = np.inner(temporary_vector, second_quantum_state)
```
With SimSIMD, the last 2 lines are fused:
```py
import simsimd as simd

result: float = simd.bilinear(first_quantum_state, second_quantum_state, interaction_matrix)
```
For 128-dimensional `np.float32` inputs, the latency dropped from 2.11 μs with NumPy to 1.31 μs with SimSIMD. For smaller 16-dimensional `np.float32` inputs, it dropped from 1.31 μs to 202 ns. As always, the gap is wider for low-precision `np.float16` representations: 2.68 μs with NumPy vs 313 ns with SimSIMD.
## Small Matrices and AVX-512
In the past, developers used to ship separate precompiled kernels for every reasonable matrix size when dealing with small matrices. That bloats the binary and makes the CPU's `L1i` instruction cache ineffective. With AVX-512, however, we can reuse the same vectorized loops for different matrix sizes, with just one additional `BZHI` instruction precomputing the load masks, as sketched below.
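Here is a minimal sketch of that trick, assuming a hypothetical `load_partial_f32` helper rather than SimSIMD's actual internals: `BZHI` keeps only the lowest `n` bits of an all-ones pattern, and the resulting mask drives a zero-filling masked load, so the same loop body serves any dimensionality up to 16 `f32` lanes.

```c
#include <immintrin.h>

// Hypothetical sketch, not SimSIMD's exact code: build a load mask for
// `n <= 16` single-precision entries with one BZHI, then perform a
// zero-filling masked load so one vectorized loop fits any small size.
static inline __m512 load_partial_f32(float const *ptr, unsigned n) {
    __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFu, n); // lowest `n` bits set
    return _mm512_maskz_loadu_ps(mask, ptr);           // lanes >= n become 0
}
```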
## Avoiding Data Dependency
A common approach in dot products is to use a single register to accumulate the partial sums with a chain of fused multiply-adds. That [`VFMADD132PS` instruction](https://uops.info/html-instr/VFMADD132PS_ZMM_ZMM_ZMM.html):
- on AMD Zen 4 has a latency of 4 cycles and can execute on ports 0 and 1;
- on Intel Skylake-X has a latency of 4 cycles and can execute on ports 0 and 5.
With 2 ports able to execute it simultaneously, a dependency chain through a single accumulator is bound by the 4-cycle latency rather than by the available throughput, so chaining consecutive statements is inefficient even on modern hardware. Future generations may issue it on even more ports, so to "future-proof" the solution, I use 4 independent accumulators, as sketched below.
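As an illustration of the idea (a sketch assuming an `f32` dot product with a dimension divisible by 64, not the library's actual kernel), four independent accumulators let consecutive FMAs issue without waiting on each other's 4-cycle latency:

```c
#include <immintrin.h>
#include <stddef.h>

// Sketch: a dot product with 4 independent accumulators, so consecutive
// FMAs don't form a dependency chain through a single register.
// Assumes `n` is a multiple of 64 for brevity.
static float dot_f32_unrolled(float const *a, float const *b, size_t n) {
    __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps();
    __m512 acc2 = _mm512_setzero_ps(), acc3 = _mm512_setzero_ps();
    for (size_t i = 0; i < n; i += 64) {
        acc0 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i), _mm512_loadu_ps(b + i), acc0);
        acc1 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 16), _mm512_loadu_ps(b + i + 16), acc1);
        acc2 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 32), _mm512_loadu_ps(b + i + 32), acc2);
        acc3 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 48), _mm512_loadu_ps(b + i + 48), acc3);
    }
    __m512 sum = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
    return _mm512_reduce_add_ps(sum); // one horizontal reduction at the end
}
```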
## Avoiding Horizontal Reductions
When computing $a^T M b$, we may prefer to evaluate $M b$ first due to the associativity of matrix multiplication. On tiny inputs, however, that approach is bottlenecked by a horizontal reduction for every row of $M$. Instead, we use more serial loads and broadcasts, but perform only one horizontal accumulation at the very end, assuming all the needed intermediaries fit into a single register (or a few, if we also want to minimize data dependencies), as sketched below.
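A minimal sketch of that layout, assuming a fixed 16x16 `f32` matrix and a hypothetical `bilinear_f32_16x16_sketch` function rather than the shipped kernel: every row of $M$ is scaled by the broadcast scalar $a_i$ and accumulated in a single register, and the only horizontal reduction is the final dot product with $b$.

```c
#include <immintrin.h>

// Sketch: a^T M b for a 16x16 f32 matrix with a single horizontal reduction.
// Rows of M are scaled by broadcast scalars of `a` and summed in one register.
static float bilinear_f32_16x16_sketch(float const *a, float const *b, float const *m) {
    __m512 b_vec = _mm512_loadu_ps(b);          // the whole vector b
    __m512 weighted_rows = _mm512_setzero_ps(); // accumulates a[i] * M[i,:]
    for (unsigned i = 0; i != 16; ++i) {
        __m512 row = _mm512_loadu_ps(m + i * 16);
        weighted_rows = _mm512_fmadd_ps(_mm512_set1_ps(a[i]), row, weighted_rows);
    }
    // Single horizontal reduction: dot(weighted_rows, b) == a^T M b
    return _mm512_reduce_add_ps(_mm512_mul_ps(weighted_rows, b_vec));
}
```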
## Intel Sapphire Rapids Benchmarks
Running on recent Intel Sapphire Rapids CPUs, one can expect the following performance metrics for 128-dimensional Bilinear Forms for SimSIMD and OpenBLAS:
```sh
-----------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations UserCounters...
-----------------------------------------------------------------------------------------------------------------
bilinear_f64_blas<128d>/min_time:10.000/threads:1 3584 ns 3584 ns 3906234 abs_delta=3.8576a bytes=571.503M/s pairs=279.054k/s relative_error=1.45341f
bilinear_f64c_blas<128d>/min_time:10.000/threads:1 7605 ns 7604 ns 1856665 abs_delta=3.90906a bytes=538.656M/s pairs=131.508k/s relative_error=3.10503f
bilinear_f32_blas<128d>/min_time:10.000/threads:1 1818 ns 1818 ns 7621072 abs_delta=743.294p bytes=563.325M/s pairs=550.122k/s relative_error=301.396n
bilinear_f32c_blas<128d>/min_time:10.000/threads:1 3607 ns 3606 ns 3886483 abs_delta=958.531p bytes=567.864M/s pairs=277.278k/s relative_error=1.4445u
bilinear_f16_haswell<128d>/min_time:10.000/threads:1 1324 ns 1324 ns 10597225 abs_delta=1.31674n bytes=386.742M/s pairs=755.355k/s relative_error=851.968n
bilinear_bf16_haswell<128d>/min_time:10.000/threads:1 1305 ns 1305 ns 10752131 abs_delta=1.33001n bytes=392.464M/s pairs=766.532k/s relative_error=561.046n
bilinear_bf16_genoa<128d>/min_time:10.000/threads:1 862 ns 862 ns 16241596 abs_delta=1.40284n bytes=593.885M/s pairs=1.15993M/s relative_error=849.533n
bilinear_bf16c_genoa<128d>/min_time:10.000/threads:1 2610 ns 2610 ns 5355435 abs_delta=351.596p bytes=392.313M/s pairs=383.118k/s relative_error=243.698n
bilinear_f16_sapphire<128d>/min_time:10.000/threads:1 875 ns 875 ns 16038203 abs_delta=10.5652u bytes=584.951M/s pairs=1.14248M/s relative_error=9.42998m
bilinear_f16c_sapphire<128d>/min_time:10.000/threads:1 2159 ns 2159 ns 6449575 abs_delta=4.43296u bytes=474.398M/s pairs=463.28k/s relative_error=3.98057m
bilinear_f64_skylake<128d>/min_time:10.000/threads:1 3483 ns 3483 ns 4019657 abs_delta=4.3853a bytes=587.96M/s pairs=287.09k/s relative_error=3.02046f
bilinear_f64c_skylake<128d>/min_time:10.000/threads:1 7178 ns 7178 ns 1949803 abs_delta=3.45547a bytes=570.624M/s pairs=139.313k/s relative_error=4.07708f
bilinear_f32_skylake<128d>/min_time:10.000/threads:1 1783 ns 1783 ns 7848896 abs_delta=2.45041n bytes=574.255M/s pairs=560.796k/s relative_error=811.561n
bilinear_f32c_skylake<128d>/min_time:10.000/threads:1 3504 ns 3504 ns 3976879 abs_delta=1.94251n bytes=584.494M/s pairs=285.397k/s relative_error=2.99757u
bilinear_f64_serial<128d>/min_time:10.000/threads:1 5528 ns 5528 ns 2529904 abs_delta=0 bytes=370.459M/s pairs=180.888k/s relative_error=0
bilinear_f64c_serial<128d>/min_time:10.000/threads:1 12324 ns 12324 ns 1140788 abs_delta=0 bytes=332.371M/s pairs=81.1453k/s relative_error=0
bilinear_f32_serial<128d>/min_time:10.000/threads:1 5299 ns 5298 ns 2649614 abs_delta=1.69242n bytes=193.264M/s pairs=188.734k/s relative_error=776.834n
bilinear_f32c_serial<128d>/min_time:10.000/threads:1 10217 ns 10216 ns 1370535 abs_delta=1.89398n bytes=200.461M/s pairs=97.8816k/s relative_error=3.25219u
bilinear_f16_serial<128d>/min_time:10.000/threads:1 42372 ns 42371 ns 330369 abs_delta=1.93284n bytes=12.0838M/s pairs=23.6011k/s relative_error=1.51289u
bilinear_f16c_serial<128d>/min_time:10.000/threads:1 46101 ns 46100 ns 303997 abs_delta=1.77214n bytes=22.2124M/s pairs=21.6918k/s relative_error=1.5494u
bilinear_bf16_serial<128d>/min_time:10.000/threads:1 85325 ns 85324 ns 163256 abs_delta=1.34067n bytes=6.00066M/s pairs=11.72k/s relative_error=527.801n
bilinear_bf16c_serial<128d>/min_time:10.000/threads:1 178970 ns 178967 ns 78235 abs_delta=1.46323n bytes=5.72174M/s pairs=5.58764k/s relative_error=1004.88n
```

Highlights:
- Single- and double-precision kernels are only about 5% faster than BLAS due to removed temporary buffer stores.
- Both `bf16` and `f16` kernels provide speedups roughly proportional to the reduction in the number of bits per scalar.
On low-dimensional inputs, the performance gap is larger:

```sh
---------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations UserCounters...
---------------------------------------------------------------------------------------------------------------
bilinear_f64_blas<8d>/min_time:10.000/threads:1 42.7 ns 42.7 ns 328247670 abs_delta=15.9107a bytes=3.00004G/s pairs=23.4378M/s relative_error=550.946a
bilinear_f64c_blas<8d>/min_time:10.000/threads:1 57.4 ns 57.4 ns 243896993 abs_delta=21.3452a bytes=4.46378G/s pairs=17.4366M/s relative_error=514.643a
bilinear_f32_blas<8d>/min_time:10.000/threads:1 32.2 ns 32.2 ns 434784869 abs_delta=6.73645n bytes=3.97757G/s pairs=31.0747M/s relative_error=235.395n
bilinear_f32c_blas<8d>/min_time:10.000/threads:1 50.6 ns 50.6 ns 276504577 abs_delta=7.97379n bytes=2.52823G/s pairs=19.7518M/s relative_error=251.204n
bilinear_f16_haswell<8d>/min_time:10.000/threads:1 13.7 ns 13.7 ns 1000000000 abs_delta=6.06053n bytes=9.35133G/s pairs=73.0573M/s relative_error=139.096n
bilinear_bf16_haswell<8d>/min_time:10.000/threads:1 13.0 ns 13.0 ns 1000000000 abs_delta=5.03892n bytes=9.84787G/s pairs=76.9365M/s relative_error=114.101n
bilinear_bf16_genoa<8d>/min_time:10.000/threads:1 12.6 ns 12.6 ns 1000000000 abs_delta=5.63947n bytes=10.1297G/s pairs=79.1384M/s relative_error=166.305n
bilinear_bf16c_genoa<8d>/min_time:10.000/threads:1 69.0 ns 69.0 ns 203022389 abs_delta=1.61581n bytes=1.85573G/s pairs=14.4979M/s relative_error=60.9203n
bilinear_f16_sapphire<8d>/min_time:10.000/threads:1 8.52 ns 8.52 ns 1000000000 abs_delta=51.4863u bytes=15.0256G/s pairs=117.387M/s relative_error=1.92771m
bilinear_f16c_sapphire<8d>/min_time:10.000/threads:1 64.6 ns 64.6 ns 216692584 abs_delta=43.8492u bytes=1.98133G/s pairs=15.4791M/s relative_error=1.48218m
bilinear_f32_skylake<8d>/min_time:10.000/threads:1 7.28 ns 7.28 ns 1000000000 abs_delta=8.92396n bytes=17.5799G/s pairs=137.343M/s relative_error=266.557n
bilinear_f32c_skylake<8d>/min_time:10.000/threads:1 42.8 ns 42.8 ns 326789735 abs_delta=10.4774n bytes=2.98821G/s pairs=23.3454M/s relative_error=267.67n
bilinear_f64_skylake<8d>/min_time:10.000/threads:1 7.16 ns 7.16 ns 1000000000 abs_delta=16.8322a bytes=17.8732G/s pairs=139.634M/s relative_error=776.898a
bilinear_f64c_skylake<8d>/min_time:10.000/threads:1 31.2 ns 31.2 ns 449958679 abs_delta=17.4692a bytes=8.20188G/s pairs=32.0386M/s relative_error=477.326a
bilinear_f64_serial<8d>/min_time:10.000/threads:1 19.3 ns 19.3 ns 724453573 abs_delta=0 bytes=6.63046G/s pairs=51.8005M/s relative_error=0
bilinear_f64c_serial<8d>/min_time:10.000/threads:1 47.7 ns 47.7 ns 293638808 abs_delta=0 bytes=5.36703G/s pairs=20.965M/s relative_error=0
bilinear_f32_serial<8d>/min_time:10.000/threads:1 18.4 ns 18.4 ns 759547931 abs_delta=7.93122n bytes=6.94336G/s pairs=54.245M/s relative_error=213.04n
bilinear_f32c_serial<8d>/min_time:10.000/threads:1 45.6 ns 45.6 ns 307012654 abs_delta=9.52236n bytes=2.80829G/s pairs=21.9398M/s relative_error=282.08n
bilinear_f16_serial<8d>/min_time:10.000/threads:1 171 ns 171 ns 81713243 abs_delta=7.46151n bytes=747.117M/s pairs=5.83685M/s relative_error=187.409n
bilinear_f16c_serial<8d>/min_time:10.000/threads:1 208 ns 208 ns 67195854 abs_delta=8.79194n bytes=614.281M/s pairs=4.79907M/s relative_error=265.818n
bilinear_bf16_serial<8d>/min_time:10.000/threads:1 359 ns 359 ns 38947709 abs_delta=5.77119n bytes=356.094M/s pairs=2.78198M/s relative_error=122.725n
bilinear_bf16c_serial<8d>/min_time:10.000/threads:1 744 ns 744 ns 18821435 abs_delta=6.72388n bytes=172.071M/s pairs=1.34431M/s relative_error=145.277n
```

Highlights:
- For `f32`, the performance __grew from 31.07 to 137.34 Million__ operations per second.
- For `f64`, the performance __grew from 23.44 to 139.63 Million__ operations per second.
## Commits
- Add: Bilinear complex kernels for NEON (cd15779)
- Add: Half-precision bilinear forms on x86 (f804694)
- Add: `f64` bilinear forms in AVX-512 (7f24d59)
- Add: Complex bilinear forms (3a56174)
- Add: Complex structs (ee2de83)
- Docs: Improved benchmarks table (acc61b5)
- Make: Revert `s390x` and `ppc64le` support (82666ff)
- Make: Pin `cibuildwheel==2.21.3` (d865a46)
- Make: Drop `s390x` and `ppc64le` (5d9a219)
- Improve: Fewer PyTest runs (6effbea)
- Make: Override s390x container (08c7ac0)
- Make: Bump Py & TinySemVer (03eba14)
- Make: Upgrade Py CI (d82c2bb)
- Improve: Computing masks in AVX-512 (f259492)
- Improve: Unroll AVX-512 bilinear forms (05e29b3)
- Docs: Multi-threading in BLAS (4f351ad)
- Fix: Write beyond buffer bounds (2bef182)
- Improve: cBLAS Bilinear Form benchmarks (1e5b9d7)
- Improve: Shorter macros (15eedb5)
- Improve: Dispatch complex kernels (3ba20d0)
- Fix: Missing `idx_scalars` in SVE (19a56bc)
- Docs: Using `complex` & `int` types (67e39e6)
- Improve: Complex numbers handling in `bench.cxx` (b923c1a)
- Improve: Use complex types for `dense.h` (8941462)