Divergence functions are a bit more complex than Cosine Similarity, primarily because they have to compute logarithms, which are relatively slow with LibC's `logf`.
So, aside from minor patches, this PR rewrites the [Jensen Shannon](https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence) distances, leveraging several optimizations around the AVX-512 and AVX-512FP16 extensions. The result is a __4.6x__ improvement over the auto-vectorized single-precision variant and a whopping __118x__ improvement over the half-precision code produced by GCC 12.
## Optimizations
- __Logarithm Computation__. Instead of multiple bitwise operations, `_mm512_getexp_ph` and `_mm512_getmant_ph` are now used to extract the exponent and the mantissa of the floating-point number, streamlining the process. I've also used Horner's method for the polynomial approximation; a scalar sketch of the idea follows this list.
- __Division Avoidance__. To avoid expensive division operations, reciprocal approximations are used: `_mm512_rcp_ph` for half-precision and `_mm512_rcp14_ps` for single-precision. The higher-precision `_mm512_rcp28_ps` turned out to be unnecessary for this implementation.
- __Handling Zeros__. The `_mm512_cmp_ph_mask` is used to compute a mask for close-to-zero values, avoiding the addition of an "epsilon" to every component, which is both cleaner and more accurate.
- __Parallel Accumulation__. $KL(P || M)$ and $KL(Q || M)$ are now accumulated in separate registers, and the masked `_mm512_maskz_fmadd_ph` replaces distinct addition and multiplication operations, streamlining the calculation further.
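As referenced in the first bullet, here is a scalar sketch of the logarithm trick before diving into the intrinsics. It is only an illustration: `frexpf` stands in for `_mm512_getexp_ph` / `_mm512_getmant_ph`, while the polynomial coefficients are the same ones used in the vectorized kernel further down.

```c
#include <math.h>

// Scalar sketch of the `log2` approximation used in the AVX-512 kernel below:
// split a positive `x` into exponent and mantissa, then approximate `log2` of
// the mantissa with a degree-5 polynomial evaluated via Horner's method.
static inline float approximate_log2(float x) {
    int exponent;
    float mantissa = frexpf(x, &exponent); // x = mantissa * 2^exponent, mantissa in [0.5, 1)
    mantissa *= 2.f, exponent -= 1;        // renormalize so that mantissa falls in [1, 2)
    float p = -3.4436006e-2f;
    p = p * mantissa + 3.1821337e-1f;
    p = p * mantissa - 1.2315303f;
    p = p * mantissa + 2.5988452f;
    p = p * mantissa - 3.3241990f;
    p = p * mantissa + 3.1157899f;
    // log2(mantissa) ~ p(mantissa) * (mantissa - 1), exact at 1 and 2
    return p * (mantissa - 1.f) + (float)exponent;
}
```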
## Implementation
As a reminder, the Jensen-Shannon divergence is the symmetric version of the Kullback-Leibler divergence:
```math
JSD(P, Q) = \frac{1}{2} D(P || M) + \frac{1}{2} D(Q || M)
```

```math
M = \frac{1}{2}(P + Q), \quad D(P || Q) = \sum P(i) \cdot \log \left( \frac{P(i)}{Q(i)} \right)
```
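For reference, a direct serial translation of the formula might look like the sketch below. This is only an illustration, not the library's exact baseline; the name is made up, and it assumes non-negative inputs, skipping near-zero components the same way the masking logic does further down.

```c
#include <math.h>
#include <stddef.h>

// Minimal serial sketch of the Jensen-Shannon divergence between `p` and `q`.
static float serial_f32_js(float const* p, float const* q, size_t n) {
    float const epsilon = 1e-6f;
    float sum = 0;
    for (size_t i = 0; i < n; ++i) {
        float m = (p[i] + q[i]) * 0.5f;
        // Skip near-zero components to avoid division by zero and log of zero
        if (p[i] > epsilon && q[i] > epsilon)
            sum += p[i] * logf(p[i] / m) + q[i] * logf(q[i] / m);
    }
    return sum * 0.5f;
}
```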
For AVX-512FP16, the current implementation looks like this:
```c
__attribute__((target("avx512f,avx512vl,avx512fp16")))
inline __m512h simsimd_avx512_f16_log2(__m512h x) {
    // Extract the exponent and mantissa
    __m512h one = _mm512_set1_ph((_Float16)1);
    __m512h e = _mm512_getexp_ph(x);
    __m512h m = _mm512_getmant_ph(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src);

    // Compute the polynomial using Horner's method
    __m512h p = _mm512_set1_ph((_Float16)-3.4436006e-2f);
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)3.1821337e-1f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)-1.2315303f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)2.5988452f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)-3.3241990f));
    p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((_Float16)3.1157899f));

    return _mm512_add_ph(_mm512_mul_ph(p, _mm512_sub_ph(m, one)), e);
}

__attribute__((target("avx512f,avx512vl,avx512fp16")))
inline static simsimd_f32_t simsimd_avx512_f16_js(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n) {
    __m512h sum_a_vec = _mm512_set1_ph((_Float16)0);
    __m512h sum_b_vec = _mm512_set1_ph((_Float16)0);
    __m512h epsilon_vec = _mm512_set1_ph((_Float16)1e-6f);
    for (simsimd_size_t i = 0; i < n; i += 32) {
        __mmask32 mask = n - i >= 32 ? 0xFFFFFFFF : ((1u << (n - i)) - 1u);
        __m512h a_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a + i));
        __m512h b_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b + i));
        __m512h m_vec = _mm512_mul_ph(_mm512_add_ph(a_vec, b_vec), _mm512_set1_ph((_Float16)0.5f));

        // Avoid division-by-zero problems from near-zero probabilities down the road.
        // Masking is a nicer way to do this than adding an `epsilon` to every component.
        __mmask32 nonzero_mask_a = _mm512_cmp_ph_mask(a_vec, epsilon_vec, _CMP_GE_OQ);
        __mmask32 nonzero_mask_b = _mm512_cmp_ph_mask(b_vec, epsilon_vec, _CMP_GE_OQ);
        __mmask32 nonzero_mask = nonzero_mask_a & nonzero_mask_b & mask;

        // Division is an expensive operation. Instead of doing it twice,
        // we can approximate the reciprocal of `m` and multiply instead.
        __m512h m_recip_approx = _mm512_rcp_ph(m_vec);
        __m512h ratio_a_vec = _mm512_mul_ph(a_vec, m_recip_approx);
        __m512h ratio_b_vec = _mm512_mul_ph(b_vec, m_recip_approx);

        // The natural logarithm is just `log2` scaled by `ln(2)`, which we apply at the end
        __m512h log_ratio_a_vec = simsimd_avx512_f16_log2(ratio_a_vec);
        __m512h log_ratio_b_vec = simsimd_avx512_f16_log2(ratio_b_vec);

        // Instead of separate multiplication and addition, invoke the FMA
        sum_a_vec = _mm512_maskz_fmadd_ph(nonzero_mask, a_vec, log_ratio_a_vec, sum_a_vec);
        sum_b_vec = _mm512_maskz_fmadd_ph(nonzero_mask, b_vec, log_ratio_b_vec, sum_b_vec);
    }

    simsimd_f32_t log2_normalizer = 0.693147181f;
    return _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec)) * 0.5f * log2_normalizer;
}
```
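For completeness, a hypothetical call site might look like the sketch below. It assumes `simsimd_f16_t` aliases `_Float16` (the actual typedef may differ), and the uniform vectors are made up purely to illustrate the expected inputs: two probability distributions of equal length.

```c
// Hypothetical usage sketch, assuming `simsimd_f16_t` aliases `_Float16`.
void example_usage(void) {
    enum { dimensions = 1536 };
    static simsimd_f16_t p[dimensions], q[dimensions];
    for (simsimd_size_t i = 0; i != dimensions; ++i)
        p[i] = q[i] = (simsimd_f16_t)(1.f / dimensions);
    simsimd_f32_t divergence = simsimd_avx512_f16_js(p, q, dimensions);
    // Identical distributions should yield a divergence close to zero
    (void)divergence;
}
```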
## Benchmarks
I conducted benchmarks at both the higher-level Python and lower-level C++ layers, comparing the auto-vectorization on GCC 12 to our new implementation on an Intel Sapphire Rapids CPU on AWS:
<img width="1178" alt="Screenshot 2023-10-23 at 14 33 03" src="https://github.com/ashvardanian/SimSIMD/assets/1983160/39d4b937-b684-4a5f-a010-dafa6dc8d114">
The program was compiled with `-O3` and `-ffast-math` and ran on all cores of the 4-core instance, potentially favoring the non-vectorized solution. When normalized and tabulated, the results are as follows:
| Benchmark | Pairs/s | Gigabytes/s | Absolute Error | Relative Error |
|-----------------------|----------:|------------:|---------------:|---------------:|
| `serial_f32_js_1536d` | 0.243 M/s | 2.98 G/s | 0 | 0 |