
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "operating-illinois",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import combined as c\n",
"import NeuralNetworkSpecifications as nns"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "white-lottery",
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 5\n",
"STATES = 3\n",
"CONSTELLATIONS = STATES -1 #determined by debris tracking\n",
"MAX = 10\n",
"FEATURES = 1\n",
"\n",
"stocks = torch.randint(MAX,(BATCH_SIZE,1,CONSTELLATIONS), dtype=torch.float32, requires_grad=True)\n",
"debris = torch.randint(MAX,(BATCH_SIZE,1), dtype=torch.float32, requires_grad=True)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "quick-extraction",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[[1.],\n",
" [0.]]]], requires_grad=True) torch.Size([1, 1, 2, 1])\n",
"tensor([[ 1.0000, -0.1000]], requires_grad=True) torch.Size([1, 2])\n"
]
}
],
"source": [
"#launch_costs = torch.randint(3,(1,CONSTELLATIONS,CONSTELLATIONS,FEATURES), dtype=torch.float32)\n",
"launch_costs = torch.tensor([[[[1.0],[0.0]]]], requires_grad=True)\n",
"print(launch_costs, launch_costs.shape)\n",
"#payoff = torch.randint(5,(STATES,CONSTELLATIONS), dtype=torch.float32)\n",
"payoff = torch.tensor([[1.0, -0.1]], requires_grad=True)\n",
"print(payoff, payoff.shape)\n",
"\n",
"debris_cost = -0.2"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "textile-cleanup",
"metadata": {},
"outputs": [],
"source": [
"def linear_profit(stocks, debris, choices, constellation_number):\n",
"    # Per-batch profit for one constellation.\n",
"    #   stocks:  (BATCH_SIZE, 1, CONSTELLATIONS) stock levels (see setup cell)\n",
"    #   debris:  (BATCH_SIZE, 1) debris counts\n",
"    #   choices: policy output; assumed (BATCH_SIZE, CONSTELLATIONS, FEATURES)\n",
"    #            based on the printed `output` below -- TODO confirm\n",
"    #   constellation_number: index of the constellation being charged for launches\n",
"    # Returns a (BATCH_SIZE,) tensor of profits.\n",
"    # NOTE: relies on the globals `payoff` and `debris_cost` from earlier cells.\n",
"\n",
"    # Launch expense: flat cost of 5 per unit chosen, sliced to the requested\n",
"    # constellation. Bug fix: the original read the global `output` here,\n",
"    # silently ignoring the `choices` parameter (the call site passes `output`\n",
"    # as `choices`, so this is behavior-identical for existing callers).\n",
"    launch_expense = (-5 * choices)[:, constellation_number, :]\n",
"\n",
"    # Revenue: payoff-weighted stocks, summed over the constellation axis.\n",
"    revenue = (payoff * stocks).sum(dim=2)\n",
"\n",
"    # Debris penalty (debris_cost is negative, so this subtracts).\n",
"    debris_costs = debris * debris_cost\n",
"\n",
"    profit = (revenue + debris_costs + launch_expense).sum(dim=1)\n",
"    return profit"
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "single-wheat",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 3.2451, 4.3734, 6.5474, -0.2722, -2.8843], grad_fn=<SumBackward1>),\n",
" torch.Size([5]))"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"profit = linear_profit(stocks, debris, output,0)\n",
"profit, profit.shape"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "handy-perry",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-0.2000],\n",
" [-0.0000],\n",
" [-0.0000],\n",
" [-0.0000],\n",
" [-0.0000]]),)"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.autograd.grad(profit[0], (debris), create_graph=True)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "purple-superior",
"metadata": {},
"outputs": [],
"source": [
"policy = nns.ChoiceFunction(BATCH_SIZE\n",
" ,STATES\n",
" ,FEATURES\n",
" ,CONSTELLATIONS\n",
" ,12\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "auburn-leonard",
"metadata": {},
"source": [
"example to get profit = 1\n",
"```python\n",
"optimizer = torch.optim.Adam(policy.parameters(),lr=0.001)\n",
"\n",
"for i in range(10000):\n",
" #training loop\n",
" optimizer.zero_grad()\n",
"\n",
" output = policy.forward(s.values)\n",
"\n",
" l = ((1-linear_profit(s.values,output))**2).sum()\n",
"\n",
"\n",
" l.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" if i%200==0:\n",
" print(l)\n",
" \n",
"\n",
"results = policy.forward(s.values)\n",
"print(results.mean(dim=0), \"\\n\",results.std(dim=0))\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "herbal-manual",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.2910],\n",
" [0.4003]],\n",
"\n",
" [[0.1053],\n",
" [0.2446]],\n",
"\n",
" [[0.1705],\n",
" [0.2758]],\n",
"\n",
" [[0.1944],\n",
" [0.3421]],\n",
"\n",
" [[0.5369],\n",
" [0.6181]]], grad_fn=<ReluBackward0>)"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output = policy.forward(s.values)\n",
"output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "another-timing",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}