You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
246 lines
5.8 KiB
Plaintext
246 lines
5.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "operating-illinois",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import combined as c\n",
|
|
"import NeuralNetworkSpecifications as nns"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "white-lottery",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Problem dimensions and randomly drawn test inputs.\n",
"SEED = 0  # fixed so stocks/debris (and every output derived from them) are reproducible\n",
"torch.manual_seed(SEED)\n",
"\n",
"BATCH_SIZE = 5\n",
"STATES = 3\n",
"CONSTELLATIONS = STATES - 1  # determined by debris tracking\n",
"MAX = 10  # exclusive upper bound of the torch.randint draws below\n",
"FEATURES = 1\n",
"\n",
"# Leaf tensors with requires_grad=True so profits can be differentiated w.r.t. them.\n",
"stocks = torch.randint(MAX, (BATCH_SIZE, 1, CONSTELLATIONS), dtype=torch.float32, requires_grad=True)\n",
"debris = torch.randint(MAX, (BATCH_SIZE, 1), dtype=torch.float32, requires_grad=True)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 91,
|
|
"id": "quick-extraction",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[[[1.],\n",
|
|
" [0.]]]], requires_grad=True) torch.Size([1, 1, 2, 1])\n",
|
|
"tensor([[ 1.0000, -0.1000]], requires_grad=True) torch.Size([1, 2])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Hand-picked launch costs (an earlier draft drew them at random with shape\n",
"# (1, CONSTELLATIONS, CONSTELLATIONS, FEATURES)).\n",
"# NOTE(review): linear_profit below does not read this tensor.\n",
"launch_costs = torch.tensor([[[[1.0], [0.0]]]], requires_grad=True)\n",
"\n",
"# Payoff weights, multiplied elementwise with `stocks` inside linear_profit\n",
"# (an earlier draft drew them at random with shape (STATES, CONSTELLATIONS)).\n",
"payoff = torch.tensor([[1.0, -0.1]], requires_grad=True)\n",
"\n",
"# Scalar multiplier applied to `debris` inside linear_profit.\n",
"debris_cost = -0.2\n",
"\n",
"print(launch_costs, launch_costs.shape)\n",
"print(payoff, payoff.shape)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 92,
|
|
"id": "textile-cleanup",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"def linear_profit(stocks, debris, choices, constellation_number, launch_cost=-5):\n",
"    \"\"\"Linear profit of one constellation, computed for every batch row.\n",
"\n",
"    Args:\n",
"        stocks: satellite stocks; multiplied elementwise by the notebook global\n",
"            `payoff` and summed over dim 2 (presumably (batch, 1, CONSTELLATIONS) — TODO confirm).\n",
"        debris: debris levels; scaled by the notebook global `debris_cost`\n",
"            (presumably (batch, 1) — TODO confirm).\n",
"        choices: policy output (launch decisions), indexed as\n",
"            choices[:, constellation_number, :].\n",
"        constellation_number: index of the constellation whose launches are charged.\n",
"        launch_cost: multiplier applied to the launches; negative so that launching\n",
"            lowers profit. Defaults to -5, the previously hard-coded value.\n",
"\n",
"    Returns:\n",
"        Tensor of profits summed over dim 1, i.e. one value per batch row.\n",
"    \"\"\"\n",
"    # Use the `choices` argument rather than a notebook global, so the result\n",
"    # depends only on what the caller passes in.\n",
"    launch_expense = launch_cost * choices[:, constellation_number, :]\n",
"\n",
"    revenue = (payoff * stocks).sum(dim=2)\n",
"\n",
"    debris_costs = debris * debris_cost\n",
"\n",
"    # The three terms share a trailing dimension that only exists to line up the\n",
"    # shapes; summing over dim 1 removes it.\n",
"    profit = (revenue + debris_costs + launch_expense).sum(dim=1)\n",
"    return profit"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 100,
|
|
"id": "single-wheat",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([ 3.2451, 4.3734, 6.5474, -0.2722, -2.8843], grad_fn=<SumBackward1>),\n",
|
|
" torch.Size([5]))"
|
|
]
|
|
},
|
|
"execution_count": 100,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# NOTE(review): `output` is only created further down (the policy.forward cell);\n",
"# that cell must run before this one on a fresh kernel.\n",
"profit = linear_profit(stocks, debris, choices=output, constellation_number=0)\n",
"(profit, profit.shape)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 123,
|
|
"id": "handy-perry",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([[-0.2000],\n",
|
|
" [-0.0000],\n",
|
|
" [-0.0000],\n",
|
|
" [-0.0000],\n",
|
|
" [-0.0000]]),)"
|
|
]
|
|
},
|
|
"execution_count": 123,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Gradient of the first batch row's profit with respect to the debris leaf tensor.\n",
"# create_graph=True keeps the returned gradient itself differentiable.\n",
"torch.autograd.grad(outputs=profit[0], inputs=(debris,), create_graph=True)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 95,
|
|
"id": "purple-superior",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# NOTE(review): the meaning of the final positional argument (12) is defined by\n",
"# nns.ChoiceFunction — presumably a layer width; confirm and give it a named constant.\n",
"policy = nns.ChoiceFunction(\n",
"    BATCH_SIZE,\n",
"    STATES,\n",
"    FEATURES,\n",
"    CONSTELLATIONS,\n",
"    12,\n",
")"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "auburn-leonard",
|
|
"metadata": {},
|
|
"source": [
|
|
"Example: a training loop that pushes the policy's profit towards 1 using a squared-error loss.\n",
|
|
"```python\n",
|
|
"optimizer = torch.optim.Adam(policy.parameters(),lr=0.001)\n",
|
|
"\n",
|
|
"for i in range(10000):\n",
|
|
" #training loop\n",
|
|
" optimizer.zero_grad()\n",
|
|
"\n",
|
|
" output = policy.forward(s.values)\n",
|
|
"\n",
|
|
" l = ((1 - linear_profit(stocks, debris, output, 0))**2).sum()\n",
|
|
"\n",
|
|
"\n",
|
|
" l.backward()\n",
|
|
"\n",
|
|
" optimizer.step()\n",
|
|
"\n",
|
|
" if i%200==0:\n",
|
|
" print(l)\n",
|
|
" \n",
|
|
"\n",
|
|
"results = policy.forward(s.values)\n",
|
|
"print(results.mean(dim=0), \"\\n\",results.std(dim=0))\n",
|
|
"```\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 96,
|
|
"id": "herbal-manual",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[[0.2910],\n",
|
|
" [0.4003]],\n",
|
|
"\n",
|
|
" [[0.1053],\n",
|
|
" [0.2446]],\n",
|
|
"\n",
|
|
" [[0.1705],\n",
|
|
" [0.2758]],\n",
|
|
"\n",
|
|
" [[0.1944],\n",
|
|
" [0.3421]],\n",
|
|
"\n",
|
|
" [[0.5369],\n",
|
|
" [0.6181]]], grad_fn=<ReluBackward0>)"
|
|
]
|
|
},
|
|
"execution_count": 96,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# NOTE(review): `s` is never defined in this notebook, so this cell relies on leftover\n",
"# kernel state. Define it explicitly (presumably from the `combined` module) so that\n",
"# Restart & Run All works.\n",
"output = policy.forward(s.values)\n",
"\n",
"output"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "another-timing",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|