Lesson update
This commit is contained in:
parent
b7f3f0d750
commit
5b70a64d66
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
171
08_collab.ipynb
171
08_collab.ipynb
@ -580,63 +580,63 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>207</td>\n",
|
||||
" <td>Four Weddings and a Funeral (1994)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>542</td>\n",
|
||||
" <td>My Left Foot (1989)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>565</td>\n",
|
||||
" <td>Remains of the Day, The (1993)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>422</td>\n",
|
||||
" <td>Event Horizon (1997)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>506</td>\n",
|
||||
" <td>Kids (1995)</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>311</td>\n",
|
||||
" <td>African Queen, The (1951)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>845</td>\n",
|
||||
" <td>595</td>\n",
|
||||
" <td>Face/Off (1997)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>617</td>\n",
|
||||
" <td>Evil Dead II (1987)</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>158</td>\n",
|
||||
" <td>Jurassic Park (1993)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>836</td>\n",
|
||||
" <td>Chasing Amy (1997)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>798</td>\n",
|
||||
" <td>Being Human (1993)</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>500</td>\n",
|
||||
" <td>Down by Law (1986)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>409</td>\n",
|
||||
" <td>Much Ado About Nothing (1993)</td>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>474</td>\n",
|
||||
" <td>Emma (1996)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>721</td>\n",
|
||||
" <td>Braveheart (1995)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>316</td>\n",
|
||||
" <td>Psycho (1960)</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>466</td>\n",
|
||||
" <td>Jackie Chan's First Strike (1996)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>883</td>\n",
|
||||
" <td>Judgment Night (1993)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>554</td>\n",
|
||||
" <td>Scream (1996)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>"
|
||||
@ -661,6 +661,27 @@
|
||||
"In order to represent collaborative filtering in PyTorch we can't just use the crosstab representation directly, especially if we want it to fit into our deep learning framework. We can represent our movie and user latent factor tables as simple matrices:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...],\n",
|
||||
" 'title': (#1635) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...]}"
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dls.classes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -684,6 +705,35 @@
|
||||
"It turns out that we can represent *look up in an index* as a matrix product! The trick is to replace our indices with one hot encoded vectors. Here is an example of what happens if we multiply a vector by a one hot encoded vector representing the index three:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"one_hot_3 = one_hot(3, n_users).float()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([944, 5])"
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"user_factors.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -701,7 +751,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"one_hot_3 = one_hot(3, n_users).float()\n",
|
||||
"user_factors.t() @ one_hot_3"
|
||||
]
|
||||
},
|
||||
@ -918,32 +967,32 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>1.326261</td>\n",
|
||||
" <td>1.295701</td>\n",
|
||||
" <td>0.993168</td>\n",
|
||||
" <td>0.990168</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>1.091352</td>\n",
|
||||
" <td>1.091475</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" <td>0.884821</td>\n",
|
||||
" <td>0.911269</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>0.961574</td>\n",
|
||||
" <td>0.977690</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" <td>0.671865</td>\n",
|
||||
" <td>0.875679</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>0.829995</td>\n",
|
||||
" <td>0.893122</td>\n",
|
||||
" <td>0.471727</td>\n",
|
||||
" <td>0.878200</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.781661</td>\n",
|
||||
" <td>0.876511</td>\n",
|
||||
" <td>0.361314</td>\n",
|
||||
" <td>0.884209</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
@ -1006,33 +1055,33 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0.976380</td>\n",
|
||||
" <td>1.001455</td>\n",
|
||||
" <td>0.973745</td>\n",
|
||||
" <td>0.993206</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0.875964</td>\n",
|
||||
" <td>0.919960</td>\n",
|
||||
" <td>0.869132</td>\n",
|
||||
" <td>0.914323</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>0.685377</td>\n",
|
||||
" <td>0.870664</td>\n",
|
||||
" <td>0.676553</td>\n",
|
||||
" <td>0.870192</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>0.483701</td>\n",
|
||||
" <td>0.874071</td>\n",
|
||||
" <td>0.485377</td>\n",
|
||||
" <td>0.873865</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.385249</td>\n",
|
||||
" <td>0.878055</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" <td>0.377866</td>\n",
|
||||
" <td>0.877610</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>"
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -405,63 +405,63 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>207</td>\n",
|
||||
" <td>Four Weddings and a Funeral (1994)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>542</td>\n",
|
||||
" <td>My Left Foot (1989)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>565</td>\n",
|
||||
" <td>Remains of the Day, The (1993)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>422</td>\n",
|
||||
" <td>Event Horizon (1997)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>506</td>\n",
|
||||
" <td>Kids (1995)</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>311</td>\n",
|
||||
" <td>African Queen, The (1951)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>845</td>\n",
|
||||
" <td>595</td>\n",
|
||||
" <td>Face/Off (1997)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>617</td>\n",
|
||||
" <td>Evil Dead II (1987)</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>158</td>\n",
|
||||
" <td>Jurassic Park (1993)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>836</td>\n",
|
||||
" <td>Chasing Amy (1997)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>798</td>\n",
|
||||
" <td>Being Human (1993)</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>500</td>\n",
|
||||
" <td>Down by Law (1986)</td>\n",
|
||||
" <td>4</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>409</td>\n",
|
||||
" <td>Much Ado About Nothing (1993)</td>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>474</td>\n",
|
||||
" <td>Emma (1996)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>721</td>\n",
|
||||
" <td>Braveheart (1995)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>316</td>\n",
|
||||
" <td>Psycho (1960)</td>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>466</td>\n",
|
||||
" <td>Jackie Chan's First Strike (1996)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>883</td>\n",
|
||||
" <td>Judgment Night (1993)</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>554</td>\n",
|
||||
" <td>Scream (1996)</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>"
|
||||
@ -479,6 +479,27 @@
|
||||
"dls.show_batch()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...],\n",
|
||||
" 'title': (#1635) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...]}"
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dls.classes"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -493,6 +514,35 @@
|
||||
"movie_factors = torch.randn(n_movies, n_factors)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"one_hot_3 = one_hot(3, n_users).float()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([944, 5])"
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"user_factors.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@ -510,7 +560,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"one_hot_3 = one_hot(3, n_users).float()\n",
|
||||
"user_factors.t() @ one_hot_3"
|
||||
]
|
||||
},
|
||||
@ -641,32 +690,32 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>1.326261</td>\n",
|
||||
" <td>1.295701</td>\n",
|
||||
" <td>0.993168</td>\n",
|
||||
" <td>0.990168</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>1.091352</td>\n",
|
||||
" <td>1.091475</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" <td>0.884821</td>\n",
|
||||
" <td>0.911269</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>0.961574</td>\n",
|
||||
" <td>0.977690</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" <td>0.671865</td>\n",
|
||||
" <td>0.875679</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>0.829995</td>\n",
|
||||
" <td>0.893122</td>\n",
|
||||
" <td>0.471727</td>\n",
|
||||
" <td>0.878200</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.781661</td>\n",
|
||||
" <td>0.876511</td>\n",
|
||||
" <td>0.361314</td>\n",
|
||||
" <td>0.884209</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
@ -722,33 +771,33 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>0.976380</td>\n",
|
||||
" <td>1.001455</td>\n",
|
||||
" <td>0.973745</td>\n",
|
||||
" <td>0.993206</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0.875964</td>\n",
|
||||
" <td>0.919960</td>\n",
|
||||
" <td>0.869132</td>\n",
|
||||
" <td>0.914323</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>0.685377</td>\n",
|
||||
" <td>0.870664</td>\n",
|
||||
" <td>0.676553</td>\n",
|
||||
" <td>0.870192</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>3</td>\n",
|
||||
" <td>0.483701</td>\n",
|
||||
" <td>0.874071</td>\n",
|
||||
" <td>0.485377</td>\n",
|
||||
" <td>0.873865</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>4</td>\n",
|
||||
" <td>0.385249</td>\n",
|
||||
" <td>0.878055</td>\n",
|
||||
" <td>00:12</td>\n",
|
||||
" <td>0.377866</td>\n",
|
||||
" <td>0.877610</td>\n",
|
||||
" <td>00:11</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>"
|
||||
|
Loading…
Reference in New Issue
Block a user