store_attr

This commit is contained in:
Jeremy Howard 2020-09-13 09:55:40 -07:00
parent c6322c68a5
commit 45d5cd1c47
2 changed files with 4 additions and 4 deletions

View File

@ -1326,7 +1326,7 @@
"outputs": [],
"source": [
"class SGD:\n",
" def __init__(self, params, lr, wd=0.): store_attr(self, 'params,lr,wd')\n",
" def __init__(self, params, lr, wd=0.): store_attr()\n",
" def step(self):\n",
" for p in self.params:\n",
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
@ -1367,7 +1367,7 @@
"source": [
"class Learner:\n",
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
" store_attr(self, 'model,dls,loss_func,lr,cbs,opt_func')\n",
" store_attr()\n",
" for cb in cbs: cb.learner = self\n",
"\n",
" def one_batch(self):\n",

View File

@ -606,7 +606,7 @@
"outputs": [],
"source": [
"class SGD:\n",
" def __init__(self, params, lr, wd=0.): store_attr(self, 'params,lr,wd')\n",
" def __init__(self, params, lr, wd=0.): store_attr()\n",
" def step(self):\n",
" for p in self.params:\n",
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
@ -633,7 +633,7 @@
"source": [
"class Learner:\n",
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
" store_attr(self, 'model,dls,loss_func,lr,cbs,opt_func')\n",
" store_attr()\n",
" for cb in cbs: cb.learner = self\n",
"\n",
" def one_batch(self):\n",