Lesson update

Sylvain Gugger 2020-04-28 10:12:59 -07:00
parent b7f3f0d750
commit 5b70a64d66
6 changed files with 1411 additions and 422 deletions

File diff suppressed because one or more lines are too long

@ -580,63 +580,63 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>207</td>\n",
" <td>Four Weddings and a Funeral (1994)</td>\n",
" <td>3</td>\n",
" <td>542</td>\n",
" <td>My Left Foot (1989)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>565</td>\n",
" <td>Remains of the Day, The (1993)</td>\n",
" <td>5</td>\n",
" <td>422</td>\n",
" <td>Event Horizon (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>506</td>\n",
" <td>Kids (1995)</td>\n",
" <td>1</td>\n",
" <td>311</td>\n",
" <td>African Queen, The (1951)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>845</td>\n",
" <td>595</td>\n",
" <td>Face/Off (1997)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>617</td>\n",
" <td>Evil Dead II (1987)</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>158</td>\n",
" <td>Jurassic Park (1993)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>836</td>\n",
" <td>Chasing Amy (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>798</td>\n",
" <td>Being Human (1993)</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>500</td>\n",
" <td>Down by Law (1986)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>409</td>\n",
" <td>Much Ado About Nothing (1993)</td>\n",
" <th>7</th>\n",
" <td>474</td>\n",
" <td>Emma (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>721</td>\n",
" <td>Braveheart (1995)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>316</td>\n",
" <td>Psycho (1960)</td>\n",
" <td>2</td>\n",
" <td>466</td>\n",
" <td>Jackie Chan's First Strike (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>883</td>\n",
" <td>Judgment Night (1993)</td>\n",
" <td>5</td>\n",
" <td>554</td>\n",
" <td>Scream (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
@ -661,6 +661,27 @@
"In order to represent collaborative filtering in PyTorch we can't just use the crosstab representation directly, especially if we want it to fit into our deep learning framework. We can represent our movie and user latent factor tables as simple matrices:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...],\n",
" 'title': (#1635) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...]}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.classes"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -684,6 +705,35 @@
"It turns out that we can represent *look up in an index* as a matrix product! The trick is to replace our indices with one hot encoded vectors. Here is an example of what happens if we multiply a vector by a one hot encoded vector representing the index three:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"one_hot_3 = one_hot(3, n_users).float()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([944, 5])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_factors.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -701,7 +751,6 @@
}
],
"source": [
"one_hot_3 = one_hot(3, n_users).float()\n",
"user_factors.t() @ one_hot_3"
]
},
@ -918,32 +967,32 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.326261</td>\n",
" <td>1.295701</td>\n",
" <td>0.993168</td>\n",
" <td>0.990168</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.091352</td>\n",
" <td>1.091475</td>\n",
" <td>00:11</td>\n",
" <td>0.884821</td>\n",
" <td>0.911269</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.961574</td>\n",
" <td>0.977690</td>\n",
" <td>00:11</td>\n",
" <td>0.671865</td>\n",
" <td>0.875679</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.829995</td>\n",
" <td>0.893122</td>\n",
" <td>0.471727</td>\n",
" <td>0.878200</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.781661</td>\n",
" <td>0.876511</td>\n",
" <td>0.361314</td>\n",
" <td>0.884209</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
@ -1006,33 +1055,33 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.976380</td>\n",
" <td>1.001455</td>\n",
" <td>0.973745</td>\n",
" <td>0.993206</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.875964</td>\n",
" <td>0.919960</td>\n",
" <td>0.869132</td>\n",
" <td>0.914323</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.685377</td>\n",
" <td>0.870664</td>\n",
" <td>0.676553</td>\n",
" <td>0.870192</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.483701</td>\n",
" <td>0.874071</td>\n",
" <td>0.485377</td>\n",
" <td>0.873865</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.385249</td>\n",
" <td>0.878055</td>\n",
" <td>00:12</td>\n",
" <td>0.377866</td>\n",
" <td>0.877610</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"

File diff suppressed because one or more lines are too long

@ -405,63 +405,63 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>207</td>\n",
" <td>Four Weddings and a Funeral (1994)</td>\n",
" <td>3</td>\n",
" <td>542</td>\n",
" <td>My Left Foot (1989)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>565</td>\n",
" <td>Remains of the Day, The (1993)</td>\n",
" <td>5</td>\n",
" <td>422</td>\n",
" <td>Event Horizon (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>506</td>\n",
" <td>Kids (1995)</td>\n",
" <td>1</td>\n",
" <td>311</td>\n",
" <td>African Queen, The (1951)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>845</td>\n",
" <td>595</td>\n",
" <td>Face/Off (1997)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>617</td>\n",
" <td>Evil Dead II (1987)</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>158</td>\n",
" <td>Jurassic Park (1993)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>836</td>\n",
" <td>Chasing Amy (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>798</td>\n",
" <td>Being Human (1993)</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>500</td>\n",
" <td>Down by Law (1986)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>409</td>\n",
" <td>Much Ado About Nothing (1993)</td>\n",
" <th>7</th>\n",
" <td>474</td>\n",
" <td>Emma (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>721</td>\n",
" <td>Braveheart (1995)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>316</td>\n",
" <td>Psycho (1960)</td>\n",
" <td>2</td>\n",
" <td>466</td>\n",
" <td>Jackie Chan's First Strike (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>883</td>\n",
" <td>Judgment Night (1993)</td>\n",
" <td>5</td>\n",
" <td>554</td>\n",
" <td>Scream (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
@ -479,6 +479,27 @@
"dls.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...],\n",
" 'title': (#1635) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...]}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.classes"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -493,6 +514,35 @@
"movie_factors = torch.randn(n_movies, n_factors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"one_hot_3 = one_hot(3, n_users).float()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([944, 5])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_factors.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -510,7 +560,6 @@
}
],
"source": [
"one_hot_3 = one_hot(3, n_users).float()\n",
"user_factors.t() @ one_hot_3"
]
},
@ -641,32 +690,32 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.326261</td>\n",
" <td>1.295701</td>\n",
" <td>0.993168</td>\n",
" <td>0.990168</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.091352</td>\n",
" <td>1.091475</td>\n",
" <td>00:11</td>\n",
" <td>0.884821</td>\n",
" <td>0.911269</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.961574</td>\n",
" <td>0.977690</td>\n",
" <td>00:11</td>\n",
" <td>0.671865</td>\n",
" <td>0.875679</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.829995</td>\n",
" <td>0.893122</td>\n",
" <td>0.471727</td>\n",
" <td>0.878200</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.781661</td>\n",
" <td>0.876511</td>\n",
" <td>0.361314</td>\n",
" <td>0.884209</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
@ -722,33 +771,33 @@
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.976380</td>\n",
" <td>1.001455</td>\n",
" <td>0.973745</td>\n",
" <td>0.993206</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.875964</td>\n",
" <td>0.919960</td>\n",
" <td>0.869132</td>\n",
" <td>0.914323</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.685377</td>\n",
" <td>0.870664</td>\n",
" <td>0.676553</td>\n",
" <td>0.870192</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.483701</td>\n",
" <td>0.874071</td>\n",
" <td>0.485377</td>\n",
" <td>0.873865</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.385249</td>\n",
" <td>0.878055</td>\n",
" <td>00:12</td>\n",
" <td>0.377866</td>\n",
" <td>0.877610</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"