Wrestling with vmap in JAX
When I discovered JAX, it ticked all the right boxes for me: performant, flexible and extensible number crunching using a familiar API and a functional programming mindset. I was keen to move beyond the tutorial phase and test it with a real-life application. I decided to port to JAX a personal project that I wrote almost 10 years ago in NumPy and OpenCV. The old project is called "ChopShop". It is a library for computational art.
ChopShop 1.0
I dabble in computational art as an extension to my dabbles with photography and painting. I use computational art to experiment with analytic techniques in a different context from the rest of my work. At the heart of ChopShop 1.0 is a Monte-Carlo based recommendation engine. It takes a source image (such as the Mona Lisa) and does some image processing on it to turn it into a tiny little thumbnail. This little thumbnail is like a crude sketch. There is no detail. The sketch only contains some key defining "features" of the original image. ChopShop then analyses this thumbnail pixel by pixel and, for each pixel, finds an image in its library that contains the features present in that pixel of the thumbnail. It creates a new image as a mosaic of thousands of images selected from the library. In the case of the Mona Lisa above, the target images were paper folds.
In what way is this a Monte-Carlo method? Each source feature is described by a statistical distribution. The target images also have features that are described by distributions. ChopShop evaluates the source distribution(s) and samples library images whose distributions are similar. It then makes a random selection for each coordinate in the mosaic, based on how well the summary statistics of the source and library distributions match.
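To make the selection step concrete, here is a minimal sketch in JAX. It is not the ChopShop implementation: it assumes a single feature (mean brightness), and the function name sample_tile_indices, the temperature parameter and the distance-based scoring are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def sample_tile_indices(key, thumbnail, tile_means, temperature=0.05):
    """Draw one library index per thumbnail pixel (hypothetical helper).

    thumbnail:  (H, W) array of brightness values in [0, 1]
    tile_means: (N,) mean brightness of each library image
    """
    # Distance between every thumbnail pixel and every tile summary: (H, W, N).
    distance = jnp.abs(thumbnail[..., None] - tile_means)
    # Closer matches get higher logits; the assumed temperature sets how
    # random the choice is (small values approach a nearest-match lookup).
    logits = -distance / temperature
    # Sample one tile index per pixel from the categorical distribution.
    return jax.random.categorical(key, logits, axis=-1)


key = jax.random.PRNGKey(0)
thumbnail = jnp.array([[0.1, 0.9], [0.5, 0.3]])
tile_means = jnp.linspace(0.0, 1.0, 11)
choices = sample_tile_indices(key, thumbnail, tile_means)
```

The result has one integer per thumbnail pixel. Running it with a different key gives a different but statistically similar mosaic layout, which is the Monte-Carlo flavour described above.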
Getting into trouble with JAX ChopShop
When I wrote ChopShop 1.0 using NumPy and OpenCV many moons ago, I used a fair amount of imperative programming and made extensive use of NumPy's mutability, updating arrays in place. The overall structure of my code looked roughly like this:
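What follows is a minimal sketch of that kind of structure, not the original ChopShop code. It assumes a grayscale thumbnail and a library of equally sized tiles. The name build_mosaic_imperative is hypothetical, and a nearest-mean lookup stands in for the real selection logic.

```python
import numpy as np


def build_mosaic_imperative(thumbnail, tiles):
    """thumbnail: (H, W) brightness; tiles: (N, th, tw) library images."""
    H, W = thumbnail.shape
    _, th, tw = tiles.shape
    tile_means = tiles.mean(axis=(1, 2))
    mosaic = np.zeros((H * th, W * tw), dtype=tiles.dtype)
    for i in range(H):
        for j in range(W):
            # Assumed stand-in for the selection step: pick the tile whose
            # mean brightness is closest to this thumbnail pixel.
            k = np.argmin(np.abs(tile_means - thumbnail[i, j]))
            # Select the slice of the mosaic for this pixel and update it
            # in place, which NumPy allows and JAX arrays do not.
            mosaic[i * th:(i + 1) * th, j * tw:(j + 1) * tw] = tiles[k]
    return mosaic


tiles = np.stack([np.full((2, 2), v) for v in (0.0, 0.5, 1.0)])
thumbnail = np.array([[0.0, 1.0], [0.4, 0.9]])
mosaic = build_mosaic_imperative(thumbnail, tiles)
```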
It looped through each pixel in the thumbnail, selected a slice of the mosaic, then updated that slice. As is common with imperative code, it tells a story from the start (iterating through the thumbnail) to the end (a filled mosaic) as a series of steps.
For my first iteration of the JAX version I was wise enough not to describe these steps as nested loops, but I still started from a thumbnail array and tried to turn it into a mosaic array via a series of transformations performed as vmaps. I will spare you the gory details of my several failed attempts to get this working. I could have got it to work functionally, but my code was getting horribly complex and was also performing horrendously. I eventually resorted to using the new experimental loops feature. The looping constructs themselves have a nice and simple API to work with, but whatever code you place inside a loop becomes a black hole to the debugger. My couple of days spent wrestling with JAX had left me frustrated that I could not find my errors or diagnose performance issues in my code. I was on the verge of giving up on JAX.
Total reset in thinking
Not ready to admit defeat, I started again from scratch. This time I was determined to solve my little image creation puzzle with no loops, and with none of the other features that had got me into trouble by enticing me to replicate my original 1.0 imperative coding style: lax.dynamic_slice and x.at[slice].set(). The hardest part of all of this was knowing that it was time to hit the reset button and start again. My new strategy was to work backwards from something in the shape of the desired output, not forwards from something in the shape of a thumbnail. In the future I will adopt this strategy from the start for anything that I do with JAX.
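As a rough illustration of working backwards, the sketch below starts from the coordinates of the output mosaic and asks, for each output pixel, which thumbnail cell and which position inside the chosen tile it comes from. This is an assumed simplification, not the code from the notebook: build_mosaic is a hypothetical name, and the per-pixel tile indices (choices) are taken as given.

```python
import jax
import jax.numpy as jnp


def build_mosaic(choices, tiles):
    """choices: (H, W) int tile index per thumbnail pixel; tiles: (N, th, tw)."""
    H, W = choices.shape
    _, th, tw = tiles.shape

    def output_pixel(y, x):
        # Which thumbnail cell does this output pixel belong to?
        k = choices[y // th, x // tw]
        # Where inside the chosen tile does it sit?
        return tiles[k, y % th, x % tw]

    ys = jnp.arange(H * th)
    xs = jnp.arange(W * tw)
    # Inner vmap runs over columns for a fixed row; outer vmap runs over rows.
    per_row = jax.vmap(output_pixel, in_axes=(None, 0))
    return jax.vmap(per_row, in_axes=(0, None))(ys, xs)


tiles = jnp.arange(3 * 2 * 2, dtype=jnp.float32).reshape(3, 2, 2)
choices = jnp.array([[0, 2], [1, 2]])
mosaic = build_mosaic(choices, tiles)
```

Because every output pixel is computed independently from its own coordinates, there is no loop, no lax.dynamic_slice and no x.at[slice].set(). For this simplified case the same array can also be obtained with plain indexing, tiles[choices] followed by a transpose and reshape, which is a handy cross-check.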
The selfie below was my first output from a still very basic implementation of JAX ChopShop. The notebook shows a simplified walk-through of the thought process behind my "work backwards" strategy.
Image by author