Sunday, December 14, 2014

ScalaZ - Getting to grips with the Free Monad

Free Monads

About 18 months ago I started working at Maana Inc, a startup in the Big Data space. About 12 months ago we decided to move from Java to Scala as our language of choice.

With hindsight, it was a good decision. If you consider that all programs ever really do is transform data from one form to another, Scala (in fact any functional language) is much better than Java at expressing these transformations: code is more compact, more readable and, as a result, more maintainable.

ScalaZ has interested me for some time, and I've started integrating some of its concepts into our codebase. Validation and Either are the obvious wins, but the library as a whole is enormous and it's hard to even get a handle on parts of it. A friend sent me this link, which turned me onto the Free Monad, and I think Runar's talk presents a compelling idea. However, all I'm going to discuss here is my attempt to get to grips with the Free Monad itself, and specifically its use with interpreters.

One of the problems I've found with ScalaZ is the lack of documentation and examples. There are lots of blogs that discuss Haskell and the Free Monad, but the ScalaZ variant is much less commonly discussed, and the discussions often show code that won't even compile with the latest version of ScalaZ.

The idea of the Free Monad is simple enough: you build an AST and provide one or more interpreters over it. So let's start with some code; I figured a basic stack-based language was a good place to start.

Building a Free Monad

The first thing we need to do is define our operators

sealed trait ForthOperators[A]

final case class Push[A](value: Int, o: A) extends ForthOperators[A]
final case class Add[A](o: A) extends ForthOperators[A]
final case class Mul[A](o: A) extends ForthOperators[A]
final case class Dup[A](o: A) extends ForthOperators[A]
final case class End[A](o: A) extends ForthOperators[A]

Since it's a stack-based language, the only operator that takes an argument is Push. The o: A part is really there just to fulfill the Functor/Monad requirement, i.e. it has to wrap a value. When we get to our interpreter implementation, it will wrap the continuation.

The Free Monad basically provides a Monad given a Functor, so the next thing we need to provide is our Functor:

  import scalaz.Functor

  implicit val ForthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    // Apply f to the wrapped continuation; the instruction itself is unchanged
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] =
      fa match {
        case Push(value, cont) => Push(value, f(cont))
        case Add(cont) => Add(f(cont))
        case Mul(cont) => Mul(f(cont))
        case Dup(cont) => Dup(f(cont))
        case End(cont) => End(f(cont))
      }
  }

Functor requires that we provide a map function. This is basically a mechanical implementation: we just apply f to the wrapped value. If we were using Haskell we'd just specify deriving (Functor) (with the DeriveFunctor extension); in Scala, if we use the Free Monad, we have to provide the implementation ourselves (later I'll discuss FreeC, which doesn't require an explicit Functor).
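Since map is purely mechanical, the only thing it really has to get right is the Functor laws: mapping the identity function changes nothing, and mapping f and then g is the same as mapping f andThen g. A small stdlib-only sketch that checks both laws for this instance follows; the Functor trait and the ForthFunctorLaws object are hypothetical stand-ins written for illustration, not scalaz's definitions.

```scala
object ForthFunctorLaws {
  // Hypothetical minimal stand-in for scalaz.Functor
  trait Functor[F[_]] {
    def map[A, B](fa: F[A])(f: A => B): F[B]
  }

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]

  // Same mechanical instance as above: only the wrapped value changes
  val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  // Identity law: map(fa)(identity) == fa
  def identityLaw[A](fa: ForthOperators[A]): Boolean =
    forthFunctor.map(fa)((a: A) => a) == fa

  // Composition law: map(map(fa)(f))(g) == map(fa)(f andThen g)
  def compositionLaw[A, B, C](fa: ForthOperators[A], f: A => B, g: B => C): Boolean =
    forthFunctor.map(forthFunctor.map(fa)(f))(g) == forthFunctor.map(fa)(f andThen g)
}
```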

We can now use the Free.liftF function to lift our functor into a Free Monad. To make things easier, we're going to use a little implicit magic and provide some helpers for each of the operators. This is really syntactic sugar, but without it you would have to specify the encapsulated value every time you used an operator.

  import scalaz.Free

  type ForthProg[A] = Free[ForthOperators, A]

  import scala.language.implicitConversions
  implicit def liftForth[A](forth: ForthOperators[A]): ForthProg[A] = Free.liftF(forth)

  def push(value: Int)  = Push(value, ())
  def add = Add(())
  def mul = Mul(())
  def dup = Dup(())
  def end = End(())

Note that the type we're wrapping (o: A) is actually Unit, because in this case we don't care about the returned value.
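To see what Free.liftF and the Unit payload actually produce, here is a stdlib-only sketch using a hand-rolled, simplified Free: Return is a finished program and Suspend is one instruction whose payload is the rest of the program. Free, Return, Suspend, liftF and LiftSketch are hypothetical stand-ins for illustration, not scalaz's source (the library version adds machinery this omits). Lifting an operator gives a single instruction whose continuation is a finished program returning ().

```scala
object LiftSketch {
  // Hypothetical minimal stand-ins for scalaz.Functor and scalaz.Free
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }

  sealed trait Free[F[_], A]
  // A finished program that returns a
  final case class Return[F[_], A](a: A) extends Free[F, A]
  // One instruction whose payload is the rest of the program
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]

  // Lift one instruction: its payload becomes "finish, returning that payload"
  def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): Free[F, A] =
    Suspend[F, A](F.map[A, Free[F, A]](fa)(a => Return[F, A](a)))

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]

  implicit val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  type ForthProg[A] = Free[ForthOperators, A]

  // Explicitly typed helpers; the article gets the same effect via an implicit conversion
  def push(value: Int): ForthProg[Unit] = liftF[ForthOperators, Unit](Push(value, ()))
  def add: ForthProg[Unit] = liftF[ForthOperators, Unit](Add(()))
  def mul: ForthProg[Unit] = liftF[ForthOperators, Unit](Mul(()))
  def dup: ForthProg[Unit] = liftF[ForthOperators, Unit](Dup(()))
  def end: ForthProg[Unit] = liftF[ForthOperators, Unit](End(()))
}
```

For example, LiftSketch.push(3).toString is Suspend(Push(3,Return(()))).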

A Program using the DSL

At this point we can use the Monad in for comprehensions (Scala's version of Haskell's do notation):

    val testProg = for {
      _ <- push(3)
      _ <- push(6)
      _ <- add
      _ <- push(7)
      _ <- push(2)
      _ <- add
      _ <- mul
      _ <- dup
      _ <- add
    } yield ()

The important thing to realize here is that all this for comprehension does is build the AST. And in fact you can define subtrees (subroutines).

    val square = for {
      _ <- dup
      _ <- mul
    } yield ()

    val testProg = for {
      _ <- push(3)
      _ <- square
      _ <- push(4)
      _ <- square
      _ <- add
    } yield ()
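Because the for comprehension only builds a tree, a program is just a data structure, and using square really does splice its instructions into the caller's tree. The stdlib-only sketch below demonstrates this with a hand-rolled, simplified Free (Free, Return, Suspend, liftF, expanded and instructions are hypothetical names, not scalaz's API): the version using square compares equal to the same program written out in full, and a small pure function lists the instructions without running anything.

```scala
object AstSketch {
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }

  // Hypothetical simplified Free: Return (finished) or Suspend (one instruction + rest)
  sealed trait Free[F[_], A] {
    // Graft f onto the leaf of the tree; nothing is executed here
    def flatMap[B](f: A => Free[F, B])(implicit F: Functor[F]): Free[F, B] = this match {
      case Return(a)     => f(a)
      case Suspend(next) => Suspend[F, B](F.map(next)(_.flatMap(f)))
    }
    def map[B](f: A => B)(implicit F: Functor[F]): Free[F, B] =
      flatMap[B](a => Return[F, B](f(a)))
  }
  final case class Return[F[_], A](a: A) extends Free[F, A]
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]

  def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): Free[F, A] =
    Suspend[F, A](F.map[A, Free[F, A]](fa)(a => Return[F, A](a)))

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]

  implicit val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  type ForthProg[A] = Free[ForthOperators, A]
  def push(value: Int): ForthProg[Unit] = liftF[ForthOperators, Unit](Push(value, ()))
  def add: ForthProg[Unit] = liftF[ForthOperators, Unit](Add(()))
  def mul: ForthProg[Unit] = liftF[ForthOperators, Unit](Mul(()))
  def dup: ForthProg[Unit] = liftF[ForthOperators, Unit](Dup(()))

  val square: ForthProg[Unit] = for {
    _ <- dup
    _ <- mul
  } yield ()

  val testProg: ForthProg[Unit] = for {
    _ <- push(3)
    _ <- square
    _ <- push(4)
    _ <- square
    _ <- add
  } yield ()

  // The same program with the subroutine written out by hand
  val expanded: ForthProg[Unit] = for {
    _ <- push(3); _ <- dup; _ <- mul
    _ <- push(4); _ <- dup; _ <- mul
    _ <- add
  } yield ()

  // A pure "interpreter" that only walks the tree and names each instruction
  def instructions(program: ForthProg[Unit]): List[String] = program match {
    case Return(_)              => Nil
    case Suspend(Push(v, rest)) => s"Push $v" :: instructions(rest)
    case Suspend(Add(rest))     => "Add" :: instructions(rest)
    case Suspend(Mul(rest))     => "Mul" :: instructions(rest)
    case Suspend(Dup(rest))     => "Dup" :: instructions(rest)
    case Suspend(End(rest))     => "End" :: instructions(rest)
  }
}
```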

OK, that's actually pretty cool, but not overly useful without some way to "run" it.

If you look at the source code for scalaz.Free, you'll discover there are a lot of ways to "run" the code; I'll cover several here.

Providing a Driver

The first and possibly most obvious way is to write your own driver over the tree

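  // Recursive driver: fold over the tree, threading the stack through by hand
  final def runProgram(stack: List[Int], program: ForthProg[Unit]): List[Int] =
    program.fold(
      // Return: the program has finished, so the current stack is the result
      _ => stack,
      // Suspend: interpret one instruction, then recurse into its continuation
      {
        case Push(value, cont) =>
          runProgram(value :: stack, cont)
        case Add(cont) =>
          val a :: b :: tail = stack
          runProgram((a + b) :: tail, cont)
        case Mul(cont) =>
          val a :: b :: tail = stack
          runProgram((a * b) :: tail, cont)
        case Dup(cont) =>
          val a :: tail = stack
          runProgram(a :: a :: tail, cont)
        case End(cont) =>
          // Stop early, ignoring the rest of the program
          stack
      })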

We use fold, which takes two functions: the first is applied to the final value if the program has finished (a Return), and the second to the next instruction otherwise. With that we can now run the program.

println (runProgram(Nil, testProg))

Which correctly prints "List(25)"
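For readers without scalaz on the classpath, here is a stdlib-only sketch of the same fold-based driver over a hand-rolled, simplified Free. The fold method, the binary helper, stopsEarly and the FoldDriver object are hypothetical illustrations, not scalaz's API. It also uses an explicit pattern match instead of val a :: b :: tail = stack, so a stack underflow produces a readable error, and stopsEarly shows End cutting the run short.

```scala
object FoldDriver {
  // Hypothetical minimal stand-ins for scalaz.Functor and scalaz.Free
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }

  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B])(implicit F: Functor[F]): Free[F, B] = this match {
      case Return(a)     => f(a)
      case Suspend(next) => Suspend[F, B](F.map(next)(_.flatMap(f)))
    }
    def map[B](f: A => B)(implicit F: Functor[F]): Free[F, B] = flatMap[B](a => Return[F, B](f(a)))
    // fold: r handles a finished program, s handles the next instruction
    def fold[B](r: A => B, s: F[Free[F, A]] => B): B = this match {
      case Return(a)     => r(a)
      case Suspend(next) => s(next)
    }
  }
  final case class Return[F[_], A](a: A) extends Free[F, A]
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]
  def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): Free[F, A] =
    Suspend[F, A](F.map[A, Free[F, A]](fa)(a => Return[F, A](a)))

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]
  implicit val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  type ForthProg[A] = Free[ForthOperators, A]
  def push(value: Int): ForthProg[Unit] = liftF[ForthOperators, Unit](Push(value, ()))
  def add: ForthProg[Unit] = liftF[ForthOperators, Unit](Add(()))
  def mul: ForthProg[Unit] = liftF[ForthOperators, Unit](Mul(()))
  def dup: ForthProg[Unit] = liftF[ForthOperators, Unit](Dup(()))
  def end: ForthProg[Unit] = liftF[ForthOperators, Unit](End(()))

  val square: ForthProg[Unit] = for { _ <- dup; _ <- mul } yield ()
  val testProg: ForthProg[Unit] =
    for { _ <- push(3); _ <- square; _ <- push(4); _ <- square; _ <- add } yield ()

  // End stops the driver, so the final push never runs
  val stopsEarly: ForthProg[Unit] = for {
    _ <- push(1)
    _ <- end
    _ <- push(2)
  } yield ()

  private def binary(stack: List[Int])(op: (Int, Int) => Int): List[Int] = stack match {
    case a :: b :: tail => op(a, b) :: tail
    case _              => sys.error(s"stack underflow: $stack")
  }

  def runProgram(stack: List[Int], program: ForthProg[Unit]): List[Int] =
    program.fold[List[Int]](
      _ => stack, // Return: finished, the current stack is the result
      {           // Suspend: interpret one instruction, then recurse on the rest
        case Push(value, cont) => runProgram(value :: stack, cont)
        case Add(cont)         => runProgram(binary(stack)(_ + _), cont)
        case Mul(cont)         => runProgram(binary(stack)(_ * _), cont)
        case Dup(cont)         => runProgram(stack.head :: stack, cont)
        case End(_)            => stack // stop early, ignoring the rest of the program
      })
}
```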
There are several issues with this, most notably that runProgram is not tail recursive, which can lead to stack overflows on long programs. It turns out that this pattern of threading state through a Free Monad is common enough that there is an entry point, foldRun, that does exactly this and is tail recursive.
To use foldRun we have to provide a function like the one below:

  def runFn(stack: List[Int], program: ForthOperators[ForthProg[Unit]]): (List[Int], ForthProg[Unit]) = program match {
    case Push(value, cont) =>
      (value :: stack, cont)
    case Add(cont) =>
      val a :: b :: tail = stack
      ((a + b) :: tail, cont)
    case Mul(cont) =>
      val a :: b :: tail = stack
      ((a * b) :: tail, cont)
    case Dup(cont) =>
      val a :: tail = stack
      (a :: a :: tail, cont)
    case End(cont) =>
      (stack, Free.point(()))
  }

The implementation is basically identical to the first driver, but we return the new state together with the continuation we were passed. For End we pass back a finished program, Free.point(()), so the run stops there. We can execute our program like this:

println(testProg.foldRun(List[Int]())(runFn))

and it prints "(List(25),())", the second element of the tuple being the () we encapsulated in the Free Monad.
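The reason foldRun avoids the stack problem is that each step hands back the next state and the rest of the program, so the driver loop can be written tail recursively. Here is a stdlib-only sketch of that idea over a hand-rolled, simplified Free; this foldRun is a hypothetical stand-in with a similar shape to scalaz's method, not its implementation. The @tailrec annotation makes the compiler reject the method if the recursive call is not in tail position, and the hypothetical pushes helper builds a long program directly from Suspend and Push so the loop can be exercised on many instructions.

```scala
object FoldRunSketch {
  import scala.annotation.tailrec

  // Hypothetical minimal stand-ins for scalaz.Functor and scalaz.Free
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }

  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B])(implicit F: Functor[F]): Free[F, B] = this match {
      case Return(a)     => f(a)
      case Suspend(next) => Suspend[F, B](F.map(next)(_.flatMap(f)))
    }
    def map[B](f: A => B)(implicit F: Functor[F]): Free[F, B] = flatMap[B](a => Return[F, B](f(a)))
  }
  final case class Return[F[_], A](a: A) extends Free[F, A]
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]
  def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): Free[F, A] =
    Suspend[F, A](F.map[A, Free[F, A]](fa)(a => Return[F, A](a)))

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]
  implicit val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  type ForthProg[A] = Free[ForthOperators, A]
  def push(value: Int): ForthProg[Unit] = liftF[ForthOperators, Unit](Push(value, ()))
  def add: ForthProg[Unit] = liftF[ForthOperators, Unit](Add(()))
  def mul: ForthProg[Unit] = liftF[ForthOperators, Unit](Mul(()))
  def dup: ForthProg[Unit] = liftF[ForthOperators, Unit](Dup(()))
  val square: ForthProg[Unit] = for { _ <- dup; _ <- mul } yield ()
  val testProg: ForthProg[Unit] =
    for { _ <- push(3); _ <- square; _ <- push(4); _ <- square; _ <- add } yield ()

  private def binary(stack: List[Int])(op: (Int, Int) => Int): List[Int] = stack match {
    case a :: b :: tail => op(a, b) :: tail
    case _              => sys.error(s"stack underflow: $stack")
  }

  // One step, as in the article's runFn: the new stack plus the rest of the program
  def runFn(stack: List[Int], op: ForthOperators[ForthProg[Unit]]): (List[Int], ForthProg[Unit]) =
    op match {
      case Push(value, cont) => (value :: stack, cont)
      case Add(cont)         => (binary(stack)(_ + _), cont)
      case Mul(cont)         => (binary(stack)(_ * _), cont)
      case Dup(cont)         => (stack.head :: stack, cont)
      case End(_)            => (stack, Return[ForthOperators, Unit](())) // finished: stop
    }

  // Tail-recursive driver: compiled to a loop, so program length does not grow the call stack
  @tailrec
  def foldRun[S, A](state: S, program: ForthProg[A])(
      step: (S, ForthOperators[ForthProg[A]]) => (S, ForthProg[A])): (S, A) =
    program match {
      case Return(a) => (state, a)
      case Suspend(op) =>
        val (nextState, rest) = step(state, op)
        foldRun(nextState, rest)(step)
    }

  // A long program built directly from constructors: Push(n, ... Push(1, Return(())))
  def pushes(n: Int): ForthProg[Unit] =
    (1 to n).foldLeft[ForthProg[Unit]](Return[ForthOperators, Unit](())) { (rest, i) =>
      Suspend[ForthOperators, Unit](Push(i, rest))
    }
}
```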

Natural Transformations

Yet another way to run a program is to use natural transformations. A natural transformation maps one functor to another for every element type A; in our case it maps each ForthOperators instruction to some other monad, and foldMap then uses it to interpret the whole Free program.
So which other monad?
Well, to do anything useful we need to carry our stack around, and we know the value wrapped by our operators is () rather than our stack. So we could either implement one ourselves or use the State Monad, which is designed for exactly this purpose.

  import scalaz.{State, ~>}

  type Stack = List[Int]
  type StackState[A] = State[Stack, A]

Now we just need to create the natural transformation

  def runProgram: ForthOperators ~> StackState = new (ForthOperators ~> StackState) {
    def apply[A](t: ForthOperators[A]) : StackState[A] = t match {
      case Push(value : Int, cont) =>
        State((a: Stack) => (value::a, cont))
      case Add(cont) =>
        State((stack : Stack) => {
          val a :: b :: tail = stack
          ((a + b) :: tail, cont)
        })
      case Mul(cont) =>
        State((stack : Stack) => {
          val a :: b :: tail = stack
          ((a * b) :: tail, cont)
        })
      case Dup(cont) =>
        State((stack : Stack) => {
          val a :: tail = stack
          (a :: a :: tail, cont)
        })
      case End(cont) =>
        // This doesn't work as intended: there may not be a way
        // to stop the program early using ~>, so End is a no-op
        State((a : Stack) => (a, cont))
    }
  }

It requires us to implement a single apply method that, given a single ForthOperators instruction, returns an instance of our State Monad. It's worth noting that the code is slightly larger in this case because the State Monad stores a function that transforms the state.
Note that one disadvantage of this mechanism is that you lose the ability to terminate the program early: End here becomes a no-op.
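A natural transformation has to work the same way for every element type A, which is why apply is a polymorphic method rather than a function on one concrete type. The stdlib-only sketch below shows two simple ones; the ~> trait here is a hypothetical simplified stand-in for scalaz.~>, and firstElement, optionToList and naturalityHolds are illustrative names. naturalityHolds checks the defining property: transforming and then mapping gives the same answer as mapping and then transforming.

```scala
object NaturalTransformationSketch {
  // Hypothetical stand-in for scalaz.~>: one apply that must work for every A
  trait ~>[F[_], G[_]] {
    def apply[A](fa: F[A]): G[A]
  }

  // List ~> Option: keep the first element, whatever its type
  val firstElement: List ~> Option = new (List ~> Option) {
    def apply[A](fa: List[A]): Option[A] = fa.headOption
  }

  // Option ~> List: an absent value becomes an empty list
  val optionToList: Option ~> List = new (Option ~> List) {
    def apply[A](fa: Option[A]): List[A] = fa.toList
  }

  // Naturality: transform-then-map equals map-then-transform
  def naturalityHolds[A, B](xs: List[A], f: A => B): Boolean =
    firstElement(xs.map(f)) == firstElement(xs).map(f)
}
```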
To run the program using the natural transformation we use this

println(testProg.foldMap(runProgram).exec(List[Int]()))

The call to foldMap interprets each instruction with our natural transformation and chains the results into a single State action, but that still hasn't done anything yet, so we call exec and provide an initial stack. Unsurprisingly, it prints "List(25)".
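The same pipeline can be sketched without scalaz: foldMap translates each instruction with the natural transformation and chains the resulting State actions, and only running the final State on an initial stack does any work. Everything in this sketch (Functor, Monad, ~>, Free, State, execute) is a hypothetical, simplified stand-in rather than scalaz's implementation, and this foldMap is not stack-safe for very long programs.

```scala
object FoldMapStateSketch {
  // Hypothetical minimal stand-ins for scalaz's Functor, Monad and ~>
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }
  trait Monad[M[_]] {
    def pure[A](a: A): M[A]
    def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
  }
  trait ~>[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }

  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B])(implicit F: Functor[F]): Free[F, B] = this match {
      case Return(a)     => f(a)
      case Suspend(next) => Suspend[F, B](F.map(next)(_.flatMap(f)))
    }
    def map[B](f: A => B)(implicit F: Functor[F]): Free[F, B] = flatMap[B](a => Return[F, B](f(a)))
    // Translate each instruction with nt and sequence the results in the monad M
    def foldMap[M[_]](nt: F ~> M)(implicit M: Monad[M]): M[A] = this match {
      case Return(a)     => M.pure(a)
      case Suspend(next) => M.flatMap[Free[F, A], A](nt(next))(_.foldMap[M](nt))
    }
  }
  final case class Return[F[_], A](a: A) extends Free[F, A]
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]
  def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): Free[F, A] =
    Suspend[F, A](F.map[A, Free[F, A]](fa)(a => Return[F, A](a)))

  // State: a function from the current stack to (new stack, result)
  final case class State[S, A](run: S => (S, A))
  type Stack = List[Int]
  type StackState[A] = State[Stack, A]
  implicit val stackStateMonad: Monad[StackState] = new Monad[StackState] {
    def pure[A](a: A): StackState[A] = State[Stack, A](s => (s, a))
    def flatMap[A, B](ma: StackState[A])(f: A => StackState[B]): StackState[B] =
      State[Stack, B] { s =>
        val (s1, a) = ma.run(s)
        f(a).run(s1)
      }
  }

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]
  implicit val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  type ForthProg[A] = Free[ForthOperators, A]
  def push(value: Int): ForthProg[Unit] = liftF[ForthOperators, Unit](Push(value, ()))
  def add: ForthProg[Unit] = liftF[ForthOperators, Unit](Add(()))
  def mul: ForthProg[Unit] = liftF[ForthOperators, Unit](Mul(()))
  def dup: ForthProg[Unit] = liftF[ForthOperators, Unit](Dup(()))
  val square: ForthProg[Unit] = for { _ <- dup; _ <- mul } yield ()
  val testProg: ForthProg[Unit] =
    for { _ <- push(3); _ <- square; _ <- push(4); _ <- square; _ <- add } yield ()

  private def binary(stack: Stack)(op: (Int, Int) => Int): Stack = stack match {
    case a :: b :: tail => op(a, b) :: tail
    case _              => sys.error(s"stack underflow: $stack")
  }

  // Mirrors the article's runProgram natural transformation
  val runProgram: ForthOperators ~> StackState = new (ForthOperators ~> StackState) {
    def apply[A](op: ForthOperators[A]): StackState[A] = op match {
      case Push(value, cont) => State[Stack, A](stack => (value :: stack, cont))
      case Add(cont)         => State[Stack, A](stack => (binary(stack)(_ + _), cont))
      case Mul(cont)         => State[Stack, A](stack => (binary(stack)(_ * _), cont))
      case Dup(cont)         => State[Stack, A](stack => (stack.head :: stack, cont))
      case End(cont)         => State[Stack, A](stack => (stack, cont)) // no-op: cannot stop early
    }
  }

  // foldMap only builds one State action; running it on an empty stack does the work
  def execute(program: ForthProg[Unit]): Stack =
    program.foldMap[StackState](runProgram).run(Nil)._1
}
```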

FreeC

When I asked about idiomatic uses of the scalaz Free Monad on the #scalaz IRC channel, I was pointed at some example code here. It uses a natural transformation, but rather than creating a scalaz.Free it uses scalaz.Free.FreeC, which is built on Coyoneda (which I don't pretend to understand). The big advantage is that you no longer need to provide the Functor implementation.
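Roughly, the trick is that Coyoneda gives you a Functor for free: it stores an F[I] together with a deferred function I => A, so map only composes functions and never needs to look inside F. A Functor is then only required for the target of the interpretation. The stdlib-only sketch below illustrates that idea; Coyoneda, of, lift, payload and demo are hypothetical simplified stand-ins, not scalaz's Coyoneda or FreeC, and ForthOperators deliberately has no Functor instance here.

```scala
object CoyonedaSketch {
  // Hypothetical minimal stand-ins, not scalaz's definitions
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }
  trait ~>[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }

  // Note: no Functor[ForthOperators] instance is defined anywhere in this sketch
  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]

  // An F[I] plus a deferred function I => A, with I hidden as a type member
  sealed abstract class Coyoneda[F[_], A] {
    type I
    def fi: F[I]
    def k: I => A
    // map never looks inside F: it only composes onto the deferred function
    def map[B](f: A => B): Coyoneda[F, B] = Coyoneda.of[F, I, B](fi)(k andThen f)
    // Only the target G needs a Functor when we finally interpret
    def foldMap[G[_]](nt: F ~> G)(implicit G: Functor[G]): G[A] = G.map(nt(fi))(k)
  }
  object Coyoneda {
    def of[F[_], X, A](fx: F[X])(g: X => A): Coyoneda[F, A] = new Coyoneda[F, A] {
      type I = X
      val fi: F[X] = fx
      val k: X => A = g
    }
    def lift[F[_], A](fa: F[A]): Coyoneda[F, A] = of[F, A, A](fa)((a: A) => a)
  }

  implicit val optionFunctor: Functor[Option] = new Functor[Option] {
    def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
  }

  // Interpret an instruction by extracting the value it wraps
  val payload: ForthOperators ~> Option = new (ForthOperators ~> Option) {
    def apply[A](op: ForthOperators[A]): Option[A] = op match {
      case Push(_, a) => Some(a)
      case Add(a)     => Some(a)
      case Mul(a)     => Some(a)
      case Dup(a)     => Some(a)
      case End(a)     => Some(a)
    }
  }

  // Map twice over a Push without ever needing Functor[ForthOperators]
  val demo: Option[String] =
    Coyoneda.lift[ForthOperators, Int](Push(3, 10)).map(_ + 1).map(_.toString).foldMap[Option](payload)
}
```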

Defining the operators remains the same; we no longer provide the implicit Functor, and lifting to FreeC becomes:

  type ForthProg[A] = Free.FreeC[ForthOperators, A]

  import scala.language.implicitConversions
  implicit def liftForth[A](forth: ForthOperators[A]): ForthProg[A] = Free.liftFC(forth)

  def push(value: Int)  = Push(value, ())
  def add = Add(())
  def mul = Mul(())
  def dup = Dup(())
  def end = End(())

The runProgram natural transformation remains the same, and we execute it as follows:


    println(Free.runFC(testProg)(runProgram).exec(List[Int]()))

Which of course prints "List(25)".
I think not having to explicitly provide the Functor makes this the preferable model.

Another Interpreter

And finally, we can provide an alternate interpreter that just prints out the program. No state is required in this case, so we just provide a transformation to the Identity Monad.

  import scalaz.Id._

  def printProgram: ForthOperators ~> Id = new (ForthOperators ~> Id) {
    def apply[A](t: ForthOperators[A]): Id[A] = t match {
      case Push(value: Int, cont) =>
        println(s"Push $value")
        cont
      case Add(cont) =>
        println("Add")
        cont
      case Mul(cont) =>
        println("Mul")
        cont
      case Dup(cont) =>
        println("Dup")
        cont
      case End(cont) =>
        println("End")
        cont
    }
  }

And call it with


    Free.runFC(testProg)(printProgram)

And we get

Push 3
Dup
Mul
Push 4
Dup
Mul
Add
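To make the printing interpreter testable, here is a stdlib-only variant that appends each line to a buffer instead of calling println, interpreting into a hand-rolled Id type with a trivial Monad. As before, Functor, Monad, ~>, Free, Id, printer and render are hypothetical simplified stand-ins, not scalaz's definitions; for the square-based test program, render returns the same seven lines listed above.

```scala
object PrintInterpreterSketch {
  import scala.collection.mutable.ListBuffer

  // Hypothetical minimal stand-ins for scalaz's Functor, Monad, ~> and Free
  trait Functor[F[_]] { def map[A, B](fa: F[A])(f: A => B): F[B] }
  trait Monad[M[_]] {
    def pure[A](a: A): M[A]
    def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]
  }
  trait ~>[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }

  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B])(implicit F: Functor[F]): Free[F, B] = this match {
      case Return(a)     => f(a)
      case Suspend(next) => Suspend[F, B](F.map(next)(_.flatMap(f)))
    }
    def map[B](f: A => B)(implicit F: Functor[F]): Free[F, B] = flatMap[B](a => Return[F, B](f(a)))
    def foldMap[M[_]](nt: F ~> M)(implicit M: Monad[M]): M[A] = this match {
      case Return(a)     => M.pure(a)
      case Suspend(next) => M.flatMap[Free[F, A], A](nt(next))(_.foldMap[M](nt))
    }
  }
  final case class Return[F[_], A](a: A) extends Free[F, A]
  final case class Suspend[F[_], A](next: F[Free[F, A]]) extends Free[F, A]
  def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): Free[F, A] =
    Suspend[F, A](F.map[A, Free[F, A]](fa)(a => Return[F, A](a)))

  // Id: the "do nothing" monad, where a value is just itself
  type Id[A] = A
  implicit val idMonad: Monad[Id] = new Monad[Id] {
    def pure[A](a: A): Id[A] = a
    def flatMap[A, B](ma: Id[A])(f: A => Id[B]): Id[B] = f(ma)
  }

  sealed trait ForthOperators[A]
  final case class Push[A](value: Int, o: A) extends ForthOperators[A]
  final case class Add[A](o: A) extends ForthOperators[A]
  final case class Mul[A](o: A) extends ForthOperators[A]
  final case class Dup[A](o: A) extends ForthOperators[A]
  final case class End[A](o: A) extends ForthOperators[A]
  implicit val forthFunctor: Functor[ForthOperators] = new Functor[ForthOperators] {
    def map[A, B](fa: ForthOperators[A])(f: A => B): ForthOperators[B] = fa match {
      case Push(value, cont) => Push(value, f(cont))
      case Add(cont)         => Add(f(cont))
      case Mul(cont)         => Mul(f(cont))
      case Dup(cont)         => Dup(f(cont))
      case End(cont)         => End(f(cont))
    }
  }

  type ForthProg[A] = Free[ForthOperators, A]
  def push(value: Int): ForthProg[Unit] = liftF[ForthOperators, Unit](Push(value, ()))
  def add: ForthProg[Unit] = liftF[ForthOperators, Unit](Add(()))
  def mul: ForthProg[Unit] = liftF[ForthOperators, Unit](Mul(()))
  def dup: ForthProg[Unit] = liftF[ForthOperators, Unit](Dup(()))
  val square: ForthProg[Unit] = for { _ <- dup; _ <- mul } yield ()
  val testProg: ForthProg[Unit] =
    for { _ <- push(3); _ <- square; _ <- push(4); _ <- square; _ <- add } yield ()

  // Like the article's printProgram, but records each line in `out`
  def printer(out: ListBuffer[String]): ForthOperators ~> Id = new (ForthOperators ~> Id) {
    private def log[B](line: String, cont: B): B = { out += line; cont }
    def apply[A](op: ForthOperators[A]): Id[A] = op match {
      case Push(value, cont) => log(s"Push $value", cont)
      case Add(cont)         => log("Add", cont)
      case Mul(cont)         => log("Mul", cont)
      case Dup(cont)         => log("Dup", cont)
      case End(cont)         => log("End", cont)
    }
  }

  def render(program: ForthProg[Unit]): List[String] = {
    val out = ListBuffer.empty[String]
    program.foldMap[Id](printer(out))
    out.toList
  }
}
```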

Complete code for the Free based example can be found here and code for the FreeC version here.

Comments:

  1. Nice article. Just a comment about coding style: consider using List.empty[Int] to create your empty lists.

  2. The `ForthOperators` don’t need the `o` parameter to fulfill the functor/monad requirement. You can make them case objects that extend `ForthOperator[Unit]`. You’re not actually using `o` or `cont` anywhere. Great example and post though!
