Today I'd like to talk about doing Calculus with Haskell. It's a topic I like revisiting every now and again, because I like to think there are a few different strategies one can employ with it.

(Note: this is an older post that I am trying to revise, so I will probably make lots of edits over time as I realize my mistakes)

Let's dive right into it.

Understanding Functions

Now obviously, Calculus is a hard area to approach in programming. The idea of Calculus is that we study mathematical functions: mappings of numbers from a domain to a range. For people already familiar with Haskell, this is loosely reminiscent of a Functor, which maps between two categories (whether or not the two categories are the same). However, the analogy isn't really important here.

Let's take a function in Haskell that takes a number and does something with it.

f :: Num a => a -> a
f x = x + 1

This is a very simple f(x) = x + 1 function, nothing crazy. Now, Calculus is the study of the rate at which a function changes at any instantaneous point in its domain. Our f(x) takes in the domain of all real numbers and outputs a number in the range of real numbers. If we graphed it, it would appear as a slanted line with its y-intercept at (x=0, y=1), because setting x to zero in our function returns one.

If we wanted a deeper look at this, how fast would we say this function is changing? That is to say, if we draw a tangent line anywhere along this graph, what would its slope be? Here it's a constant change of one: the function grows at a constant rate of 1, so it isn't very special.
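If you want to double-check that intuition numerically, a quick-and-dirty difference quotient works. Note that numDeriv here is just a throwaway helper I'm making up for illustration, not something we'll need later.

-- approximate the slope of fn at x using a small step h
numDeriv :: Fractional a => (a -> a) -> a -> a -> a
numDeriv fn h x = (fn (x + h) - fn x) / h

For our f, something like numDeriv f 0.0001 10 gives back 1.0, and it gives back 1.0 no matter which x we pick, matching the constant slope we just described.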

For a more complex function, we can look at something like a power function.

g :: Num a => a -> a
g x = x * x

This is the same as g(x) = x^2, or x to the power of 2. At a given value x, how fast is this function changing? If we traced a tangent line along this function's graph, we would see that at x = 0 the tangent line is completely flat, while away from zero it gets steeper and steeper.
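Using the same throwaway numDeriv helper from earlier, we can watch g's slope change with x:

numDeriv g 0.0001 0 -- roughly 0.0, flat at the origin
numDeriv g 0.0001 3 -- roughly 6.0, steeper as x grows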

Now, we can simplify this process of finding tangent lines by using something called a derivative. A derivative is another function that describes how quickly our original function changes at any given point in its domain. I am really not great at describing how derivatives and limits work, so I'll leave this Wikipedia link here for those curious, but for the most part there is a set of formulas and rules we can use to get derivatives of functions with ease.
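For reference, that limit definition reads:

f'(x) = lim[h -> 0] (f(x + h) - f(x)) / h

which is exactly the difference quotient from our numDeriv sketch, with the step h shrinking to zero.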

In the next section, we'll talk about how we can represent these functions using ordinary Haskell sum types.

Expression Trees

Haskell supports something called sum types, which let us define a data type as a choice between several alternative values under one type name. For example, the Maybe type (whose monad everyone should know by heart) is defined as simply as

data Maybe a = Nothing | Just a

This defines a type Maybe a, where the a can be literally anything, with two kinds of values: one being Nothing, while the other is Just a, or "just the value".

-- plain definitions
let blank = Nothing
let aValue = Just 3

-- wrap any value into a Maybe type
wrap :: a -> Maybe a
wrap a = Just a

It's a polymorphic type, meaning we can choose any type we like to be stored inside the Maybe. Even another Maybe type!
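For example, nesting them works out of the box:

nested :: Maybe (Maybe Int)
nested = Just (Just 3)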

The point is that it's easy to define these weird, recursive types that are oddly expressive and minimal, thanks to Haskell's algebraic data types and type inference. If we wanted to define a linked list entirely within a single data type, we can.

data List a = Nil | Cons a (List a) 

If you know Lisp or get the gist of singly linked lists in computer science, then Haskell just made it really easy to define one with a single line of code.
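For instance, the list [1, 2, 3] in this encoding is:

Cons 1 (Cons 2 (Cons 3 Nil))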

If we wanted to do n-ary trees, another hot topic in computer science, we could also do that, so let's do a binary tree.

data Tree a = Nil | Node a (Tree a) (Tree a)

Like the linked list, we need a Nil case to say that there's nothing here, indicating the end of a branch during the tree's traversal. But since we're on the topic of trees, let's talk about expression trees, something oftentimes used for parsing in calculator programs.

Imagine we have a limited set of arithmetic, like addition, subtraction, division, multiplication, and so on. Let's picture that as a sum-type.

data Expr a = NaN
            | Var Char
            | Const a
            | Add (Expr a) (Expr a)
            | Sub (Expr a) (Expr a)
            | Mul (Expr a) (Expr a)
            | Div (Expr a) (Expr a)
            | Pow (Expr a) (Expr a)
            deriving (Show, Eq)
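As a quick example of how these constructors nest, the expression 2x + 3 would be encoded as:

Add (Mul (Const 2) (Var 'x')) (Const 3)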

We have the basic PEMDAS operations here, so let's start talking about what we can do to derive some basic functions.

As we saw with the basic f(x) = x + 1, the variable x on its own has a derivative of 1, because it changes only at the rate of itself. If the function were f(x) = 2x + 1, then the derivative becomes 2. Let's start mapping out a function that can derive any expression tree for us.

derive :: Num a => Expr a -> Expr a
derive NaN       = NaN     -- can't derive a NaN value
derive (Const _) = Const 0 -- a constant has no change in value
derive (Var _)   = Const 1 -- a lone variable's derivative is always 1

So that covers the basic tools. What happens when we have a function that incorporates addition or subtraction? Those are also very easy. Picture that we have a function f(x) = 2x + 3x. Instead of adding the terms together first because of the shared variable, we can derive them separately: we know 2x's derivative is 2, and 3x's derivative is 3, so the final derivative is f'(x) = 2 + 3 = 5, which is exactly what the derivative of f(x) = 2x + 3x = 5x would have been.

There will be times when the two sides of an addition aren't so uniform, so in general we just derive each operand of the addition and subtraction nodes.

derive (Add l r) = Add (derive l) (derive r)
derive (Sub l r) = Sub (derive l) (derive r)

We recursively apply our deriver function to both sides to get our new values. Whether they become combined is going to be up to another function, but for now let's focus on what happens with multiplication, because that's where it's going to get wacky.

First, how do we get the value 2x into our expression tree? It's going to look like a nested Mul value, something like Mul (Const 2) (Var 'x'), which may not be the prettiest thing in the world to look at. Writing rules gets a little tricky around here, so let's start off small.

derive (Mul (Const c) (Var _)) = Const c -- d/dx 5x -> 5
derive (Mul (Var _) (Const c)) = Const c -- same, except flipped

Our basic derivatives like 3x just extract the constant part of the Mul value. That's the easy half of multiplication. For the rest, we have to cover the product rule of Calculus.

Pretend we have a weird-looking function like f(x) = sin(x) * cos(x). While we don't have cos or sin defined in our expression trees (yet), how would we even go about deriving something like this, especially with trigonometric functions involved? It turns out there's a special rule for multiplying functions together: multiply the left-hand function by the right-hand function's derivative, multiply the right-hand function by the left-hand function's derivative, then add those two products together. In symbols, (f * g)' = f * g' + g * f'. Confused? It's hard to motivate without a deeper look at the limit definition of the derivative, but trust me that this is the formula.

derive (Mul f g) = Add (Mul f (derive g)) (Mul g (derive f))

While we're here, we may as well handle division with the quotient rule, which is oddly similar: subtract instead of add, then divide by the square of the right-hand function. In symbols, (f / g)' = (g * f' - f * g') / g^2.

derive (Div f g) = Div (Sub (Mul g (derive f)) (Mul f (derive g))) (Pow g (Const 2))

Well, so far we've covered addition, subtraction, multiplication, and division. What's left? Exponents! What happens when we take a variable raised to a power? This one is pretty simple.

If we have a variable x raised to a power n, then its derivative will look like n*x^(n-1). We bring the power down as a coefficient, then subtract 1 from the original n and use that as the new power. There will be a few cases to cover here.

-- when we have x^n -> n * x^(n-1)
derive (Pow (Var l) (Const n)) = Mul (Const n) (Pow (Var l) (Const (n-1)))
-- when we have (c*x)^n -> n*c * (c*x)^(n-1), by the chain rule
derive (Pow (Mul (Const c) (Var l)) (Const n)) = Mul (Const (n*c)) (Pow (Mul (Const c) (Var l)) (Const (n-1)))
-- when we have a generic function f(x)^n -> f'(x) * n * f(x)^(n-1)
derive (Pow f (Const n)) = Mul (derive f) (Mul (Const n) (Pow f (Const (n-1))))

Whew, it got a little wacky there with the last one, but that is actually the chain rule at work, which kicks in when we have composite functions (functions that are wrapped like f(g(x))): the derivative of f(g(x)) is f'(g(x)) * g'(x).

So that covers the basics of the derivative rules in Calculus. In the next section, we are going to talk about simplifying expression trees so that our deriver function doesn't have to account for every edge case of how a tree may look.

Simplifying Expressions

Simplifying works much like it would with pen and paper: is one expression just another, more cluttered way of writing the same thing? This is going to help us clean up expression trees later so that they are as compact as possible.

Let's take for instance some basic arithmetic properties.

1 + 1 = 2
2 + 0 = 2
0 + 2 = 2

When adding two flat values together, we combine them. When adding a zero, we just use whatever is non-zero. Does this look like anything relating to an expression tree?

Add (Const 1) (Const 1) => Const 2
Add (Const 2) (Const 0) => Const 2
Add (Const 0) (Const 2) => Const 2

Those are their rough equivalents in expression tree terms. But we really don't need the components that don't do anything; they just clutter up our tree. Our deriver doesn't need these values because they cancel out no matter what, so we can clean up the tree by creating a simplify function.

simplify :: (Eq a, Floating a) => Expr a -> Expr a
-- (Eq for the literal patterns and guards, Floating for the / and ** rules below)
simplify (Add (Const l) (Const r)) = Const (l+r)
simplify (Add l         (Const 0)) = simplify l
simplify (Add (Const 0)         r) = simplify r

-- catch-all base case; this clause must come last in the finished definition
simplify e = e

The first line is a rule that combines two constant values into a single constant, removing the need for an addition node. The other two clean up an addition node when one of its branches is zero. We recursively apply our simplifier over the entire tree so we can match other rules along the way, but we also need the catch-all case of simplify e = e so that anything that doesn't match a rule is returned unchanged (just remember it has to be the very last clause).

The same can be done for subtraction because it is the inverse of addition (with a minor twist).

simplify (Sub (Const l) (Const r)) = Const (l-r)
simplify (Sub a         (Const 0)) = simplify a
simplify (Sub (Const 0)         b) = simplify (neg b)
simplify (Sub a (Mul (Const (-1)) b)) = Add a b

So one difference is that the expression x - 0 reduces to x, but 0 - x becomes -x. Let's pretend we have a neg function capable of negating anything we pass it; the last rule, meanwhile, shows how we express negation using a multiply node: if we end up with an expression that looks like x - (-y), it becomes the addition x + y instead.
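We never actually wrote that neg function, so here's a minimal sketch of one, using the same Mul (Const (-1)) convention as the rule above (and folding the sign straight into constants when we can):

neg :: Num a => Expr a -> Expr a
neg (Const c) = Const (negate c)   -- fold the sign into the constant directly
neg e         = Mul (Const (-1)) e -- otherwise tag on a factor of -1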

Now, for the multiplication, division, and power nodes there are a lot of weird edge cases. I'm going to list them out and leave comments in the spots that might need explaining, but hopefully it's easy to follow.

-- multiply node rules
simplify (Mul (Const a) (Const b)) = Const (a * b)
simplify (Mul a         (Const 1)) = simplify a
simplify (Mul (Const 1)         b) = simplify b
simplify (Mul _         (Const 0)) = Const 0
simplify (Mul (Const 0)         _) = Const 0
simplify (Mul (Var l)   (Const c)) = Mul (Const c) (Var l) -- keep the constant on the left-hand side
simplify (Mul (Mul (Const (-1)) f) (Mul (Const (-1)) g)) = Mul f g -- two negatives cancel

-- combine equal terms (including when negative factors are involved)
simplify (Mul a                    b) | a==b = Pow a (Const 2)
simplify (Mul a (Mul (Const (-1)) b)) | a==b = Mul (Const (-1)) (Pow a (Const 2))
simplify (Mul (Mul (Const (-1)) a) b) | a==b = Mul (Const (-1)) (Pow a (Const 2))

-- division rules (dividing by zero results in our NaN)
simplify (Div _         (Const 0)) = NaN     -- checked first, so 0/0 is NaN too
simplify (Div (Const 0)         _) = Const 0
simplify (Div (Const a) (Const b)) = Const (a / b)
simplify (Div a                 b) | a == b = Const 1   -- (x/x) = 1
simplify (Div a         (Const 1)) = simplify a

simplify (Pow (Const a) (Const b)) = Const (a ** b)
simplify (Pow a         (Const 1)) = simplify a
simplify (Pow _         (Const 0)) = Const 1
simplify (Pow (Pow c (Const b)) (Const a)) = Pow c (Const (a*b)) -- (c^b)^a = c^(a*b)

-- instead of having (-f + g), let's turn it into (g - f)
simplify (Add (Mul (Const (-1)) a) b) = Sub b a

We use Haskell guards to express rules that depend on comparing subtrees. For division, if we have Div a b where a == b, then it's really Const 1, because anything divided by itself is one (setting aside the point where the denominator is zero). The same trick appears in the multiplication rules to combine equal terms into powers.

Lastly, let's add some recursive simplify rules.

simplify (Add a b) = Add (simplify a) (simplify b)
simplify (Sub a b) = Sub (simplify a) (simplify b)
simplify (Mul a b) = Mul (simplify a) (simplify b)
simplify (Div a b) = Div (simplify a) (simplify b)
simplify (Pow a b) = Pow (simplify a) (simplify b)
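One caveat: a single pass of simplify applies at most one rewrite per node, so a deeply nested tree may need several passes before it stops shrinking. A little fixed-point wrapper takes care of that (simplifyFull is a name I'm inventing here, not something from a library):

simplifyFull :: (Eq a, Floating a) => Expr a -> Expr a
simplifyFull e
  | e' == e   = e               -- nothing changed, so we're done
  | otherwise = simplifyFull e' -- keep going until the tree settles
  where e' = simplify e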

Now, with all those rules defined, we can test-derive some equations (that sounds horribly like "test drive", doesn't it).

Testing it Out

Okay, can we finally use this to do some Calculus homework for us? Absolutely!

testD :: (Show a, Eq a, Floating a) => Expr a -> IO ()
-- simplify before deriving (tidy input) and after (tidy output)
testD = print . simplify . derive . simplify

main :: IO ()
main = do
  testD $ Const 5
  testD $ Var 'x'
  testD $ Mul (Const 5) (Var 'x')
  testD $ Pow (Var 'x') (Const 2)
  testD $ Pow (Mul (Const 7) (Var 'x')) (Const 5)
  testD $ Add (Pow (Mul (Const 7) (Var 'x')) (Const 5)) (Pow (Mul (Const 8) (Var 'x')) (Const 3))

If we run this, our output should be the following (because of the Floating constraint, the numeric literals default to Double, so GHCi will really print Const 5.0 and so on; I'll write whole numbers here for readability):

Const 0
Const 1
Const 5
Mul (Const 2) (Var 'x')
Mul (Const 35) (Pow (Mul (Const 7) (Var 'x')) (Const 4))
Add (Mul (Const 35) (Pow (Mul (Const 7) (Var 'x')) (Const 4))) (Mul (Const 24) (Pow (Mul (Const 8) (Var 'x')) (Const 2)))

So it isn't exactly the prettiest thing to type out at times, but this is the underlying basis for symbolic differentiation, which libraries like SymPy are really great at. Only this time we're leveraging Haskell's algebraic data types and pattern-match destructuring to create simple rules. In fact, you can dive even deeper into SymPy's definition of its Add class, which handles many kinds of internal values, to see how wild it can get in Python.

Evaluating Derivatives with Variable Replacement

So in total, we've defined an expression tree, created derivative rules, and then simplified the trees if they get too complex and need to be shortened. What can we do now?

A key thing would be the ability to, well, use these functions numerically. We can do that with some basic variable substitution. This will involve some pattern matching against the overall structure of the expression tree, but it will enable us to evaluate trees later.

Let's start with a basic replace function, which will substitute a given variable in an expression tree with whatever value we want, given that we provide the name of the variable to replace (otherwise no replacement should happen).

replace :: Char -> a -> Expr a -> Expr a
replace x v (Var s) | x == s = Const v
replace x v (Add l r) = Add (replace x v l) (replace x v r)
replace x v (Sub l r) = Sub (replace x v l) (replace x v r)
replace x v (Mul l r) = Mul (replace x v l) (replace x v r)
replace x v (Div l r) = Div (replace x v l) (replace x v r)
replace x v (Pow l r) = Pow (replace x v l) (replace x v r)

replace _ _ e = e -- anything else (constants, NaN, non-matching variables) stays as-is

Let's test this out on a basic 2x + 3 expression.

let sample = Add (Mul (Const 2) (Var 'x')) (Const 3)

let sample' = replace 'x' 3 sample
-- evaluates to `Add (Mul (Const 2) (Const 3)) (Const 3)`
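To actually get a number out at the end, we still need an evaluator. Here's one possible sketch (eval is my own hypothetical helper; it returns a Maybe so that leftover variables and division by zero fail gracefully instead of crashing):

eval :: (Eq a, Floating a) => Expr a -> Maybe a
eval NaN       = Nothing
eval (Var _)   = Nothing -- an unsubstituted variable can't be evaluated
eval (Const c) = Just c
eval (Add l r) = (+)  <$> eval l <*> eval r
eval (Sub l r) = (-)  <$> eval l <*> eval r
eval (Mul l r) = (*)  <$> eval l <*> eval r
eval (Pow l r) = (**) <$> eval l <*> eval r
eval (Div l r) = do
  d <- eval r
  if d == 0 then Nothing else (/ d) <$> eval l

With that in hand, eval sample' gives back Just 9.0, which is indeed 2*3 + 3.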

There are many more rules we can add to our Haskell program, like logarithmic functions, trig functions, hyperbolic functions, or even complex functions, but I'll leave those up to both the reader and to me to add in later as I see fit. I doubt I did a great job of explaining both Haskell pattern-matching rules and differential Calculus rules in the same post; I can't claim to be a Haskell professor or a math professor, but there's room for growth.
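As a taste of what the trig extension might look like, here's a sketch, assuming we added Sin and Cos constructors to our Expr type (they are not in the type as defined above):

-- hypothetical constructors: ... | Sin (Expr a) | Cos (Expr a)
derive (Sin f) = Mul (Cos f) (derive f)                    -- (sin f)' = cos f * f'
derive (Cos f) = Mul (Mul (Const (-1)) (Sin f)) (derive f) -- (cos f)' = -sin f * f'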

Hopefully you enjoyed reading this very minimal dive into differential calculus with Haskell.