@ -3,7 +3,7 @@
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 1,
"execution_count": 1,
"id": "religious-anaheim ",
"id": "ceramic-doctrine ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
@ -11,16 +11,8 @@
"from torch.autograd.functional import jacobian\n",
"from torch.autograd.functional import jacobian\n",
"import itertools\n",
"import itertools\n",
"import math\n",
"import math\n",
"import abc"
"import abc\n",
]
"\n",
},
{
"cell_type": "code",
"execution_count": 2,
"id": "green-brunei",
"metadata": {},
"outputs": [],
"source": [
"class EconomicAgent(metaclass=abc.ABCMeta):\n",
"class EconomicAgent(metaclass=abc.ABCMeta):\n",
" @abc.abstractmethod\n",
" @abc.abstractmethod\n",
" def period_benefit(self,state,estimand_interface):\n",
" def period_benefit(self,state,estimand_interface):\n",
@ -99,16 +91,8 @@
" return len(self.stocks)\n",
" return len(self.stocks)\n",
" @property\n",
" @property\n",
" def number_debris_trackers(self):\n",
" def number_debris_trackers(self):\n",
" return len(self.debris)"
" return len(self.debris)\n",
]
"\n",
},
{
"cell_type": "code",
"execution_count": 3,
"id": "sweet-injection",
"metadata": {},
"outputs": [],
"source": [
" \n",
" \n",
"class EstimandInterface():\n",
"class EstimandInterface():\n",
" \"\"\"\n",
" \"\"\"\n",
@ -177,16 +161,9 @@
" \n",
" \n",
" def __str__(self):\n",
" def __str__(self):\n",
" #just a human readable descriptor\n",
" #just a human readable descriptor\n",
" return \"Launch Decisions and Partial Derivativs of value function with\\n\\tlaunches\\n\\t\\t {}\\n\\tPartials\\n\\t\\t{}\".format(self.choices,self.partials)\n"
" return \"Launch Decisions and Partial Derivativs of value function with\\n\\tlaunches\\n\\t\\t {}\\n\\tPartials\\n\\t\\t{}\".format(self.choices,self.partials)\n",
]
"\n",
},
"\n",
{
"cell_type": "code",
"execution_count": 4,
"id": "right-dinner",
"metadata": {},
"outputs": [],
"source": [
"class ChoiceFunction(torch.nn.Module):\n",
"class ChoiceFunction(torch.nn.Module):\n",
" \"\"\"\n",
" \"\"\"\n",
" This is used to estimate the launch function\n",
" This is used to estimate the launch function\n",
@ -234,16 +211,8 @@
" \n",
" \n",
" intermediate_values = self.relu(intermediate_values) #launches are always positive, this may need removed for other types of choices.\n",
" intermediate_values = self.relu(intermediate_values) #launches are always positive, this may need removed for other types of choices.\n",
" \n",
" \n",
" return intermediate_values"
" return intermediate_values\n",
]
"\n",
},
{
"cell_type": "code",
"execution_count": 5,
"id": "global-wallet",
"metadata": {},
"outputs": [],
"source": [
"class PartialDerivativesOfValueEstimand(torch.nn.Module):\n",
"class PartialDerivativesOfValueEstimand(torch.nn.Module):\n",
" \"\"\"\n",
" \"\"\"\n",
" This is used to estimate the partial derivatives of the value functions\n",
" This is used to estimate the partial derivatives of the value functions\n",
@ -305,8 +274,8 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 6 ,
"execution_count": 2 ,
"id": "resident-cooper ",
"id": "executive-royal ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
@ -338,7 +307,7 @@
},
},
{
{
"cell_type": "markdown",
"cell_type": "markdown",
"id": "compatible-conviction ",
"id": "numerical-mexico ",
"metadata": {},
"metadata": {},
"source": [
"source": [
"# Testing\n",
"# Testing\n",
@ -348,8 +317,8 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 7 ,
"execution_count": 3 ,
"id": "explicit-sponsorship ",
"id": "packed-economics ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
@ -362,25 +331,25 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 8 ,
"execution_count": 4 ,
"id": "desperate-color ",
"id": "compliant-circle ",
"metadata": {},
"metadata": {},
"outputs": [
"outputs": [
{
{
"data": {
"data": {
"text/plain": [
"text/plain": [
"tensor([[[88., 68., 13 .]],\n",
"tensor([[[ 1., 30., 11 .]],\n",
"\n",
"\n",
" [[23., 8., 62 .]],\n",
" [[60., 74., 1 .]],\n",
"\n",
"\n",
" [[96., 65., 89 .]],\n",
" [[46., 33., 70 .]],\n",
"\n",
"\n",
" [[16., 27., 6 2.]],\n",
" [[42., 29., 3 2.]],\n",
"\n",
"\n",
" [[40., 38., 20 .]]])"
" [[82., 72., 57 .]]])"
]
]
},
},
"execution_count": 8 ,
"execution_count": 4 ,
"metadata": {},
"metadata": {},
"output_type": "execute_result"
"output_type": "execute_result"
}
}
@ -391,22 +360,8 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"id": "median-nurse",
"id": "theoretical-spectrum",
"metadata": {},
"outputs": [],
"source": [
"enn = EstimandNN(batch_size\n",
" ,states\n",
" ,choices\n",
" ,constellations\n",
" ,12)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "under-monroe",
"metadata": {},
"metadata": {},
"outputs": [
"outputs": [
{
{
@ -415,35 +370,35 @@
"text": [
"text": [
"Launch Decisions and Partial Derivativs of value function with\n",
"Launch Decisions and Partial Derivativs of value function with\n",
"\tlaunches\n",
"\tlaunches\n",
"\t\t tensor([[[0.8138 ],\n",
"\t\t tensor([[[0.0000 ],\n",
" [4.6481 ]],\n",
" [0.0000 ]],\n",
"\n",
"\n",
" [[1.1540 ],\n",
" [[2.0907 ],\n",
" [2.0568 ]],\n",
" [0.1053 ]],\n",
"\n",
"\n",
" [[2.117 0],\n",
" [[2.973 0],\n",
" [6.2769 ]],\n",
" [2.2000 ]],\n",
"\n",
"\n",
" [[1.3752 ],\n",
" [[2.3975 ],\n",
" [2.4555 ]],\n",
" [1.2877 ]],\n",
"\n",
"\n",
" [[0.7025 ],\n",
" [[4.2107 ],\n",
" [2.5947 ]]], grad_fn=<ReluBackward0>)\n",
" [2.0752 ]]], grad_fn=<ReluBackward0>)\n",
"\tPartials\n",
"\tPartials\n",
"\t\ttensor([[[-1.7285, -1.5841, -1.0559 ],\n",
"\t\ttensor([[[ 0.1939, 0.3954, 0.0730 ],\n",
" [ 2.9694, 4.2772, 3.6800 ]],\n",
" [-0.9428, 0.6145, -0.9247 ]],\n",
"\n",
"\n",
" [[-0.6313, -1.6874, -0.1176 ],\n",
" [[ 1.1686, 3.0170, 0.3393 ],\n",
" [ 2.3680, 3.5758, 2.4247 ]],\n",
" [-7.1474, 2.3495, -7.0566 ]],\n",
"\n",
"\n",
" [[-2.1381, -3.2882, -0.9620 ],\n",
" [[-2.0849, 3.0883, -3.3791 ],\n",
" [ 5.2646, 7.8475, 5.8994 ]],\n",
" [-0.6664, 0.0361, -2.2530 ]],\n",
"\n",
"\n",
" [[-1.2167, -2.0969, -0.499 8],\n",
" [[-0.7117, 2.5474, -1.645 8],\n",
" [ 1.7140, 2.4235, 2.1813 ]],\n",
" [-2.1937, 0.6897, -3.0382 ]],\n",
"\n",
"\n",
" [[-1.1293, -1.2674, -0.638 6],\n",
" [[-1.0262, 4.5973, -2.660 6],\n",
" [ 1.5440, 2.1548, 2.0289 ]]], grad_fn=<AddBackward0>)\n"
" [-5.4307, 1.4510, -6.6972 ]]], grad_fn=<AddBackward0>)\n"
]
]
}
}
],
],
@ -453,8 +408,8 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 11 ,
"execution_count": 7 ,
"id": "nonprofit-castle ",
"id": "vulnerable-penalty ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
@ -465,12 +420,12 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 12 ,
"execution_count": 30 ,
"id": "crucial-homeless ",
"id": "classified-estimate ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"b = ChoiceFunction(batch_size\n",
"ch = ChoiceFunction(batch_size\n",
" ,states\n",
" ,states\n",
" ,choices\n",
" ,choices\n",
" ,constellations\n",
" ,constellations\n",
@ -479,16 +434,16 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 1 3,
"execution_count": 31 ,
"id": "practical-journalist ",
"id": "martial-premium ",
"metadata": {},
"metadata": {},
"outputs": [
"outputs": [
{
{
"name": "stdout",
"name": "stdout",
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"tensor(29.0569 , grad_fn=<SumBackward0>)\n",
"tensor(46.8100 , grad_fn=<SumBackward0>)\n",
"tensor(23187.2695 , grad_fn=<SumBackward0>)\n",
"tensor(82442.4219 , grad_fn=<SumBackward0>)\n",
"tensor(0., grad_fn=<SumBackward0>)\n",
"tensor(0., grad_fn=<SumBackward0>)\n",
"tensor(0., grad_fn=<SumBackward0>)\n",
"tensor(0., grad_fn=<SumBackward0>)\n",
"tensor(0., grad_fn=<SumBackward0>)\n",
"tensor(0., grad_fn=<SumBackward0>)\n",
@ -518,19 +473,19 @@
" [0.]]], grad_fn=<ReluBackward0>)"
" [0.]]], grad_fn=<ReluBackward0>)"
]
]
},
},
"execution_count": 1 3,
"execution_count": 31 ,
"metadata": {},
"metadata": {},
"output_type": "execute_result"
"output_type": "execute_result"
}
}
],
],
"source": [
"source": [
"optimizer = torch.optim.SGD(b .parameters(),lr=0.01)\n",
"optimizer = torch.optim.SGD(ch .parameters(),lr=0.01)\n",
"\n",
"\n",
"for i in range(10):\n",
"for i in range(10):\n",
" #training loop\n",
" #training loop\n",
" optimizer.zero_grad()\n",
" optimizer.zero_grad()\n",
"\n",
"\n",
" output = b .forward(stocks_and_debris)\n",
" output = ch .forward(stocks_and_debris)\n",
"\n",
"\n",
" l = lossb(output)\n",
" l = lossb(output)\n",
"\n",
"\n",
@ -541,71 +496,260 @@
" print(l)\n",
" print(l)\n",
" \n",
" \n",
"\n",
"\n",
"b .forward(stocks_and_debris)"
"ch .forward(stocks_and_debris)"
]
]
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 1 4,
"execution_count": 45 ,
"id": "correct-complex ",
"id": "corrected-jewelry ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": [
"source": [
"def lossa (a):\n",
"def lossc (a):\n",
" #test loss function\n",
" #test loss function\n",
" return (a.choices**2).sum() + (a.partials**2).sum()"
" return (a**2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "opened-figure",
"metadata": {},
"outputs": [],
"source": [
"pd = PartialDerivativesOfValueEstimand(\n",
" batch_size\n",
" ,constellations\n",
" ,states\n",
" ,12)"
]
]
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 15,
"execution_count": 74 ,
"id": "pharmaceutical-brush",
"id": "chicken-inspector ",
"metadata": {},
"metadata": {},
"outputs": [
"outputs": [
{
{
"name": "stdout",
"name": "stdout",
"output_type": "stream",
"output_type": "stream",
"text": [
"text": [
"tensor(336.1971, grad_fn=<AddBackward0>)\n",
"tensor(1.9948e-06, grad_fn=<SumBackward0>)\n",
"tensor(67583.6484, grad_fn=<AddBackward0>)\n",
"tensor(1.7427e-05, grad_fn=<SumBackward0>)\n",
"tensor(1.5658e+26, grad_fn=<AddBackward0>)\n",
"tensor(5.7993e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n",
"tensor(2.9985e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n",
"tensor(6.5281e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n",
"tensor(7.8818e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n",
"tensor(4.4327e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n",
"tensor(1.1240e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n",
"tensor(1.2478e-06, grad_fn=<SumBackward0>)\n",
"tensor(nan, grad_fn=<AddBackward0>)\n"
"tensor(3.5818e-06, grad_fn=<SumBackward0>)\n",
"tensor(4.3732e-06, grad_fn=<SumBackward0>)\n",
"tensor(2.7699e-06, grad_fn=<SumBackward0>)\n",
"tensor(8.9659e-07, grad_fn=<SumBackward0>)\n",
"tensor(5.7541e-07, grad_fn=<SumBackward0>)\n",
"tensor(1.5010e-06, grad_fn=<SumBackward0>)\n"
]
]
},
},
{
{
"data": {
"data": {
"text/plain": [
"text/plain": [
"tensor([[[0.],\n",
"tensor([[[ 0.0002, -0.0002, -0.0003 ],\n",
" [0.]],\n",
" [ 0.0001, -0.0003, -0.0002 ]],\n",
"\n",
"\n",
" [[0.],\n",
" [[ 0.0002, -0.0003, -0.0003 ],\n",
" [0.]],\n",
" [ 0.0003, -0.0004, -0.0002 ]],\n",
"\n",
"\n",
" [[0.],\n",
" [[ 0.0002, -0.0003, -0.0003 ],\n",
" [0.]],\n",
" [ 0.0002, -0.0003, -0.0003 ]],\n",
"\n",
"\n",
" [[0.],\n",
" [[ 0.0002, -0.0002, -0.0004 ],\n",
" [0.]],\n",
" [ 0.0003, -0.0003, -0.0003 ]],\n",
"\n",
"\n",
" [[0.],\n",
" [[ 0.0003, -0.0003, -0.0002 ],\n",
" [0.]]], grad_fn=<Relu Backward0>)"
" [ 0.0003, -0.0003, -0.0002]]], grad_fn=<Add Backward0>)"
]
]
},
},
"execution_count": 15 ,
"execution_count": 74 ,
"metadata": {},
"metadata": {},
"output_type": "execute_result"
"output_type": "execute_result"
}
}
],
],
"source": [
"source": [
"optimizer = torch.optim.SGD(enn.parameters(),lr=0.001) #note the use of enn in the optimizer \n",
"optimizer = torch.optim.Adam(pd.parameters(),lr=0.0001) \n",
"\n",
"\n",
"for i in range(10):\n",
"for i in range(15):\n",
" #training loop\n",
" optimizer.zero_grad()\n",
"\n",
" output = pd.forward(stocks_and_debris)\n",
"\n",
" l = lossc(output)\n",
"\n",
" l.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" print(l)\n",
" \n",
"\n",
"pd.forward(stocks_and_debris)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "southwest-diamond",
"metadata": {},
"outputs": [],
"source": [
"def lossa(a):\n",
" #test loss function\n",
" return (a.choices**2).sum() + (a.partials**2).sum()"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "brave-treat",
"metadata": {},
"outputs": [],
"source": [
"enn = EstimandNN(batch_size\n",
" ,states\n",
" ,choices\n",
" ,constellations\n",
" ,12)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "functional-render",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(112.1970, grad_fn=<AddBackward0>)\n",
"10 tensor(79.8152, grad_fn=<AddBackward0>)\n",
"20 tensor(55.6422, grad_fn=<AddBackward0>)\n",
"30 tensor(38.5636, grad_fn=<AddBackward0>)\n",
"40 tensor(26.9156, grad_fn=<AddBackward0>)\n",
"50 tensor(18.9986, grad_fn=<AddBackward0>)\n",
"60 tensor(13.6606, grad_fn=<AddBackward0>)\n",
"70 tensor(10.1881, grad_fn=<AddBackward0>)\n",
"80 tensor(8.0395, grad_fn=<AddBackward0>)\n",
"90 tensor(6.7618, grad_fn=<AddBackward0>)\n",
"100 tensor(6.0101, grad_fn=<AddBackward0>)\n",
"110 tensor(5.5517, grad_fn=<AddBackward0>)\n",
"120 tensor(5.2434, grad_fn=<AddBackward0>)\n",
"130 tensor(5.0054, grad_fn=<AddBackward0>)\n",
"140 tensor(4.7988, grad_fn=<AddBackward0>)\n",
"150 tensor(4.6069, grad_fn=<AddBackward0>)\n",
"160 tensor(4.4235, grad_fn=<AddBackward0>)\n",
"170 tensor(4.2468, grad_fn=<AddBackward0>)\n",
"180 tensor(4.0763, grad_fn=<AddBackward0>)\n",
"190 tensor(3.9117, grad_fn=<AddBackward0>)\n",
"200 tensor(3.7532, grad_fn=<AddBackward0>)\n",
"210 tensor(3.6005, grad_fn=<AddBackward0>)\n",
"220 tensor(3.4535, grad_fn=<AddBackward0>)\n",
"230 tensor(3.3121, grad_fn=<AddBackward0>)\n",
"240 tensor(3.1761, grad_fn=<AddBackward0>)\n",
"250 tensor(3.0454, grad_fn=<AddBackward0>)\n",
"260 tensor(2.9198, grad_fn=<AddBackward0>)\n",
"270 tensor(2.7991, grad_fn=<AddBackward0>)\n",
"280 tensor(2.6832, grad_fn=<AddBackward0>)\n",
"290 tensor(2.5720, grad_fn=<AddBackward0>)\n",
"300 tensor(2.4653, grad_fn=<AddBackward0>)\n",
"310 tensor(2.3629, grad_fn=<AddBackward0>)\n",
"320 tensor(2.2646, grad_fn=<AddBackward0>)\n",
"330 tensor(2.1704, grad_fn=<AddBackward0>)\n",
"340 tensor(2.0800, grad_fn=<AddBackward0>)\n",
"350 tensor(1.9933, grad_fn=<AddBackward0>)\n",
"360 tensor(1.9103, grad_fn=<AddBackward0>)\n",
"370 tensor(1.8306, grad_fn=<AddBackward0>)\n",
"380 tensor(1.7543, grad_fn=<AddBackward0>)\n",
"390 tensor(1.6812, grad_fn=<AddBackward0>)\n",
"400 tensor(1.6111, grad_fn=<AddBackward0>)\n",
"410 tensor(1.5440, grad_fn=<AddBackward0>)\n",
"420 tensor(1.4797, grad_fn=<AddBackward0>)\n",
"430 tensor(1.4180, grad_fn=<AddBackward0>)\n",
"440 tensor(1.3590, grad_fn=<AddBackward0>)\n",
"450 tensor(1.3025, grad_fn=<AddBackward0>)\n",
"460 tensor(1.2484, grad_fn=<AddBackward0>)\n",
"470 tensor(1.1965, grad_fn=<AddBackward0>)\n",
"480 tensor(1.1469, grad_fn=<AddBackward0>)\n",
"490 tensor(1.0994, grad_fn=<AddBackward0>)\n",
"500 tensor(1.0540, grad_fn=<AddBackward0>)\n",
"510 tensor(1.0104, grad_fn=<AddBackward0>)\n",
"520 tensor(0.9688, grad_fn=<AddBackward0>)\n",
"530 tensor(0.9290, grad_fn=<AddBackward0>)\n",
"540 tensor(0.8908, grad_fn=<AddBackward0>)\n",
"550 tensor(0.8544, grad_fn=<AddBackward0>)\n",
"560 tensor(0.8195, grad_fn=<AddBackward0>)\n",
"570 tensor(0.7861, grad_fn=<AddBackward0>)\n",
"580 tensor(0.7542, grad_fn=<AddBackward0>)\n",
"590 tensor(0.7237, grad_fn=<AddBackward0>)\n",
"600 tensor(0.6945, grad_fn=<AddBackward0>)\n",
"610 tensor(0.6667, grad_fn=<AddBackward0>)\n",
"620 tensor(0.6400, grad_fn=<AddBackward0>)\n",
"630 tensor(0.6146, grad_fn=<AddBackward0>)\n",
"640 tensor(0.5903, grad_fn=<AddBackward0>)\n",
"650 tensor(0.5671, grad_fn=<AddBackward0>)\n",
"660 tensor(0.5449, grad_fn=<AddBackward0>)\n",
"670 tensor(0.5237, grad_fn=<AddBackward0>)\n",
"680 tensor(0.5035, grad_fn=<AddBackward0>)\n",
"690 tensor(0.4842, grad_fn=<AddBackward0>)\n",
"700 tensor(0.4658, grad_fn=<AddBackward0>)\n",
"710 tensor(0.4482, grad_fn=<AddBackward0>)\n",
"720 tensor(0.4315, grad_fn=<AddBackward0>)\n",
"730 tensor(0.4155, grad_fn=<AddBackward0>)\n",
"740 tensor(0.4002, grad_fn=<AddBackward0>)\n",
"750 tensor(0.3857, grad_fn=<AddBackward0>)\n",
"760 tensor(0.3718, grad_fn=<AddBackward0>)\n",
"770 tensor(0.3586, grad_fn=<AddBackward0>)\n",
"780 tensor(0.3460, grad_fn=<AddBackward0>)\n",
"790 tensor(0.3340, grad_fn=<AddBackward0>)\n",
"800 tensor(0.3226, grad_fn=<AddBackward0>)\n",
"810 tensor(0.3117, grad_fn=<AddBackward0>)\n",
"820 tensor(0.3013, grad_fn=<AddBackward0>)\n",
"830 tensor(0.2914, grad_fn=<AddBackward0>)\n",
"840 tensor(0.2820, grad_fn=<AddBackward0>)\n",
"850 tensor(0.2730, grad_fn=<AddBackward0>)\n",
"860 tensor(0.2645, grad_fn=<AddBackward0>)\n",
"870 tensor(0.2564, grad_fn=<AddBackward0>)\n",
"880 tensor(0.2486, grad_fn=<AddBackward0>)\n",
"890 tensor(0.2413, grad_fn=<AddBackward0>)\n",
"900 tensor(0.2342, grad_fn=<AddBackward0>)\n",
"910 tensor(0.2276, grad_fn=<AddBackward0>)\n",
"920 tensor(0.2212, grad_fn=<AddBackward0>)\n",
"930 tensor(0.2151, grad_fn=<AddBackward0>)\n",
"940 tensor(0.2094, grad_fn=<AddBackward0>)\n",
"950 tensor(0.2039, grad_fn=<AddBackward0>)\n",
"960 tensor(0.1986, grad_fn=<AddBackward0>)\n",
"970 tensor(0.1936, grad_fn=<AddBackward0>)\n",
"980 tensor(0.1889, grad_fn=<AddBackward0>)\n",
"990 tensor(0.1844, grad_fn=<AddBackward0>)\n"
]
},
{
"data": {
"text/plain": [
"<__main__.EstimandInterface at 0x7f85609fce20>"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimizer = torch.optim.Adam(enn.parameters(),lr=0.0001) #note the use of enn in the optimizer\n",
"\n",
"for i in range(1000):\n",
" #training loop\n",
" #training loop\n",
" optimizer.zero_grad()\n",
" optimizer.zero_grad()\n",
"\n",
"\n",
@ -617,16 +761,17 @@
"\n",
"\n",
" optimizer.step()\n",
" optimizer.step()\n",
"\n",
"\n",
" print(l)\n",
" if i%10==0:\n",
" print(i, l)\n",
" \n",
" \n",
"\n",
"\n",
"b .forward(stocks_and_debris)"
"enn .forward(stocks_and_debris)"
]
]
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": null,
"execution_count": null,
"id": "other-subdivision ",
"id": "voluntary-postage ",
"metadata": {},
"metadata": {},
"outputs": [],
"outputs": [],
"source": []
"source": []