CalculusWithJuliaNotes.jl/quarto/derivatives/symbolic_derivatives.qmd
2024-07-31 11:24:53 -04:00

# Symbolic derivatives
{{< include ../_common_code.qmd >}}
This section uses the `TermInterface` add-on package.
```{julia}
using TermInterface
```
```{julia}
#| echo: false
const frontmatter = (
title = "Symbolic derivatives",
description = "Calculus with Julia: Symbolic derivatives",
tags = ["CalculusWithJulia", "derivatives", "symbolic derivatives"],
);
```
---
The ability to break down an expression into operations and their arguments is necessary when applying the differentiation rules. Such rules are applied from the outside in. Identifying the proper "outside" function is usually most of the battle when finding derivatives.
In the following example, we provide a sketch of a framework to differentiate expressions by a chosen symbol to illustrate how the outer function drives the task of differentiation.
The `Symbolics` package provides native symbolic manipulation abilities for `Julia`, similar to `SymPy`, though without the dependence on `Python`. The `TermInterface` package, used by `Symbolics`, provides a generic interface for expression manipulation that is *also* implemented for `Julia`'s expressions and symbols.
An expression is an unevaluated portion of code that for our purposes below contains other expressions, symbols, and numeric literals. Expressions are held in the `Expr` type. A symbol, such as `:x`, is distinct from a string (e.g. `"x"`) and lets the programmer distinguish the name of a variable from the contents the variable points to. Symbols are fundamental to metaprogramming in `Julia`. An expression is a specification of some set of statements to execute. A numeric literal is just a number.
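The three kinds of objects can be told apart using only base `Julia`. The following is a minimal sketch; the variable name `q` is chosen here just for illustration:

```{julia}
q = :(sin(x) + 2)      # quoted, so `x` need not be defined
# one check per kind of object: expression, symbol, numeric literal
q isa Expr, :x isa Symbol, :(2) isa Number
```

The fields of an `Expr` can also be inspected directly; for a function call the head is `:call` and the first entry of `args` names the function:

```{julia}
q.head, q.args
```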
The three main functions from `TermInterface` we leverage are `isexpr`, `operation`, and `arguments`. The `operation` function returns the "outside" function of an expression. For example:
```{julia}
operation(:(sin(x)))
```
We see the `sin` function, referred to by a symbol (`:sin`). The `:(...)` above *quotes* the argument, and does not evaluate it, hence `x` need not be defined above. (The `:` notation is used to create both symbols and expressions.)
The arguments are the terms that the outside function is called on. For our purposes there may be $1$ (*unary*), $2$ (*binary*), or more than $2$ (*nary*) arguments. (We ignore zero-argument functions.) For example:
```{julia}
arguments(:(-x)), arguments(:(pi^2)), arguments(:(1 + x + x^2))
```
(The last one may be surprising, but all three arguments are passed to the `+` function.)
`TermInterface` has an `arity` function defined by `length(arguments(ex))` that will be used for dispatch below.
Differentiation must distinguish between expressions, variables, and numbers. Mathematically, expressions have an "outer" function, whereas variables and numbers can be directly differentiated. The `isexpr` function in `TermInterface` returns `true` when passed an expression, and `false` when passed a symbol or numeric literal. The latter two may be distinguished by `isa(..., Symbol)`.
Here we create a function, `D`, which, when it encounters an expression, *dispatches* to a specific method of `D` based on the outer operation and arity; when it encounters a symbol or a numeric literal, it does the differentiation directly:
```{julia}
function D(ex, var=:x)
    if isexpr(ex)
        op, args = operation(ex), arguments(ex)
        D(Val(op), Val(arity(ex)), args, var)
    elseif isa(ex, Symbol) && ex == var
        1
    else
        0
    end
end
```
(The use of `Val` is an idiom of `Julia` allowing dispatch on certain values such as function names and numbers.)
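As a small illustration of this idiom, separate from the differentiation code, the sketch below dispatches on a symbol wrapped in `Val`. The function `rule_name` is a made-up name used only here:

```{julia}
# `rule_name` is a hypothetical function, for illustration only
rule_name(::Val{:+}) = "sum rule"
rule_name(::Val{:*}) = "product rule"
rule_name(::Val)     = "no rule defined"   # fallback for any other value

rule_name(Val(:+)), rule_name(Val(:*)), rule_name(Val(:sin))
```

Each value gets its own type (`Val{:+}`, `Val{:*}`, ...), so `Julia`'s usual method dispatch selects the matching definition.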
Now to develop methods for `D` for different "outside" functions and arities.
Addition can be unary (`:(+x)` is a valid quoting, even if it might simplify to the symbol `:x` when evaluated), *binary*, or *nary*. Here we implement the *sum rule*:
```{julia}
D(::Val{:+}, ::Val{1}, args, var) = D(first(args), var)
function D(::Val{:+}, ::Val{2}, args, var)
    a′, b′ = D.(args, var)
    :($a′ + $b′)
end
function D(::Val{:+}, ::Any, args, var)
    a′s = D.(args, var)
    :(+($(a′s...)))
end
```
The `args` are always held in a container, so the unary method must pull out the first one. The binary case should read as: apply `D` to each of the two arguments, and then create a quoted expression containing the sum of the results. The dollar signs interpolate into the quoting. (The "primes" are unicode notation achieved through `\prime[tab]` and not operations.) The *nary* method (which catches *any* arity besides `1` and `2`) does something similar, only using splatting to produce the sum.
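Interpolation and splatting inside a quote can be tried on their own. In this sketch the values are assigned by hand rather than computed by `D`, and the name `terms` is chosen just for illustration:

```{julia}
a′, b′ = 1, :(cos(x))            # stand-ins for two computed derivatives
two_terms = :($a′ + $b′)         # interpolate each value into the quote

terms = [1, :(2x), :(cos(x))]    # any number of terms
many_terms = :(+($(terms...)))   # splat them in as arguments of one `+` call

two_terms, many_terms
```

The parentheses in `$(terms...)` matter: they splat the individual terms into the call, rather than interpolating the container as a single argument.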
Subtraction must also be implemented in a similar manner, but not for the *nary* case, as subtraction is not associative:
```{julia}
function D(::Val{:-}, ::Val{1}, args, var)
    a′ = D(first(args), var)
    :(-$a′)
end
function D(::Val{:-}, ::Val{2}, args, var)
    a′, b′ = D.(args, var)
    :($a′ - $b′)
end
```
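The difference between the two operations shows up in how `Julia` parses them. A short check using only the `args` field of `Expr` (the names `sub` and `add` are illustrative):

```{julia}
sub = :(a - b - c)    # parsed left to right, as (a - b) - c
add = :(a + b + c)    # parsed as one call, +(a, b, c)
sub.args, add.args
```

A chained subtraction therefore arrives as nested binary calls, which the two methods above already cover.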
The *product rule* is similar to addition, in that $3$ cases are considered:
```{julia}
D(op::Val{:*}, ::Val{1}, args, var) = D(first(args), var)
function D(::Val{:*}, ::Val{2}, args, var)
    a, b = args
    a′, b′ = D.(args, var)
    :($a′ * $b + $a * $b′)
end
function D(op::Val{:*}, ::Any, args, var)
    a, bs... = args
    b = :(*($(bs...)))
    a′ = D(a, var)
    b′ = D(b, var)
    :($a′ * $b + $a * $b′)
end
```
The *nary* case above just peels off the first factor and then uses the binary product rule.
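Written as an equation, with $a$ the first factor and $b_1, b_2, \dots, b_n$ the remaining ones, the peeling step is:

$$
(a \cdot b_1 b_2 \cdots b_n)' = a' \cdot (b_1 b_2 \cdots b_n) + a \cdot (b_1 b_2 \cdots b_n)'.
$$

The derivative remaining on the right is again that of a product with one fewer factor, so the recursion continues until only two factors are left.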
Division is only a binary operation, so here we have the *quotient rule*:
```{julia}
function D(::Val{:/}, ::Val{2}, args, var)
    u, v = args
    u′, v′ = D(u, var), D(v, var)
    :( ($u′*$v - $u*$v′)/$v^2 )
end
```
Powers are handled a bit differently. The power rule would require checking that the exponent does not contain the variable of differentiation; the rule for exponentials would require checking that the base does not contain the variable of differentiation. Trying to implement both would be tedious, so we use the fact that $x = \exp(\log(x))$ (for `x` in the domain of `log`; more care is necessary if `x` is negative) to rewrite $a^b$ as $\exp(b \log(a))$ and differentiate that:
```{julia}
function D(::Val{:^}, ::Val{2}, args, var)
    a, b = args
    D(:(exp($b*log($a))), var) # a > 0 assumed here
end
```
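For reference, carrying out this differentiation by hand, with $a$ and $b$ both allowed to depend on $x$ and $a > 0$, gives:

$$
\frac{d}{dx} a^b = \frac{d}{dx} e^{b \log(a)} = e^{b \log(a)} \cdot \left(b' \log(a) + b \cdot \frac{a'}{a}\right) = a^b \left(b' \log(a) + \frac{b \, a'}{a}\right).
$$

When $b$ is constant this reduces to the power rule, $b \, a^{b-1} a'$; when $a$ is constant it reduces to the exponential rule, $a^b \log(a) \, b'$.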
That leaves the task of defining a rule to differentiate both `exp` and `log`. We do so with *unary* definitions. In the following we also implement `sin` and `cos` rules:
```{julia}
function D(::Val{:exp}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(exp($a) * $a′)
end
function D(::Val{:log}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(1/$a * $a′)
end
function D(::Val{:sin}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(cos($a) * $a′)
end
function D(::Val{:cos}, ::Val{1}, args, var)
    a = only(args)
    a′ = D(a, var)
    :(-sin($a) * $a′)
end
```
The pattern is similar for each. The `$a′` factor is needed due to the *chain rule*. The above illustrates the simple pattern necessary to add a derivative rule for a function.
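In equation form, each unary rule is an instance of the chain rule for an outer function $f$ applied to an inner expression $a(x)$:

$$
[f(a(x))]' = f'(a(x)) \cdot a'(x), \quad \text{for example} \quad [\log(a(x))]' = \frac{1}{a(x)} \cdot a'(x).
$$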
:::{.callout-note}
Several automatic differentiation packages use a set of rules defined following an interface spelled out in the package `ChainRules.jl`. By leveraging multi-dimensional derivatives, the chain rule is the only one of the sum, product, quotient, and chain rules that is needed.
:::
More functions could be included, but for this example the above will suffice, as now the system is ready to be put to work.
```{julia}
ex = :(x + 2/x)
D(ex, :x)
```
The output does not simplify, so some work is needed to identify `1 - 2/x^2` as the answer.
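One way to gain confidence in an unsimplified answer is to evaluate it numerically. The sketch below assumes the output has the shape of the hand-copied expression `ex′`; it is typed in here rather than computed by `D`, and the names `ex′` and `f′` are illustrative:

```{julia}
# Assumed shape of an unsimplified derivative of x + 2/x (copied by hand)
ex′ = :(1 + (0 * x - 2 * 1) / x^2)
f′ = eval(:(x -> $ex′))     # wrap the expression in an anonymous function
f′(2.0), 1 - 2 / 2.0^2      # the two values should agree
```

Agreement at one point is only a spot check, not a proof that the two expressions are equal.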
```{julia}
ex = :( (x + sin(x))/sin(x))
D(ex, :x)
```
Again, simplification is not performed.
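To give a sense of what even light simplification involves, here is a minimal sketch that removes additions of `0` and multiplications by `0` or `1` from binary `+` and `*` calls. The function `simplify_basic` is a hypothetical helper, not part of `TermInterface` or `Symbolics`, and it uses only the fields of `Expr`:

```{julia}
# Hypothetical helper: recursively apply 0 + a, a + 0, 1 * a, a * 1, 0 * a, a * 0
simplify_basic(ex) = ex                          # symbols and numbers are left alone
function simplify_basic(ex::Expr)
    ex.head === :call || return ex               # only function calls are handled
    op = ex.args[1]
    args = map(simplify_basic, ex.args[2:end])   # simplify the inner terms first
    if length(args) == 2
        a, b = args
        if op === :+
            a == 0 && return b
            b == 0 && return a
        elseif op === :*
            (a == 0 || b == 0) && return 0       # assumes the other factor is finite
            a == 1 && return b
            b == 1 && return a
        end
    end
    Expr(:call, op, args...)                     # otherwise rebuild the call
end

simplify_basic(:(1 * cos(x) + 0 * x))
```

A real simplifier needs far more than this (collecting like terms, cancelling factors, handling *nary* calls), which is part of what packages such as `Symbolics` provide.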
Finally, we have a second derivative taken below:
```{julia}
ex = :(sin(x) - x - x^3/6)
D(D(ex, :x), :x)
```
The length of the expression should lead to further appreciation for simplification steps taken when doing such a computation by hand.