Doing Quantum Mechanics with a Machine Learning Framework: PyTorch and Correlated Gaussian Wavefunctions: Part 1) Introduction



A Quantum Mechanics problem coded up in PyTorch?! Sure! Why not?

Machine Learning / Deep Learning Frameworks are primarily designed for solving problems in Statistical modeling and have utilities to make working with Artificial Neural Networks, convolutions, automatic differentiation (i.e. back-propagation), etc., easier. They are generally highly optimized, including taking advantage of GPU compute acceleration. However, at their core they implement a set of very useful Numerical Linear Algebra routines. Those core algorithms are general purpose and can be applied to a wide variety of problems that involve Matrix/Tensor mathematical calculations. I’ll look at one of these “other” problems.

My problem of interest is a “simple” Quantum Mechanics eigen-function partial differential equation for few particle systems utilizing a wave-function expansion with “correlated Gaussian basis functions”. It’s a relatively simple problem really, and we can code the whole thing up in a couple hundred lines of Python using PyTorch.

I’ll explain just enough of the Quantum Mechanics and Mathematics to make the problem and solution (kind of) understandable. The focus is on how easy it is to implement in PyTorch.

This first post will give some explanation of the problem and do some testing of a couple of the formulas that will need to be coded up.


Why PyTorch?

The main reason I chose PyTorch is that I’m interested in it. I started playing with it a few weeks ago and wrote this post, Why You Should Consider PyTorch (includes Install and a few examples). I was thinking about coding up some of my old scientific work and decided to give PyTorch a try. I wasn’t sure if this was feasible at first, but after looking deeper into PyTorch I realized that everything I needed was there, at least for a naive, simple implementation of the problem. Here are some considerations,

  • PyTorch is compatible with Numpy and has equivalents for most of its routines
  • Switching between CPU and CUDA on a GPU is seamless and very simple (see the short snippet after this list).
  • Automatic differentiation is well implemented and relatively easy to use
  • PyTorch contains a rich set of both CPU and CUDA based BLAS (Basic Linear Algebra Subprograms) and LAPACK (higher level linear algebra algorithms) routines.
  • It has “batched” routines to extend matrix operations to larger tensor structures.
  • It’s interactive and easy to use in a Jupyter notebook!
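
As a tiny illustration of the first two points, here is a hedged sketch (the array and tensor names are just examples, not part of the project code) showing the same data moving between Numpy, the CPU, and a GPU:

import numpy as np
import torch as th

a_np  = np.random.rand(4, 4)             # start with a NumPy array
a_cpu = th.from_numpy(a_np)              # zero-copy view of it as a CPU tensor

device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
a_dev  = a_cpu.to(device)                # same tensor, now on the GPU if one is present

b_dev = a_dev @ th.t(a_dev)              # identical code runs on either device
b_np  = b_dev.cpu().numpy()              # and comes back to NumPy when needed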

What about TensorFlow?

TensorFlow may be a better choice(??) The development community around TensorFlow has created a huge library of functionality. It is hard not to recommend at least considering it. The main drawback is that it has a substantial learning curve and doesn’t have the “Pythonic” interactive feel of PyTorch. If I were starting work on an important new scientific programming project I would seriously consider using TensorFlow as a numerical computing framework. I would still use PyTorch for quick prototyping and testing (which is basically what I do in this post).


The Quantum Mechanics

If any of the math makes your head hurt then just read it as “bla bla bla”. The important thing will be to realize that you can implement some interesting matrix equations with PyTorch. Here is your brief introduction to doing quantum mechanics of few particle systems from “first principles”.

The basic QM problem that we are looking at is trying to find the lowest energy state of a system with a small number of particles (electrons, and atomic nuclei … or maybe interesting particles like muons and positrons). We want to solve the Schrodinger equation,

$$H\Psi = E\Psi$$

$H$ is the Hamiltonian energy operator, $E$ is an energy eigenvalue and $\Psi$ is an eigenfunction of $H$, the “wavefunction” of the system. We can write down expressions for $H$ for almost any system of particles but for problems with more than 2 particles we can’t find an exact wavefunction that solves the problem. It’s just mathematically not exactly solvable except for a few model systems (it is the basic many body problem). What we can do is turn the problem into a numerical partial differential equation that can be written as a Matrix eigenvalue problem by expanding $\Psi$ in a basis set of appropriate functions $\phi_k$ that we can find integrals for.

Our system is a collection of $p$ particles with masses $\{M_1,\cdots ,M_p\}$ and charges $\{Q_1,\cdots ,Q_p\}$ interacting under a Coulomb potential (atoms and molecules). The Hamiltonian can be transformed from real particle coordinates to “internal” coordinates by removing the center of mass. We want the internal energy, so we remove the translational energy from the equations. That removes 3 degrees of freedom from the problem, and it is an exact transformation. What you are left with is a set of $p-1$ “pseudo particles” whose coordinates describe the internal interactions between particles.

In these transformed coordinates the Hamiltonian energy operator $H$ is the sum of kinetic $T$ and potential $V$ energy terms and looks like,

$$H = -\frac 12\left( \sum_i^n\frac 1{\mu_i}\nabla_i^2+\sum_{i\neq j}^n\frac 1{M_1}\nabla_i\cdot \nabla_j\right) + \sum_i^n\frac {q_0q_i}{r_i} + \sum_i^n\sum_{i<j}^n\frac {q_iq_j}{r_{ij}}$$

Here $\mu_i=M_1M_i/\left( M_1+M_i\right)$ is the reduced mass of (pseudo) particle $i$, $M_1$ is the mass at the coordinate origin, $\nabla_i$ is the gradient operator with respect to the $x,y,z$ coordinates of particle $i$, $q_0=Q_1,\,q_1=Q_2,\ldots$, and $r_{ij}=\left\| r_i-r_j\right\|$ where $r_i$ is the Cartesian coordinate vector for particle $i$. [I normally transform this into a matrix vector operator but it is easier to understand with the summation notation.]

The basis functions we’ll use consist of negative exponentials of positive definite quadratic forms (multi-dimensional Gaussian functions).

$$\phi_k =\exp \left[ -r^{\prime }\left( L_kL_k^{\prime }\otimes I_3\right) r\right] =\exp \left[ -r^{\prime }\left( A_k\otimes I_3\right) r\right] =\exp \left[ -r^{\prime }\bar{A}_kr\right]$$

Here $r$ is a $3n\times 1$ vector of Cartesian coordinates for the $n$ particles, $L_k$ is an $n\times n$ rank $n$ lower triangular matrix. $k$ ranges from 1 to $N$ where $N$ is the number of basis functions. $A_k=L_kL_k^{\prime }$ is written in this Cholesky factored form to assure positive definiteness of the quadratic form.

Correlation in the basis is achieved by including terms of the form $a_{ij}r_i\cdot r_j$ in the quadratic form, i.e. $\exp \left[ -r^{\prime }\left( A_k\otimes I_3\right) r\right] =\exp \left[ -\sum_{i,j}a_{ij}r_i\cdot r_j\right]$. This is perhaps easier to see by noting the identity, $r_{ij}^2=r_i\cdot r_i+r_j\cdot r_j-2r_i\cdot r_j$. In this sense the $\phi_k$ contain information on all inter-particle distances $r_{ij}$ and are thus explicitly correlated.
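
For example, with $n=2$ the quadratic form expands as (a small worked case of the identity above),

$$r^{\prime }\left( A\otimes I_3\right) r = a_{11}\,r_1\cdot r_1 + a_{22}\,r_2\cdot r_2 + 2a_{12}\,r_1\cdot r_2 = \left( a_{11}+a_{12}\right) r_1\cdot r_1 + \left( a_{22}+a_{12}\right) r_2\cdot r_2 - a_{12}\,r_{12}^2$$

so the off-diagonal parameter $a_{12}$ directly multiplies the squared inter-particle distance $r_{12}^2$.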

The Kronecker product with the $3\times 3$ identity matrix $I_3$ ensures rotational invariance of the basis functions. The $\phi_k$ are simultaneously eigenfunctions of the total angular momentum with $J=0$.
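
As a concrete (and entirely illustrative) sketch, one basis function can be evaluated in a few lines of PyTorch. The names and random values here are made up for the example; it uses the identity $r^{\prime }\left( A_k\otimes I_3\right) r=\mathrm{tr}\left( A_kRR^{\prime }\right)$, where $R$ is the $n\times 3$ matrix whose rows are the pseudo-particle coordinate vectors, so the Kronecker product never has to be formed explicitly.

import torch as th
dtype = th.float64

n  = 3
Lk = th.tril(th.rand(n, n, dtype=dtype) + 0.5)   # a random lower triangular Lk
Ak = Lk @ th.t(Lk)                               # Ak = Lk Lk' is positive definite by construction
R  = th.rand(n, 3, dtype=dtype)                  # rows are the coordinate vectors r_1 ... r_n

quad  = th.trace(Ak @ R @ th.t(R))               # equals r'(Ak ⊗ I3) r
phi_k = th.exp(-quad)                            # value of the basis function at these coordinates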

There is one more detail, permutational symmetry and “spin”. The Hamiltonian operator is invariant to interchange of like particles. Also, one of the defining aspects of quantum mechanics is the symmetry of the wavefunction that is induced by spin angular momentum (particularly the half integral spin of electrons). Let’s just say that this introduces some very interesting math and leads to a difficult computational challenge — it scales as $n!$ (n factorial). [It is the process of “sweeping the symmetry problem under the rug” that makes practical quantum chemistry so difficult and also makes it possible to do it for larger systems.] It is consideration of the symmetry aspects of the system that defines the valid “particular solutions” to the Schrodinger equation. I’m only going to say that for what I’m doing we need a proper symmetry projection operator to get the wavefunction we are looking for. What we are doing is close to “pure” quantum mechanics and we have a correlated wavefunction. So when you see something about “symmetry” that’s what is being referred to. We are doing quantum mechanics with very few approximations!

The end result of the discussion above is that we need to find “matrix elements” for $H$ and for the “overlap” $S$ and then minimize the smallest eigenvalue of the generalized eigenvalue equation $(H-ES)c = 0$. $H$ and $S$ are matrices of integrals like,

$$H_{kl} = \langle\phi_k|H|\phi_l\rangle = \int^\infty_{-\infty} \phi_k H \phi_l dr$$

$\langle\phi_k|H|\phi_l\rangle$ is a probabilistic “expectation value” for the operator $H$ (the energy operator). This is the “bra” “ket” notation of quantum mechanics. It’s basically the functional form of an inner product in a vector space. It’s an integral of the coordinates $r$ over all space.
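
Once the $H$ and $S$ matrices are in hand, extracting the energy is ordinary linear algebra. Here is a minimal sketch of that step (assuming a recent PyTorch where the torch.linalg module is available; the function name lowest_energy is just for illustration). It reduces the generalized problem to a standard symmetric eigenvalue problem using a Cholesky factor of $S$,

import torch as th

def lowest_energy(H, S):
    # (H - E S) c = 0  with  S = L L'  and  y = L' c  becomes  (Linv H Linv') y = E y
    L = th.linalg.cholesky(S)           # S must be symmetric positive definite
    Linv = th.inverse(L)
    Ht = Linv @ H @ th.t(Linv)          # symmetric matrix, standard eigenvalue problem
    return th.linalg.eigvalsh(Ht)[0]    # eigenvalues come back sorted, smallest first

Minimizing this smallest eigenvalue with respect to the nonlinear parameters in the $\phi_k$ is what the later posts will set up.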

Those are non-trivial integrals, as you can imagine by looking at the definition of $H$ and $\phi_k$. I will have mercy and not derive the formulas here but, we will code them up in PyTorch!

There, now you know the fundamentals of “doing” quantum mechanics. Congratulations if you read through this!


Overlap Matrix Elements and Gradients

To get started we’ll code up the formula for the “normalized” overlap integrals $S_{kl}$ and their derivatives. This will show how simple it can be to go directly from the math to the code. We’ll also check automatic differentiation in PyTorch against the analytically derived gradient.

This is just the first step. When the complete code is finished in later blog posts I’ll put it all in a Jupyter notebook and post it on GitHub.

We’re going to compute two terms. The first is the normalized overlap matrix element including the symmetry projection term (a permutation matrix).

Let $A_{kl}=A_k+\tau_P^{\prime }A_l\tau_P$, where $\tau_P$ is a term in the symmetry projection operator $\mathcal{P}$. The overlap matrix element is then defined as,

$$S_{kl}=\frac{\left\langle \phi_k\right. |\left. \mathcal{P}\phi_l\right\rangle }{\left( \left\langle \phi_k\right. |\left. \phi_k\right\rangle \left\langle \phi_l\right. |\left. \phi_l\right\rangle \right) ^{1/2}}=\sum_P\chi_p 2^{3n/2}\left( \frac{\left\| L_k\right\| \left\| L_l\right\| }{\left| A_{kl}\right| }\right) ^{3/2}$$

We’ll only code for one term in that sum.

The other formula to code up is the derivative of the term above with respect to the matrices $L_k$ and $L_l$.

$$\frac{\partial S_{kl}^P}{\partial \left( \,\mathrm{vech}\,L_k\right)^{\prime }}=\frac 32S_{kl}^P\left( \,\mathrm{vech}\,\left[ \left( \,\mathrm{diag}\,L_k\right) ^{-1}-2A_{kl}^{-1}L_k\right] \right)^{\prime }$$

$$\frac{\partial S_{kl}^P}{\partial \left(\,\mathrm{vech}\,L_l\right)^{\prime}}=\frac 32S_{kl}^P\left( \,\mathrm{vech}\,\left[ \left( \,\mathrm{diag}\,L_l\right)^{-1}-2\tau_PA_{kl}^{-1}\tau_P^{\prime }L_l\right]\right) ^{\prime }$$

  • $\left\|L_k\right\|$ is the absolute value of the determinant of $L_k$
  • $\left| A_{kl}\right|$ is the determinant of $A_{kl}$
  • vech() extracts the lower triangular elements of a matrix to a vector with $n(n+1)/2$ elements
  • diag() keeps just the diagonal elements of a matrix
  • $A^{-1}$ is the matrix inverse of $A$
  • $A^\prime$ is the transpose of $A$

The PyTorch Code

  • n: the number of “pseudo” particles
  • vechLk: n(n+1)/2 nonlinear exponent parameters that form the lower triangular matrix Lk (these are the parameters we will need the gradient with respect to)
  • vechLl: the same for Ll
  • Sym: symmetry projection matrix for the term being computed

Now for some code!

(I’ll only test the overlap and its gradient terms in this post.)

import torch as th   # PyTorch is imported as "torch"
dtype = th.float64   # Use float32 if you are on GeForce GPU

gpuid = 0
device = th.device("cuda:"+ str(gpuid))
#device = th.device("cpu")  # un-comment to change back to CPU

print("Execution device: ",device)
print("PyTorch version: ", th.__version__ )
print("CUDA available: ", th.cuda.is_available())
print("CUDA version: ", th.version.cuda)
print("CUDA device:", th.cuda.get_device_name(gpuid))
Execution device:  cuda:0
PyTorch version:  0.4.0
CUDA available:  True
CUDA version:  9.1.85
CUDA device: TITAN V
# Utility functions

# return the lower triangle of A in column order i.e. vech(A)
def vech(A):
    count = 0
    c = A.shape[0]
    v = th.zeros(c * (c + 1) // 2, device=device, dtype=dtype)
    for j in range(c):
        for i in range(j,c):
            v[count] = A[i,j]
            count += 1
    return v

# vech2L   create lower triangular matrix L from vechA
def vech2L(v,n):
    count = 0
    L = th.zeros((n,n), device=device, dtype=dtype)
    for j in range(n):
        for i in range(j,n):
            L[i,j]=v[count]
            count += 1
    return L
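
A quick sanity check of these two utilities (just an illustrative round trip, not part of the final code): packing a vector into a lower triangular matrix with vech2L and then extracting it again with vech should give back the original vector.

v = th.arange(1, 7, device=device, dtype=dtype)   # 6 = n(n+1)/2 elements for n = 3
Ltest = vech2L(v, 3)                              # 3x3 lower triangular matrix
print(th.allclose(vech(Ltest), v))                # expect: True
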
def matrix_elements(n, vechLk, vechLl, Sym):

    # reconstruct lower triangular matrices
    Lk = vech2L(vechLk,n);
    Ll = vech2L(vechLl,n);

    # apply symmetry projection on Ll  
    # th.t() is shorthand for th.transpose(X, 0,1)
    PLl = th.t(Sym) @ Ll;

    # build Ak, Al, Akl, invAkl

    Ak = Lk@th.t(Lk);
    Al = PLl@th.t(PLl);
    Akl = Ak+Al;

    invAkl = th.inverse(Akl);

    # Overlap: (normalized)
    skl = 2**(3*n/2) * th.sqrt( th.pow(th.abs(th.det(Lk))*th.abs(th.det(Ll))/th.det(Akl) ,3) );

    #Analytic gradient formulas with respect to vechLk vechLl
    checkdsk = vech( 3/2 * skl * (th.diag(1/th.diag(Lk)) - 2*invAkl@Lk) )
    checkdsl = vech( 3/2 * skl * (th.diag(1/th.diag(Ll)) - 2*Sym@invAkl@PLl) )

    # Now get the gradient terms using autograd
    dsk = th.autograd.grad(skl, vechLk, retain_graph=True)
    dsl = th.autograd.grad(skl, vechLl)


    return {'skl':skl, 'dsk':dsk, 'dsl':dsl, 'checkdsk':checkdsk, 'checkdsl':checkdsl}
def test_matrix_elements():
    n = 3;

    # using numbers that I know the correct results for
    vechLk = th.tensor([  1.00000039208682,
              0.02548044275764261,
              0.3525161612610669,
              1.6669144815242515,
              0.9630555318946559,
              1.8382882034659822 ], device=device, dtype=dtype, requires_grad=True);

    vechLl = th.tensor([  1.3353550436464964,
               0.9153272033682132,
               0.7958636766525028,
               1.8326931436447955,
               0.3450426931160630,
               1.8711839323167831 ], device=device, dtype=dtype, requires_grad=True);

    Sym = th.tensor([[0,0,1],
                    [0,1,0],
                    [1,0,0]], device=device, dtype=dtype);

    matels = matrix_elements(n, vechLk, vechLl, Sym)

    print('skl:      ',matels['skl'])
    print('dsk:      ',matels['dsk'])
    print('checkdsk:  ',matels['checkdsk'])
    print('dsl:      ',matels['dsl'])
    print('checkdsl:  ',matels['checkdsl'])
test_matrix_elements()
skl:       tensor(0.5334, dtype=torch.float64, device='cuda:0')
dsk:       (tensor([ 0.4898,  0.0786, -0.0560,  0.1179, -0.1113, -0.1632], dtype=torch.float64, device='cuda:0'),)
checkdsk:   tensor([ 0.4898,  0.0786, -0.0560,  0.1179, -0.1113, -0.1632], dtype=torch.float64, device='cuda:0')
dsl:       (tensor([ 0.3198, -0.0666, -0.1495, -0.0751, -0.0352, -0.1917], dtype=torch.float64, device='cuda:0'),)
checkdsl:   tensor([ 0.3198, -0.0666, -0.1495, -0.0751, -0.0352, -0.1917], dtype=torch.float64, device='cuda:0')

Notice how easy it is to use automatic differentiation. I just added “requires_grad=True” when defining the vectors of independent variables, vechLk and vechLl. That tells PyTorch to keep track of those variables as it builds a computation graph so that it can generate gradient information. Then “dsk = th.autograd.grad(skl, vechLk, retain_graph=True)” returns the gradient of the function skl with respect to vechLk (the retain_graph=True part says I’m not done with that “graph” yet). Brilliant and easy!

This all looks good. I feel confident that “autograd” will do the right thing so I can avoid coding up the gradient terms from analytic formulas (even though I have those formulas in this case).

The code above is a direct implementation of the math using high level constructs available in PyTorch. It is not necessarily optimal code structure! In fact since these matrices are small it is surely not optimal. However, when we finish coding up the entire problem there will be a substantial amount of computing going on. I believe there will be opportunity to use “batch” matrix operations which will utilize the hardware and give good performance.
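
As a hint of what that might look like (a sketch only, with made-up shapes; the actual batching layout will come in a later post), PyTorch’s th.bmm multiplies a whole stack of matrix pairs in one call,

A = th.rand(1000, 3, 3, device=device, dtype=dtype)   # 1000 small matrices
B = th.rand(1000, 3, 3, device=device, dtype=dtype)
C = th.bmm(A, B)                                      # C[i] = A[i] @ B[i], all in one batched call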


In the next post we’ll add the Kinetic and Potential energy terms and use those to build out the full energy and gradient calculation. After that we’ll set up the optimization and run some real calculations. I’m really having fun with this!


References

The two primary references for this post are, (if you are interested in the science)

Integrals and derivatives for correlated Gaussian functions using matrix differential calculus
and,
Implementation of gradient formulas for correlated gaussians: He, ∞He, Ps2, 9Be, and ∞Be test results

However, those are in “pay-walled” journals so you may not be able to access them. (Sorry, I don’t have reprints. I don’t even have the original LaTeX source I wrote them in!)

This link on “ResearchGate” has references to several papers that reference this work and some of them are likely to be available on-line. Warning, some of those papers are going to be real heavy on the math!

Happy computing –dbk