这一部分展示了关于State Monad的例子以及使用。

Tasteful stateful computations

Haskell features a thing called the state monad, which makes dealing with stateful problems a breeze while still keeping everything nice and pure.

一个简单的例子就是Stack,这是一个典型的保存状态的计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
type Stack = List[Int]

def pop(stack: Stack): (Int, Stack) = stack match {
case x::xs => (x, xs)
}

def push(a: Int, stack: Stack): (Unit, Stack) = ((), a :: stack)

def stackManip(stack: Stack): (Int, Stack) = {
val (_, newStack1) = push(3, stack)
val (a, newStack2) = pop(newStack1)
pop(newStack2)
}

stackManip(List(3,4,5,2))

State and StateT

We’ll say that a stateful computation is a function that takes some state and returns a value along with some new state. That function would have the following type:

1
s -> (a, s)

Scalaz中的State大概是这个样子的

1
2
3
4
5
6
7
8
type State[S, +A] = StateT[Id, S, A]

// important to define here, rather than at the top-level, to avoid Scala 2.9.2 bug
object State extends StateFunctions {
def apply[S, A](f: S => (S, A)): State[S, A] = new StateT[Id, S, A] {
def apply(s: S) = f(s)
}
}

这里可以看到State一个特殊的地方,他固定了函数的形式f: S => (S, A)。这里的StateT大概是

1
2
3
4
5
6
7
8
9
10
11
trait StateT[F[+_], S, +A] { self =>
/** Run and return the final value and state in the context of `F` */
def apply(initial: S): F[(S, A)]

/** An alias for `apply` */
def run(initial: S): F[(S, A)] = apply(initial)

/** Calls `run` using `Monoid[S].zero` as the initial state */
def runZero(implicit S: Monoid[S]): F[(S, A)] =
run(S.zero)
}

定义一个State需要指定具体的S => (S, A)的函数。可以使用下面的语法直接构造一个State

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
trait StateFunctions extends IndexedStateFunctions {
def constantState[S, A](a: A, s: => S): State[S, A] =
State((_: S) => (s, a))

def state[S, A](a: A): State[S, A] =
State((_ : S, a))

def init[S]: State[S, S] = State(s => (s, s))

def get[S]: State[S, S] = init

def gets[S, T](f: S => T): State[S, T] = State(s => (s, f(s)))

def put[S](s: S): State[S, Unit] = State(_ => (s, ()))

def modify[S](f: S => S): State[S, Unit] = State(s => {
val r = f(s);
(r, ())
})
}

通过使用语法糖,可以直接通过连续的操作最终得到结果,而不是显示的传递State

StateFunctions中还定义了一些辅助函数,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
trait StateFunctions extends IndexedStateFunctions {
def constantState[S, A](a: A, s: => S): State[S, A] =
State((_: S) => (s, a))

def state[S, A](a: A): State[S, A] =
State((_ : S, a))

def init[S]: State[S, S] = State(s => (s, s))

def get[S]: State[S, S] = init

def gets[S, T](f: S => T): State[S, T] = State(s => (s, f(s)))

def put[S](s: S): State[S, Unit] = State(_ => (s, ()))

def modify[S](f: S => S): State[S, Unit] = State(s => {
val r = f(s);
(r, ())
})
}

这里面写的很有意思,首先明确一点,State里面给的是要形如S => (S, A)的行为,比如对于get我们就是要把状态取出来所以就应该在保持原有状态的时候,将状态也放到取出值的位置上s => (s, s)。然后,put其实就是为了将状态放过去,_ => (s, ())

一个使用的例子可以是这样的

1
2
3
4
5
6
for {
_ <- push(3)
a <- pop
b <- State.get
r <- if (b.length > 3) State.put(List(1,2,3,4, 5)) else State.put(List(0, 1, 2))
} yield r

同样,我们也可以直接使用getput去实现poppush操作。

1
2
3
4
5
6
7
8
9
10
val pop = for {
s <- get[Stack]
(x::xs) = s
_ <- put(xs)
} yield x

def push(s: Int) = for {
xs <- get[Stack]
_ <- put(s::xs)
} yield ()

Composing monadic functions

When we were learning about the monad laws, we said that the <=< function is just like composition, only instead of working for normal functions like a -> b, it works for monadic functions like a -> m b.

Kleisli

Scalaz中,这就是那个A => M[B]的封装。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
sealed trait Kleisli[M[+_], -A, +B] { self =>
def run(a: A): M[B]
...
/** alias for `andThen` */
def >=>[C](k: Kleisli[M, B, C])(implicit b: Bind[M]): Kleisli[M, A, C] =
kleisli((a: A) => b.bind(this(a))(k(_)))
def andThen[C](k: Kleisli[M, B, C])(implicit b: Bind[M]): Kleisli[M, A, C] = this >=> k
/** alias for `compose` */
def <=<[C](k: Kleisli[M, C, A])(implicit b: Bind[M]): Kleisli[M, C, B] = k >=> this
def compose[C](k: Kleisli[M, C, A])(implicit b: Bind[M]): Kleisli[M, C, B] = k >=> this
...
}

object Kleisli extends KleisliFunctions with KleisliInstances {
def apply[M[+_], A, B](f: A => M[B]): Kleisli[M, A, B] = kleisli(f)
}

最主要的功能就是它的<=<,即把多个处理monad的函数组合起来

1
2
3
4
5
6
val f = Kleisli { (x: Int) => (x + 1).some }

val g = Kleisli { (x: Int) => (x * 2).some}

5.some >>= (f <=< g) // res: 11
5.some >>= (f >=> g) // res: 12

Reader as Kleisli

Scalaz中定义通过Kleisli来定义ReaderT

1
2
3
4
5
type ReaderT[F[+_], E, A] = Kleisli[F, E, A]
type Reader[E, A] = ReaderT[Id, E, A]
object Reader {
def apply[E, A](f: E => A): Reader[E, A] = Kleisli[Id, E, A](f)
}

这里的ReaderTMonad Transformers

Making monads

In this section, we’re going to look at an example of how a type gets made, identified as a monad and then given the appropriate Monad instance. … What if we wanted to model a non-deterministic value like [3,5,9], but we wanted to express that 3 has a 50% chance of happening and 5 and 9 both have a 25% chance of happening?

由于Scala没有内置的分数,所以我们直接使用Double

1
2
3
4
5
6
7
case class Prob[A](list: List[(A, Double)])

trait ProbInstances {
implicit def probShow[A]: Show[Prob[A]] = Show.showA
}

case object Prob extends ProbInstances

由于Listfunctor,好像Prob本身也应该是个functor。我们给他加上map

1
2
3
4
5
6
7
trait ProbInstances {
implicit def probShow[A]: Show[Prob[A]] = Show.showA
implicit val probInstance = new Functor[Prob] {
def map[A, B](fa: Prob[A])(f: A => B): Prob[B] =
Prob(fa.list map { case (x, p) => (f(x), p)})
}
}

接下来先实现flatten,我们先确定一下语义,[([(1,0.5),(2, 0.5), 0.3]), ([(4, 0.3), (3, 0.7)], 0.4), 0.7)]flatten之后应该是[(1, 0.15), (2, 0.15),(4, 0.21), (3, 0.49)]

1
2
3
4
5
def flatten[B](xs: Prob[Prob[B]]): Prob[B] = {
def multall(innerxs: Prob[B], p: Double) =
innerxs.list map { case (x, r) => (x, p * r)}
Prob((xs.list map {case (innerxs, p) => multall(innerxs, p)}).flatten)
}

最后,我们实现Monad。实现需要实现pointbind两个函数。其中Point相当于return即返回包含值的最小上下文。

1
def point[A](a : => A): Prob[A] = Prob(List((a, 1.0)))

然后是函数bind,他可以简单通过mapflatten完成。顺便加一个用来合并相同项的辅助函数collected。最后的结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
case class Prob[A](list: List[(A, Double)])

trait ProbInstances {
implicit def probShow[A]: Show[Prob[A]] = Show.showA
implicit val probInstance = new Functor[Prob] with Monad[Prob] {
def point[A](a : => A): Prob[A] = Prob(List((a, 1.0)))
def bind[A, B](fa: Prob[A])(f: A => Prob[B]): Prob[B] =
flatten(map(fa)(f))

override def map[A, B](fa: Prob[A])(f: A => B): Prob[B] =
Prob(fa.list map { case (x, p) => (f(x), p)})
}

def flatten[B](xs: Prob[Prob[B]]): Prob[B] = {
def multall(innerxs: Prob[B], p: Double) =
innerxs.list map { case (x, r) => (x, p * r)}
Prob((xs.list map {case (innerxs, p) => multall(innerxs, p)}).flatten)
}

def collected[B](xs: Prob[B]): Prob[B] = {
Prob(xs.list.foldRight(Map[B, Double]())((x, y) => {
y get x._1 match {
case Some(v) => y + (x._1 -> (x._2 + v))
case None => y + (x._1 -> x._2)
}
}).toList)
}

}
case object Prob extends ProbInstances

Coin example

可以使用一个硬币的例子看看效果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

sealed trait Coin
case object Heads extends Coin
case object Tails extends Coin
implicit val coinEqual: Equal[Coin] = Equal.equalA


def coin: Prob[Coin] = Prob(Heads -> 0.5 :: Tails -> 0.5 :: Nil)
def loadedCoin: Prob[Coin] = Prob(Heads -> 0.1 :: Tails -> 0.9 :: Nil)

lazy val flipThree: Prob[List[Coin]] = for {
a <- coin
b <- coin
c <- loadedCoin
} yield { List(a, b, c)}


Prob.collected(flipThree map { _ all {_ === Tails}})

这里可以看出来,我们利用for的语法糖,构建了一个Coin的序列,然后从里面找出来全都是Tails的,然后计算出来概率,虽然感觉复杂度有点高,但是很直观。




X