原文链接: https://blog.moertel.com/posts/2013-05-11-recursive-to-iterative.html

另一个标题:我希望python有尾递归消除

递归编程很强大,因为它很容易映射到通过归纳法来证明,这使得设计算法和证明算法的正确性变得容易。

但是很多流行的编程语言对递归的支持很差。在python中将一个大的输入丢进递归算法中,你可能会碰到运行时的递归限制。提高限制,你可能会耗尽栈空间从而触发段错误。

这些都很操蛋。

因此,一个重要的技巧就是知道如何将递归算法转换成迭代算法。这样以来,你就可以设计、证明和编写初步递归代码。在完成之后,你可以通过一系列机械化的步骤把你的算法转换成等价的迭代形式。

把递归变成迭代,这个话题足够吸引人,所以我打算做成一个系列的文章。尾部调用,trampolines(蹦床, 就是两个函数互相调用,可以参照[译] 使用JavaScript中的蹦床函数实现安全递归),continuation-passing style(参照CPS) 等等

本篇文章中,我们只看一个简单的方法和一个辅助的技巧。

简单方法

这种转换方法适用于很多简单的递归函数。能用的时候,效果很好,而且结果是精简和快速的。我通常先试一下这种方案,只有在失败的时候才会考虑更加复杂的方案。

简而言之:

  1. 研究函数
  2. 将所有的函数调用转化成尾递归(如果不能,停止尝试,试试其他方法)
  3. 在函数体周围引入一次循环
  4. 将尾部调用转化成continue语句
  5. 整理一下

这个方法的一个重要特性是它是增量正确的– 在每一步之后,你都有一个与原函数等价的函数。所以如果你有单元测试,你可以在每一步之后运行它们,确保没有犯错。

下面让我们来看看这个方法。

例子: 阶乘

对于这个简单的函数,我们能够不需要使用任何技术就直接转化成迭代版本。但是在这里主要想强调的是上面的机械步骤,当我们的递归函数不是那么简单的时候,我们可以信任这个过程。所以我们要研究一个非常简单的函数,这样我们就能够专注于这个过程。

  1. 研究原函数
1
2
3
4
def factorial(n):
if n < 2:
return 1
return n*factorial(n-1)
  1. 将递归转化成尾递归
1
2
3
4
def factorial1a(n, acc=1):
if n < 2:
return 1
return factorial1a(n-1, acc*n)

(如果这一步看起来很混乱,请看文章最后的奖励说明,了解这一步背后的”秘密功能”技巧)

  1. 在函数体周围引入一次循环。你需要while True: body; break
1
2
3
4
5
6
def factorial1b(n, acc=1):
while True:
if n<2:
return 1*acc
return factorial1b(n-1, acc*n)
break

是的,我知道在return 后面放一个break很疯狂,但还是要这么做。清理工作稍后进行。现在,我们要根据数字来判断.

  1. 将所有的尾递归调用f(x=x1, y=y1, ...)替换为(x,y,...)=(x1,y1,...);continue。确保更新所有参数
1
2
3
4
5
6
7
def factorial1c(n, acc=1):
while True:
if n<2:
return 1*acc
(n, acc) = (n-1, acc *n)
continue
break

这一步,我把原来函数的参数列表,括号什么的都复制过来,然后复制到return语句上. 这样就减少了搞砸事情的机会,一切都是机械化的。

  1. 整理代码,让它更符合习惯
1
2
3
4
def factorial1d(n, acc=1):
while n>1:
n, acc = n-1, acc*n
return acc

好吧,这一步不是关于机械化,而是关于风格。消除杂乱无章的东西,简化,让它闪闪发光。

  1. 这样你就完成了

我们收获了什么

我们只是做了五步工作,将我们原来的递归阶乘转化成了等价的迭代。如果我们的编程语言支持尾递归消除,我们可以在第二步停止factorial1a的运算。但是不!!!!! 我们必须继续,一直到第五步,因为我们使用的是Python。

过程虽然并不困难,但还是得手工操作。那它给我们带来了什么?

为了看看它给我们带来了什么好处,我们来看看Python运行时环境里面的情况。我们将使用Online Python Tutor 的可视化查看器来观察factorial, factorial1a和factorial1d各自计算5的阶乘时栈帧的建立情况。

这非常酷,所以不要错过这个环境。可视化它 (在新标签页下打开它)

点击Forward按钮,逐步完成函数的执行。注意Frames栏中发生的情况。当factorial在计算5的阶乘时,堆栈上建立了5个帧。不是巧合。

我们的尾递归函数factorial1a也是一样(你说的对,很惨)

但是对于我们的迭代函数factorial1d来说,就不一样了。它只存在一个堆栈,一次又一次,直到完成。这就是经济!

所以我们才做了这个工作。经济性。我们将O(n)堆栈使用量转换成O(1)堆栈使用量。当n可能很大的时候,这种节省很重要。这可能是得到一个答案和得到段错误的区别

非简单案例

好了,我们解决了factorial. 但那是个简单的问题。如果你的函数没有那么简单呢?那就需要更高级的方法了。

这就是我们下次的话题了。

奖金: 利用秘密功能进行尾递归转化(注: 这一步应该算是全篇最重要的部分)

在简单方法的第二步中,我们用这段代码转换了递归调用:

1
2
3
4
def factorial(n):
if n < 2:
return 1
return n*factorial(n-1)

转换到这个尾递归调用

1
2
3
4
def factorial(n, acc=1):
if n < 2:
return 1*acc
return factorial(n-1, acc*n)

这个转换只要你掌握了窍门就很容易,但你第一次看到它的时候,它就像魔术一样。那么我们来一步步看一下

首先,我们要去掉下面代码中的n*。

1
return n*factorial(n-1)

n* 位于我们对factorial的递归调用和return关键字之间。换句话说,这段代码相当于下面的代码

1
2
3
x = factorial(n-1)
result = n * x
return result

也就是说,我们的代码必须调用factorial函数,等待它的结果x,然后对这个结果做一些事情(乘以n),才能返回它的结果。这个result = n*x太讨厌了,我们必须去掉它。我们要的只是在返回语句中递归调用factorial。

那么我们该如何摆脱这种乘法呢?

这就是诀窍。我们用乘法功能来扩展我们的函数,用它来为我们做乘法。<<– 注: 这个地方看着有些莫名其妙

🤫,这是一个秘密功能。

本质上,每次调用我们的扩展函数的时候,不仅仅会计算一个阶乘,它还会(秘密地) 将阶乘乘以我们给他的任何额外值。持有这些额外值的变量通常被称为”累加器”,所以我在这里使用acc这个名字是为了向传统致敬。

所以这是我们新扩展的函数:

1
2
3
4
def factorial(n, acc=1):
if n<2:
return acc * 1
return acc *n * factorial(n - 1)

看看我是怎么增加秘密乘法功能的?两件事情。

首先,我在原函数中增加了一个额外的参数acc, 即乘数。请注意,它的默认值是1,所以在我们给它一些其他值之前,它没有任何影响(因为1*x = x).

其次,我把每一条return语句都从 return {whatever}改成了return acc*{whatever}。无论我们的函数何时x, 现在都会返回acc*x,就是这样。我们的秘密功能已经完成了!而且证明它的正确性很简单(事实上,我们刚刚证明了它的正确性! 重读第二句)

这两个变化是机械的,很难搞砸,而且,默认什么都不做。这些都是你在给函数添加秘密功能时想要的属性。

好了,现在我们有一个函数,计算n的阶乘,并秘密地将其乘以acc.

现在让我们回到那行麻烦的代码,但在我们新扩展的函数中

1
return acc * n * factorial(n-1)

它计算出n-1的阶乘,然后乘以acc*n,但是等等!我们不需要自己做这个乘法。现在不需要了。现在我们可以让我们的扩展阶乘函数使用秘密功能为我们做这件事情。

因此,我们可以改写成

1
return factorial(n-1, acc*n)

这就是一个尾递归!!!

所以我们的尾递归函数是这样的

1
2
3
4
def factorial(n, acc=1):
if n<2:
return acc*1
return factorial(n-1, acc*n)

现在我们所有的递归调用都是尾递归,这个函数很容易使用本文中介绍的方法转换成迭代形式

我们来复习一下把递归调用变成尾递归的秘籍

  1. 找到不是尾递归的递归调用
  2. 确定该调用与其返回return之间的工作内容
  3. 扩展该函数的秘密功能来完成该工作,比如由一个新的累加器参数控制,该参数的默认值导致它什么也不做。
  4. 使用扩展功能,消除旧工作
  5. 现在你就有一个尾递归了
  6. 重复它,直到所有的递归调用都是尾递归

练习

你的任务是摆脱以下函数中的递归。觉得自己能搞定?fork这个仓库,完成练习

1
2
3
4
5
6
7
8
9
10
11
12
def find_val_or_next_smallest(bst, x):
if bst is None:
return None
elif bst.val == x:
return x
elif bst.val > x:
return find_val_or_next_smallest(bst.left, x)
else:
right_best = find_val_or_next_smallest(bst.right, x)
if right_best is None:
return bst.val
return right_best

答案:

最核心的是,把目前为止找到的最好结果作为额外的参数传递到递归调用中

1
2
3
4
5
6
7
8
9
def find_val_or_next_smallest(bst, x, best=None):
if bst is None:
return best
elif bst.val == x:
return x
elif bst.val > x:
return find_val_or_next_smallest(bst.left, x, best)
else:
return find_val_or_next_smallest(bst.right, x, bst.val)

转变成迭代:

1
2
3
4
5
6
7
8
9
10
def find_val_or_next_smallest(bst, x, best=None):
while True:
if bst is None:
return best
elif bst.val == x:
return x
elif bst.val > x:
bst = bst.left
else:
bst, best = bst.right, bst.val