Continuation Passing Style

Table of Contents

写一写自己对 CPS 的理解,个人经历有限,难免会有错误或者认识不全,所以请见谅.

更新于 2018-11-12

偶遇 CPS

有很久一段时间没有写 Python 了,语法虽好,不过语法糖实在太多,断断续续的接触了一段时间的 Lisp, 从一开始的 Common Lisp,到现在的 Emacs LispRacket,经过这段洗礼以后,有一种"不写 Python 了,干脆写 Lisp 为生算了"的想法(不过想生活的话,还是写 Python 靠谱点,都学就最好),我不是 Lisp 的佼佼者,不过应该也算是个 Lisp fanboy 了吧.

之所以学 Lisp 是因为当时看了王垠的博客而有了想去了解 Programming Language 的欲望.刚好前一段时间大概的读了 Semantics EngineeringEOPL 3rd,对这一块有一个大概的了解.目前还在补 EOPL 上的习题,这本书上有一个我挺感兴趣的内容,看到目录上有这个的时候我当场就兴奋不已,"终于找到有讲这一块的书了","这一块"就是 Continuation Passing Style,简称 CPS.

Continuation

在讲什么是 CPS 之前,得先说一下什么是 continuation.

Lisp 的方言之一 Racket 的里面,它是一个特性,记录下计算过程中的某个执行点的上下文.

接下来我会用 Racket 来讲解,放心,我会简单地讲一下要用到的一些 forms,不会真的涉及到 Racketcontinuation 的实际操作,只是讲它的意像.

简单的 Lisp

Racket 里面一般是这样定义函数的,以定义一个形式参数为 ab 的整型加法函数 add ,和一个减法函数 sub 为例子.

(define (add a b) (+ a b))
(define (sub a b) (- a b))

;; 也可以利用lambda表达式定义函数
;; (define add (lambda (a b) (+ a b)))

;; call it with 1 and 2, then returns 3 as the result
(add 1 2)
;; call it with 3 and 2, then returns 1 as the result
(sub 3 2)

;; if condition then true-branch else false-branch
(if (> (sub 3 2) 0)
    (sub 3 2)
    (add 1 2))

上面注释的 lambda 表达式绑定给 add 函数,而 lambda 表达式就像是匿名函数,反映了 Racket 支持函数式编程.简单点说就是过程与数据有着同等地位,这个特性后面会用上.

接下来用一个更复杂的例子,定义一个名为 ans 的函数,接收三个整型形参 x,y,z 返回结果为 (add (sub x y) z).

(define (ans x y z)
   (add (sub x y) z))

控制上下文(Control context)

好了,介绍什么是 continuation 之前先介绍一下什么是控制上下文?首先要先了解什么是上下文,所谓上下文就是贯穿整个过程的一个意像,在整个过程中的每一个点上,上下文的状态都不一样.

举个现实例子,时间就是一个上下文,昨天我做了什么,今天我又做了什么,昨天做的事情对我今天做的事情造成什么影响,等等.

回到程序上,假如调用 (ans 1 2 3),那么会发生以下几件事情,

  1. 形式参数和实际参数发生绑定: x=1, y=2, z=3.
  2. 计算 (sub x y), 结果为 -1.
  3. 计算 (add 1 3), 结果为 2.

整一个计算过程需要记住变量的值,这需要一个叫做 environment 的东西进行储存,它就是一个从变量到值的映射, environment 是数据上下文(data context)的一个抽象.

除了数据上下文,还有另外一种上下文,上面的整个计算过程中的第二步和第三步,每次执行一步都需要记录下一步要从哪里开始,这个就是所谓的控制上下文.事实上还可以想象一下用调试器的逐步调试的过程,每一步都是一个执行点,逐步执行直到计算结束.

continuation 就是控制上下文的一个抽象,它就像 stack 的帧(frame). Environment 是一个映射,它的 representation 有很多选择:哈希表,关联链表等等,那么 continuationrepresentation 又是什么呢?

Continuation 的 representation

在给 continuation 选择 representation 之前先给每一个执行点选择一种 representation.每一执行点相当于一个过程,每执行一个点就是一次调用.

比如调用 (ans 1 2 3), 这样的话 (sub 1 2) 的结果 -1,不过 -1 不是一个过程,在 Racket 里面,一个函数(procedure here, not function)其实就是一个过程,所以这一步可以这样表示,

;; 第一步,把第一步保存在first-step
(define first-step (lambda (pre-step-res) (sub 1 2)))

;; 执行第一步相当于以下,void是Racket里面的一个值
(define res1 (first-step void))

第二步需要等待第一步的运算结果,

;; 第二步,把第二步保存在second-step
(define second-step (lambda (pre-step-res) (add pre-step-res 3)))

;; 执行第二步
(define res2 (second-step res1))

最后一步就是返回上一步的结果做为整个计算的返回值,

(define last-step (lambda (v) v))

除了用函数作为 representation 外,还可以选择数据结构作为 representation,这个数据结构包含了下一步需要的信息,比如 environment, expression, value, procedurecontinuation.

其中 expression 是下一个要执行的表达式, environment 就是该 expression 执行的 environment, continuation 就是该 expressioncontinuation,如此类推.

我想你应该多少能看出来了这是一个递归.下面开始演示如何编写 CPS 程序.

Continuation Passing Style

什么是 CPS

顾名思意, CPS 就是一种风格,这种风格就是把 continuation 作为参数传递.类似的还有 Environment Passing Style.

用这个参数记录运行时的内容.

CPS 的意义

其实就是对控制上下文(control context)进行抽象和实例化,以及把函数转化成尾调用(Tail call).

尾调用有一个好处,如果函数是一个递归函数,我们可以很容易地在已改为尾调用基础上进行优化,把普通递归优化成尾递归(Tail recursion).

首先要明白什么是尾调用,尾调用就是指函数 \(func\) 的最后一个语句为调用一个函数 \(proc\),它同时也是函数 \(func\) 的返回语句,返回值为 proc(args, ...) 结果.

所谓的尾递归就是指函数 \(func\) 在尾调用中调用自身来进行递归,使得 \(func\) 和迭代(循环)两者在不同的手段(一个用递归,一个用循环语句)下实现同样的计算行为.

虽然计算行为一样,但不同编译器/解释器对待递归和迭代的处理可是不一样的,导致了两者效率方面的不同.

为了提高效率,我们需要把效率较差的实现改写为效率高的实现. 这也是本文章的终极目标: 如何把递归转改写成迭代.

对控制上下文进行抽象就是说程序执行的下一步可以被抽象成数据表示,就像上面的说明例子那样,就是使用了函数来表示下一步要执行的动作.

把上面的 ans 改写成 CPS 程序 ans/k

(define (ans/k x y z cont)
   (sub/k x y
      (lambda (res1)
         (add/k res1 z
            (lambda (res2)
               (cont res2))))))

这是这个例子的最终结果.现在开始说改写的思路,也就是一套把程序转换 CPS 程序的算法.

Simple Expression and Tail Form Expression

要转变成 CPS,则要理解 CPS 对表达式做出的划分: SimpleExp 以及 TfExp (Tail Form Expression).

SimpleExp 包括常量以及变量: 1, a, (- 2 x), (lambda (v) v) 以及 (zero? x) 等等 TfExp 则是由 SimpleExp 组合而成的复杂表达式(ComplexExp), 有: (f x), (if (exp1) exp2 exp3) 等等. 由于 TfExp 是可以拆分成多个 SimpleExp 的,而 SimpleExp 是不可再分的,所以 SimpleExp 也可以叫 AtomicExp.

关系大致就像这样:

SimpleExp    ::= Constant || Variable
TfExp        ::= (SimpleExp SimpleExp*)

这样就可以保证 TfExp 处于函数的 tail position, tail position 就是函数的退出的位置,也就是结束的地方,

这一步的 continuation 和整个函数的 continuation 是一样的,也就是说栈空间没有发生改变, 在这种地方的调用就是尾调用(前面有提到过这个概念),而这样的函数称为 tail form 的.

一套把程序转化为 CPS 程序的算法

先示范如何把 ans 改写成 CPS,也就是把表达式改写成 TfExp.

  1. (ans x y z) 改写 (ans/k x y z cont),
  2. (ans x y z) 整个的计算过程里面, (sub x y) 是第一个进行运算的,同时它也是一个 non-simple expression,所以先对它的调用进行改写, (sub/k x y cont1), cont1 就是要执行的下一步,
  3. 计算完 (sub x y),下一步就是计算整个结果了,假设上一步的 (sub x y) 的结果为 res1,那么下一步就是 (add res1 z) 的计算了, 它及是第二个 non-simple expression, 对它进行改写 (add/k res1 z cont2),
  4. 回到 cont1cont2 上,在第2步上面说到 cont1 是下一步,而 (add res1 z) 刚好也是 (sub x y) 的下一步,那么 cont1 就是 (lambda (res1) (add res1 z)),不要忘记把 (add res1 z) 改写成 (add/k res1 z cont2);

    cont2 就是 (add res1 z) 的下一步,它的下一步就是返回结果,假设 (add res1 z) 的结果为 res2,那么 cont2 就是 (lambda (res2) res2)?,不,这是错的!正确是 (lambda (res2) (cont res2)),不要忘记了 (ans/k x y z cont) 里面的 cont,

    这是才是整个计算过程的真正最后一步,这一步也是什么都没有做,就是返回结果,所以它应该是这样的 (lambda (res) res),这也被叫做空延续.

  5. 最后还要把 addsub 的定义也要改写,注意 +-primitive operators,不能对它们的定义进行修改,所以它们就不用改写,

    (define (add/k x y cont) (cont (+ x y)))
    (define (sub/k x y cont) (cont (- x y)))
    

CPS 的程序实际上反映了这程序的计算过程,这一步到下一步,如此类推,直到计算完毕.

EOPL 中有一段 The CPS Recipe 有对这套算法进行总结,内容大概如下:

  1. 把每个子程序的定义需要多一个参数,这个参数表示的是 continuation, 一般叫做 cont 或者 k;
  2. 当子程序的返回值是常量或者变量 v,都要改成返回 continuation 的返回值,也就是 (cont v) 或者 (k v);
  3. 当一个子程序的调用 (proc x) 出现在 tail position,也就是子程序返回的地方,那么改成就使用同一个 cont 来调用函数, (proc x cont);
  4. 当一个子程序的调用 (proc x) 出现在 operand position,也就是参数位置,比如 (operator (proc x)), 那么先要在一个新的 continuation 下运算完这个调用, 给它的运算结果一个名字,比如 v1,v2..., 并且继续计算.

书中的例子还是很经典的,所以就抄下来了:

#lang racket

(lambda (x)
  (cond
    [(zero? x) 17]
    [(= x 1) (f x)]
    [(= x 2) (+ 22 (f x))]
    [(= x 3) (g 22 (f x))]
    [(= x 4) (+ (f x) 33 (g y))]
    [else (h (f x) (- 44 y) (g y))]))


;; after transforming

(lambda (x cont)
  (cond
    [(zero? x) (cont 17)]
    [(= x 1) (f x cont)]
    [(= x 2)
     (f x
        (lambda (v)
           (+ 22 v)))]
    [(= x 3)
     (f x
        (lambda (v)
           (g 22 v cont)))]
    [(= x 4)
     (f x
        (lambda (v1)
           (g y
              (lambda (v2)
                 (cont (+ v1 33 v2))))))]
    [else (f x
             (lambda (v1)
                (g y
                   (lambda (v2)
                      (h v1 (- 44 y) v2 cont)))))]))

正如上面的代码所展示一样,一旦把子程序转化成 CPS,里面调用的所有子程序都要对应的变化成 CPS.

最后一个例子

分别定义累加从1到n的函数 bad-acc, acc-tailacc, 以及(Fibonacci) bad-fibfib.

#lang racket
(require racket/trace)

;; bad acc
(trace-define (bad-acc n)
    (if (= n 0)
        0
        (+ n (bad-acc (- n 1)))))

;; tail form
(define (acc-tail n)
    (acc-tail-inner n 0))

(trace-define (acc-tail-inner n res)
    (if (= n 0)
        res
        (acc-tail-inner (- n 1) (+ res n))))

;; cps
(define (acc n)
    (acc/k n (lambda (val) val)))

(trace-define (acc/k n cont)
    (if (= n 0)
        (cont 0)
        (acc/k (- n 1)
               (lambda (res) (cont (+ n res))))))

;; BONUS: even tail form can be convert into cps
(define (acc-tail-inner/k n res cont)
    (if (= n 0)
        (cont res)
        (acc-tail-inner/k
          (- n 1) (+ res n) cont)))

;;; bad fib
(trace-define (bad-fib n)
  (if (or (= n 1) (= n 2))
      1
      (+ (bad-fib (- n 2)) (bad-fib (- n 1)))))

;; cps
(trace-define (fib/k n cont)
  (if (or (= n 1) (= n 2))
      (cont 1)
      (fib/k
       (- n 2)
       (lambda (res1)
         (fib/k
          (- n 1)
          (lambda (res2)
            (cont (+ res1 res2))))))))

(define (fib n)
    (fib/k n (lambda (val) val)))

;;; tests
(bad-acc 10)

(acc-tail 10)

(acc 10)

(bad-fib 7)

(fib 7)

利用 racket/trace 中的 trace 跟踪计算过程,会发现在3者中, acc-tailacc 的计算行为和迭代是一样的,对于 bad-acc,可以明显观察到每一步,并且有明显的起伏.

这是因为 Racket 支持尾递归优化,把尾递归优化成迭代,显然 bad-acc 没能优化成功.

这下面是计算时候的进出栈的变化,可以看到 bad-acc 以及 fib 的栈深是先增长后缩小.

而经过 cps 转换的则是栈深没有发生任何改变.

Welcome to Racket v7.5.
racket@> ,enter "/home/saltb0rn/.config/emacs/site-lisp/test.rkt"
>(bad-acc 10)
> (bad-acc 9)
> >(bad-acc 8)
> > (bad-acc 7)
> > >(bad-acc 6)
> > > (bad-acc 5)
> > > >(bad-acc 4)
> > > > (bad-acc 3)
> > > > >(bad-acc 2)
> > > > > (bad-acc 1)
> > > >[10] (bad-acc 0)
< < < <[10] 0
< < < < < 1
< < < < <3
< < < < 6
< < < <10
< < < 15
< < <21
< < 28
< <36
< 45
<55
55
>(acc-tail-inner 10 0)
>(acc-tail-inner 9 10)
>(acc-tail-inner 8 19)
>(acc-tail-inner 7 27)
>(acc-tail-inner 6 34)
>(acc-tail-inner 5 40)
>(acc-tail-inner 4 45)
>(acc-tail-inner 3 49)
>(acc-tail-inner 2 52)
>(acc-tail-inner 1 54)
>(acc-tail-inner 0 55)
<55
55
>(acc/k 10 #<procedure:...te-lisp/test.rkt:78:13>)
>(acc/k 9 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 8 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 7 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 6 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 5 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 4 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 3 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 2 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 1 #<procedure:...te-lisp/test.rkt:84:15>)
>(acc/k 0 #<procedure:...te-lisp/test.rkt:84:15>)
<55
55
>(bad-fib 7)
> (bad-fib 5)
> >(bad-fib 3)
> > (bad-fib 1)
< < 1
> > (bad-fib 2)
< < 1
< <2
> >(bad-fib 4)
> > (bad-fib 2)
< < 1
> > (bad-fib 3)
> > >(bad-fib 1)
< < <1
> > >(bad-fib 2)
< < <1
< < 2
< <3
< 5
> (bad-fib 6)
> >(bad-fib 4)
> > (bad-fib 2)
< < 1
> > (bad-fib 3)
> > >(bad-fib 1)
< < <1
> > >(bad-fib 2)
< < <1
< < 2
< <3
> >(bad-fib 5)
> > (bad-fib 3)
> > >(bad-fib 1)
< < <1
> > >(bad-fib 2)
< < <1
< < 2
> > (bad-fib 4)
> > >(bad-fib 2)
< < <1
> > >(bad-fib 3)
> > > (bad-fib 1)
< < < 1
> > > (bad-fib 2)
< < < 1
< < <2
< < 3
< <5
< 8
<13
13
>(fib/k 7 #<procedure:...te-lisp/test.rkt:112:13>)
>(fib/k 5 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 3 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 1 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 4 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 3 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 1 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 6 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 4 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 3 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 1 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 5 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 3 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 1 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 4 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 3 #<procedure:...te-lisp/test.rkt:108:10>)
>(fib/k 1 #<procedure:...te-lisp/test.rkt:105:7>)
>(fib/k 2 #<procedure:...te-lisp/test.rkt:108:10>)
<13
13
racket@test.rkt>

和前面三个函数的迭代计算行为不一样, Fibonacci 的计算行为是树形递归,同样可以通过 CPS 得到优化.

仔细看你会发现 acc-tail-inner/kcont 是多余的,在实际开发中如果遇到尾递归的话可以绕路不写,因为那样意义不大,因为 CPS 保证尾调用的发生只是为了方便程序员能够在必要时候把普通递归改写称尾递归,就没必要改成 CPS 了.

即便递归函数已经是尾递归了,可像在 PythonEmacs Lisp 这种有”先天缺陷“ (不支持尾递归优化 Tail Call Optimization, abbr. TCO) 的语言中,尾递归也不能解决爆栈的问题,

不过既然都是尾递归了形式,那就很容易地把尾递归转化成迭代/循环,并且 PythonEmacs Lisp 这些语言都有循环语句.

Trampoline 和 Bounce

写于 2019/4/8

PythonEmacs Lisp 这类语言之所以会因为递归而爆栈,归根到底是因为这些语言中函数会发生多次不返回的调用(就是没来得及返回就开始新的调用).

这个时候直译器就会在每次调用的时候记录下函数的调用状态和其他信息,这个时候栈就会一直上涨,在不加以限制的情况下,若递归逻辑出错就会引发崩溃.

而如果语言支持尾递归优化,那么函数的调用状态信息就会只进栈一次,递归过程中只是修改栈上的状态信息,和迭代一样,就和上面 Racketacc 例子一样.

既然函数调用会压栈,那么在函数调用链增长的时候断开一下就好了,而且尾递归优化的语言中尾递归和迭代效率一样,那就 把尾递归改成迭代.

此外我们还需要一个叫做 Trampolinecontinuation "调度器" (毕竟和事件循环挺像的) 以及一个叫做 bounce 对象, bounce 对象 包含上一步的计算结果和下一步的 continuation,

而 "调度器" 就是一个循环,每一次循环就是通过 bounce 对象里面的信息计算出下一个 bounce, 直到某一步的计算结果不再是 bounce 对象.

接下来我们会演示如何把上面的 acc/k 改写成可以供 Trampoline 使用的形式,也就是把 acc/k 改成 bounce 版本 acc-bounce.

为了断开函数调用链,我们需要 acc/k 中所有的尾调用换成 bounce 值,也就是说把返回值需要变为 bounce ,这是使用 Trampoline 技术的关键.

(至于 Fibonacci 的,这里只给出 Racket 版本,一是为了优先专注于简单的 acc-bounce,二是我懒)

;; bounce 的定义,
(struct bounce (res cont))

;; 对 acc/k 进行修改,要点之一就是把所有 acc/k 的调用(也就是返回值)换成bounce值.
;; 可以这么想,这么做是把原本的调用给拆开了.
(define (acc-bounce n cont)
  (if (= n 0)
      (bounce 0 cont) ;; <= (cont 0), 直接拆
      (bounce n       ;; <= (acc-bounce (- n 1) (lambda (res) (cont (+ n res)))), 这里拆的地方是 (cont (+ n res)),
              (lambda (t)  ;; 和 (cont 0) 不一样,这个需要先把 n 直接提到外面然后得到一个新的 continuation: (lambda (t) (acc-bounce (- t 1) (cont (+ t res)))), 最后拆 (cont (+ t res)) 变成 (bounce (+ t res) cont)
                (acc-bounce (- t 1)
                            (lambda (res) (bounce (+ t res) cont)))))))
;; 由于 Racket 没有什么类似与 Python while statement 那种东西,所以我还是用了递归,在 Python 中可以改写成 while

(define (fib-bounce n cont)
  (if (or (= n 1) (= n 2))
      (bounce 1 cont) ;; <= (cont 0), 直接拆
      (bounce n       ;; <= 拆开第一个调用 (fib/k (- n 2) ...)
              (lambda (t)
                (fib-bounce
                 (- t 2)
                 (lambda (res1)
                   (bounce  ;; <= 拆开第二个调用 (fib/k (-n 1) ...),这里其实可以不拆,
                    (- t 1) ;; 不过为了和把尾调用换成bounce的说法保持一致,应该拆开
                    (lambda (t1)
                      (fib-bounce
                       t1
                       (lambda (res2)
                         (bounce (+ res1 res2) cont))))))))))) ;; <= 最后一次拆开

(define trampoline
  [lambda (b)
    ;; let* 操作符可以暂时理解为其它语言中的局部定义
    ;; kont = bounce-cont(b)
    ;; n    = bounce-res(b)
    ;; res  = kont(n)
    (let* ([kont (bounce-cont b)]
           [n (bounce-res b)]
           [res (kont n)])
      (if (bounce? res)
          (trampoline res)
          res))])

(trampoline (acc-bounce 10000000 (lambda (x) x))

(trampoline (fib-bounce 7 (lambda (x) x))

由于 Racket 本身就支持尾递归优化,所以上面的代码其实只要能保证尾递归就能用的了,根本不需要确保 bounce 版的 acc/k 去掉所有调用.

接下来,我会按照这段代码的思路用 PythonJavaScript 实现一边,这两门语言都是不支持尾递归优化的.

貌似默认最大递归层数都为一千(Python 确定, JavaScript 没说),如果上面的方法能行那么突破了限制的的计算规模就可以通过计算了.

Python 版本:

class Bounce:
    def __init__(self, res, cont):
        self.res = res
        self.cont = cont


def trampoline(bounce):
    while 1:
        res = bounce.cont(bounce.res)
        if isinstance(res, Bounce):
            bounce = res
        else:
            return res


def acc_bounce(n, cont):
    if n == 0:
        return Bounce(n, cont)
    else:
        def kont(t):
            return acc_bounce(t-1, lambda res: Bounce(t+res, cont))
            # return acc_bounce(t-1, lambda res: cont(t+res)) # 这样依然会爆栈,也就是我上面说的一定要把所有调用换成 bounce
        return Bounce(n, kont)


bounce = acc_bounce(50000, lambda x: x)  # 我决定给个50000运算规模
print(trampoline(bounce))

我还发现有一个更加高级的实现方法,直接对 AST 实现变换: https://github.com/0x65/trampoline.

JavaScript 版本:

function trampoline(bounce) {
    let res;
    while (1) {
        res = bounce.cont(bounce.res);
        if (typeof(res) == 'object')
            bounce = res;
        else
            return res;
    }
}

function accBounce(n, cont) {
    if (n == 0) {
        return {
            res: n,
            cont: cont
        };
    } else {
        return {
            res: n,
            cont: function(t) {
                return accBounce(
                    t-1,
                    res => {
                        return {
                            cont: cont,
                            res: t+res
                        };
                    });
            }
        };
    }
}

console.log(trampoline(accBounce(50000, x => x)));

Emacs Lisp 突破 Fibonacci 的递归限制

个人也还挺喜欢 Emacs Lisp,所以展示一个在 Emacs Lisp 下的 fib 逐渐改成 fib-bounce 的过程.

;; bounce
(defstruct (bounce
            (:constructor nil)
            (:constructor bounce (res cont)))
  res cont)

;; tampoline
(defun tampoline (b)
  (let* ((kont (bounce-cont b))
         (val (bounce-res b))
         (res (funcall kont val)))
    (while (bounce-p res)
      (setq
       kont (bounce-cont res)
       val (bounce-res res)
       res (funcall kont val)))
    res))

;; 常规版本
(defun fib (n)
  (if (or (= n 1) (= n 2))
      1
    (+ (fib (- n 2)) (fib (- n 1)))))

;; cps版本
(defun fib/k (n cont)
  (if (or (= n 1) (= n 2))
      (funcall cont 1)
    (fib/k (- n 2)
           (lambda (res1)
             (fib/k (- n 1)
                    (lambda (res2)
                      (funcall cont (+ res1 res2))))))))

;; 基于cps改写的bounce版本
(defun fib-bounce (n cont)
  (if (or (= n 1) (= n 2))
      (bounce 1 cont)
    (bounce n
            (lambda (t1)
              (fib-bounce
               (- t1 2)
               (lambda (res1)
                 (bounce
                  (- t1 1)
                  (lambda (t2)
                    (fib-bounce
                     t2
                     (lambda (res2)
                       (bounce (+ res1 res2) cont)))))))))))

(fib 1000) ;; 直接报错: Lisp nesting exceeds ‘max-lisp-eval-depth
(tampoline (fib-bounce 1000 (lambda (v) v))) ;; 需要很长时间来得出计算结果,但是不会触发递归限制而报错

结语

还是觉得这篇东西有很多地方有欠缺,也说明了我对 CPS 的理解还不够深入.突然觉得 EOPL 写的很好,因为我能明白给我传达的知识,原来写一篇易懂的科普文是如此艰难,真的是佩服这些老前辈.

EOPL 这本书都写到了这些内容,个人十分推荐去阅读这本书,总的来说, CPS 可以保证把函数改写为尾调用,从而优化成尾递归; 而 Trampoline 可以保证不支持尾递归的语言运行尾递归不会发生爆栈,也就是说这两项技术可以实现尾递归优化,不受到编译器/解释器的限制.

最后给出别的一些文章作为参考资料: http://matt.might.net/articles/cps-conversion/,顺便推荐一下 Matt Might 的博客(也就是推荐文章的出处),内容都十分不错,基本都是和 PLT 相关的,有兴趣的可以看一下.

Update

如何把遍历树的递归改写为Trampoline的形式

写于 2021/10/4 21:49

这个问题最早是我1年前在工作时(岗位: WEB前端)想到的,当然由于树的规模不大(我当时测试的火狐浏览器的最大可进栈层数是20000+),所以就算没有答案也不会对当时的工作造成影响.

同事也劝我不要花费精力在上面,可是我的想法是: 20000+ 次递归难道就很大了吗,难道工作中遇到树就不会达到这个规模吗?

现在的我已经在另一家公司了,我的想法很快就有了答案: 20000+ 次递归还真的能遇到.

其实我1年前就尝试过改写了,不过失败了.

今天无意间想起了这个问题,于是我又尝试改写,而这次尝试成功了.

1年前的代码我已经不记得了,所以我假设了一份非常简单的代码来作为问题:

// 树结构的定义
let testData = [
    {
        title: '标题1',
        value: 5,
        children: [
            {
                title: '1-子标题',
                value: 0,
                children: []
            },
            {
                title: '1-子标题2',
                value: 3,
                children: []
            }
        ]
    },

    {
        title: '标题2',
        value: 6,
        children: [
            {
                title: '2-子标题',
                value: 7,
                children: []
            },
            {
                title: '2-子标题2',
                value: 8,
                children: []
            }
        ]
    },
];

// 树遍历
function iterTree(tree) {
    for (let node of tree) {
        console.log(node.title);
        if (node.children.length != 0) {
            iterTree(node.children);
        }
    }
}

这里定义了树结构和树遍历的函数 iterTree, 函数是我们的重点,终极目标是把它改写成符合 Trampoline 形式的 iterTreeBounce.

然而 iterTree 里面有一个 for 循环,这个 for 循环控制这遍历到哪一个树节点,也就是遍历到哪一步, 这导致了无法直接把 iterTree 改写成 iterTreeBounce.

符合 Trampoline 的函数需要返回一个 bounce 对象, bounce 是每一步的计算信息(遍历到哪一个节点), 而"返回"这个动作是会导致 for 循环中断, 也就是说 iterTreeBouncefor 存在天生的冲突.

因此第一步应当把 iterTreefor 循环去掉,这个很简单,把循环改写成递归:

function iterTreeWithoutLoop(tree) {
    if (tree.length != 0) {
        let node = tree[0];
        let treeRest = tree.slice(1);
        console.log(node.title);
        if (node.children.length != 0) {
            iterTreeWithoutLoop(node.children);
        }
        iterTreeWithoutLoop(treeRest);
    }
}

要注意, iterTreeWithoutLoop 并不符合尾递归的定义: 首先它进行了两次自身的调用,其次尾调用时并没有返回自身计算结果.

那么问题来了: 如何把 iterTreeWithoutLoop 改写为尾递归呢?

在1年前,我的第一反映是用 CPS, 一直以为 CPS 是把普通递归函数改写为尾递归的"万能药", 然而它只是把函数改写为尾调用的"万能药",

这个问题暴露了我最初的错误理解,理所当然改写失败了.

最直接的证明就是,即便按照前面的规则改成了 CPS 也无法把 iterTreeWithoutLoop 改写为尾递归,

更何况,按照之前的方法是没办法 直接 把这个函数改成 CPS 的, 当年我就卡在这一步了.

今天的我再看这个问题的时候就有了一个想法:

能不能把两个对于同一个函数的调用合并为一次调用呢? 并且该调用的行为等同于依次对 node.childrentreeRest 调用?

经过一番脑内"尝试"后发现可行,稍微想一下就明白了,在遍历树的时节点分"已遍历"和"未遍历"两种,只有在处理未遍历的节点时才会进行递归.

node.children 以及 treeRest 都是未遍历的节点.那么把它们整理为一个由未遍历节点构成的树不就好了吗?

所以我得到了 iterTreeTail:

function iterTreeTail(tree) {
    if (tree.length != 0) {
        let node = tree[0];
        let treeRest;
        console.log(node.title);
        if (node.children.length != 0) {
            treeRest = node.children.concat(tree.slice(1));
        } else {
            treeRest = tree.slice(1);
        }
        iterTreeTail(treeRest);
    }
}

严格上来说 iterTreeTail 还达不到尾递归定义,因为 iterTreeTail(treeRest); 并非尾调用(没有 return 语句).

最开始的 iterTree 工作就是打印每一个树节点的名字而已,并没有指定返回什么,返回的默认值 undefined.

不过仔细思考后会发现 iterTreeTail 已经符合 iterTree 的结果了; 同时在考虑 JavaScript 函数默认的返回值后,某种意义上 iterTreeTail 也符合了尾递归的定义了:

function iterTreeTail(tree) {
    if (tree.length != 0) {
        let node = tree[0];
        let treeRest;
        console.log(node.title);
        if (node.children.length != 0) {
            treeRest = node.children.concat(tree.slice(1));
        } else {
            treeRest = tree.slice(1);
        }
        return iterTreeTail(treeRest);
    } else {
        return undefined;
    }
}

现在要做的是把它改写成 CPS,把尾递归改写成 CPS 是非常简单的,可以参考前面的 acc-tail-inner/k:

function iterTreeCPS(tree, cont) {
    if (tree.length != 0) {
        let node = tree[0];
        let treeRest;
        console.log(node.title);
        if (node.children.length != 0) {
            treeRest = node.children.concat(tree.slice(1));
        } else {
            treeRest = tree.slice(1);
        }
        return iterTreeCPS(treeRest, cont);
    } else {
        return cont(undefined);
    }
}

基本上到这里了就宣告改写成功了,现在改成符合 Trampoline 的函数基本没什么难度了,下面是最终答案:

function iterTreeBounce(tree, cont) {
    if (tree.length != 0) {
        let node = tree[0];
        let treeRest;
        console.log(node.title);
        if (node.children.length != 0) {
            treeRest = node.children.concat(tree.slice(1));
        } else {
            treeRest = tree.slice(1);
        }
        return {
            cont: (val) => iterTreeBounce(val, cont),
            res: treeRest
        };
        // return iterTreeBounce(treeRest, cont);
    } else {
        return {
            cont: cont,
            res: undefined
        };
        // return cont(undefined);
    }
}

function trampoline(bounce) {
    let res;
    while (1) {
        if (bounce) {
            res = bounce.cont(bounce.res);
            bounce = res;
        } else {
            return res;
        }
    }
}

// 最后测试
trampoline(iterTreeBounce(testData, val => val));

找个改写的方法只花了我半个小时,这半个小时让我解开了1年前的心结,并且对 CPS 的认知发生了翻天覆地的变化.

整体上来说,这半个小时的代价绝对是值得的.

Author: saltb0rn (asche34@outlook.com)

Date: 2018-06-20

Emacs 28.2 (Org mode 9.5.5)

Validate