vectorized update_r
This commit is contained in:
@@ -89,7 +89,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"$ r(i, k) \\leftarrow - \\max\\limits_{k' s.t. k' \\neq k}\\{ a(i, k') + s(i, k') \\}$"
|
||||
"$ r(i, k) \\leftarrow s(i, k) - \\max\\limits_{k' s.t. k' \\neq k}\\{ a(i, k') + s(i, k') \\}$"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -97,13 +97,36 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def update_r(lmda=0.9): \n",
|
||||
" for i in range(x.shape[0]):\n",
|
||||
" for k in range(x.shape[0]):\n",
|
||||
" v = S[i, :] + A[i, :]\n",
|
||||
" v[k] = -np.inf\n",
|
||||
" v[i]= -np.inf\n",
|
||||
" R[i, k] = R[i, k] * lmda + (1 - lmda) * (S[i, k] - np.max(v))"
|
||||
"def update_r(lmda=0.9):\n",
|
||||
" # For every column k, except for the column with the maximum value the max is the same.\n",
|
||||
" # So we can subtract the maximum for every row, and only need to do something different for k == argmax\n",
|
||||
" \n",
|
||||
" global R\n",
|
||||
" v = S + A\n",
|
||||
" rows = np.arange(x.shape[0])\n",
|
||||
" # We only compare the current point to all other points, so the diagonal can be filled with -infinity\n",
|
||||
" np.fill_diagonal(v, -np.inf)\n",
|
||||
" \n",
|
||||
" # max values\n",
|
||||
" idx_max = np.argmax(v, axis=1)\n",
|
||||
" first_max = v[rows, idx_max]\n",
|
||||
" \n",
|
||||
" # Second max values. For every column where k is the max value.\n",
|
||||
" v[rows, idx_max] = -np.inf\n",
|
||||
" second_max = v[rows, np.argmax(v, axis=1)]\n",
|
||||
" \n",
|
||||
" # Broadcast the maximum value per row over all the columns per row.\n",
|
||||
" max_matrix = np.zeros_like(R) + first_max[:, None]\n",
|
||||
" max_matrix[rows, idx_max] = second_max\n",
|
||||
" \n",
|
||||
" max_matrix = S - max_matrix\n",
|
||||
" \n",
|
||||
" R = R * lmda + (1 - lmda) * R_new\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"A, R, S = create_matrices()\n",
|
||||
"update_r()\n",
|
||||
"R"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -113,6 +136,30 @@
|
||||
"$ a(i, k) \\leftarrow \\min\\{0, r(k,k) + \\sum\\limits_{i' s.t. i' \\notin \\{i, k\\}}{\\max\\{0, r(i', k)\\}}$ "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.diag(k)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"A, R, S = create_matrices()\n",
|
||||
"preference = np.median(S)\n",
|
||||
"# update_r()\n",
|
||||
"\n",
|
||||
"i, k = np.meshgrid(np.arange(100), np.arange(100))\n",
|
||||
"\n",
|
||||
"v = S + A\n",
|
||||
"\n",
|
||||
"v[i, k] = -np.inf\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user