For the past few days, folds have been stuck in my head for some reason and needed some unfolding. I did so, and below is a summary of my understanding for the benefit of my future self.
Why
Consider the scenario where we have an array of numbers and we would like to add them together without using a loop. No loops? No problem: we can use recursion.
import assert from 'node:assert/strict';
const sum = ([h, ...t]: number[]): number => h === undefined ? 0 : h + sum(t);
assert.equal(sum([1, 2, 3]), 6);
assert.equal(sum([5]), 5); // array with 1 element
assert.equal(sum([]), 0); // empty array
The function sum:
- accepts an array of numbers.
- destructures it into head h and tail t: [h, ...t].
- returns 0 if the head is undefined, i.e. the array is empty. This serves as the base case for the recursion.
- otherwise carries on the sum operation with the tail: h + sum(t).
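To see what the [h, ...t] pattern produces at each recursive step, here is a small sketch; headTail, countByHead and countByLength are hypothetical helpers for illustration. It also shows a caveat of using h === undefined as the base case: the recursion stops early on an array whose next element is itself undefined. That cannot happen with number[], but it can with generic arrays, where checking the length is a safer test.

```typescript
import assert from 'node:assert/strict';

// Hypothetical helper: split an array into head and tail, exactly like the
// `[h, ...t]` parameter pattern used by `sum`.
const headTail = <A>([h, ...t]: A[]): [A | undefined, A[]] => [h, t];

assert.deepEqual(headTail([1, 2, 3]), [1, [2, 3]]);
assert.deepEqual(headTail([5]), [5, []]);
assert.deepEqual(headTail([]), [undefined, []]); // empty array: head is undefined

// Caveat: an `h === undefined` base case also stops on an undefined element.
const countByHead = <A>([h, ...t]: A[]): number =>
  h === undefined ? 0 : 1 + countByHead(t);

// Hypothetical alternative: test the length instead of the head.
const countByLength = <A>(xs: A[]): number =>
  xs.length === 0 ? 0 : 1 + countByLength(xs.slice(1));

const withHole = [1, undefined, 3];
assert.equal(countByHead(withHole), 1); // stops at the undefined element
assert.equal(countByLength(withHole), 3); // counts every element
```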
Now, let's define a function to multiply the numbers in an array:
const product = ([h, ...t]: number[]): number => h === undefined ? 1 : h * product(t);
assert.equal(product([2, 2, 3]), 12);
As we can see, both look almost the same. The only bits that vary are:
- The base case value: what to return when we get down to an empty array, i.e. the base case of the recursion (0 for sum, 1 for product).
- The operation: addition in sum and multiplication in product.
This is where folds come in. They generalize traversing the array and carrying out an operation that combines the array elements in some way.
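As a first step, the two varying bits can simply become parameters. A minimal number-only sketch, with reduceWith, sumNums and productNums as hypothetical names, rebuilds both functions; the folds that follow generalize this further by making the element and result types generic.

```typescript
import assert from 'node:assert/strict';

// Hypothetical helper: the base case value and the operation are its only inputs.
const reduceWith =
  (base: number, op: (x: number, acc: number) => number) =>
  (xs: number[]): number => {
    // Same recursion shape as sum and product.
    const go = ([h, ...t]: number[]): number =>
      h === undefined ? base : op(h, go(t));
    return go(xs);
  };

const sumNums = reduceWith(0, (x, acc) => x + acc);
const productNums = reduceWith(1, (x, acc) => x * acc);

assert.equal(sumNums([1, 2, 3]), 6);
assert.equal(sumNums([]), 0);
assert.equal(productNums([2, 2, 3]), 12);
assert.equal(productNums([]), 1);
```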
Folds
We can traverse an array in one of two ways: from the right or from the left.
Right Fold
Let's define the right fold, foldr:
const foldr = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B => h === undefined ? acc : f(h, foldr(f, acc, t));
There is quite a bit that's going on there. Let's go over it step by step.
Arguments:
- The combiner function f: (x: A, acc: B) => B: it accepts the current element of the array and the existing accumulator, combines them in some fashion and produces a new value of the accumulator.
- The accumulator acc: B: the initial value, which is also the value returned for the base case of the recursion.
- The array [h, ...t]: A[]: the array that we need to traverse and combine in some fashion.
Coming to the generic types in <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B, it could be surprising to see two separate types being used: A for the array elements and B for the accumulator. The final return type of foldr is also B, i.e. the generic type of the accumulator.
Why not only A, the type of the array elements, when all we are doing is traversing the array and producing a final result by combining the elements in some fashion?
It turns out it's very much possible to combine the array elements into a value of a different type, and the generic type B covers that usage. In some cases A and B will be the same; in others, not. We'll see an example later where they differ.
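For a quick taste of A and B differing, here is a small sketch that repeats the foldr definition above so it runs on its own; totalLength and joinNumbers are hypothetical names.

```typescript
import assert from 'node:assert/strict';

// foldr as defined above.
const foldr = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : f(h, foldr(f, acc, t));

// A = string (array elements), B = number (accumulator): total length of all words.
const totalLength = (words: string[]): number =>
  foldr((w, acc) => w.length + acc, 0, words);

// A = number, B = string: render the numbers as a comma-separated list.
const joinNumbers = (xs: number[]): string =>
  foldr((x, acc) => (acc === '' ? `${x}` : `${x}, ${acc}`), '', xs);

assert.equal(totalLength(['fold', 'left', 'right']), 13);
assert.equal(joinNumbers([1, 2, 3]), '1, 2, 3');
```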
Now, let's see foldr in action. Let's define our sum and product functions in terms of foldr:
const sumFoldr = (xs: number[]) => foldr((x, acc) => x + acc, 0, xs);
assert.equal(sumFoldr([1, 2, 3]), 6);
const productFoldr = (xs: number[]) => foldr((x, acc) => x * acc, 1, xs);
assert.equal(productFoldr([2, 2, 3]), 12);
As we can see, we get the expected results.
I found that John Whitington's book More OCaml has one of the most straightforward, to-the-point illustrations of how folds execute.
The call trace makes one thing obvious: foldr is not tail-recursive. The call stack grows until we reach the end of the array; only then does the combine operation start and the stack unwind.
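The same trace can be reproduced in code. A combiner that builds a string instead of adding makes the nesting visible, and logging each call shows when the combines actually run. This sketch repeats the foldr definition above; traceR and combineOrder are hypothetical names.

```typescript
import assert from 'node:assert/strict';

// foldr as defined above.
const foldr = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : f(h, foldr(f, acc, t));

// A string-building combiner shows the shape of the computation.
const traceR = foldr((x: number, acc: string) => `(${x} + ${acc})`, '0', [1, 2, 3]);
assert.equal(traceR, '(1 + (2 + (3 + 0)))');

// Record the order in which the combiner runs.
const combineOrder: number[] = [];
const total = foldr(
  (x: number, acc: number) => {
    combineOrder.push(x);
    return x + acc;
  },
  0,
  [1, 2, 3],
);
assert.equal(total, 6);
// The last element is combined first: nothing is combined until the recursion
// has reached the end of the array and the stack starts unwinding.
assert.deepEqual(combineOrder, [3, 2, 1]);
```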
Left Fold
Let's define the left fold, foldl:
const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B => h === undefined ? acc : foldl(f, f(h, acc), t);
The function signature is the same as foldr's; the difference is how the combiner function is applied: foldl(f, f(h, acc), t). We start with the initial value of the accumulator, apply the combiner function to produce a new value for the accumulator, and use that new value to continue recursing over the rest of the array.
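Because the recursive call in foldl is the last thing it does (a tail call), it can be mechanically rewritten as a loop. This matters in Node.js: ES2015 specifies proper tail calls, but V8 does not implement them, so the recursive foldl still uses one stack frame per element. Here is a minimal loop-based sketch with the same argument order; foldlLoop is a hypothetical name, and unlike the recursive version it does not stop early at an undefined element.

```typescript
import assert from 'node:assert/strict';

// Same signature as the recursive foldl, but the tail call becomes a loop:
// each iteration feeds the current element and accumulator to the combiner.
const foldlLoop = <A, B>(f: (x: A, acc: B) => B, acc: B, xs: A[]): B => {
  let result = acc;
  for (const x of xs) {
    result = f(x, result);
  }
  return result;
};

assert.equal(foldlLoop((x: number, acc: number) => x + acc, 0, [1, 2, 3]), 6);

// An input large enough that the recursive version would likely exceed the call stack.
const big = Array.from({ length: 1_000_000 }, () => 1);
assert.equal(foldlLoop((x: number, acc: number) => x + acc, 0, big), 1_000_000);
```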
Here is what the execution trace looks like:
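Spelled out in code, a string-building combiner shows the nesting that foldl produces, and logging each call shows that every combine runs on the way down, before the next recursive call. This sketch repeats the foldl definition; traceL and combineOrder are hypothetical names.

```typescript
import assert from 'node:assert/strict';

// foldl as defined above.
const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : foldl(f, f(h, acc), t);

// The first element ends up innermost: it was combined first.
const traceL = foldl((x: number, acc: string) => `(${x} + ${acc})`, '0', [1, 2, 3]);
assert.equal(traceL, '(3 + (2 + (1 + 0)))');

// Record the order in which the combiner runs.
const combineOrder: number[] = [];
foldl(
  (x: number, acc: number) => {
    combineOrder.push(x);
    return x + acc;
  },
  0,
  [1, 2, 3],
);
// Elements are combined left to right as the recursion proceeds.
assert.deepEqual(combineOrder, [1, 2, 3]);
```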
Now, let's see foldl in action. Let's define our sum and product functions in terms of foldl:
const sumFoldl = (xs: number[]) => foldl((x, acc) => x + acc, 0, xs);
assert.equal(sumFoldl([1, 2, 3]), 6);
const productFoldl = (xs: number[]) => foldl((x, acc) => x * acc, 1, xs);
assert.equal(productFoldl([2, 2, 3]), 12);
And again, we get the expected results.
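sum and product give the same result with either fold because + and * are associative and commutative. With a combiner where order matters, the two folds differ. A sketch using a hypothetical cons combiner (put the element in front of the accumulator) and string concatenation, repeating both fold definitions:

```typescript
import assert from 'node:assert/strict';

// foldr and foldl as defined above.
const foldr = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : f(h, foldr(f, acc, t));
const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : foldl(f, f(h, acc), t);

// Combiner that puts the current element in front of the accumulator.
const cons = (x: number, acc: number[]): number[] => [x, ...acc];

// foldr rebuilds the array in its original order...
assert.deepEqual(foldr(cons, [] as number[], [1, 2, 3]), [1, 2, 3]);
// ...while foldl, which combines the first element first, reverses it.
assert.deepEqual(foldl(cons, [] as number[], [1, 2, 3]), [3, 2, 1]);

// The same effect with string concatenation.
const concat = (x: string, acc: string): string => x + acc;
assert.equal(foldr(concat, '', ['a', 'b', 'c']), 'abc');
assert.equal(foldl(concat, '', ['a', 'b', 'c']), 'cba');
```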
Map and Reduce
Now that we have the fold implementations in place, let's implement two common functions, map and reduce, in terms of folds. These are defined as Array instance methods in the standard JavaScript API, but we'll implement them as standalone functions.
const map = <A, B>(xs: A[], cb: (x: A) => B): B[] => foldl((x, acc) => {
acc.push(cb(x));
return acc;
}, [] as B[], xs);
assert.deepEqual(map([1, 2, 3], x => x * 2), [2, 4, 6]);
// to demonstrate usage of return array containing different type
assert.deepEqual(map([1, 2, 3], _x => 'ho'), ['ho', 'ho', 'ho']);
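// reduce: the callback receives (accumulator, current element), as in Array.prototype.reduce
const reduce = <A>([h, ...t]: A[], cb: (pre: A, cur: A) => A) => foldl((x, acc) => cb(acc, x), h, t);
assert.equal(reduce([7, 3, 8], (pre, cur) => pre + cur), 18);
assert.equal(reduce([7, 3, 8], (pre, cur) => pre - cur), -4); // (7 - 3) - 8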
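The map example demonstrates the use of a different type for the accumulator: the elements are of type A, while the accumulator is a B[]. The second assertion, which maps numbers to strings, is rather contrived, but it demonstrates the point well.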
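A quick way to check a hand-written reduce is to compare it with the built-in Array.prototype.reduce using a non-commutative callback such as subtraction, where swapping the callback's arguments changes the result. A sketch with the callback receiving the accumulator first, as the built-in does; minus is a hypothetical helper and foldl is repeated from above.

```typescript
import assert from 'node:assert/strict';

// foldl as defined above.
const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : foldl(f, f(h, acc), t);

// The head seeds the accumulator; the callback gets (accumulator, current).
const reduce = <A>([h, ...t]: A[], cb: (pre: A, cur: A) => A): A =>
  foldl((x: A, acc: A) => cb(acc, x), h, t);

// Subtraction is not commutative, so it exposes the argument order.
const minus = (pre: number, cur: number) => pre - cur;
assert.equal(reduce([7, 3, 8], minus), [7, 3, 8].reduce(minus));
assert.equal(reduce([7, 3, 8], minus), -4); // (7 - 3) - 8
```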
Folding over functions
We went over folding over primitive values in the last section. Folding over functions is also a common and useful operation. Function piping and composition are two use cases where folding over an array of functions creates a new function.
Pipe
A pipe function of functions f1, f2 and f3 can be defined as: pipe([f1, f2, f3])(x) = f3(f2(f1(x))).
We give input x to the first function f1, take the result and pipe it as input to f2, then take that result and pipe it as input to f3 to get the final result.
Let's create a pipe-creator function called plumber that takes two functions and returns their pipe function.
const plumber = <A>(fn1: IdType<A>, fn2: IdType<A>) => (x: A) => fn2(fn1(x));
What is this IdType<A> type of the functions, and why is it needed?
IdType<A> is the type (x: A) => A: a function that takes a value of type A and returns a value of the same type. The identity function is the simplest function of this type.
If we have an array of functions and would like to create a pipe function using plumber, we have a problem kickstarting the process with the first function: plumber expects two functions and we have just one. That's where the identity function comes in. It simply returns the argument it gets.
We use the identity function as the initial value, pairing it with the first function in the array to kickstart the pipe formation.
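The identity function works as a starting value because piping it with any function f, on either side, behaves exactly like f. A quick check, reusing the IdType, plumber, idNumber and double definitions from this section:

```typescript
import assert from 'node:assert/strict';

type IdType<A> = (x: A) => A;

// plumber as defined in this section: run fn1, then fn2.
const plumber = <A>(fn1: IdType<A>, fn2: IdType<A>) => (x: A) => fn2(fn1(x));

const idNumber: IdType<number> = x => x;
const double = (i: number) => i * 2;

// Piping through the identity on either side changes nothing,
// so idNumber is a safe starting value for the accumulator.
assert.equal(plumber(idNumber, double)(21), double(21));
assert.equal(plumber(double, idNumber)(21), double(21));
```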
Let's create a pipe function in imperative fashion first to understand it better.
type IdType<A> = (x: A) => A;
const double = (i: number) => i * 2;
const triple = (i: number) => i * 3;
const quadruple = (i: number) => i * 4;
const fns = [double, triple, quadruple];
const plumber = <A>(fn1: IdType<A>, fn2: IdType<A>) => (x: A) => fn2(fn1(x));
// since plumber needs two functions to form the pipeline, we need something to start with the
// first function in the array and that something is the id function.
const idNumber: IdType<number> = x => x; // id function for number type
let acc = idNumber;
for (const fn of fns) {
acc = plumber(acc, fn);
}
assert.equal(acc(1), 24); // acc is the final pipe function
As we can see, we traverse the array from left to right, assigning the pipe composed up to that point to the accumulator; the final value of the accumulator is the complete pipe function. As such, this is a perfect fit for foldl, and below is an implementation based on foldl.
// pipe([f1, f2, f3])(x) = f3(f2(f1(x)))
// each new function runs on the output of the pipe built so far: fn(acc(x))
const pipe = <A>(fns: Array<IdType<A>>) => foldl((fn, acc) => x => fn(acc(x)), (x: A) => x, fns);
const half = (x: number) => x / 2;
const third = (x: number) => x / 3;
const tenTimes = (x: number) => x * 10;
const pipeline = pipe([half, third, tenTimes]);
// this is equivalent to tenTimes(third(half(24))) === 40
assert.equal(pipeline(24), tenTimes(third(half(24))));
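half, third and tenTimes only multiply by constants, so they commute: any ordering of them gives 40 for an input of 24, and the assertion above cannot tell a pipe apart from a composition. A sketch with two hypothetical functions that do not commute, addOne and doubleIt, makes the ordering visible; it repeats IdType and foldl.

```typescript
import assert from 'node:assert/strict';

type IdType<A> = (x: A) => A;

const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : foldl(f, f(h, acc), t);

// pipe([f1, f2, f3])(x) = f3(f2(f1(x))): each new function runs after the pipe so far.
const pipe = <A>(fns: Array<IdType<A>>) =>
  foldl((fn: IdType<A>, acc: IdType<A>) => (x: A) => fn(acc(x)), (x: A) => x, fns);

// Hypothetical helpers whose order changes the answer.
const addOne = (x: number) => x + 1;
const doubleIt = (x: number) => x * 2;

assert.equal(pipe([addOne, doubleIt])(5), 12); // doubleIt(addOne(5))
assert.equal(pipe([doubleIt, addOne])(5), 11); // addOne(doubleIt(5))
```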
Compose
A compose function of functions f1, f2 and f3 can be defined as: compose([f1, f2, f3])(x) = f1(f2(f3(x))).
We start traversing the array from the right: give input x to function f3, take the result and provide it as input to f2, then take that result and provide it as input to f1 to get the final result. It's a perfect fit for foldr, and here is the implementation.
const compose = <A>(fns: Array<IdType<A>>) => foldr((fn, acc) => x => fn(acc(x)), (x: A) => x, fns);
const plusOne: IdType<number> = x => x + 1;
// or add type to the parameter to conform to IdType<number>
const fiveTimes = (x: number) => x * 5;
const composition = compose([plusOne, fiveTimes]);
// this is equivalent to plusOne(fiveTimes(10)) === 51
assert.equal(composition(10), plusOne(fiveTimes(10)));
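In this sketch, pipe and compose share the same combiner, (fn, acc) => x => fn(acc(x)), and differ only in whether they fold from the left or from the right. It also checks that composing an array is the same as piping its reverse, using plusOne and fiveTimes as defined above; both folds are repeated so the snippet runs on its own.

```typescript
import assert from 'node:assert/strict';

type IdType<A> = (x: A) => A;

const foldr = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : f(h, foldr(f, acc, t));
const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B =>
  h === undefined ? acc : foldl(f, f(h, acc), t);

// Same combiner, different fold direction.
const pipe = <A>(fns: Array<IdType<A>>) =>
  foldl((fn: IdType<A>, acc: IdType<A>) => (x: A) => fn(acc(x)), (x: A) => x, fns);
const compose = <A>(fns: Array<IdType<A>>) =>
  foldr((fn: IdType<A>, acc: IdType<A>) => (x: A) => fn(acc(x)), (x: A) => x, fns);

const plusOne = (x: number) => x + 1;
const fiveTimes = (x: number) => x * 5;
const fns = [plusOne, fiveTimes];

assert.equal(compose(fns)(10), 51); // plusOne(fiveTimes(10))
assert.equal(pipe(fns)(10), 55); // fiveTimes(plusOne(10))
// compose applies right to left, pipe left to right.
assert.equal(compose(fns)(10), pipe([...fns].reverse())(10));
```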
Here is the complete code listing for quick reference.
import assert from 'node:assert/strict';
// recursive addition of elements of an array
const sum = ([h, ...t]: number[]): number => h === undefined ? 0 : h + sum(t);
assert.equal(sum([1, 2, 3]), 6);
assert.equal(sum([5]), 5); // array with 1 element
assert.equal(sum([]), 0); // empty array
// recursive multiplication of elements of an array
const product = ([h, ...t]: number[]): number => h === undefined ? 1 : h * product(t);
assert.equal(product([2, 2, 3]), 12);
assert.equal(product([5]), 5);
assert.equal(product([]), 1);
/* as we can see, sum and product are almost the same. The things that vary are the base case value
 * (0 for sum and 1 for product) and the operation. Let's generalize it.
 */
const foldr = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B => h === undefined ? acc : f(h, foldr(f, acc, t));
const sumFoldr = (xs: number[]) => foldr((x, acc) => x + acc, 0, xs);
assert.equal(sumFoldr([1, 2, 3]), 6);
const productFoldr = (xs: number[]) => foldr((x, acc) => x * acc, 1, xs);
assert.equal(productFoldr([2, 2, 3]), 12);
/* now let's look at foldl */
const foldl = <A, B>(f: (x: A, acc: B) => B, acc: B, [h, ...t]: A[]): B => h === undefined ? acc : foldl(f, f(h, acc), t);
const sumFoldl = (xs: number[]) => foldl((x, acc) => x + acc, 0, xs);
assert.equal(sumFoldl([1, 2, 3]), 6);
const productFoldl = (xs: number[]) => foldl((x, acc) => x * acc, 1, xs);
assert.equal(productFoldl([2, 2, 3]), 12);
/* let's implement a couple of JavaScript standard apis using folds: map, reduce, not exact but close enough. */
// map - the reason for two type parameters is the returned array can be of any type.
const map = <A, B>(xs: A[], cb: (x: A) => B): B[] => foldl((x, acc) => {
acc.push(cb(x));
return acc;
}, [] as B[], xs);
assert.deepEqual(map([1, 2, 3], x => x * 2), [2, 4, 6]);
// to demonstrate usage of return array containing different type
assert.deepEqual(map([1, 2, 3], _x => 'ho'), ['ho', 'ho', 'ho']);
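// reduce: the callback receives (accumulator, current element), as in Array.prototype.reduce
const reduce = <A>([h, ...t]: A[], cb: (pre: A, cur: A) => A) => foldl((x, acc) => cb(acc, x), h, t);
assert.equal(reduce([7, 3, 8], (pre, cur) => pre + cur), 18);
assert.equal(reduce([7, 3, 8], (pre, cur) => pre - cur), -4); // (7 - 3) - 8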
/* pipe and compose */
/* define type for identity */
type IdType<A> = (x: A) => A;
const double = (i: number) => i * 2;
const triple = (i: number) => i * 3;
const quadruple = (i: number) => i * 4;
const fns = [double, triple, quadruple];
const plumber = <A>(fn1: IdType<A>, fn2: IdType<A>) => (x: A) => fn2(fn1(x));
// since plumber needs two functions to form the pipeline, we need something to start with the
// first function in the array and that something is the id function.
const idNumber: IdType<number> = x => x; // id function for number type
let acc = idNumber;
for (const fn of fns) {
acc = plumber(acc, fn);
}
assert.equal(acc(1), 24); // acc is the final pipe function
// pipe([f1, f2, f3])(x) = f3(f2(f1(x)))
const pipe = <A>(fns: Array<IdType<A>>) => foldl((fn, acc) => x => fn(acc(x)), (x: A) => x, fns);
const half = (x: number) => x / 2;
const third = (x: number) => x / 3;
const tenTimes = (x: number) => x * 10;
const pipeline = pipe([half, third, tenTimes]);
// this is equivalent to tenTimes(third(half(24))) === 40
assert.equal(pipeline(24), tenTimes(third(half(24))));
/* compose: compose([f1, f2, f3])(x) = f1(f2(f3(x))) */
const compose = <A>(fns: Array<IdType<A>>) => foldr((fn, acc) => x => fn(acc(x)), (x: A) => x, fns);
const plusOne: IdType<number> = x => x + 1;
// or add type to the parameter to conform to IdType<number>
const fiveTimes = (x: number) => x * 5;
const composition = compose([plusOne, fiveTimes]);
// this is equivalent to plusOne(fiveTimes(10)) === 51
assert.equal(composition(10), plusOne(fiveTimes(10)));
That's it for today. Happy coding!