[This post](https://github.com/tensorflow/tensorflow/issues/54463) in a TensorFlow issue thread points out that cuBLAS 11.3 (the library that handles matrix multiplication on the GPU) has a bug: if a matrix dimension exceeds 2^20 (1,048,576, or about a million), matrix multiplication may produce incorrect results. This seems to happen only when `tf.einsum` is used, not when matmul is called directly, so I replaced the einsum with a matmul to work around a crash.
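As a rough illustration of the kind of swap described above, here is a minimal sketch. The function names, the einsum equation, and the shapes are all hypothetical, since the note does not say which contraction was affected; the sketch assumes a plain batched matrix product. The tensors are deliberately tiny, so this only shows that the two forms agree numerically and does not reproduce the large-dimension bug.

```python
import tensorflow as tf


def contract_einsum(a, b):
    # Hypothetical original form: batched matrix product written as an einsum.
    # a: [batch, i, j], b: [batch, j, k] -> result: [batch, i, k]
    return tf.einsum("bij,bjk->bik", a, b)


def contract_matmul(a, b):
    # Equivalent contraction expressed with tf.matmul, which multiplies the
    # last two dimensions and treats the leading dimension as a batch.
    return tf.matmul(a, b)


# Small illustrative inputs (assumed shapes, far below the 2^20 threshold).
a = tf.random.uniform([2, 3, 4], seed=0)
b = tf.random.uniform([2, 4, 5], seed=1)

out_einsum = contract_einsum(a, b)
out_matmul = contract_matmul(a, b)

# Largest absolute difference between the two results.
max_diff = float(tf.reduce_max(tf.abs(out_einsum - out_matmul)))
print("max abs difference:", max_diff)
```

For einsum equations that are not already a plain matrix product, the operands would first need a `tf.transpose` or `tf.reshape` to bring the contracted axis into the last position of the first operand and the second-to-last position of the second.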