Revised with 1.6.0

This commit is contained in:
David Doblas Jiménez 2021-06-22 18:55:20 +02:00
parent 017f2e8300
commit ef2cf0c06b

View File

@ -0,0 +1,842 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural Networks\n",
"In this notebook, we will walk through one main neural nets example. And that is, classifying the infamous MNIST dataset. **If you have no experience with neural nets prior to this notebook, I recommend doing a quick search for an \"intro to neural nets\"**, there are multiple tutorials/blog posts out there and you can choose the one that works for you.\n",
"\n",
"Here, we will use the `Flux` package, but if you want to look at other packages I encourage you to look at `Knet.jl` and `TensorFlow.jl`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"#import Pkg; Pkg.add([\"Flux\",\n",
"# ])\n",
"using Flux, Flux.Data.MNIST\n",
"using Flux: onehotbatch, argmax, crossentropy, throttle\n",
"using Base.Iterators: repeated\n",
"using Images"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at one of the images."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: Flux's datasets are deprecated, please use the package MLDatasets.jl\n",
"└ @ Flux.Data /opt/julia/packages/Flux/qp1gc/src/data/Data.jl:17\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAAAAADji6uXAAAABGdBTUEAALGPC/xhBQAAAAFzUkdCAK7OHOkAAAAgY0hSTQAAeiYAAICEAAD6AAAAgOgAAHUwAADqYAAAOpgAABdwnLpRPAAAAYtJREFUaAW9wb+LzgEAB+AnPqPUXUZdWUhRlMhksFyZDbIYleVMlH/ATpRJ2WwiFv+DssggkqvrhktnYOAY3v3e7w99nifKoizKoizKoizKoizKoizKoizKoizKoizKoizKoixmuornuIkn+Gt/URZlURYz3cUeHuENvtpflEVZlEVZlEVZlMUMp3DUwk/8sVyURVmUxQzXccTCBjYtF2VRFmUxw6qFH9gyTJRFWZTFRIdwwcJrvDJMlEVZlMVEF3HawpbhoizKoiwmOIyTFnbwwHBRFmVRFhOcxW3s4DE+Gy7KoizKYoJ7WMMOnhknyqIsymKkFXzAZfzGnnGiLMqiLEY6hyvYxkN8Mk6URVmUxQiruI9jeImnxouyKIuyGGgFt3AGb3EHm8aLsiiLshjoPK5hGxv4aJooi7Ioi4Fu4AS+4ZfpoizKoiwGWMdxvMcl7JouyqIsymKJNbzAd6xj1zxRFmVRFkscQLCBd+aLsiiLsljiCw76f6IsyqIsyqIsyqIsyqIsyqLsH5MmL74zpQwVAAAAAElFTkSuQmCC",
"text/plain": [
"28×28 Array{Gray{N0f8},2} with eltype Gray{N0f8}:\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) … Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) … Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) … Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" ⋮ ⋱ \n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) … Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) … Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)\n",
" Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0) Gray{N0f8}(0.0)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"imgs = MNIST.images()\n",
"colorview(Gray,imgs[100])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Matrix{Gray{N0f8}} (alias for Array{Gray{Normed{UInt8, 8}}, 2})"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"typeof(imgs[100])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we will transofrm the gray scale values to Float32 types. Here, using Float32 will speedup the neural network substantially withough compromising the quality of the solution."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"60000-element Vector{Matrix{Float32}}:\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" ⋮\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]\n",
" [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"myFloat32(X) = Float32.(X)\n",
"fpt_imgs = myFloat32.(imgs) "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Matrix{Float32} (alias for Array{Float32, 2})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"typeof(fpt_imgs[3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will now create a few helpful functions..."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"vectorize(x) = x[:]\n",
"vectorized_imgs = vectorize.(fpt_imgs);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Vector{Vector{Float32}} (alias for Array{Array{Float32, 1}, 1})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"typeof(vectorized_imgs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will again make use of `...` as the splat operator to concatenate all images into one matrix."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(784, 60000)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = hcat(vectorized_imgs...)\n",
"size(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, every column in `X` is an image of a number. We have `60,000` images. When reshaped into a 28-by-28 matrix, and displayed as an image, can be seen as a handwritten number. Here is an example below."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAAAAADji6uXAAAABGdBTUEAALGPC/xhBQAAAAFzUkdCAK7OHOkAAAAgY0hSTQAAeiYAAICEAAD6AAAAgOgAAHUwAADqYAAAOpgAABdwnLpRPAAAAbVJREFUaAW9wT2rDQAABuDnnN6UhSLdlNQdbrZrMjCgbkl2E9kY/ASrSfkBBukOjMJkMJkQRvJRihhsSpFSDGdw7nHP/Tid3ueJsiiLsiiLsiiLsiiLsiiLsiiLsiiLsiiLsiiLsiiLOVvBHZzAW/+LsiiLsphwHHtxz2yO4IXpoizKoiwmnMQS7tm+IRZxEAPri7Ioi7KYcAFPzGY/LuI23lhflEVZlMWEodndNPLedFEWZVEWY5axYHa7jTwyXZRFWZTFmDPYaTYLWDTyxXRRFmVRFmMOGXll+65jAe/w3XRRFmVRFut4bmt24TTO45SRq/hmuiiLsiiLdeyx1mEMsYID2IFzGOInnuEXgpc2FmVRFmUx5if+4Aau+GcZA/zGD7zGLbzAY3zFZ+zEGxuLsiiLshhzGR9xzFqf8ACv8dT/LmEfPthclEVZlMWEa7Zvxchdm4uyKIuymKP7NhdlURZlURZlURZlMScDLOGJjUVZlEVZzMkfDG0uyqIsymKOjmLVxqIsyqIs5mRga6IsyqIs5uAhztqaKIuyKIs5WMWqrYmyKIuyKIuyKIuyKIuyKIuyv/irMYSJ7ydGAAAAAElFTkSuQmCC",
"text/plain": [
"28×28 reinterpret(reshape, Gray{Float32}, ::Matrix{Float32}) with eltype Gray{Float32}:\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" ⋮ ⋱ \n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onefigure = X[:,3]\n",
"t1 = reshape(onefigure,28,28)\n",
"colorview(Gray,t1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will obtain the labels. These are the true labels for the `60,000` images."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: Flux's datasets are deprecated, please use the package MLDatasets.jl\n",
"└ @ Flux.Data /opt/julia/packages/Flux/qp1gc/src/data/Data.jl:17\n"
]
},
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = MNIST.labels()\n",
"labels[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From these labels, we will create a new output column for each image. These columns will be indicator vectors of where the correct label is.\n",
"\n",
"For example if the figure corresponding to column `X[:,i]` is a `3`, the `i'th` column in this new matrix `Y` is `[0 0 0 1 0 0 0 0 0 0]`. (It is the entry number 4 because entry 1 corresponds to the digit 0, so the counting starts from zero). The `onehotbatch` function allows us to create this easily."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"Y = onehotbatch(labels,0:9);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we will actually build our neural network. We will use two layers. The hidden layer will have 32 nodes, and the output layer will have 10 nodes. i.e. we will go from: `28*28 => 32 => 10`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Chain(Dense(784, 32, relu), Dense(32, 10), softmax)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = Chain(\n",
" Dense(28^2,32,relu),\n",
" Dense(32,10),\n",
" softmax)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What does `m`, the neural network mean here? \n",
"\n",
"If you've worked with neural networks before you know that the solution is often not found by just one pass on the neural network. One pass happens, and a solution is generated at the output layer, then this solution is compared to the ground truth solution we already have (the columns from `Y`), and the network goes back and adjusts its weights and parameters and then try again. Here, since `m` is not \"trained\" yet, one pass of `m` on a figure generates the following (not-so-great) answer. We will see later how this changes after training."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10-element Vector{Float32}:\n",
" 0.07273515\n",
" 0.1181932\n",
" 0.07084775\n",
" 0.09932565\n",
" 0.091627285\n",
" 0.0535648\n",
" 0.08405617\n",
" 0.17503478\n",
" 0.12413119\n",
" 0.110483915"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m(onefigure)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run our neural network, we need a loss function and an accuracy function. The accuracy function is used to compare the output result from the output layer in the neural network to the groundtruth result. The loss function is used to evaluate the performance of the overall model after new weights have been recalculated at each pass."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"accuracy (generic function with 1 method)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss(x, y) = Flux.crossentropy(m(x),y)\n",
"accuracy(x, y) = mean(argmax(m(x)).== argmax(y))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we will repeat our data so that we have more samples to pass to the neural network, which means there will be more chances for corrections."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"datasetx = repeated((X,Y),200)\n",
"C = collect(datasetx);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will create a function to display the loss at each step."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"#1 (generic function with 1 method)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evalcb = ()->@show(loss(X,Y))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Params([Float32[0.0069256853 0.046717647 … 0.044092037 0.011616245; 0.027889356 -0.036823545 … -0.03683305 -0.037362; … ; -0.04521925 0.08110898 … -0.023713868 -0.01291985; -0.048453096 -0.06124558 … 0.07536976 0.014303556], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.05023211 0.11609526 … -0.011820401 -0.257053; 0.24301344 -0.07858523 … 0.26390162 0.25781706; … ; 0.15992884 0.076332204 … -0.29110464 -0.016269587; -0.06458335 0.0910367 … 0.14926136 -0.3354616], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ps = Flux.params(m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we are ready to train the model, we will use the `Flux.train!` function. Let's take a look at the documentation."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"\\begin{verbatim}\n",
"train!(loss, params, data, opt; cb)\n",
"\\end{verbatim}\n",
"For each datapoint \\texttt{d} in \\texttt{data}, compute the gradient of \\texttt{loss} with respect to \\texttt{params} through backpropagation and call the optimizer \\texttt{opt}.\n",
"\n",
"If \\texttt{d} is a tuple of arguments to \\texttt{loss} call \\texttt{loss(d...)}, else call \\texttt{loss(d)}.\n",
"\n",
"A callback is given with the keyword argument \\texttt{cb}. For example, this will print \"training\" every 10 seconds (using \\href{@ref}{\\texttt{Flux.throttle}}): train!(loss, params, data, opt, cb = throttle(() -> println(\"training\"), 10))\n",
"\n",
"The callback can call \\href{@ref}{\\texttt{Flux.stop}} to interrupt the training loop.\n",
"\n",
"Multiple optimisers and callbacks can be passed to \\texttt{opt} and \\texttt{cb} as arrays.\n",
"\n"
],
"text/markdown": [
"```\n",
"train!(loss, params, data, opt; cb)\n",
"```\n",
"\n",
"For each datapoint `d` in `data`, compute the gradient of `loss` with respect to `params` through backpropagation and call the optimizer `opt`.\n",
"\n",
"If `d` is a tuple of arguments to `loss` call `loss(d...)`, else call `loss(d)`.\n",
"\n",
"A callback is given with the keyword argument `cb`. For example, this will print \"training\" every 10 seconds (using [`Flux.throttle`](@ref)): train!(loss, params, data, opt, cb = throttle(() -> println(\"training\"), 10))\n",
"\n",
"The callback can call [`Flux.stop`](@ref) to interrupt the training loop.\n",
"\n",
"Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.\n"
],
"text/plain": [
"\u001b[36m train!(loss, params, data, opt; cb)\u001b[39m\n",
"\n",
" For each datapoint \u001b[36md\u001b[39m in \u001b[36mdata\u001b[39m, compute the gradient of \u001b[36mloss\u001b[39m with respect to\n",
" \u001b[36mparams\u001b[39m through backpropagation and call the optimizer \u001b[36mopt\u001b[39m.\n",
"\n",
" If \u001b[36md\u001b[39m is a tuple of arguments to \u001b[36mloss\u001b[39m call \u001b[36mloss(d...)\u001b[39m, else call \u001b[36mloss(d)\u001b[39m.\n",
"\n",
" A callback is given with the keyword argument \u001b[36mcb\u001b[39m. For example, this will\n",
" print \"training\" every 10 seconds (using \u001b[36mFlux.throttle\u001b[39m): train!(loss,\n",
" params, data, opt, cb = throttle(() -> println(\"training\"), 10))\n",
"\n",
" The callback can call \u001b[36mFlux.stop\u001b[39m to interrupt the training loop.\n",
"\n",
" Multiple optimisers and callbacks can be passed to \u001b[36mopt\u001b[39m and \u001b[36mcb\u001b[39m as arrays."
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"?Flux.train!"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss(X, Y) = 2.2936802f0\n",
"loss(X, Y) = 1.492911f0\n",
"loss(X, Y) = 0.93367034f0\n",
"loss(X, Y) = 0.65324676f0\n",
"loss(X, Y) = 0.51964563f0\n",
"loss(X, Y) = 0.44787213f0\n",
"loss(X, Y) = 0.39854404f0\n",
"loss(X, Y) = 0.36422792f0\n",
"loss(X, Y) = 0.34129208f0\n",
"loss(X, Y) = 0.32058308f0\n",
"loss(X, Y) = 0.30383688f0\n",
"loss(X, Y) = 0.28992468f0\n"
]
}
],
"source": [
"opt = ADAM()\n",
"Flux.train!(loss,ps,datasetx,opt,cb=throttle(evalcb,10))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will now get the test data."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Warning: Flux's datasets are deprecated, please use the package MLDatasets.jl\n",
"└ @ Flux.Data /opt/julia/packages/Flux/qp1gc/src/data/Data.jl:17\n"
]
}
],
"source": [
"tX = hcat(float.(reshape.(MNIST.images(:test),:))...);\n",
"test_image = m(tX[:,1]);"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"argmax(test_image)-1"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAHAAAABwCAAAAADji6uXAAAABGdBTUEAALGPC/xhBQAAAAFzUkdCAK7OHOkAAAAgY0hSTQAAeiYAAICEAAD6AAAAgOgAAHUwAADqYAAAOpgAABdwnLpRPAAAAdBJREFUaAW9wb1qlgcABtCDeToUXLRU6FB/cOtSgggFWyh0EV0EvYXUoXQpBFxCQYdAxg7egeAFlBIKKXTRJYv4UyrGwYoIhQS0Q1ChDu8QBL/4vfnCc06URVmURVmURVmURVmURVmURVmURVmURVmURVmURVmURVmURVmURVmURVlMcAkLeIZt3MBzPDKbKIuyKIsJVnDcjst4ift29xQrWPd+URZlURYTLOBLPMAXmMe3+Ar/4HM73uBffGbwBOveL8qiLMpigjWsGawaHMI81nHajm08xF84jMcmi7Ioi7IYYQt/GKx510Ucwl3cNFmURVmUxT44gus4gKvYNFmURVmUxT74AZ9iC3/bXZRFWZTFjM7gisEF3LO7KIuyKIsZncNHWMNtHxZlURZlMYOPcRav8DNe+7Aoi7IoixksYh6ruGU6URZlURZ7dB5LeIFrphdlURZlsQef4BfM4TfcNr0oi7Ioi5HmsIoT2MCScaIsyqIsRjqJUwY/YcM4URZlURYjHMPvBov41XhRFmVRFiN8j6MGf+J/40VZlEVZTOkb/Gh2URZlURZT+hoHDTbwn72JsiiLshjpDr7Dpr2JsiiLspjSMpbNLsqiLMreApamPWWOWvFrAAAAAElFTkSuQmCC",
"text/plain": [
"28×28 Array{Gray{Float32},2} with eltype Gray{Float32}:\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" ⋮ ⋱ \n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) … Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)\n",
" Gray{Float32}(0.0) Gray{Float32}(0.0) Gray{Float32}(0.0)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t1 = reshape(tX[:,1],28,28)\n",
"colorview(Gray, t1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What about the image we tried a few cells earlier and returned the \"not-so-great\" answer."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10-element Vector{Float32}:\n",
" 0.9968736\n",
" 2.957651f-7\n",
" 0.00023512624\n",
" 0.00019518785\n",
" 1.3028435f-7\n",
" 0.0024404125\n",
" 2.0499616f-5\n",
" 9.000515f-5\n",
" 0.00012760425\n",
" 1.7137647f-5"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onefigure = X[:,2]\n",
"m(onefigure)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10-element Flux.OneHotVector{10,UInt32}:\n",
" 1\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0\n",
" 0"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y[:,2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Finally...\n",
"After finishing this notebook, you should be able to:\n",
"- [ ] prepare data to fit the format to create a neural network using Flux.jl\n",
"- [ ] create a neural network with Flux.jl\n",
"- [ ] creating an accuracy function and loss function to be passed to train the neural network\n",
"- [ ] train the neural network\n",
"- [ ] describe a few tips that can help make your nerual network faster or more accurate (such as using Float32 as opposed to Float32)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🥳 One cool finding\n",
"\n",
"We ran a trained a neural network on a dataset of of handwritten digits (called the MNIST dataset). At the end, we were able to pass this figure to the neural network and the return result was:\n",
"\n",
"<center><img src=\"data/1001.png\" width=\"150\"></center>\n",
"\n",
"```\n",
"10-element Array{Float32,1}:\n",
" 0.00029263002\n",
" 1.5993925f-5\n",
" 0.0002862561\n",
" 0.0035434738\n",
" 1.388653f-5\n",
" 2.4878627f-5\n",
" 6.433018f-7\n",
" 0.99414164 ### <= this is the highest number!\n",
" 0.000118321994\n",
" 0.0015623316\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.6.0",
"language": "julia",
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.6.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}