Review: Continuations, CPS

This notebook summarizes our exploration into Continuation-Passing Style (CPS) in Python.

Consider the old favorite, factorial, written with standard recursion:

In [1]:
def fact(n):
    if n == 1:
        return 1
    else:
        return n * fact(n - 1)
In [2]:
fact(5)
Out[2]:
120
In [3]:
print fact(976)
53144659552343284330959565563167628691804670007112040195112113949922526080877970825390373443857959078900156715325649902697453214520812208787364331330623828918907062365826969854301654740277959188864982156894709596749682002201514381499129832162372458673624292973144858272452387825534350109668726537630323936102591875177064440162526931899834949137151384566998601637103920335854677574330926311225780273364451573616115827720059364158446816339103150870925592553172043241739939744258649035118922305185747732223278455562026396208021129498059368934835407464264141906577025999796527876656269547441478402960173103219288785327969716830951032183415886883983216019445923247053577393555282842793943813992017485306625615733747526522141368905565982151462015913873738190347193083917356778498460865233857150083685862983499348942069607377067995850567750070654307545819622323299309986682405606165653037524955726607354400206283230407002085439409559900973157182550371051594624561608646787498532711350681555391266512190921181889545757447812086126324054479695328851273717533951942042851565313258549495181016411109579135614882593393684847500559756464681160104388674081011136331797361928747122559451471938411309405727553346487671848702788599442819275127076127032880779819807849471511505010230962531073779182012568044444566453544985819931117212293153959702701639972641557209454361637744175327616680132786824800106635361362765785513293179423437250057817869717171750168448368343631985282059793084935930961152029015695386468836971627714870226464213336959931092740851552052417438999191920812994051219054228419585445289176973234319522593096641426657604445221530637972459558136097008805822550835681834160744324213745466712030090993656231007428605348983297122454396679656540683055332900502970798843386211315234403940504207396604496910523801819293186562594481243582595009197417207443515779135812666595205105845024078530028812053194014440533361096102739754747836960328696996335459955169757008360407114091995295529353905301761922775864309
3192832276743328192515465444835252850336693050127801446586017331993048426463712130494517101027942772280353910146890682015563832787826393249564472331961617434770469083563423242370503470449067872272145399337603689464181492554975227074272323172834806543155200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

This works, in Python!

But this code would fail in languages like C and Java for a completely different reason than the one we are about to see: the result would "overflow" the fixed number of bits that an integer can hold. We'll discuss types later. For now, realize that both Scheme and Python do a smart thing here: they automatically switch from fixed-size integers to BigNums (arbitrary-precision numbers represented by a data structure rather than a fixed number of bits).
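To see the fixed-width problem concretely, here is a small sketch that emulates a Java `long` (64-bit two's complement, which silently wraps around on overflow) next to Python's own integers. The helper names `wrap_int64`, `fact_exact`, and `fact_int64` are hypothetical and were made up for this illustration. (In C, signed overflow is undefined behavior rather than a guaranteed wraparound, so Java is the cleaner model here.)

```python
# Sketch: emulate Java's 64-bit `long` (two's complement, wraps on overflow)
# next to Python's arbitrary-precision ints. Helper names are hypothetical.

def wrap_int64(x):
    """Reduce x to the range of a signed 64-bit integer, as a Java long would."""
    x &= (1 << 64) - 1                    # keep only the low 64 bits
    return x - (1 << 64) if x >= (1 << 63) else x

def fact_exact(n):
    result = 1
    for i in range(2, n + 1):
        result *= i                       # Python ints grow as needed
    return result

def fact_int64(n):
    result = 1
    for i in range(2, n + 1):
        result = wrap_int64(result * i)   # overflow wraps around silently
    return result

# 20! still fits in 64 bits; 21! does not, and silently comes out negative.
same_at_20 = fact_int64(20) == fact_exact(20)   # True
wrapped_21 = fact_int64(21)                     # -4249290049419214848
```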

But if we try to compute just one more, fact(977):

In [4]:
try:
    print fact(977)
except:
    print "Fail!"
Fail!

Why?

In [5]:
try:
    print fact(977)
except Exception as e:
    print e.message
maximum recursion depth exceeded

This is the problem that we are exploring with continuations and CPS.
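The error comes from the interpreter, not from the arithmetic. CPython refuses to nest more than a fixed number of Python calls. You can read that limit with `sys.getrecursionlimit()`, and the default is 1000. The boundary fell at 977 rather than 1000 presumably because IPython's own frames already use part of that budget. Here is a minimal probe, using a hypothetical helper `depth_ok`:

```python
import sys

def fact(n):
    return 1 if n == 1 else n * fact(n - 1)

def depth_ok(n):
    """Return True if fact(n) finishes within the interpreter's recursion limit."""
    try:
        fact(n)
        return True
    except RuntimeError:   # Python 3.5+ raises RecursionError, a RuntimeError subclass
        return False

limit = sys.getrecursionlimit()
# depth_ok(100) -> True; depth_ok(limit + 10) -> False
```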

We can rewrite fact in CPS. Here, we use functions to represent continuations, so in this example a continuation is a function that takes a value. When you have a result, you "pass it to the continuation"; since the continuation is a function, that simply means calling it with the result as its argument.
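The idea is easiest to see on a function with no recursion at all. In this sketch, `add_cps` and `square_cps` are hypothetical names. Instead of returning a + b, the function hands the sum to whatever continuation it was given:

```python
def add_cps(a, b, k):
    # Direct style would be `return a + b`; in CPS we pass the result to k.
    return k(a + b)

def square_cps(x, k):
    return k(x * x)

# Compute (1 + 2) squared: the continuation of the addition is
# "square the value, then hand it to the identity function".
result = add_cps(1, 2, lambda v: square_cps(v, lambda w: w))   # 9
```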

Recall the three steps of CPS transformation:

  1. Add a continuation, k, to the recursive function
  2. Give any result to the continuation
  3. Any computation that is left to do after a recursive call must be done in a new continuation
In [6]:
def fact_cps(n, k):
    if n == 1:
        return k(1)
    else:
        return fact_cps(n - 1,
                        lambda v: k(v * n))
In [7]:
fact_cps(2, lambda i: i)
Out[7]:
2

As above, we give the identity function as the initial continuation, so the final value is simply returned:

In [8]:
fact_cps(5, lambda i: i)
Out[8]:
120
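The same three steps apply to any directly recursive function. Here is a sketch on a hypothetical example (`sum_list` and `sum_list_cps` are not from the original notebook) that sums a list:

```python
def sum_list(xs):
    if not xs:
        return 0
    return xs[0] + sum_list(xs[1:])

def sum_list_cps(xs, k):            # step 1: add a continuation k
    if not xs:
        return k(0)                 # step 2: give the result to k
    # step 3: the pending "xs[0] + ..." moves into a new continuation
    return sum_list_cps(xs[1:], lambda v: k(xs[0] + v))

# Start the CPS version with the identity function, as with fact_cps.
total = sum_list_cps([1, 2, 3, 4], lambda i: i)   # 10
```

Like fact_cps, this works for short inputs but still nests one Python call per element.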

Yet:

In [9]:
try:
    print fact_cps(1000, lambda i: i)
except Exception as e:
    print e.message
maximum recursion depth exceeded
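Why does CPS alone not help? Every call in fact_cps is a tail call, but CPython does not eliminate tail calls, so each one still pushes a new stack frame. Calling the chain of continuations at the end nests roughly n more frames on top. The sketch below records the stack depth at the base case and shows it growing by one frame per level of n. It relies on a hypothetical `stack_depth` helper built on the CPython-specific `sys._getframe`, and on a hypothetical `fact_cps_probe` variant of fact_cps.

```python
import sys

def stack_depth():
    """Count the frames on the current call stack (CPython-specific)."""
    frame, depth = sys._getframe(), 0
    while frame is not None:
        depth += 1
        frame = frame.f_back
    return depth

base_depths = []   # stack depth observed each time the base case is reached

def fact_cps_probe(n, k):
    # Same as fact_cps, but records the stack depth at the base case.
    if n == 1:
        base_depths.append(stack_depth())
        return k(1)
    return fact_cps_probe(n - 1, lambda v: k(v * n))

fact_cps_probe(10, lambda i: i)
fact_cps_probe(50, lambda i: i)
# base_depths[1] - base_depths[0] == 40: one extra frame per level of n
```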

The first thing we can do is move the "calling of the continuation" out from inside the recursive function:

In [10]:
def fact_cps(n, k):
    if n == 1:
        return ["apply-cont", k, 1]
    else:
        return fact_cps(n - 1,
                        lambda v: ["apply-cont", k, v * n])
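Without a trampoline, this fact_cps no longer returns a number. Instead it returns a description of the next step. The following sketch repeats the definition above so that it runs on its own, then steps through fact_cps(3, ...) by hand, doing exactly what a trampoline will automate:

```python
def fact_cps(n, k):
    if n == 1:
        return ["apply-cont", k, 1]
    else:
        return fact_cps(n - 1,
                        lambda v: ["apply-cont", k, v * n])

step1 = fact_cps(3, lambda i: i)   # ["apply-cont", <continuation>, 1]
step2 = step1[1](step1[2])         # ["apply-cont", <continuation>, 2]
step3 = step2[1](step2[2])         # ["apply-cont", <identity>, 6]
answer = step3[1](step3[2])        # 6: the identity returns a plain number
```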

And then make a "trampoline" to run it:

In [11]:
def trampoline(result):
    while isinstance(result, list):
        result = result[1](result[2])
    return result
In [12]:
trampoline(fact_cps(5, lambda i: i))
Out[12]:
120

Yet:

In [13]:
try:
    print trampoline(fact_cps(1000, lambda i: i))
except Exception as e:
    print e.message
maximum recursion depth exceeded

The problem is that fact_cps still calls itself directly, so every level of recursion still adds a frame to Python's stack. We can move the recursive call out, just as we did with the calling of the continuation.

We note that after we have transformed a function into CPS, every call it makes is a tail call: nothing is left to do after the call returns, so the rest of the computation could be carried out by a simple GOTO. We approximate this by returning a "goto" list that contains the function itself (not its name) and its arguments. Really, we are just moving the "calling of the recursive function" from inside the recursive function to the trampoline:

In [14]:
def fact_cps(n, k):
    if n == 1:
        return ["apply-cont", k, 1]
    else:
        return ["goto", fact_cps, n - 1,
                        lambda v: ["apply-cont", k, v * n]]
In [15]:
def trampoline(result):
    while isinstance(result, list):
        if result[0] == "apply-cont":
            result = result[1](result[2])
        elif result[0] == "goto":
            result = apply(result[1], result[2:])
    return result

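Thus, calling a trampolined CPS function can return one of three things:

  • an "apply-cont": a continuation to call, together with the value to pass to it
  • a "goto": a (recursive) call to make, together with its arguments
  • a result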

The trampoline bounces through these until it reaches a result:

In [16]:
try:
    print trampoline(fact_cps(1000, lambda i: i))
except Exception as e:
    print e.message
40238726007709377354370243392300398571937486421071463254379991042993851239862902059204420848696940480047998861019719605863166687299480855890132382966994459099742450408707375991882362772718873251977950595099527612087497546249704360141827809464649629105639388743788648733711918104582578364784997701247663288983595573543251318532395846307555740911426241747434934755342864657661166779739666882029120737914385371958824980812686783837455973174613608537953452422158659320192809087829730843139284440328123155861103697680135730421616874760967587134831202547858932076716913244842623613141250878020800026168315102734182797770478463586817016436502415369139828126481021309276124489635992870511496497541990934222156683257208082133318611681155361583654698404670897560290095053761647584772842188967964624494516076535340819890138544248798495995331910172335555660213945039973628075013783761530712776192684903435262520001588853514733161170210396817592151090778801939317811419454525722386554146106289218796022383897147608850627686296714667469756291123408243920816015378088989396451826324367161676217916890977991190375403127462228998800519544441428201218736174599264295658174662830295557029902432415318161721046583203678690611726015878352075151628422554026517048330422614397428693306169089796848259012545832716822645806652676995865268227280707578139185817888965220816434834482599326604336766017699961283186078838615027946595513115655203609398818061213855860030143569452722420634463179746059468257310379008402443243846565724501440282188525247093519062092902313649327349756551395872055965422874977401141334696271542284586237738753823048386568897646192738381490014076731044664025989949022222176590433990188601856652648506179970235619389701786004081188972991831102117122984590164192106888438712185564612496079872290851929681937238864261483965738229112312502418664935314397013742853192664987533721894069428143411852015801412334482801505139969429015348307764456909907315243327828826986460278986432113908350621709500259738986355
4277196742822248757586765752344220207573630569498825087968928162753848863396909959826280956121450994871701244516461260379029309120889086942028510640182154399457156805941872748998094254742173582401063677404595741785160829230135358081840096996372524230560855903700624271243416909004153690105933983835777939410970027753472000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
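The notebook is written in Python 2, and the trampoline above uses the built-in `apply`, which Python 3 removed. Argument unpacking with `*` replaces it and already works in Python 2 as well. Below is a sketch of the same trampoline in that style. As an addition not in the original, it raises an error on an unrecognized tag; the original would loop forever on such a list because `result` never changes.

```python
import math

def fact_cps(n, k):
    if n == 1:
        return ["apply-cont", k, 1]
    else:
        return ["goto", fact_cps, n - 1,
                lambda v: ["apply-cont", k, v * n]]

def trampoline(result):
    # Keep bouncing until something other than a list comes back.
    while isinstance(result, list):
        tag = result[0]
        if tag == "apply-cont":        # ["apply-cont", k, value]
            result = result[1](result[2])
        elif tag == "goto":            # ["goto", f, arg1, arg2, ...]
            result = result[1](*result[2:])
        else:
            raise ValueError("unknown trampoline instruction: %r" % (tag,))
    return result

# Well past the default recursion limit of 1000:
big = trampoline(fact_cps(5000, lambda i: i))
matches = (big == math.factorial(5000))   # True
```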

At this point, the Python stack no longer grows with the recursion!

That doesn't mean we can compute the factorial of any huge number: fact_cps still uses memory for every level of recursion, in the form of a continuation (a closure) that holds on to the previous one. But the recursion depth limit no longer holds us back!
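The per-step memory comes from the chain of continuations: each new lambda holds on to the previous k. For comparison, here is a variation that is not part of the CPS recipe above. If the pending multiplication is carried in a hypothetical accumulator argument `acc`, every goto passes the same k along unchanged, so the loop no longer builds up a chain of continuations. The numbers themselves still grow, of course.

```python
def fact_acc(n, acc, k):
    # The work left to do ("multiply by n") is folded into acc right away,
    # so k is passed along unchanged instead of being wrapped in a new lambda.
    if n == 1:
        return ["apply-cont", k, acc]
    return ["goto", fact_acc, n - 1, acc * n, k]

def trampoline(result):
    while isinstance(result, list):
        if result[0] == "apply-cont":
            result = result[1](result[2])
        elif result[0] == "goto":
            result = result[1](*result[2:])
        else:
            raise ValueError("unknown instruction: %r" % (result[0],))
    return result

# Start with acc = 1 and the identity continuation.
five = trampoline(fact_acc(5, 1, lambda i: i))   # 120
```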

On the other hand, a recursion that builds up no continuations at all can now run forever, since each step uses no extra memory. Consider:

In [17]:
def loop():
    return loop()
In [18]:
try:
    loop()
except:
    print "Fail!"
Fail!

But we can do the same tricks:

In [19]:
def loop_cps(k):
    return ["goto", loop_cps, k]
In [20]:
trampoline(loop_cps(lambda i: i))
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-20-73311b469729> in <module>()
----> 1 trampoline(loop_cps(lambda i: i))

<ipython-input-15-f74f04bae47b> in trampoline(result)
      4             result = result[1](result[2])
      5         elif result[0] == "goto":
----> 6             result = apply(result[1], result[2:])
      7     return result

KeyboardInterrupt: 

(I had to press the Interrupt button to stop it, as it would otherwise run forever.)
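Rather than interrupting by hand, we could give the trampoline a budget. The sketch below adds a `max_steps` parameter and an `OutOfSteps` exception, both hypothetical and not in the original. It raises once that many bounces have happened, which makes runaway loops like loop_cps easy to stop and to test. A hypothetical `exceeds_budget` helper wraps the check.

```python
def loop_cps(k):
    return ["goto", loop_cps, k]

class OutOfSteps(Exception):
    """Raised when the trampoline exceeds its step budget."""

def trampoline_bounded(result, max_steps=10000):
    steps = 0
    while isinstance(result, list):
        if steps >= max_steps:
            raise OutOfSteps("gave up after %d steps" % steps)
        steps += 1
        if result[0] == "apply-cont":
            result = result[1](result[2])
        elif result[0] == "goto":
            result = result[1](*result[2:])
        else:
            raise ValueError("unknown instruction: %r" % (result[0],))
    return result

def exceeds_budget(instruction, budget=1000):
    """True if the computation is still bouncing after `budget` steps."""
    try:
        trampoline_bounded(instruction, budget)
        return False
    except OutOfSteps:
        return True

# exceeds_budget(loop_cps(lambda i: i)) -> True
# exceeds_budget(["apply-cont", lambda i: i, 42]) -> False
```

A budget cannot tell a slow computation from an endless one; it only bounds how long we are willing to wait.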

Handling multiple recursive calls can get a little tricky, but it is still very mechanical. Consider fib:

In [21]:
def fib(n):
    if n == 0:
        return 1
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)
In [22]:
fib(5)
Out[22]:
8

Again, the three steps of CPS transformation:

  1. Add a continuation, k, to the recursive function
  2. Give any result to the continuation
  3. Any computation that is left to do after a recursive call must be done in a new continuation

First, we write it in standard CPS form:

In [23]:
def fib_cps(n, k):
    if n == 0:
        return k(1)
    elif n == 1:
        return k(1)
    else:
        return fib_cps(n - 1, 
                lambda v1: fib_cps(n - 2,
                            lambda v2: k(v1 + v2)))
In [24]:
fib_cps(5, lambda i: i)
Out[24]:
8

Then, we "unroll" the function calls so that they are done in the trampoline:

In [25]:
def fib_cps(n, k):
    if n == 0:
        return ["apply-cont", k, 1]
    elif n == 1:
        return ["apply-cont", k, 1]
    else:
        return ["goto", fib_cps, n - 1, 
                lambda v1: ["goto", fib_cps, n - 2,
                            lambda v2: ["apply-cont", k, v1 + v2]]]
In [26]:
trampoline(fib_cps(5, lambda i: i))
Out[26]:
8

This becomes painfully slow once n gets much past 20. Why? Not because it is recursive, but because it stupidly recomputes the same values over and over: fib(n - 1) computes fib(n - 2) all over again, and so on all the way down.
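The blow-up is easy to measure. This sketch repeats the trampolined fib_cps and the trampoline, adding a hypothetical call counter and a hypothetical `fib_with_count` helper. With fib(0) = fib(1) = 1 as defined above, computing fib(n) makes 2*fib(n) - 1 calls to fib_cps. The count satisfies C(n) = 1 + C(n-1) + C(n-2), the same recurrence as fib itself, so the work grows exponentially in n, just like the answer does.

```python
calls = [0]   # mutable counter shared with fib_cps (hypothetical instrumentation)

def fib_cps(n, k):
    calls[0] += 1
    if n == 0 or n == 1:
        return ["apply-cont", k, 1]
    return ["goto", fib_cps, n - 1,
            lambda v1: ["goto", fib_cps, n - 2,
                        lambda v2: ["apply-cont", k, v1 + v2]]]

def trampoline(result):
    while isinstance(result, list):
        if result[0] == "apply-cont":
            result = result[1](result[2])
        elif result[0] == "goto":
            result = result[1](*result[2:])
        else:
            raise ValueError("unknown instruction: %r" % (result[0],))
    return result

def fib_with_count(n):
    """Return (fib(n), number of calls made to fib_cps)."""
    calls[0] = 0
    value = trampoline(fib_cps(n, lambda i: i))
    return value, calls[0]

# fib_with_count(10) -> (89, 177); fib_with_count(20) -> (10946, 21891)
```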

Don't confuse recursion with stupidity! Recursion is often the right tool for the job, especially if you have a recursive data structure and are working in an appropriate programming language.
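To illustrate that last point, consider a standard fix that the notebook itself does not use. We keep the plain recursive definition but cache results with Python 3's `functools.lru_cache`. That removes the repeated work: each fib(k) is computed only once, so moderately large n becomes instant. The recursion depth is still n, though.

```python
from functools import lru_cache

@lru_cache(maxsize=None)     # remember every fib(n) already computed
def fib(n):
    if n == 0 or n == 1:
        return 1
    return fib(n - 1) + fib(n - 2)

# Still recursive, but each fib(k) for k <= n is evaluated only once.
# fib(5) -> 8, matching the notebook's definition
```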