current work, including a whole lot of julia stuff (Pluto.jl FTW)

temporaryWork^2
youainti 5 years ago
parent d5635622cc
commit 66a237cd19

1
.gitignore vendored

@ -300,4 +300,5 @@ TSWLatexianTemp*
*.pdf
#Don't track python/jupyterlab stuff
*/.ipynb_checkpoints/*
.ipynb_checkpoints/*
*/__pycache__/*

@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "royal-trace",
"id": "operating-illinois",
"metadata": {},
"outputs": [],
"source": [
@ -14,26 +14,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "atlantic-finish",
"execution_count": 25,
"id": "white-lottery",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[8., 5., 3.]],\n",
"\n",
" [[3., 6., 6.]],\n",
"\n",
" [[3., 7., 2.]],\n",
"\n",
" [[4., 8., 2.]],\n",
"\n",
" [[0., 6., 8.]]], grad_fn=<CatBackward>) torch.Size([5, 1, 3])\n"
]
}
],
"outputs": [],
"source": [
"BATCH_SIZE = 5\n",
"STATES = 3\n",
@ -42,25 +26,14 @@
"FEATURES = 1\n",
"\n",
"stocks = torch.randint(MAX,(BATCH_SIZE,1,CONSTELLATIONS), dtype=torch.float32, requires_grad=True)\n",
"debris = torch.randint(MAX,(BATCH_SIZE,1,1), dtype=torch.float32, requires_grad=True)\n",
"\n",
"s = c.States(stocks, debris)\n",
"\n",
"print(s.values,s.values.shape)"
"debris = torch.randint(MAX,(BATCH_SIZE,1), dtype=torch.float32, requires_grad=True)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "prostate-liverpool",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 3,
"id": "simplified-permission",
"execution_count": 91,
"id": "quick-extraction",
"metadata": {},
"outputs": [
{
@ -68,53 +41,100 @@
"output_type": "stream",
"text": [
"tensor([[[[1.],\n",
" [0.]],\n",
"\n",
" [[0.],\n",
" [1.]]]]) torch.Size([1, 2, 2, 1])\n",
"tensor([[ 1.0000, 0.0000],\n",
" [ 0.0000, 1.0000],\n",
" [-0.2000, -0.2000]]) torch.Size([3, 2])\n"
" [0.]]]], requires_grad=True) torch.Size([1, 1, 2, 1])\n",
"tensor([[ 1.0000, -0.1000]], requires_grad=True) torch.Size([1, 2])\n"
]
}
],
"source": [
"#launch_costs = torch.randint(3,(1,CONSTELLATIONS,CONSTELLATIONS,FEATURES), dtype=torch.float32)\n",
"launch_costs = torch.tensor([[[[1.0],[0]],[[0.0],[1]]]])\n",
"launch_costs = torch.tensor([[[[1.0],[0.0]]]], requires_grad=True)\n",
"print(launch_costs, launch_costs.shape)\n",
"#payoff = torch.randint(5,(STATES,CONSTELLATIONS), dtype=torch.float32)\n",
"payoff = torch.tensor([[1.0, 0],[0,1.0],[-0.2,-0.2]])\n",
"print(payoff, payoff.shape)"
"payoff = torch.tensor([[1.0, -0.1]], requires_grad=True)\n",
"print(payoff, payoff.shape)\n",
"\n",
"debris_cost = -0.2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "level-angle",
"execution_count": 92,
"id": "textile-cleanup",
"metadata": {},
"outputs": [],
"source": [
"def linear_profit(states, choices):\n",
"def linear_profit(stocks, debris, choices,constellation_number):\n",
" #Pay particular attention to the dimensions\n",
" #note that there is an extra dimension in there just ot match that of the profit vector we'll be giving out.\n",
" \n",
" #calculate launch expenses\n",
" \n",
" launch_expense = torch.tensordot(choices,launch_costs, [[-2,-1],[-2,-1]])\n",
" launch_expense = (-5 * output)[:,constellation_number,:]\n",
"\n",
" #calculate revenue\n",
"\n",
" revenue = torch.tensordot(s.values, payoff, [[-1],[0]])\n",
" revenue = (payoff * stocks).sum(dim=2)\n",
" \n",
" debris_costs = debris * debris_cost \n",
"\n",
"\n",
" profit = revenue - launch_expense\n",
" profit = (revenue + debris_costs + launch_expense).sum(dim=1)\n",
" return profit"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "copyrighted-acting",
"execution_count": 100,
"id": "single-wheat",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([ 3.2451, 4.3734, 6.5474, -0.2722, -2.8843], grad_fn=<SumBackward1>),\n",
" torch.Size([5]))"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"profit = linear_profit(stocks, debris, output,0)\n",
"profit, profit.shape"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "handy-perry",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-0.2000],\n",
" [-0.0000],\n",
" [-0.0000],\n",
" [-0.0000],\n",
" [-0.0000]]),)"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.autograd.grad(profit[0], (debris), create_graph=True)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "purple-superior",
"metadata": {},
"outputs": [],
"source": [
@ -128,7 +148,7 @@
},
{
"cell_type": "markdown",
"id": "casual-career",
"id": "auburn-leonard",
"metadata": {},
"source": [
"example to get profit = 1\n",
@ -159,30 +179,30 @@
},
{
"cell_type": "code",
"execution_count": 6,
"id": "straight-negative",
"execution_count": 96,
"id": "herbal-manual",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.0000],\n",
" [0.0000]],\n",
"tensor([[[0.2910],\n",
" [0.4003]],\n",
"\n",
" [[0.0000],\n",
" [0.0000]],\n",
" [[0.1053],\n",
" [0.2446]],\n",
"\n",
" [[0.0000],\n",
" [0.0000]],\n",
" [[0.1705],\n",
" [0.2758]],\n",
"\n",
" [[0.0000],\n",
" [0.0000]],\n",
" [[0.1944],\n",
" [0.3421]],\n",
"\n",
" [[0.3742],\n",
" [0.0000]]], grad_fn=<ReluBackward0>)"
" [[0.5369],\n",
" [0.6181]]], grad_fn=<ReluBackward0>)"
]
},
"execution_count": 6,
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
@ -192,73 +212,10 @@
"output"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "independent-deficit",
"metadata": {},
"outputs": [],
"source": [
"t = torch.ones_like(output, requires_grad=True)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "romance-force",
"metadata": {
"tags": []
},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-57-efee93d7c257>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;31m#this is where I lose the gradient. This is where I need a gradient so that I can call .backward below\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtest_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n",
"\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
]
}
],
"source": [
"def test_loss(options):\n",
" return torch.autograd.functional.jacobian(linear_profit, (s.values, options))[0].sum()\n",
" #something is off here ^\n",
" #this is where I lose the gradient. This is where I need a gradient so that I can call .backward below\n",
"\n",
"test_loss(output).backward()"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "asian-death",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-55-ac1f78ecd780>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n",
"\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "prospective-nelson",
"id": "another-timing",
"metadata": {},
"outputs": [],
"source": []

@ -0,0 +1,945 @@
### A Pluto.jl notebook ###
# v0.17.0
using Markdown
using InteractiveUtils
# ╔═╡ cc16838c-3b25-11ec-2489-11c3d35f26f4
using Flux,LinearAlgebra,Zygote ,PlutoUI
# ╔═╡ 458bb826-4eaf-42ca-b889-4d1c50a2ecae
using Flux.Optimise: AbstractOptimiser
# ╔═╡ 7d6afccb-c45e-4371-827a-43d588d4945f
md"""
# Actor Critic Model
This is an implementation of an optimizer for an actor critic model.
"""
# ╔═╡ a810a371-d2e4-4d4f-9cf6-1b98582c4b0f
md"""
## Physical Model
"""
# ╔═╡ 2e59f530-aeb0-4ce1-9312-e81a66520181
#Supertype for physical models; BasicModel subtypes it, and survival and G dispatch on it.
abstract type AbstractPhysicalModel end
# ╔═╡ c57e9967-75e8-47d2-836e-f28159dc6592
#setup physical model
#Parameter container for the basic physical model; its fields are read by survival, G and H.
struct BasicModel <: AbstractPhysicalModel
#rate at which debris hits satellites
debris_collision_rate::Real
#rate at which satellites of different constellations collide
satellite_collision_rates::Matrix{Float64}
#rate at which debris exits orbits (G also subtracts it from the survival rate of the stocks)
decay_rate::Real
#rate at which existing debris generates additional debris (added to the debris growth factor in H)
autocatalysis_rate::Real
#ratio at which a collision between satellites produces debris
satellite_collision_debris_ratio::Real
#ratio at which launches produce debris
launch_debris_ratio::Real
end
# ╔═╡ 4153bfd6-2ec6-4025-93ba-8e6e061809e8
md"""
## Setup NeuralNets
"""
# ╔═╡ bbc7e4b0-5edb-4e5d-9512-1aaeae1bb4ec
md"""
Custom function to zero out parameter traces
"""
# ╔═╡ 533250c6-fd76-47ef-a561-d567da04ae61
# ╔═╡ a305c33f-3388-4120-b036-3093c3ed6aa3
md"""
## setup Actor Critic struct and loop
"""
# ╔═╡ 0d375c34-3226-40b6-adaf-2f531c00d3ac
#Return a new Flux.Params whose entries are Base.zero of each entry of `a`.
function zero(a::Flux.Params)
    zeroed = collect(Base.zero(p) for p in a)
    return Flux.Params(zeroed)
end
# ╔═╡ e79e4ecf-caec-4c1a-9626-0c1a8b5006ce
#=
Mutable state of the actor-critic optimizer: hyperparameters, eligibility
traces and the running reward estimate.

NOTE(review): the identifiers zʷ and R̄ were unreadable in the source this was
restored from; they are reconstructed here. Confirm them against the original
notebook; any code accessing these fields must use the same names.
=#
mutable struct ActorCritic
    #parameters
    #decay factors applied to the w and θ eligibility traces
    λʷ::Real
    λᶿ::Real
    #step sizes for the w update, the θ update and the reward estimate
    αʷ::Real
    αᶿ::Real
    αᴿ::Real
    #keep track of eligibility traces (one zeroed array per parameter array)
    zᶿ::Flux.Params
    zʷ::Flux.Params
    #keep track of update rate (presumably the running average reward - TODO confirm)
    R̄::AbstractFloat
    #Inner constructor: takes the hyperparameters, the parameter collections θ and w,
    #and the initial reward estimate R; the traces start at zero.
    ActorCritic(
        λʷ::Real
        ,λᶿ::Real
        ,αʷ::Real
        ,αᶿ::Real
        ,αᴿ::Real
        ,θ::Flux.Params
        ,w::Flux.Params
        ,R::Real
    ) =
    begin
        zᶿ = zero(θ) #custom zero handles params
        zʷ = zero(w) #custom zero handles params
        new(λʷ,λᶿ,αʷ,αᶿ,αᴿ,zᶿ,zʷ,R)
    end
end
# ╔═╡ 0ac4ac61-0ca2-4ed6-8855-85c9d1939afa
md"""
# Example use: Model setup
"""
# ╔═╡ 7102c12c-581b-4204-9986-741889cb03f6
#implement transition function
#percentage survival function:
#elementwise exp of minus (satellite_collision_rates*stocks) minus (debris_collision_rate*debris)
function survival(
    stocks
    ,debris
    ,physical_model::AbstractPhysicalModel
    )
    negated_satellite_term = -physical_model.satellite_collision_rates*stocks
    debris_term = physical_model.debris_collision_rate*debris
    return exp.(negated_satellite_term .- debris_term)
end
# ╔═╡ 283bdd91-d742-4f4b-a19b-2756ac01ce2c
#stock update rules
#Next-period stocks: each entry of stocks is scaled by (survival - decay_rate)
#and the launches are added.
function G(
    stocks::Vector
    ,debris::Vector
    ,launches::Vector
    , physical_model::AbstractPhysicalModel
    )
    retained_share = survival(stocks,debris,physical_model) .- physical_model.decay_rate
    #elementwise product gives the same result as diagm(retained_share)*stocks
    #without materialising a dense N×N diagonal matrix
    return retained_share .* stocks + launches
end
# ╔═╡ 8a3f9518-ae0a-46f3-b37d-1017ce99c70d
#debris evolution
#Next-period debris: carried-over debris plus debris from lost satellites and from launches.
function H(stocks,debris,launches,physical_model)
    #existing debris after decay and autocatalytic growth
    growth_factor = 1-physical_model.decay_rate+physical_model.autocatalysis_rate
    carried_over = growth_factor * debris
    #debris from satellite loss: ratio times the sum over (1 - survival) .* stocks
    lost_share = 1 .- survival(stocks,debris,physical_model)
    from_losses = (physical_model.satellite_collision_debris_ratio * lost_share')*stocks
    #debris from launches
    from_launches = physical_model.launch_debris_ratio*sum(launches)
    #total debris level
    return carried_over .+ from_losses .+ from_launches
end
# ╔═╡ e4c8abfa-e96c-44b7-9858-af46adc1164a
#implement reward function
begin
    #NOTE(review): payoff is defined here but F below does not use it.
    const payoff = 3*LinearAlgebra.I #- 0.02*ones(N_constellations,N_constellations)
    #Define the market profit function.
    #Returns the first element of stocks - launch_cost*launches + debris*debris_cost.
    #launch_cost and debris_cost default to the previously hard-coded 3.0 and -0.2,
    #so existing three-argument calls are unchanged.
    function F(stocks,debris,launches; launch_cost=3.0, debris_cost=-0.2)
        return (stocks - launch_cost*launches .+ (debris*debris_cost))[1]
    end
end
# ╔═╡ 42be8352-7ecf-4da4-bff0-61e328144ed1
#test: scalar inputs, 4 - 3.0*2 + 2*(-0.2) = -2.4
F(4,2,2)
# ╔═╡ e2135298-e280-40ac-a549-6a15f13b84ee
md"""
## Example use: Single actor.
### Model and Optimizer Parameterization
"""
# ╔═╡ 5a2dab62-985c-4689-ab48-da04981d15a0
#Model shape
begin
#length of the stock state vector (one entry per constellation)
const N_constellations = 1;
#length of the debris state vector
const N_debris = 1;
#combined state size, used to size the hidden layers of the networks
const N_states = N_constellations + N_debris;
end
# ╔═╡ 3f14b399-cae8-433a-8729-fc91c9b0bee9
# Launch function
# Network mapping a (stocks, debris) tuple to N_constellations launch levels.
# The final relu keeps the outputs non-negative.
launch_policy = Flux.Chain(
Flux.Parallel(vcat
#parallel joins together stocks and debris, along with intermediate interpretation
#stock branch: N_constellations inputs -> N_states*2 outputs
,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu)
,Flux.Dense(N_states*2, N_states*2,Flux.σ)
)
#debris branch: N_debris inputs -> N_states outputs
,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu)
,Flux.Dense(N_states, N_states,Flux.σ)
)
)
#Apply some transformations
#the vcat of both branches has N_states*2 + N_states = N_states*3 entries
,Flux.Dense(N_states*3,128,Flux.σ)
,Flux.Dense(128,128,Flux.σ)
,Flux.Dense(128,N_constellations,Flux.relu)
)
# ╔═╡ 89dabcc8-7f4e-4df4-8445-eae9f50ebfea
#inspect parameters
Flux.params(launch_policy)
# ╔═╡ 15d5ce62-135d-493e-abad-3ec1f99bfd3b
#Test the custom zero(::Flux.Params) on the launch policy parameters
zero(Flux.params(launch_policy))
# ╔═╡ 2c9346eb-048d-4667-b22b-583b814ae1a3
# Value function
# Same layout as launch_policy, but the last layer is a linear Dense(128,1),
# so it maps a (stocks, debris) tuple to a single unbounded value.
value = Flux.Chain(
Flux.Parallel(vcat
#parallel joins together stocks and debris, along with intermediate interpretation
#stock branch: N_constellations inputs -> N_states*2 outputs
,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu)
,Flux.Dense(N_states*2, N_states*2,Flux.σ)
)
#debris branch: N_debris inputs -> N_states outputs
,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu)
,Flux.Dense(N_states, N_states,Flux.σ)
)
)
#Apply some transformations
#the vcat of both branches has N_states*2 + N_states = N_states*3 entries
,Flux.Dense(N_states*3,128,Flux.σ)
,Flux.Dense(128,128,Flux.σ)
,Flux.Dense(128,1)
);
# ╔═╡ 95b9943b-8c13-4717-8239-2bae9f93ee63
#create the current actor critic optimizer
#positional arguments follow the ActorCritic inner constructor: λʷ, λᶿ, αʷ, αᶿ, αᴿ, θ, w, R
optim = ActorCritic(
0.5 #λʷ
,0.5 #λᶿ
,2.0 #αʷ
,2.0 #αᶿ
,2.0 #αᴿ
,Flux.params(launch_policy) #θ: launch policy parameters
,Flux.params(value) #w: value function parameters
,0.5 #R: initial reward estimate
)
# ╔═╡ 43704f08-6f1a-47bb-87e3-483a529b1654
#Starting States
#both state vectors start at all ones
begin
stock_state = ones(N_constellations);
debris_state = ones(N_debris);
end;
# ╔═╡ 6f3649c1-2ade-437d-81cb-d9f46306d1e0
#check launch policy
launch_policy((stock_state,debris_state))
# ╔═╡ 7dfc61a9-569a-4e97-a292-4c2b15a78e3a
#check value function at the starting state
value((stock_state,debris_state))
# ╔═╡ 2f0ff236-0dff-4cd8-bbb4-f1e81ecef308
#=
Setup the physical model
These values are just guesstimates
=#
begin
#Getting loss parameters together.
loss_param = 2e-3;
#collision rates between different constellations: loss_param off the diagonal, zero on it
loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I);
#orbital decay rate
decay_param = 0.01;
#debris generation parameters
autocatalysis_param = 0.001;
satellite_loss_debris_rate = 5.0;
launch_debris_rate = 0.05;
#Collect the parameters in the BasicModel struct (arguments are in field order).
#NOTE(review): loss_param is used both as debris_collision_rate and as the scale of loss_weights.
bm = BasicModel(
loss_param
,loss_weights
,decay_param
,autocatalysis_param
,satellite_loss_debris_rate
,launch_debris_rate
);
end
# ╔═╡ df771ced-e541-4bb2-b6a3-bb562bbaa89c
md"""
### Evaluation loop
"""
# ╔═╡ 6d9a71ea-b9d1-4efd-aa32-9053b96a66e1
#=Actor Critic Loop
Runs 12 iterations of the actor critic update from the fixed starting state.
Only the error term δ and the running reward estimate are updated so far: the
first `continue` deliberately skips the trace and parameter updates, because
the gradient calls there do not work yet.

NOTE(review): the field names optim.zʷ and optim.R̄ were unreadable in the
source this was restored from and are reconstructed; they must match the field
names of the ActorCritic struct.
=#
with_terminal() do
    for iit in 1:12
        w = Flux.params(value)
        θ = Flux.params(launch_policy)
        #Get learning parameters
        action = launch_policy((stock_state,debris_state))
        new_stock_state = G(stock_state,debris_state,action,bm)
        new_debris_state = H(stock_state,debris_state,action,bm)
        println(new_debris_state)
        #get current value
        current_value = value((stock_state,debris_state))
        new_value = value((new_stock_state,new_debris_state))
        #reward for the current state and action
        R = F(stock_state,debris_state,action)
        #the value outputs are 1-element vectors, so take the first entry to get a scalar δ
        δ = ((R .- optim.R̄) .+ (new_value .- current_value))[1] #fix
        #store values
        optim.R̄ = optim.R̄ + δ*optim.αᴿ
        #check for exit conditions.
        #probably use grad=0
        continue #issue in calculating gradients
        #--- unreachable until the gradient computation is sorted out ---
        #update the learning traces
        optim.zʷ = optim.λʷ .* optim.zʷ .+ Flux.gradient(value,w)
        optim.zᶿ = optim.λᶿ .* optim.zᶿ .+ Flux.gradient(launch_policy,θ)
        continue
        #update the policies
        w = w .+ δ* optim.αʷ .* optim.zʷ;
        θ = θ .+ δ* optim.αᶿ .* optim.zᶿ;
        println("its working, $iit")
    end
end
# ╔═╡ 9941e984-e51b-4e64-8d16-13c8071363b0
#Gradient of the summed value output at the starting state with respect to the value parameters.
#value takes one (stocks, debris) tuple, and the parameter collection is the second
#argument of gradient (it was previously passed to sum, and `w` is not defined at top level).
Zygote.gradient(() -> sum(value((stock_state,debris_state))), Flux.params(value))
# ╔═╡ dd9f5f04-31e4-441e-9f5f-3019a30692d6
#Stand-alone example: take the gradient of a loss with respect to a model's
#parameters and print the gradient of each parameter array.
with_terminal() do
    input = rand(Float32, 10)
    model = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
    loss(v) = Flux.Losses.crossentropy(model(v), [0.5, 0.5])
    grads = gradient(() -> loss(input), params(model))
    foreach(p -> println(grads[p]), params(model))
end
# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
Flux = "~0.12.8"
PlutoUI = "~0.7.17"
Zygote = "~0.6.29"
"""
# ╔═╡ 00000000-0000-0000-0000-000000000002
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised
[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"
[[AbstractTrees]]
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.4"
[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.1"
[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
[[ArrayInterface]]
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "d9352737cef8525944bf9ef34392d756321cbd54"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.38"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[BFloat16s]]
deps = ["LinearAlgebra", "Printf", "Random", "Test"]
git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.2.0"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"
[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "2c8329f16addffd09e6ca84c556e2185a4933c64"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.5.0"
[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"]
git-tree-sha1 = "035ef8a5382a614b2d8e3091b6fdbb1c2b050e11"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.12.1"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "3533f5a691e60601fe60c90d8bc47a27aa2907ec"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.0"
[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.11.0"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.12.8"
[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.40.0"
[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
[[DataAPI]]
git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.9.0"
[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.10"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.3"
[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.3.1"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.6"
[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[ExprTools]]
git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.6"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "8756f9935b7ccc9064c6eef0bff0ad643df733a3"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.7"
[[FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "ArrayInterface", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NNlibCUDA", "Pkg", "Printf", "Random", "Reexport", "SHA", "SparseArrays", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"]
git-tree-sha1 = "e8b37bb43c01eed0418821d1f9d20eca5ba6ab21"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.12.8"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "63777916efbcb0ab6173d09a658fb7f2783de485"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.21"
[[Functors]]
git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.7"
[[GPUArrays]]
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "7772508f17f1d482fe0df72cabc5b55bec06bbe0"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.1.2"
[[GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "77d915a0af27d474f0aaf12fcd46c400a552e84c"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.13.7"
[[Hyperscript]]
deps = ["Test"]
git-tree-sha1 = "8d511d5b81240fc8e6802386302675bdf47737b9"
uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91"
version = "0.0.4"
[[HypertextLiteral]]
git-tree-sha1 = "5efcf53d798efede8fee5b2c8b09284be359bf24"
uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2"
version = "0.9.2"
[[IOCapture]]
deps = ["Logging", "Random"]
git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.2.2"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.3"
[[IfElse]]
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.1"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.1"
[[IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"
[[JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.3.0"
[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.2"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile"]
git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.8.4"
[[LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "46092047ca4edc10720ecab437c42283cd7c44f3"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "4.6.0"
[[LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "6a2af408fe809c4f1a54d2b3f188fdd3698549d6"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.11+0"
[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[LogExpFunctions]]
deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.4"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.9"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.0.2"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.29"
[[NNlibCUDA]]
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
git-tree-sha1 = "04490d5e7570c038b1cb0f5c3627597181cc15a9"
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
version = "0.1.9"
[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"
[[OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"
[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "d911b6a12ba974dabe2291c6d450094a7226b372"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.1.1"
[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[PlutoUI]]
deps = ["Base64", "Dates", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "Markdown", "Random", "Reexport", "UUIDs"]
git-tree-sha1 = "615f3a1eff94add4bca9476ded096de60b46443b"
uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
version = "0.7.17"
[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.2"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Random123]]
deps = ["Libdl", "Random", "RandomNumbers"]
git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.4.2"
[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.5.3"
[[RealDot]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
version = "0.1.0"
[[Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.2.2"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.0.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "f0bccf98e16759818ffc5d97ac3ebf87eb950150"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.8.1"
[[Static]]
deps = ["IfElse"]
git-tree-sha1 = "e7bc80dc93f50857a5d1e3c8121495852f407e6a"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.4.0"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.13"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsAPI]]
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.0.0"
[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.12"
[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.13"
[[TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.6"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["Libdl", "Printf", "Zlib_jll"]
git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.9.4"
[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "0fc9959bcabc4668c403810b4e851f6b8962eac9"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.29"
[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.2"
[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
"""
# ╔═╡ Cell order:
# ╟─7d6afccb-c45e-4371-827a-43d588d4945f
# ╠═cc16838c-3b25-11ec-2489-11c3d35f26f4
# ╟─a810a371-d2e4-4d4f-9cf6-1b98582c4b0f
# ╠═2e59f530-aeb0-4ce1-9312-e81a66520181
# ╠═c57e9967-75e8-47d2-836e-f28159dc6592
# ╟─4153bfd6-2ec6-4025-93ba-8e6e061809e8
# ╟─bbc7e4b0-5edb-4e5d-9512-1aaeae1bb4ec
# ╠═533250c6-fd76-47ef-a561-d567da04ae61
# ╠═3f14b399-cae8-433a-8729-fc91c9b0bee9
# ╠═2c9346eb-048d-4667-b22b-583b814ae1a3
# ╠═89dabcc8-7f4e-4df4-8445-eae9f50ebfea
# ╠═6f3649c1-2ade-437d-81cb-d9f46306d1e0
# ╠═7dfc61a9-569a-4e97-a292-4c2b15a78e3a
# ╟─a305c33f-3388-4120-b036-3093c3ed6aa3
# ╠═458bb826-4eaf-42ca-b889-4d1c50a2ecae
# ╠═0d375c34-3226-40b6-adaf-2f531c00d3ac
# ╠═15d5ce62-135d-493e-abad-3ec1f99bfd3b
# ╠═e79e4ecf-caec-4c1a-9626-0c1a8b5006ce
# ╟─0ac4ac61-0ca2-4ed6-8855-85c9d1939afa
# ╠═7102c12c-581b-4204-9986-741889cb03f6
# ╠═283bdd91-d742-4f4b-a19b-2756ac01ce2c
# ╠═8a3f9518-ae0a-46f3-b37d-1017ce99c70d
# ╠═e4c8abfa-e96c-44b7-9858-af46adc1164a
# ╠═42be8352-7ecf-4da4-bff0-61e328144ed1
# ╟─e2135298-e280-40ac-a549-6a15f13b84ee
# ╠═95b9943b-8c13-4717-8239-2bae9f93ee63
# ╠═5a2dab62-985c-4689-ab48-da04981d15a0
# ╠═43704f08-6f1a-47bb-87e3-483a529b1654
# ╠═2f0ff236-0dff-4cd8-bbb4-f1e81ecef308
# ╟─df771ced-e541-4bb2-b6a3-bb562bbaa89c
# ╠═6d9a71ea-b9d1-4efd-aa32-9053b96a66e1
# ╠═9941e984-e51b-4e64-8d16-13c8071363b0
# ╠═dd9f5f04-31e4-441e-9f5f-3019a30692d6
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002

@ -0,0 +1,262 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "41dcca64-963f-488e-b92e-f1dc5109359a",
"metadata": {},
"outputs": [],
"source": [
"using Enzyme"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ca966ab8-469e-4f8c-af54-579c55f54bd4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"()"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"function mymul!(R, A, B)\n",
" @assert axes(A,2) == axes(B,1)\n",
" @inbounds @simd for i in eachindex(R)\n",
" R[i] = 0\n",
" end\n",
" @inbounds for j in axes(B, 2), i in axes(A, 1)\n",
" @inbounds @simd for k in axes(A,2)\n",
" R[i,j] += A[i,k] * B[k,j]\n",
" end\n",
" end\n",
" nothing\n",
"end\n",
"\n",
"\n",
"A = rand(5, 3)\n",
"B = rand(3, 7)\n",
"\n",
"R = zeros(size(A,1), size(B,2))\n",
"∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us\n",
"\n",
"∂z_∂A = zero(A)\n",
"∂z_∂B = zero(B)\n",
"\n",
"Enzyme.autodiff(mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7442bfb5-3146-493d-9abc-9afeb56c0471",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"true"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"R ≈ A * B &&\n",
"∂z_∂A ≈ ∂z_∂R * B' && # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1]\n",
"∂z_∂B ≈ A' * ∂z_∂R # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "36a4cd0f-c5e2-4a6f-b434-2d347686a08b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3×7 Matrix{Float64}:\n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n",
" 0.0 0.0 0.0 0.0 0.0 0.0 0.0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#reset\n",
"R = zeros(size(A,1), size(B,2))\n",
"∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us\n",
"\n",
"∂z_∂A = zero(A)\n",
"∂z_∂B = zero(B)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "49cbf4e1-1ef0-4428-90af-02ac2b46c2a8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"revenue! (generic function with 1 method)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function revenue!(R, A, B)\n",
" @assert axes(A,2) == axes(B,1)\n",
" @inbounds @simd for i in eachindex(R)\n",
" R[i] = 0\n",
" end\n",
" @inbounds for j in axes(B, 2), i in axes(A, 1)\n",
" @inbounds @simd for k in axes(A,2)\n",
" R[i,j] += A[i,k] * B[k,j]\n",
" end\n",
" end\n",
" nothing\n",
"end\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5fea97a5-39ac-4ade-9cb4-4aba66a80825",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 5;\n",
"constellations = 2;\n",
"payoff_mat = zeros(batch_size,1);\n",
"\n",
"stocks = rand(batch_size, constellations);\n",
"\n",
"payoffs = rand(constellations,1);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d06ec650-621c-4a97-8bbd-0bfcfd4ab8d5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×1 Matrix{Float64}:\n",
" 0.05943992677268309\n",
" 0.16746133343364858\n",
" 0.22311130107900645\n",
" 0.1326381498910713\n",
" 0.23997313509634804"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"revenue!(payoff_mat, stocks,payoffs)\n",
"\n",
"payoff_mat"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "86425c6b-baa3-494d-8c2c-35eaf08cefaf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×1 Matrix{Float64}:\n",
" 1.0\n",
" 1.0\n",
" 1.0\n",
" 1.0\n",
" 1.0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"payoff_mat = zero(payoff_mat)\n",
"∂payoff_mat = ones(size(payoff_mat)...)\n",
"\n",
"∂stocks = zero(stocks)\n",
"∂payoff_mat"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbb5fa3f-b3c6-48bf-9df2-657d5d7aae18",
"metadata": {},
"outputs": [],
"source": [
"# revenue!(R, A, B) fills R with A*B and asserts axes(A,2) == axes(B,1):\n",
"# stocks (batch_size x constellations) must be A, payoffs (constellations x 1) must be B.\n",
"autodiff(revenue!\n",
" ,Duplicated(payoff_mat, ∂payoff_mat)\n",
" ,Duplicated(stocks,∂stocks)\n",
" ,Const(payoffs)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0872e372-f3d1-4fca-b89f-f134a3dc563d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "984b1bd7-e53a-41c3-bafd-f56aacdae4b7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.6.2",
"language": "julia",
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,925 @@
### A Pluto.jl notebook ###
# v0.17.0
using Markdown
using InteractiveUtils
# ╔═╡ 20b324a4-3b6c-11ec-3673-d98ec8af9009
import Zygote, LinearAlgebra,Flux,BenchmarkTools,PlutoUI
# ╔═╡ 64f31ef7-e0ba-4353-8e51-e8356a894656
abstract type AbstractPhysicalModel end
# ╔═╡ 11a53b0c-465b-4f16-88d2-e0163e471fd6
#Physical model: parameters governing satellite survival and debris dynamics.
#NOTE(review): the scalar fields are typed with the abstract `Real`; concrete or
#parametric field types would let Julia specialize on them - consider if speed matters.
struct BasicModel <: AbstractPhysicalModel
#rate at which debris hits satellites (multiplies the debris level inside the survival exponent)
debris_collision_rate::Real
#rate at which satellites of different constellations collide (matrix applied to the stock vector in survival)
satellite_collision_rates::Matrix{Float64}
#rate at which debris exits orbits; G also subtracts it from the satellite survival share
decay_rate::Real
#autocatalysis rate: added to the debris growth factor in H
#(presumably debris created by debris-on-debris collisions - confirm)
autocatalysis_rate::Real
#ratio at which a collision between satellites produces debris
satellite_collision_debris_ratio::Real
#Ratio at which launches produce debris
launch_debris_ratio::Real
end
# ╔═╡ 1d4955f9-1d1c-4572-bc7b-9f002ea54042
begin
#number of satellite constellations (input width of the stock branch of the network)
const N_constellations = 3;
#number of debris state variables (input width of the debris branch of the network)
const N_debris = 1;
#total size of the state: constellations plus debris
const N_states = N_constellations + N_debris;
end
# ╔═╡ 6175d40b-0fa6-4cb4-8ef3-9765953ce97e
#=
Set up the physical model.
These values are just guesstimates.
=#
begin
#Getting loss parameters together.
#loss_param is used twice: as the off-diagonal satellite collision rate below
#and as the debris collision rate (first argument of BasicModel).
loss_param = 2e-3;
#collision rates between distinct constellations; the diagonal is zero
loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I);
#orbital decay rate
decay_param = 0.01;
#debris generation parameters
autocatalysis_param = 0.001;
satellite_loss_debris_rate = 5.0;
launch_debris_rate = 0.05;
#Bundle the parameters; the arguments are positional and follow BasicModel's field order
bm = BasicModel(
loss_param
,loss_weights
,decay_param
,autocatalysis_param
,satellite_loss_debris_rate
,launch_debris_rate
);
end
# ╔═╡ 28c0d31f-e90a-415c-b35e-312bdf771ddf
md"""
# Setup Model Functions
- Debris Transitions
- Derivatives of debris transitions
"""
# ╔═╡ 7462740e-23da-4151-81b5-cc2e6cbcf2c8
#Part of the transition function: percentage survival.
#Returns, per constellation, exp(-(collisions with other satellites) - (collisions with debris)).
function survival(
    stocks
    ,debris
    ,physical_model::AbstractPhysicalModel
    )
    #exposure of each constellation to the stocks of the other constellations
    satellite_hazard = physical_model.satellite_collision_rates * stocks
    #exposure to debris; broadcast against the per-constellation hazard below
    debris_hazard = physical_model.debris_collision_rate * debris
    return exp.(-satellite_hazard .- debris_hazard)
end
# ╔═╡ 1eee99c9-5fef-498d-8a57-9c18e2f1cf49
#=
Stock levels evolution.
Next-period stocks = (survival share - decay rate) .* current stocks + launches.

Arguments
- stocks         : current satellite stock per constellation
- debris         : current debris level(s)
- launches       : launches per constellation, added after losses
- physical_model : supplies the collision rates (via survival) and decay_rate
=#
function G(
    stocks::Vector
    ,debris::Vector
    ,launches::Vector
    ,physical_model::AbstractPhysicalModel
    )
    #share of each constellation retained after collisions and orbital decay
    retained_share = survival(stocks,debris,physical_model) .- physical_model.decay_rate
    #scale elementwise instead of materialising a dense diagonal matrix just to multiply by it;
    #`+` (not `.+`) keeps the length check against launches
    return retained_share .* stocks + launches
end
# ╔═╡ bc3492c6-5c3b-4b08-86d4-dd4026e25655
#=
Debris evolution.
Next-period debris = debris surviving decay (plus autocatalysis)
                   + debris from satellites lost this period
                   + debris from launches.
=#
function H(stocks,debris,launches,physical_model)
    #debris from launches: one scalar for the total number of launches
    from_launches = physical_model.launch_debris_ratio*sum(launches)
    #debris from satellite loss: (share lost)' * stocks is a row-vector times vector
    #product, i.e. a single scalar summed over constellations
    share_lost = 1 .- survival(stocks,debris,physical_model)
    from_satellite_loss = physical_model.satellite_collision_debris_ratio * share_lost'*stocks
    #natural dynamics: existing debris scaled by its net growth factor
    net_growth = 1-physical_model.decay_rate+physical_model.autocatalysis_rate
    carried_over = net_growth * debris
    #total debris level; the two scalars broadcast over the debris vector
    return carried_over .+ from_satellite_loss .+ from_launches
end
# ╔═╡ 56314025-fb0a-4e98-8628-091bda52708c
begin
#width of the hidden layer that mixes the stock and debris features
number_params=10
#Network standing in for the partials of the value function.
#It is called as ∂value((stocks, debris)) and returns the stacked outputs of the two
#final branches: N_constellations values followed by N_debris values.
∂value = Flux.Chain(
Flux.Parallel(vcat
#parallel joins together stocks and debris
#stock branch: N_constellations inputs -> N_states*2 features
,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu)
#,Flux.Dense(N_states*2, N_states*2,Flux.σ)
)
#debris branch: N_debris inputs -> N_states features
,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu)
#,Flux.Dense(N_states, N_states)
)
)
#Apply some transformations
#input width N_states*3 = N_states*2 (stock branch) + N_states (debris branch)
,Flux.Dense(N_states*3,number_params,Flux.σ)
#,Flux.Dense(number_params,number_params,Flux.σ)
#Split out into partials related to stocks and debris
,Flux.Parallel(vcat
#stock partials: squashed through σ
,Flux.Chain(
#Flux.Dense(number_params, number_params,Flux.relu)
#,
Flux.Dense(number_params, N_constellations,Flux.σ)
)
#debris partials: no activation given, so the Dense default applies
,Flux.Chain(
#Flux.Dense(number_params, number_params,Flux.relu)
#,
Flux.Dense(number_params, N_debris)
)
)
)
end
# ╔═╡ f61682fe-da30-459f-b807-4fc3e8f36f32
# ╔═╡ 31237f62-428a-4df3-9caf-1f5c5ade6ade
md"""
### Payout function
"""
# ╔═╡ 55fd682d-2808-4786-ad52-647ac6892860
begin
#payoff weights applied to the stock vector in F1; currently the identity matrix
pay_matrix = zeros(N_constellations,N_constellations) + LinearAlgebra.I
end
# ╔═╡ bf869218-034e-42f9-96c5-4267a8c8fcfb
#Per-period payoff: n * (pay_matrix * s) minus n * 5.0 * a.
#Reads the global `pay_matrix`; the debris argument `d` is accepted but not used yet.
function F1(
    s #stock levels
    ,d #debris levels
    ,a #actions
    #,econ::Int #TODO: a struct with various bits of info
    #,taxes::Int #TODO: a struct with info on taxes
    #,debris_interactions::Matrix #Ignoring for now
    ,n::Matrix #constellation of interest
    )
    #payoff earned on the current stocks
    stock_payoff = n*(pay_matrix*s)
    #actions are charged at 5.0 per unit
    action_cost = n*5.0*a
    #taxes and econ should be structs with parameters later
    return stock_payoff - action_cost
end
# ╔═╡ 4817d4fe-b86c-43e2-94d2-fa6c424fa001
begin
#1x3 row matrix passed as the `n` (constellation of interest) argument of F1 in the cells below
n1 = [1.0 2 3]
#NOTE(review): d1 is not referenced by any other cell of this notebook
d1 = [-0.002 0]
end
# ╔═╡ 48a0d42f-bd00-4298-8b28-56acf2dbc8a7
# ╔═╡ 5d85cdb6-a627-4dd5-b76c-79ed2bc019c6
# ╔═╡ 5a5ca97f-28a9-4196-a1a6-40f3ae4bbb9c
# ╔═╡ 932b7a31-d5f4-44a3-a6c3-f608003a0a6f
md"""
## Building a function for the transition residuals (transversality conditions)
"""
# ╔═╡ 29a8f5c6-d5b8-4ec6-95a4-89025245787d
function loss(x)
    #sum of squared entries of x (squared residual norm)
    squared = x .^ 2
    return sum(squared)
end
# ╔═╡ ba7ca719-9687-4f1a-9a36-89b795a1bc13
md"""
### Testing training the value partials
"""
# ╔═╡ fda49870-c355-4b46-8d04-47268cb0372d
md"""
## Building a function for Optimality Residuals (Optimality Condition)
"""
# ╔═╡ 287ff2f9-3469-4e29-91a3-91ed54109f5a
md"""
## Parameters for testing
"""
# ╔═╡ d173ec63-b7c3-4042-862c-2626039d6d94
#discount factor: multiplies the next-period value partials in the residual calculations
β = 0.95
# ╔═╡ 8eca54ad-f95b-450c-8567-480f0ab7ea19
#=
Residuals of the value-partials recursion at a given state and policy:

    ∂F/∂θ + β * J * ∂value((s,d)) - ∂value((s,d))

where J stacks the Jacobians of the transition functions G (stocks) and H (debris)
with respect to stocks and debris. Reads the globals `∂value`, `β` and `n1`.

Arguments
- G, H        : stock / debris transition functions, called as G(stocks,debris,policy,physics)
- F           : benefit function, called as F(stocks,debris,policy,constellation)
- s, s_next   : stocks at t and t+1 (s_next is not used yet)
- d, d_next   : debris at t and t+1 (d_next is not used yet)
- p           : policy; passed to G and H as the launches
- n           : the constellation we are dealing with (not used yet, see NOTE below)
- physics     : parameters of the physical model of the world
- econ, taxes : economic parameters and tax policy (to be incorporated later)

The t+1 arguments were previously also named `s` and `d`; Julia rejects duplicate
argument names, so they are renamed here. Positional callers are unaffected.
=#
function transition_residuals(G,H,F
    ,s,s_next #stocks t and t+1
    ,d,d_next #debris ''
    ,p #policy
    ,n #the constellation we are dealing with
    ,physics #Parameters of the physical model of the world
    ;econ=0.0,taxes=0.0 #economic parameters (to be incorporated later) and tax policy
    )
    #calculate partials of transition functions
    ∂G_∂s = Zygote.jacobian(stocks->G(stocks,d,p,physics),s)[1]
    ∂G_∂d = Zygote.jacobian(debris->G(s,debris,p,physics),d)[1]
    ∂H_∂s = Zygote.jacobian(stocks->H(stocks,d,p,physics),s)[1]
    ∂H_∂d = Zygote.jacobian(debris->H(s,debris,p,physics),d)[1]
    #concatenate to create iterated vector
    #NOTE(review): the value partials are evaluated at (s,d); confirm whether the
    #t+1 term should use (s_next,d_next) instead.
    ∂Vₜ₊₁_∂θ = vcat(hcat(∂G_∂s ,∂G_∂d),hcat(∂H_∂s ,∂H_∂d)) * ∂value((s,d))
    #get partials of benefit function
    #NOTE(review): these use the global `n1` rather than the argument `n`; kept as is
    #because the existing caller passes a scalar for `n`, which F1 would not accept.
    ∂F_∂d = Zygote.jacobian(debris->F(s,debris,p,n1),d)[1]
    ∂F_∂s = Zygote.jacobian(stocks->F(stocks,d,p,n1),s)[1]
    #concatenate (author's note: unsure whether this should be transposed)
    ∂F_∂θ = hcat(∂F_∂s , ∂F_∂d )'
    #calculate the optimality residuals
    return ∂F_∂θ + β*∂Vₜ₊₁_∂θ - ∂value((s,d))
end
# ╔═╡ 274e3a3f-fe23-4d6f-893c-622db0413af0
begin
#test inputs for the cells below
#stocks, one entry per constellation
s = [1.0,2,3]
#debris level
d = [1.0]
#policy; passed as the launches argument of G and H
p = [1.0,2,3]
end
# ╔═╡ 5e754d71-f4c3-476c-9656-61f07687a534
#check: next-period stocks for the test state with launches [1,1,3]
G(s,d,[1,1,3],bm)
# ╔═╡ 749e6b65-fd8e-45ed-9a2b-f97274549933
#check: survival share per constellation for the test state
survival(s,d,bm)
# ╔═╡ d38be057-952a-450e-ba23-bd2f40e79de3
#check: payoff for the test state, policy p and weights n1
F1(s,d,p,n1)
# ╔═╡ affe6dd2-05e2-4852-9ef0-9605bdb5e34b
#check: Jacobian of the payoff with respect to stocks
Zygote.jacobian(stocks -> F1(stocks,d,p,n1),s)
# ╔═╡ 1c8f2c25-b8da-4e31-adac-ad575203f9cf
#Step-by-step version of the calculation in transition_residuals, evaluated on the
#test state (s, d, p) with the physical model bm and the payoff F1.
begin
#calculate partials of transition functions
∂G_∂s = Zygote.jacobian(stocks->G(stocks,d,p,bm),s)[1]
∂G_∂d = Zygote.jacobian(debris->G(s,debris,p,bm),d)[1]
∂H_∂s = Zygote.jacobian(stocks->H(stocks,d,p,bm),s)[1]
∂H_∂d = Zygote.jacobian(debris->H(s,debris,p,bm),d)[1]
#concatenate to create iterated vector: stacked Jacobian times the value partials
∂Vₜ₊₁_∂θ = vcat(hcat(∂G_∂s ,∂G_∂d),hcat(∂H_∂s ,∂H_∂d)) * ∂value((s,d))
#get partials of benefit function
∂F_∂d = Zygote.jacobian(debris->F1(s,debris,p,n1),d)[1]
∂F_∂s = Zygote.jacobian(stocks->F1(stocks,d,p,n1),s)[1] #author's note: this probably should not be transposed
#concatenate
∂F_∂θ = hcat(∂F_∂s , ∂F_∂d )'
#calculate the optimality residuals
a = ∂F_∂θ + β * ∂Vₜ₊₁_∂θ - ∂value((s,d))
end
# ╔═╡ 48b09bfb-87fe-4974-a4f3-7a6ece521da6
#squared residual norm computed directly ...
sum(a.^2)
# ╔═╡ e1d0aeb8-8fb8-48a8-9842-7888d9fae1ad
#... and through loss; the two values should agree
loss(a)
# ╔═╡ 1e0cf592-fd96-4a01-84d1-078301ac7f99
#NOTE(review): `a` was computed in another cell, so ∂value is never called inside this
#closure; the gradients for its parameters are expected to come back as `nothing`.
#To train on the residuals, the residual calculation must happen inside the closure.
x = Flux.gradient(() -> loss(a), Flux.params(∂value))
# ╔═╡ 8dad1cc0-4cd1-45c5-9b92-8fff8d06fd9b
#print the gradient stored for each parameter array of ∂value
PlutoUI.with_terminal() do
println("beginning")
for pr in Flux.params(∂value)
println(x[pr])
end
end
# ╔═╡ 9b59255c-77b4-4df6-a70b-e1fdaab7619b
#gradient of the residual loss with the residuals computed inside the closure;
#the same s and d are passed for both t and t+1
#NOTE(review): `1` is passed as n, but transition_residuals uses the global n1 instead
grads_value1 = Flux.gradient(() -> loss(transition_residuals(G,H,F1,s,s,d,d,p,1,bm)),Flux.params(∂value))
# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
BenchmarkTools = "~1.2.0"
Flux = "~0.12.8"
PlutoUI = "~0.7.17"
Zygote = "~0.6.29"
"""
# ╔═╡ 00000000-0000-0000-0000-000000000002
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised
[[AbstractFFTs]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.0.1"
[[AbstractTrees]]
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
version = "0.3.4"
[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.3.1"
[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
[[ArrayInterface]]
deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "d9352737cef8525944bf9ef34392d756321cbd54"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.38"
[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[BFloat16s]]
deps = ["LinearAlgebra", "Printf", "Random", "Test"]
git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.2.0"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
[[BenchmarkTools]]
deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
git-tree-sha1 = "61adeb0823084487000600ef8b1c00cc2474cd47"
uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
version = "1.2.0"
[[CEnum]]
git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"
[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "2c8329f16addffd09e6ca84c556e2185a4933c64"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.5.0"
[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"]
git-tree-sha1 = "035ef8a5382a614b2d8e3091b6fdbb1c2b050e11"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.12.1"
[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "3533f5a691e60601fe60c90d8bc47a27aa2907ec"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.0"
[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.11.0"
[[Colors]]
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"]
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40"
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
version = "0.12.8"
[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"
[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.40.0"
[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
[[DataAPI]]
git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.9.0"
[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.18.10"
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
[[DelimitedFiles]]
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
version = "1.0.3"
[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.3.1"
[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.6"
[[Downloads]]
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[ExprTools]]
git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.6"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "8756f9935b7ccc9064c6eef0bff0ad643df733a3"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.7"
[[FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"
[[Flux]]
deps = ["AbstractTrees", "Adapt", "ArrayInterface", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NNlibCUDA", "Pkg", "Printf", "Random", "Reexport", "SHA", "SparseArrays", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"]
git-tree-sha1 = "e8b37bb43c01eed0418821d1f9d20eca5ba6ab21"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.12.8"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "63777916efbcb0ab6173d09a658fb7f2783de485"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.21"
[[Functors]]
git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.7"
[[GPUArrays]]
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "7772508f17f1d482fe0df72cabc5b55bec06bbe0"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.1.2"
[[GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "77d915a0af27d474f0aaf12fcd46c400a552e84c"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.13.7"
[[Hyperscript]]
deps = ["Test"]
git-tree-sha1 = "8d511d5b81240fc8e6802386302675bdf47737b9"
uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91"
version = "0.0.4"
[[HypertextLiteral]]
git-tree-sha1 = "5efcf53d798efede8fee5b2c8b09284be359bf24"
uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2"
version = "0.9.2"
[[IOCapture]]
deps = ["Logging", "Random"]
git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.2.2"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.3"
[[IfElse]]
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.1"
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.1"
[[IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"
[[JLLWrappers]]
deps = ["Preferences"]
git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
version = "1.3.0"
[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.2"
[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile"]
git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.8.4"
[[LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "46092047ca4edc10720ecab437c42283cd7c44f3"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "4.6.0"
[[LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "6a2af408fe809c4f1a54d2b3f188fdd3698549d6"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.11+0"
[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
[[LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
[[LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[LibSSH2_jll]]
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[LogExpFunctions]]
deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.4"
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[MacroTools]]
deps = ["Markdown", "Random"]
git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.9"
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[MbedTLS_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
[[Media]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
version = "0.5.0"
[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.0.2"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[MozillaCACerts_jll]]
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.29"
[[NNlibCUDA]]
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
git-tree-sha1 = "04490d5e7570c038b1cb0f5c3627597181cc15a9"
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
version = "0.1.9"
[[NaNMath]]
git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
[[OpenLibm_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
[[OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.5+0"
[[OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"
[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "d911b6a12ba974dabe2291c6d450094a7226b372"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "2.1.1"
[[Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[PlutoUI]]
deps = ["Base64", "Dates", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "Markdown", "Random", "Reexport", "UUIDs"]
git-tree-sha1 = "615f3a1eff94add4bca9476ded096de60b46443b"
uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
version = "0.7.17"
[[Preferences]]
deps = ["TOML"]
git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
uuid = "21216c6a-2e73-6563-6e65-726566657250"
version = "1.2.2"
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[Profile]]
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Random123]]
deps = ["Libdl", "Random", "RandomNumbers"]
git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.4.2"
[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.5.3"
[[RealDot]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
version = "0.1.0"
[[Reexport]]
git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
version = "1.2.2"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.0.1"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[SpecialFunctions]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "f0bccf98e16759818ffc5d97ac3ebf87eb950150"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "1.8.1"
[[Static]]
deps = ["IfElse"]
git-tree-sha1 = "e7bc80dc93f50857a5d1e3c8121495852f407e6a"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.4.0"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.13"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[StatsAPI]]
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.0.0"
[[StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.12"
[[TOML]]
deps = ["Dates"]
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
[[Tar]]
deps = ["ArgTools", "SHA"]
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.13"
[[TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.6"
[[UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
[[ZipFile]]
deps = ["Libdl", "Printf", "Zlib_jll"]
git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de"
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
version = "0.9.4"
[[Zlib_jll]]
deps = ["Libdl"]
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "0fc9959bcabc4668c403810b4e851f6b8962eac9"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.29"
[[ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.2"
[[nghttp2_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
[[p7zip_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
"""
# ╔═╡ Cell order:
# ╠═20b324a4-3b6c-11ec-3673-d98ec8af9009
# ╠═64f31ef7-e0ba-4353-8e51-e8356a894656
# ╠═11a53b0c-465b-4f16-88d2-e0163e471fd6
# ╠═1d4955f9-1d1c-4572-bc7b-9f002ea54042
# ╠═6175d40b-0fa6-4cb4-8ef3-9765953ce97e
# ╠═28c0d31f-e90a-415c-b35e-312bdf771ddf
# ╠═7462740e-23da-4151-81b5-cc2e6cbcf2c8
# ╠═1eee99c9-5fef-498d-8a57-9c18e2f1cf49
# ╠═bc3492c6-5c3b-4b08-86d4-dd4026e25655
# ╠═5e754d71-f4c3-476c-9656-61f07687a534
# ╠═749e6b65-fd8e-45ed-9a2b-f97274549933
# ╠═56314025-fb0a-4e98-8628-091bda52708c
# ╠═f61682fe-da30-459f-b807-4fc3e8f36f32
# ╟─31237f62-428a-4df3-9caf-1f5c5ade6ade
# ╠═55fd682d-2808-4786-ad52-647ac6892860
# ╠═bf869218-034e-42f9-96c5-4267a8c8fcfb
# ╠═4817d4fe-b86c-43e2-94d2-fa6c424fa001
# ╠═d38be057-952a-450e-ba23-bd2f40e79de3
# ╠═affe6dd2-05e2-4852-9ef0-9605bdb5e34b
# ╠═48a0d42f-bd00-4298-8b28-56acf2dbc8a7
# ╠═5d85cdb6-a627-4dd5-b76c-79ed2bc019c6
# ╠═5a5ca97f-28a9-4196-a1a6-40f3ae4bbb9c
# ╠═932b7a31-d5f4-44a3-a6c3-f608003a0a6f
# ╠═8eca54ad-f95b-450c-8567-480f0ab7ea19
# ╠═1c8f2c25-b8da-4e31-adac-ad575203f9cf
# ╠═48b09bfb-87fe-4974-a4f3-7a6ece521da6
# ╠═29a8f5c6-d5b8-4ec6-95a4-89025245787d
# ╠═e1d0aeb8-8fb8-48a8-9842-7888d9fae1ad
# ╠═1e0cf592-fd96-4a01-84d1-078301ac7f99
# ╠═8dad1cc0-4cd1-45c5-9b92-8fff8d06fd9b
# ╟─ba7ca719-9687-4f1a-9a36-89b795a1bc13
# ╠═9b59255c-77b4-4df6-a70b-e1fdaab7619b
# ╠═fda49870-c355-4b46-8d04-47268cb0372d
# ╠═287ff2f9-3469-4e29-91a3-91ed54109f5a
# ╠═d173ec63-b7c3-4042-862c-2626039d6d94
# ╠═274e3a3f-fe23-4d6f-893c-622db0413af0
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002

@ -0,0 +1,617 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 124,
"id": "3d049ef3-2630-4064-a2bd-ee28c7a89e9e",
"metadata": {},
"outputs": [],
"source": [
"using LinearAlgebra, Zygote, Tullio"
]
},
{
"cell_type": "markdown",
"id": "cfa5d466-5354-4ef4-a4ab-2bf2ed69626e",
"metadata": {},
"source": [
"### notes\n",
"\n",
"Notice how this approach estimates policy and the value function in a single go.\n",
"While you could eliminate the value function partials, it would be much more complex.\n",
"Note that in RL, both must be iterated on."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2bc6f9bd-8308-40a0-bee4-acbb1be4672d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Model dimensions\n",
"N_constellations = 5\n",
"N_debris = 1\n",
"N_states = N_constellations + N_debris"
]
},
{
"cell_type": "markdown",
"id": "15670d3b-9651-4fc5-b98a-29ba60b4b38e",
"metadata": {},
"source": [
"## Build Payoff functions"
]
},
{
"cell_type": "code",
"execution_count": 156,
"id": "1e99275a-9ab5-4c9e-921b-9587145b1fb5",
"metadata": {},
"outputs": [],
"source": [
"stocks = rand(1:N_constellations,N_constellations);\n",
"debris = rand(1:3);\n",
"payoff = 3*I + ones(5,5) .+ [1,0,0,0,0]; #TODO: move this into a struct\n",
"β=0.95\n",
"launches = ones(N_constellations);"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "a4e4bbaf-61de-4860-adc2-1e91f0626ead",
"metadata": {},
"outputs": [],
"source": [
"#Define the market profit function\n",
"F(stocks,debris,payoff,launches) = payoff*stocks + 3.0*launches .+ (debris*-0.2)\n",
"\n",
"#create derivative functions\n",
"∂f_∂stocks(st,debris,payoff,launches) = Zygote.jacobian(s -> F(s,debris,payoff,launches),st)[1];\n",
"∂f_∂debris(stocks,d,payoff,launches) = Zygote.jacobian(s -> F(stocks,s,payoff,launches),d)[1];\n",
"∂f_∂launches(stocks,debris,payoff,launches) = Zygote.jacobian(l -> F(stocks,debris,payoff,l),launches)[1];\n"
]
},
{
"cell_type": "code",
"execution_count": 158,
"id": "ad73e64a-2fa5-4bdf-a4e4-22357b53002c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 5.0 2.0 2.0 2.0 2.0\n",
" 1.0 4.0 1.0 1.0 1.0\n",
" 1.0 1.0 4.0 1.0 1.0\n",
" 1.0 1.0 1.0 4.0 1.0\n",
" 1.0 1.0 1.0 1.0 4.0"
]
},
"execution_count": 158,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂f_∂stocks(stocks,debris,payoff,launches) \n",
"#Rows are constellations, columns are derivatives indexed by constellations"
]
},
{
"cell_type": "code",
"execution_count": 159,
"id": "a6159740-8ef2-40e0-b10f-c451cf78418d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5-element Vector{Float64}:\n",
" -0.2\n",
" -0.2\n",
" -0.2\n",
" -0.2\n",
" -0.2"
]
},
"execution_count": 159,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂f_∂debris(stocks,debris,payoff,launches)"
]
},
{
"cell_type": "code",
"execution_count": 160,
"id": "b46da3d9-1765-4427-b3de-bf0a9fdd6576",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 3.0 0.0 0.0 0.0 0.0\n",
" 0.0 3.0 0.0 0.0 0.0\n",
" 0.0 0.0 3.0 0.0 0.0\n",
" 0.0 0.0 0.0 3.0 0.0\n",
" 0.0 0.0 0.0 0.0 3.0"
]
},
"execution_count": 160,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂f_∂launches(stocks,debris,payoff,launches)"
]
},
{
"cell_type": "markdown",
"id": "7a3a09cc-e48d-4c27-9e4f-02e45a5fc5c1",
"metadata": {},
"source": [
"## Building Physical Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "638ed4ac-36b0-4efb-bcb4-13d345470dd7",
"metadata": {},
"outputs": [],
"source": [
"struct BasicModel\n",
" #rate at which debris hits satellites\n",
" debris_collision_rate\n",
" #rate at which satellites of different constellations collide\n",
" satellite_collision_rates\n",
" #rate at which debris exits orbits\n",
" decay_rate\n",
" #rate at which existing debris generates additional debris\n",
" autocatalysis_rate\n",
" #ratio at which a collision between satellites produces debris\n",
" satellite_collision_debris_ratio\n",
" #Ratio at which launches produce debris\n",
" launch_debris_ratio\n",
"end\n",
"\n",
"#Getting loss parameters together.\n",
"loss_param = 2e-3;\n",
"loss_weights = loss_param*(ones(N_constellations,N_constellations) - I);\n",
"\n",
"#orbital decay rate\n",
"decay_param = 0.01;\n",
"\n",
"#debris generation parameters\n",
"autocatalysis_param = 0.001;\n",
"satellite_loss_debris_rate = 5.0;\n",
"launch_debris_rate = 0.05;\n",
"\n",
"#Todo, wrap physical model as a struct with the parameters\n",
"bm = BasicModel(\n",
" loss_param\n",
" ,loss_weights\n",
" ,decay_param\n",
" ,autocatalysis_param\n",
" ,satellite_loss_debris_rate\n",
" ,launch_debris_rate\n",
");"
]
},
{
"cell_type": "code",
"execution_count": 163,
"id": "ae04a4f2-7401-450e-ba98-bfec375b7646",
"metadata": {},
"outputs": [],
"source": [
"#percentage survival function\n",
"function survival(stocks,debris,physical_model) \n",
" exp.(-physical_model.satellite_collision_rates*stocks .- (physical_model.debris_collision_rate*debris));\n",
"end\n",
"\n",
"#stock update rule: next-period stocks for each constellation\n",
"#  stocks, launches: vectors with one entry per constellation\n",
"#  debris: current debris level, passed through to survival()\n",
"#  physical_model: supplies the rates read by survival() and decay_rate\n",
"function G(stocks,debris,launches, physical_model)\n",
"    #net retention per constellation: survival share minus decay_rate\n",
"    retention = survival(stocks,debris,physical_model) .- physical_model.decay_rate\n",
"    #elementwise scaling gives the same values as diagm(retention)*stocks\n",
"    #without allocating a dense diagonal matrix\n",
"    return retention .* stocks + launches\n",
"end;\n",
"\n",
"#stock evolution wrt various things\n",
"∂G_∂launches(stocks,debris,launches,bm) = Zygote.jacobian(x -> G(stocks,debris,x,bm),launches)[1];\n",
"∂G_∂debris(stocks,debris,launches,bm) = Zygote.jacobian(x -> G(stocks,x,launches,bm),debris)[1];\n",
"∂G_∂stocks(stocks,debris,launches,bm) = Zygote.jacobian(x -> G(x,debris,launches,bm),stocks)[1];\n",
"\n",
"#debris evolution \n",
"function H(stocks,debris,launches,physical_model)\n",
" #get changes in debris from natural dynamics\n",
" natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris \n",
" \n",
" #get changes in debris from satellite loss\n",
" satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks \n",
" \n",
" #get changes in debris from launches\n",
" launch_debris = physical_model.launch_debris_ratio*sum(launches)\n",
" \n",
" #return total debris level\n",
" return natural_debris_dynamics + satellite_loss_debris + launch_debris\n",
"end;\n",
"\n",
"#get jacobians of debris dynamics\n",
"∂H_∂launches(stocks,debris,launches,physical_model) = Zygote.jacobian(x -> H(stocks,debris,x,physical_model),launches)[1];\n",
"∂H_∂debris(stocks,debris,launches,physical_model) = Zygote.jacobian(x -> H(stocks,x,launches,physical_model),debris)[1];\n",
"∂H_∂stocks(stocks,debris,launches,physical_model) = Zygote.jacobian(x -> H(x,debris,launches,physical_model),stocks)[1];"
]
},
{
"cell_type": "code",
"execution_count": 166,
"id": "c6519b6e-f9ef-466e-a9ac-746b8f36b9e7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 1.0 0.0 0.0 0.0 0.0\n",
" 0.0 1.0 0.0 0.0 0.0\n",
" 0.0 0.0 1.0 0.0 0.0\n",
" 0.0 0.0 0.0 1.0 0.0\n",
" 0.0 0.0 0.0 0.0 1.0"
]
},
"execution_count": 166,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂G_∂launches(stocks,debris,launches,bm)"
]
},
{
"cell_type": "code",
"execution_count": 167,
"id": "1d46c09d-2c91-4302-8e8d-1c0abeb010aa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5-element Vector{Float64}:\n",
" -0.003843157756609293\n",
" -0.009665715046375067\n",
" -0.009665715046375067\n",
" -0.009665715046375067\n",
" -0.003843157756609293"
]
},
"execution_count": 167,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂G_∂debris(stocks,debris,launches,bm)"
]
},
{
"cell_type": "code",
"execution_count": 168,
"id": "1f791c8a-5457-4ec2-bf0b-68eb6bdd578f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 0.950789 -0.00384316 -0.00384316 -0.00384316 -0.00384316\n",
" -0.00966572 0.956572 -0.00966572 -0.00966572 -0.00966572\n",
" -0.00966572 -0.00966572 0.956572 -0.00966572 -0.00966572\n",
" -0.00966572 -0.00966572 -0.00966572 0.956572 -0.00966572\n",
" -0.00384316 -0.00384316 -0.00384316 -0.00384316 0.950789"
]
},
"execution_count": 168,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂G_∂stocks(stocks,debris,launches,bm)"
]
},
{
"cell_type": "code",
"execution_count": 169,
"id": "d97f9af0-d3a9-4739-b569-fa100a1590f3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1×5 Matrix{Float64}:\n",
" 0.05 0.05 0.05 0.05 0.05"
]
},
"execution_count": 169,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂H_∂launches(stocks,debris,launches,bm)"
]
},
{
"cell_type": "code",
"execution_count": 170,
"id": "1a7c857f-7eee-4273-b954-3a7b84a3b3e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1-element Vector{Float64}:\n",
" 1.174417303261719"
]
},
"execution_count": 170,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂H_∂debris(stocks,debris,launches,bm) "
]
},
{
"cell_type": "code",
"execution_count": 171,
"id": "9ea44146-740c-47c0-9286-846585a4db39",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1×5 Matrix{Float64}:\n",
" 0.360254 0.302231 0.302231 0.302231 0.360254"
]
},
"execution_count": 171,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂H_∂stocks(stocks,debris,launches,bm)\n",
"#columns are derivatives"
]
},
{
"cell_type": "markdown",
"id": "17038bfe-c7c2-484b-9aaf-6b608bbbbfab",
"metadata": {},
"source": [
"## Build optimality conditions"
]
},
{
"cell_type": "markdown",
"id": "ef3c7625-e5ee-4220-8bed-b289d0baea1e",
"metadata": {},
"source": [
"## Build transition conditions\n",
"\n",
"I've built the transition conditions below."
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "c3ef6bcc-37ca-4f6c-b876-8ea2a068cbe1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 2.0 1.0 1.0 1.0 1.0\n",
" 1.0 2.0 1.0 1.0 1.0\n",
" 1.0 1.0 2.0 1.0 1.0\n",
" 1.0 1.0 1.0 2.0 1.0\n",
" 1.0 1.0 1.0 1.0 2.0"
]
},
"execution_count": 172,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"W_∂stocks= ones(5,5) + I\n",
"# columns are stocks\n",
"# rows are derivatives"
]
},
{
"cell_type": "code",
"execution_count": 174,
"id": "5441641b-308b-4fec-a946-2568e79a1203",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×1 Matrix{Float64}:\n",
" 2.0\n",
" 2.0\n",
" 2.0\n",
" 2.0\n",
" 2.0"
]
},
"execution_count": 174,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
" #temporary value partials\n",
"W_∂debris = 2*ones(5,1) #temporary value partials\n",
"#columns are debris (a single column) -- TODO confirm\n",
"#rows are presumably constellations (5 rows) -- TODO confirm"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "819df0f9-e994-4461-9329-4c4607de8779",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 0.966286 -0.00195257 -0.00195257 -0.00195257 -0.00195257\n",
" -0.00785729 0.972161 -0.00785729 -0.00785729 -0.00785729\n",
" -0.00588119 -0.00588119 0.970199 -0.00588119 -0.00588119\n",
" -0.00195257 -0.00195257 -0.00195257 0.966286 -0.00195257\n",
" -0.00391296 -0.00391296 -0.00391296 -0.00391296 0.96824"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"∂G_∂stocks(stocks,debris,launches,bm)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "5cc19637-9727-479e-aa60-a30c71e9e8b8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×5 Matrix{Float64}:\n",
" 1.91695 1.91695 1.91695 1.91695 1.91695\n",
" 1.88146 1.88146 1.88146 1.88146 1.88146\n",
" 1.89335 1.89335 1.89335 1.89335 1.89335\n",
" 1.91695 1.91695 1.91695 1.91695 1.91695\n",
" 1.90518 1.90518 1.90518 1.90518 1.90518"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = ∂f_∂stocks(stocks,debris,payoff,launches)\n",
"b = ∂G_∂stocks(stocks,debris,launches,bm) * W_∂stocks #This last bit should eventually get replaced with a NN\n",
"#Need to check dimensionality above. Not sure which direction is derivatives and which is functions."
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "65917e4c-bd25-4683-8553-17bc841e594a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5×1 Matrix{Float64}:\n",
" -0.019525714195158184\n",
" -0.07857288258866406\n",
" -0.05881192039840531\n",
" -0.019525714195158184\n",
" -0.039129609402048404"
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = (∂G_∂debris(stocks,debris,launches,bm) .* ones(1,5)) * W_∂debris"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "dd5ca5b4-c910-449c-a322-b24b3ef4ecf2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"49.45445380487922"
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss(stocks,debris,payoff,launches)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "8ead6e86-3d03-48ff-bc3e-b3fa9808f981",
"metadata": {},
"outputs": [],
"source": [
"#TODO: create a launch model in flux and see if I can get it to do pullbacks"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "658ba7f0-dfbc-42da-bb20-7f8bccbb257d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.6.2",
"language": "julia",
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,327 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0b5021da-575c-4db3-9e01-dc043a7c64b3",
"metadata": {},
"outputs": [],
"source": [
"using DiffEqFlux,Flux,Zygote, LinearAlgebra"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "32ca6032-9d48-4bb2-b16e-4a66473464cd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"const N_constellations = 1\n",
"const N_debris = 1\n",
"const N_states= N_constellations + N_debris"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "77213ba3-1645-45b2-903f-b7f2817cbb47",
"metadata": {},
"outputs": [],
"source": [
"#setup physical model\n",
"struct BasicModel\n",
" #rate at which debris hits satellites\n",
" debris_collision_rate\n",
" #rate at which satellites of different constellations collide\n",
" satellite_collision_rates\n",
" #rate at which debris exits orbits\n",
" decay_rate\n",
" #rate at which existing debris generates additional debris\n",
" autocatalysis_rate\n",
" #ratio at which a collision between satellites produces debris\n",
" satellite_collision_debris_ratio\n",
" #Ratio at which launches produce debris\n",
" launch_debris_ratio\n",
"end\n",
"\n",
"#Getting loss parameters together.\n",
"loss_param = 2e-3;\n",
"loss_weights = loss_param*(ones(N_constellations,N_constellations) - I);\n",
"\n",
"#orbital decay rate\n",
"decay_param = 0.01;\n",
"\n",
"#debris generation parameters\n",
"autocatalysis_param = 0.001;\n",
"satellite_loss_debris_rate = 5.0;\n",
"launch_debris_rate = 0.05;\n",
"\n",
"#Todo, wrap physical model as a struct with the parameters\n",
"bm = BasicModel(\n",
" loss_param\n",
" ,loss_weights\n",
" ,decay_param\n",
" ,autocatalysis_param\n",
" ,satellite_loss_debris_rate\n",
" ,launch_debris_rate\n",
");\n",
"\n",
"#implement transition functions\n",
"#percentage survival function\n",
"function survival(stocks,debris,physical_model) \n",
" exp.(-physical_model.satellite_collision_rates*stocks .- (physical_model.debris_collision_rate*debris));\n",
"end\n",
"\n",
"#stock update rule: next-period stocks for each constellation\n",
"#  stocks, launches: vectors with one entry per constellation\n",
"#  debris: current debris level, passed through to survival()\n",
"#  physical_model: supplies the rates read by survival() and decay_rate\n",
"function G(stocks,debris,launches, physical_model)\n",
"    #net retention per constellation: survival share minus decay_rate\n",
"    retention = survival(stocks,debris,physical_model) .- physical_model.decay_rate\n",
"    #elementwise scaling gives the same values as diagm(retention)*stocks\n",
"    #without allocating a dense diagonal matrix\n",
"    return retention .* stocks + launches\n",
"end;\n",
"\n",
"\n",
"#debris evolution: next-period debris level\n",
"#  stocks, launches: vectors with one entry per constellation\n",
"#  debris: current debris level; may be a scalar or a 1-element vector\n",
"#  physical_model: supplies decay_rate, autocatalysis_rate and the debris ratios\n",
"function H(stocks,debris,launches,physical_model)\n",
"    #get changes in debris from natural dynamics (same shape as debris)\n",
"    natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris \n",
"    \n",
"    #get changes in debris from satellite loss (scalar: lost share per constellation dotted with stocks)\n",
"    satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks \n",
"    \n",
"    #get changes in debris from launches (scalar)\n",
"    launch_debris = physical_model.launch_debris_ratio*sum(launches)\n",
"    \n",
"    #return total debris level\n",
"    #broadcasted sum: identical to + for scalar debris, and also accepts a\n",
"    #vector debris state such as rand(1:3, N_debris), where vector + scalar would throw\n",
"    return natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris\n",
"end;\n",
"\n",
"\n",
"#implement reward function\n",
"const payoff = 3*I - 0.02*ones(N_constellations,N_constellations)\n",
"\n",
"#Define the market profit function\n",
"F(stocks,debris,launches) = payoff*stocks + 3.0*launches .+ (debris*-0.2)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "998a1ce8-a6ba-427d-a5d1-fece358146da",
"metadata": {},
"outputs": [],
"source": [
"# Launch function\n",
"# Maps a (stocks, debris) tuple to one launch quantity per constellation.\n",
"launches = Chain(\n",
"    #parallel joins together stocks and debris, along with intermediate interpretation\n",
"    #output width is N_states*2 + N_states = N_states*3\n",
"    Parallel(vcat\n",
"        ,Chain(Dense(N_constellations, N_states*2,relu)\n",
"            ,Dense(N_states*2, N_states*2,relu)\n",
"        )\n",
"        ,Chain(Dense(N_debris, N_states,relu)\n",
"            ,Dense(N_states, N_states,relu)\n",
"        )\n",
"    )\n",
"    #chain gets applied to parallel\n",
"    #(Parallel is closed above so these layers are stages of the Chain,\n",
"    # not additional parallel branches)\n",
"    ,Dense(N_states*3,128,relu)\n",
"    #,Dense(128,128,relu)\n",
"    ,Dense(128,N_constellations,relu)\n",
");\n",
"\n",
"#Value functions\n",
"# Maps a (stocks, debris) tuple to N_states outputs (one per state variable).\n",
"∂value = Chain(\n",
"    #parallel joins together stocks and debris, along with intermediate interpretation\n",
"    #output width is N_states*2 + N_states = N_states*3\n",
"    Parallel(vcat\n",
"        ,Chain(Dense(N_constellations, N_states*2,relu)\n",
"            ,Dense(N_states*2, N_states*2,relu)\n",
"        )\n",
"        ,Chain(Dense(N_debris, N_states,relu)\n",
"            ,Dense(N_states, N_states,relu)\n",
"        )\n",
"    )\n",
"    #chain gets applied to parallel\n",
"    #(Parallel is closed above so these layers are stages of the Chain,\n",
"    # not additional parallel branches)\n",
"    ,Dense(N_states*3,128,relu)\n",
"    #,Dense(128,128,relu)\n",
"    ,Dense(128,N_states,relu)\n",
");"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b4409af2-7f41-45bc-b7eb-4bda019e4092",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1-element Vector{Float64}:\n",
" 0.0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Extract parameter sets\n",
"\n",
"#= initialize Algorithm Parameters\n",
"Chose these randomly\n",
"=#\n",
"λʷ = 0.5\n",
"αʷ = 5.0\n",
"λᶿ = 0.5\n",
"αᶿ = 5.0\n",
"αʳ = 10\n",
"\n",
"# initialize averaging returns\n",
"r = zeros(N_constellations)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0529c209-55c0-49c7-815b-47578b029593",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1-element Vector{Int64}:\n",
" 3"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initial states\n",
"S₀ = rand(1:5,N_constellations)\n",
"D₀ = rand(1:3, N_debris)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ae7d4152-77b0-42ff-92f6-9d5d83d6a39d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Params([Float32[-0.110574484; 1.0583764; 0.039519094; 0.2908444], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.4765299 0.5994208 -0.43710196 0.2269359; 0.5550531 0.5423604 -0.796175 0.76214457; 0.59269524 0.7436546 0.02525105 0.85908467; 0.3774994 -0.111040816 0.84196734 -0.18133782], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.2003294; -1.24031], Float32[0.0, 0.0], Float32[-0.004074011 -0.84631246; -0.5459394 1.1513239], Float32[0.0, 0.0], Float32[-0.19545768 -0.20670874 … 0.06923863 -0.09825141; 0.097166725 0.06564395 … -0.1928437 0.19962357; … ; 0.025075339 -0.06016964 … 0.0838129 -0.11523932; 0.20085223 0.16679004 … 0.016495213 -0.1548977], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.21479565 0.0090183215 … -0.2022802 -0.19925424], Float32[0.0]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"launch_params = Flux.params(launches)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "232c0a44-be74-4431-a86e-dbc71e83c17a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"loss (generic function with 1 method)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function loss(stocks,debris)\n",
" sum(launches((stocks,debris)))\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e1e1cdef-b164-43b6-a80e-b8665bdf9b14",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Grads(...)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"g = Flux.gradient(() -> loss(S₀,D₀), launch_params)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "35d1763b-4650-4916-957d-fbb436280e1f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Params([Float32[-0.110574484; 1.0583764; 0.039519094; 0.2908444], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.4765299 0.5994208 -0.43710196 0.2269359; 0.5550531 0.5423604 -0.796175 0.76214457; 0.59269524 0.7436546 0.02525105 0.85908467; 0.3774994 -0.111040816 0.84196734 -0.18133782], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.2003294; -1.24031], Float32[0.0, 0.0], Float32[-0.004074011 -0.84631246; -0.5459394 1.1513239], Float32[0.0, 0.0], Float32[-0.19545768 -0.20670874 … 0.06923863 -0.09825141; 0.097166725 0.06564395 … -0.1928437 0.19962357; … ; 0.025075339 -0.06016964 … 0.0838129 -0.11523932; 0.20085223 0.16679004 … 0.016495213 -0.1548977], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.21479565 0.0090183215 … -0.2022802 -0.19925424], Float32[0.0]])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"launch_params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eaad2871-54ed-4674-8405-d4ebb950851d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.6.2",
"language": "julia",
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading…
Cancel
Save