Spire’s Ops macros? Why and How!

Typelevel.scala手上主要有catsshapeless以及spire。前两者主要是函数式编程的库,最后一个spire主要是个数值计算上抽象的库,其中有很多提高效率的技巧可以学习,Breeze就是用了spire,先看一个Ops macros的例子。

implicit operators on type classes

为了给generic type加上操作我们常用这种方法

1
2
def foo2[A](x: A, y: A)(implicit ev: Ordering[A]): A =
x > y

然后展开,触发隐式类型转换

1
2
def foo3[A](x: A, y: A)(implicit ev: Ordering[A]): A =
infixOrderingOps[A](x)(ev) > y

通过infixOrderingOps加上从foo3捕获的ev,然后在运行时就近似

1
2
def foo4[A](x: A, y: A)(implicit ev: Ordering[A]): A =
new ev.Ops(x) > y

通过以参数x来构造一个包含>运算符的类,然后把y当做参数传进去,所以这里就出现一个问题,每次都要重新构造一个对象,new一下,一般来说对性能的损失不大。如果在运算中,尤其是次数很多的调用性能影响就很大了。这种方法可以这么搞

1
2
def bar[A](x: A, y: A)(implicit ev: Ordering[A]): A =
ev.gt(x, y)

但是这样的问题很明显,如果复杂一点

1
2
3
4
5
def qux1[A: Field](x: A, y: A): A =
((x pow 2) + (y pow 2)).sqrt

def qux2[A](x: A, y: A)(implicit ev: Field[A]): A =
ev.sqrt(ev.plus(ev.pow(x, 2), ev.pow(y, 2)))

就变得又难写,又难读。如果可以用宏把这一部分重写就好了。

Machinist

spire中提供了Ops macros的功能,然后独立出来叫Machinist

看一个简单的例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import machinist.DefaultOps
import scala.{specialized => sp}

// 使用specialized注解,避免对基础类型的装箱
trait Eq[@sp A] {
def eqv(lhs: A, rhs: A): Boolean
}

object Eq {
// 为Int提供操作
implicit val intEq = new Eq[Int] {
def eqv(lhs: Int, rhs: Int): Boolean = lhs == rhs
}

//指定宏展开
implicit class EqOps[A](x: A)(implicit ev: Eq[A]) {
def ===(rhs: A): Boolean = macro DefaultOps.binop[A, Boolean]
}
}

object Test {
import Eq.EqOps

def test(a: Int, b: Int)(implicit ev: Eq[Int]): Int =
if (a === b) 999 else 0
}

然后,宏会展开成

1
2
3
4
5
6
7
8
9
10
if (a === b) 999 else 0

// after implicit resolution
if (Eq.EqOps(a)(Eq.intEq).===(b)) 999 else 0

// after macro application
if (Eq.intEq.eqv(a, b)) 999 else 0

// after specialization
if (Eq.intEq.eqv$mcI$sp(a, b)) 999 else 0

Details

其实这样一看,Machinist并不是特别复杂。 考虑我们自己实现这个功能,我们其实只是要识别出来需要修改的运算符,然后把new Ops(a)(ev).===(b)转化成ev.eqv(a, b)的形式。

首先我们先对于

1
2
3
4
// implicit处理后
Eq.EqOps(a)(Eq.intEq).===(b)
// 编译后
$iw.this.Eq.EqOps[Int](a)(ev).===(b)

先把eva取出来

1
2
3
4
5
val (ev, lhs) = c.prefix.tree match {
case q"$_($lhs)($ev)" => (ev, lhs)
case t => c.abort(c.enclosingPosition,
"Cannot extract subject of operator (tree = %s)" format t)
}

由于我们不关心Eq.EqOps[Int]这部分,所以用$_匹配一下。

然后,我们要先获取要替换的方法的名字,这样才能去对照找要替换的内容。这个很简单只要获取调用宏展开部分的名字就好了。

1
val s = c.macroApplication.symbol.name.toString

然后,我们需要知道把那些方法换成那些方法,这个映射是用operatorNames这个Map存下来,最后把他们拼起来就行了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import scala.language.experimental.macros
import scala.language.higherKinds
import scala.reflect.macros.blackbox.Context

trait Ops {
def operatorNames: Map[String, String]

def binop[A, B](c: Context)(rhs: c.Expr[A]): c.Expr[B] = {
import c.universe._
val (ev, lhs) = c.prefix.tree match {
case q"$_($lhs)($ev)" => (ev, lhs)
case t => c.abort(c.enclosingPosition,
"Cannot extract subject of operator (tree = %s)" format t)
}

val s = c.macroApplication.symbol.name.toString

val method = TermName(operatorNames.getOrElse(s, s))

c.Expr(q"$ev.$method($lhs, ${rhs.tree})")
}
}

trait DefaultOperatorNames {

val operatorNames = Map(
// Eq (===)
("$eq$eq$eq", "eqv")
)
}

object DefaultOps extends Ops with DefaultOperatorNames

我跑了一下benchmark

没替换的部分

[info] ::Benchmark Multi.trivalmulti::

[info] cores: 4

[info] hostname: airWifiSai.local

[info] name: Java HotSpot(TM) 64-Bit Server VM

[info] osArch: x86_64

[info] osName: Mac OS X

[info] vendor: Oracle Corporation

[info] version: 25.111-b14

[info] Parameters(size -> 3000000): 58.143474

[info] Parameters(size -> 6000000): 116.573365

[info] Parameters(size -> 9000000): 174.99838

[info] Parameters(size -> 12000000): 236.340219

[info] Parameters(size -> 15000000): 295.208304

替换之后

[info] ::Benchmark Multi.multi::

[info] cores: 4

[info] hostname: airWifiSai.local

[info] name: Java HotSpot(TM) 64-Bit Server VM

[info] osArch: x86_64

[info] osName: Mac OS X

[info] vendor: Oracle Corporation

[info] version: 25.111-b14

[info] Parameters(size -> 3000000): 40.019016

[info] Parameters(size -> 6000000): 80.235033

[info] Parameters(size -> 9000000): 120.282742

[info] Parameters(size -> 12000000): 160.866715

[info] Parameters(size -> 15000000): 203.81983

引用

  1. opsmacros代码
  2. machinist
  3. scala macros