It looped through each pixel in the thumbnail, selected the corresponding slice of the mosaic, then updated that slice. As is common with imperative code, it told a story from the start (iterating through the thumbnail) to the end (a filled mosaic) as a series of steps.
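For contrast, here is a minimal sketch of that imperative story written with jax.numpy. It is not the original ChopShop code. The function name `mosaic_imperative`, the inputs (a `thumbnail` of shape (h, w, c) and a stack of `tiles` of shape (n, t, t, c)), and the rule "pick the tile whose mean colour is closest to the pixel" are all hypothetical assumptions made for illustration.

```python
import jax.numpy as jnp

def mosaic_imperative(thumbnail, tiles):
    """Hypothetical imperative mosaic builder.

    thumbnail: (h, w, c) image; tiles: (n, t, t, c) candidate tile images.
    """
    h, w, c = thumbnail.shape
    t = tiles.shape[1]
    tile_colors = tiles.mean(axis=(1, 2))  # (n, c) mean colour of each tile
    mosaic = jnp.zeros((h * t, w * t, c), dtype=tiles.dtype)
    # Start of the story: walk the thumbnail pixel by pixel...
    for row in range(h):
        for col in range(w):
            # ...pick the tile whose mean colour is closest to this pixel...
            dist = jnp.sum((tile_colors - thumbnail[row, col]) ** 2, axis=-1)
            best = jnp.argmin(dist)
            # ...select the matching slice of the mosaic and update it.
            mosaic = mosaic.at[row * t:(row + 1) * t,
                               col * t:(col + 1) * t].set(tiles[best])
    # End of the story: a filled mosaic of shape (h * t, w * t, c).
    return mosaic
```

This runs eagerly, one step at a time. Every `.at[...].set(...)` call returns a new array, and the Python loops unroll into h * w separate updates, which is exactly the step-by-step shape that does not map well onto JAX.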
For my first iteration of the JAX version I was wise enough not to try to describe these steps as nested loops, but I still started from a thumbnail array and tried to turn it into a mosaic array through a series of transformations performed as vmaps. I will spare you the gory details of my several failed attempts to get this working. I could have got it to work functionally, but my code was becoming horribly complex and was also performing horrendously. I eventually resorted to using the new experimental loops feature. The looping constructs themselves have a nice, simple API, but whatever code you place inside a loop becomes a black hole to the debugger. After a couple of days wrestling with JAX, I was frustrated that I could neither find my errors nor diagnose the performance issues in my code. I was on the verge of giving up on JAX.
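One reason loop bodies are so opaque is that JAX does not run them once per iteration in Python; it traces the body with abstract placeholder values and compiles the result. The minimal sketch below shows this with `jax.lax.fori_loop`, which is swapped in here as an assumption in place of the experimental loops module mentioned above. The `body_calls` list is a hypothetical probe that records what plain Python sees inside the body.

```python
import jax.numpy as jnp
from jax import lax

body_calls = []  # hypothetical probe: one entry each time Python runs the body

def body(i, acc):
    # JAX traces this function to build the compiled loop, so Python code here
    # (prints, breakpoints, this append) runs at trace time, not per iteration.
    try:
        int(i)  # a tracer has a shape and dtype, but no concrete value yet
        body_calls.append("concrete")
    except TypeError:  # JAX's concretization errors subclass TypeError
        body_calls.append("abstract")
    return acc + i

total = lax.fori_loop(0, 5, body, jnp.int32(0))
print(int(total))  # the compiled loop still sums 0 + 1 + 2 + 3 + 4
print(body_calls)  # one entry per trace, not one per iteration
```

A breakpoint placed inside `body` fires while the loop is being traced, when none of the real values exist yet, which matches the "black hole" experience described above.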
Not ready to admit defeat, I started again from scratch. This time I was determined to solve my little image-creation puzzle with no loops, and without the other features that had got me into trouble by enticing me to replicate my original 1.0 imperative coding style: lax.dynamic_slice and x.at[slice].set(). The hardest part of all of this was knowing when it was time to hit the reset button and start again. My new strategy was to work backwards from something in the shape of the desired output, not forwards from something in the shape of a thumbnail. In future I will adopt this strategy from the start for anything I do with JAX.
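Here is a minimal sketch of the "work backwards" idea. It is my illustration of the strategy, not the code from the notebook. The function name `mosaic_backwards`, the inputs (a `thumbnail` of shape (h, w, c) and a stack of `tiles` of shape (n, t, t, c)), and the nearest-mean-colour tile choice are hypothetical assumptions. Instead of walking the thumbnail, it starts from the index grid of the output mosaic: each output pixel works out which thumbnail pixel it belongs to (`// t`) and where it sits inside its tile (`% t`), and a single gather fills the whole mosaic.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mosaic_backwards(thumbnail, tiles):
    """Hypothetical loop-free mosaic builder.

    thumbnail: (h, w, c) image; tiles: (n, t, t, c) candidate tile images.
    """
    h, w, c = thumbnail.shape
    t = tiles.shape[1]
    tile_colors = tiles.mean(axis=(1, 2))  # (n, c) mean colour of each tile
    # Best tile for every thumbnail pixel at once: (h, w, n) distances -> (h, w).
    dist = jnp.sum((thumbnail[:, :, None, :] - tile_colors) ** 2, axis=-1)
    best = jnp.argmin(dist, axis=-1)
    # Work backwards: start from the coordinates of every output pixel.
    rows = jnp.arange(h * t)[:, None]  # (h * t, 1)
    cols = jnp.arange(w * t)[None, :]  # (1, w * t)
    # rows // t, cols // t -> which thumbnail pixel (and so which tile);
    # rows % t,  cols % t  -> which pixel inside that tile.
    # The index arrays broadcast to (h * t, w * t), giving (h * t, w * t, c).
    return tiles[best[rows // t, cols // t], rows % t, cols % t]
```

There are no loops, no lax.dynamic_slice and no x.at[slice].set(). Every shape comes from the inputs' static shapes, so jax.jit can compile the whole function as one gather.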
The selfie below was my first output from a still very basic implementation of JAX ChopShop. The notebook gives a simplified walkthrough of the thought process behind my "work backwards" strategy.