store_attr
This commit is contained in:
parent
c6322c68a5
commit
45d5cd1c47
@ -1326,7 +1326,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class SGD:\n",
|
"class SGD:\n",
|
||||||
" def __init__(self, params, lr, wd=0.): store_attr(self, 'params,lr,wd')\n",
|
" def __init__(self, params, lr, wd=0.): store_attr()\n",
|
||||||
" def step(self):\n",
|
" def step(self):\n",
|
||||||
" for p in self.params:\n",
|
" for p in self.params:\n",
|
||||||
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
|
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
|
||||||
@ -1367,7 +1367,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"class Learner:\n",
|
"class Learner:\n",
|
||||||
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
|
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
|
||||||
" store_attr(self, 'model,dls,loss_func,lr,cbs,opt_func')\n",
|
" store_attr()\n",
|
||||||
" for cb in cbs: cb.learner = self\n",
|
" for cb in cbs: cb.learner = self\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def one_batch(self):\n",
|
" def one_batch(self):\n",
|
||||||
|
@ -606,7 +606,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"class SGD:\n",
|
"class SGD:\n",
|
||||||
" def __init__(self, params, lr, wd=0.): store_attr(self, 'params,lr,wd')\n",
|
" def __init__(self, params, lr, wd=0.): store_attr()\n",
|
||||||
" def step(self):\n",
|
" def step(self):\n",
|
||||||
" for p in self.params:\n",
|
" for p in self.params:\n",
|
||||||
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
|
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
|
||||||
@ -633,7 +633,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"class Learner:\n",
|
"class Learner:\n",
|
||||||
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
|
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
|
||||||
" store_attr(self, 'model,dls,loss_func,lr,cbs,opt_func')\n",
|
" store_attr()\n",
|
||||||
" for cb in cbs: cb.learner = self\n",
|
" for cb in cbs: cb.learner = self\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def one_batch(self):\n",
|
" def one_batch(self):\n",
|
||||||
|
Loading…
Reference in New Issue
Block a user