From 66a237cd192f849432b780153b9a04d563ac5df1 Mon Sep 17 00:00:00 2001 From: youainti Date: Tue, 2 Nov 2021 16:24:57 -0700 Subject: [PATCH] current work, including a whole lot of julia stuff (Pluto.jl FTW) --- .gitignore | 1 + Code/Untitled.ipynb | 219 ++-- julia_code/ActorCritic.jl | 945 +++++++++++++++++ julia_code/DerivativesOfProfit_enzyme.ipynb | 262 +++++ julia_code/MWE_MutabilityErrors.jl | 1002 +++++++++++++++++++ julia_code/Maliar_attempt.jl | 925 +++++++++++++++++ julia_code/Untitled1.ipynb | 617 ++++++++++++ julia_code/UsingFlux.ipynb | 327 ++++++ 8 files changed, 4167 insertions(+), 131 deletions(-) create mode 100644 julia_code/ActorCritic.jl create mode 100644 julia_code/DerivativesOfProfit_enzyme.ipynb create mode 100644 julia_code/MWE_MutabilityErrors.jl create mode 100644 julia_code/Maliar_attempt.jl create mode 100644 julia_code/Untitled1.ipynb create mode 100644 julia_code/UsingFlux.ipynb diff --git a/.gitignore b/.gitignore index 85fa8fc..997c58c 100644 --- a/.gitignore +++ b/.gitignore @@ -300,4 +300,5 @@ TSWLatexianTemp* *.pdf #Don't track python/jupyterlab stuff */.ipynb_checkpoints/* +.ipynb_checkpoints/* */__pycache__/* diff --git a/Code/Untitled.ipynb b/Code/Untitled.ipynb index 3c3b01d..19d51bb 100644 --- a/Code/Untitled.ipynb +++ b/Code/Untitled.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "royal-trace", + "id": "operating-illinois", "metadata": {}, "outputs": [], "source": [ @@ -14,26 +14,10 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "atlantic-finish", + "execution_count": 25, + "id": "white-lottery", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[8., 5., 3.]],\n", - "\n", - " [[3., 6., 6.]],\n", - "\n", - " [[3., 7., 2.]],\n", - "\n", - " [[4., 8., 2.]],\n", - "\n", - " [[0., 6., 8.]]], grad_fn=) torch.Size([5, 1, 3])\n" - ] - } - ], + "outputs": [], "source": [ "BATCH_SIZE = 5\n", "STATES = 3\n", @@ -42,25 +26,14 @@ "FEATURES = 1\n", "\n", "stocks = torch.randint(MAX,(BATCH_SIZE,1,CONSTELLATIONS), dtype=torch.float32, requires_grad=True)\n", - "debris = torch.randint(MAX,(BATCH_SIZE,1,1), dtype=torch.float32, requires_grad=True)\n", - "\n", - "s = c.States(stocks, debris)\n", - "\n", - "print(s.values,s.values.shape)" + "debris = torch.randint(MAX,(BATCH_SIZE,1), dtype=torch.float32, requires_grad=True)\n", + "\n" ] }, { "cell_type": "code", - "execution_count": null, - "id": "prostate-liverpool", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "simplified-permission", + "execution_count": 91, + "id": "quick-extraction", "metadata": {}, "outputs": [ { @@ -68,53 +41,100 @@ "output_type": "stream", "text": [ "tensor([[[[1.],\n", - " [0.]],\n", - "\n", - " [[0.],\n", - " [1.]]]]) torch.Size([1, 2, 2, 1])\n", - "tensor([[ 1.0000, 0.0000],\n", - " [ 0.0000, 1.0000],\n", - " [-0.2000, -0.2000]]) torch.Size([3, 2])\n" + " [0.]]]], requires_grad=True) torch.Size([1, 1, 2, 1])\n", + "tensor([[ 1.0000, -0.1000]], requires_grad=True) torch.Size([1, 2])\n" ] } ], "source": [ "#launch_costs = torch.randint(3,(1,CONSTELLATIONS,CONSTELLATIONS,FEATURES), dtype=torch.float32)\n", - "launch_costs = torch.tensor([[[[1.0],[0]],[[0.0],[1]]]])\n", + "launch_costs = torch.tensor([[[[1.0],[0.0]]]], requires_grad=True)\n", "print(launch_costs, launch_costs.shape)\n", "#payoff = torch.randint(5,(STATES,CONSTELLATIONS), dtype=torch.float32)\n", - "payoff = torch.tensor([[1.0, 0],[0,1.0],[-0.2,-0.2]])\n", - "print(payoff, payoff.shape)" + "payoff = torch.tensor([[1.0, -0.1]], requires_grad=True)\n", + "print(payoff, payoff.shape)\n", + "\n", + "debris_cost = -0.2" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "level-angle", + "execution_count": 92, + "id": "textile-cleanup", "metadata": {}, "outputs": [], "source": [ - "def linear_profit(states, choices):\n", + "def linear_profit(stocks, debris, choices,constellation_number):\n", " #Pay particular attention to the dimensions\n", " #note that there is an extra dimension in there just ot match that of the profit vector we'll be giving out.\n", " \n", " #calculate launch expenses\n", " \n", - " launch_expense = torch.tensordot(choices,launch_costs, [[-2,-1],[-2,-1]])\n", + " launch_expense = (-5 * output)[:,constellation_number,:]\n", "\n", " #calculate revenue\n", "\n", - " revenue = torch.tensordot(s.values, payoff, [[-1],[0]])\n", + " revenue = (payoff * stocks).sum(dim=2)\n", + " \n", + " debris_costs = debris * debris_cost \n", "\n", "\n", - " profit = revenue - launch_expense\n", + " profit = (revenue + debris_costs + launch_expense).sum(dim=1)\n", " return profit" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "copyrighted-acting", + "execution_count": 100, + "id": "single-wheat", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([ 3.2451, 4.3734, 6.5474, -0.2722, -2.8843], grad_fn=),\n", + " torch.Size([5]))" + ] + }, + "execution_count": 100, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "profit = linear_profit(stocks, debris, output,0)\n", + "profit, profit.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "id": "handy-perry", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[-0.2000],\n", + " [-0.0000],\n", + " [-0.0000],\n", + " [-0.0000],\n", + " [-0.0000]]),)" + ] + }, + "execution_count": 123, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.autograd.grad(profit[0], (debris), create_graph=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "purple-superior", "metadata": {}, "outputs": [], "source": [ @@ -128,7 +148,7 @@ }, { "cell_type": "markdown", - "id": "casual-career", + "id": "auburn-leonard", "metadata": {}, "source": [ "example to get profit = 1\n", @@ -159,30 +179,30 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "straight-negative", + "execution_count": 96, + "id": "herbal-manual", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[0.0000],\n", - " [0.0000]],\n", + "tensor([[[0.2910],\n", + " [0.4003]],\n", "\n", - " [[0.0000],\n", - " [0.0000]],\n", + " [[0.1053],\n", + " [0.2446]],\n", "\n", - " [[0.0000],\n", - " [0.0000]],\n", + " [[0.1705],\n", + " [0.2758]],\n", "\n", - " [[0.0000],\n", - " [0.0000]],\n", + " [[0.1944],\n", + " [0.3421]],\n", "\n", - " [[0.3742],\n", - " [0.0000]]], grad_fn=)" + " [[0.5369],\n", + " [0.6181]]], grad_fn=)" ] }, - "execution_count": 6, + "execution_count": 96, "metadata": {}, "output_type": "execute_result" } @@ -192,73 +212,10 @@ "output" ] }, - { - "cell_type": "code", - "execution_count": 7, - "id": "independent-deficit", - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.ones_like(output, requires_grad=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "romance-force", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "element 0 of tensors does not require grad and does not have a grad_fn", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;31m#this is where I lose the gradient. This is where I need a gradient so that I can call .backward below\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mtest_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n", - "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn" - ] - } - ], - "source": [ - "def test_loss(options):\n", - " return torch.autograd.functional.jacobian(linear_profit, (s.values, options))[0].sum()\n", - " #something is off here ^\n", - " #this is where I lose the gradient. This is where I need a gradient so that I can call .backward below\n", - "\n", - "test_loss(output).backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "id": "asian-death", - "metadata": {}, - "outputs": [ - { - "ename": "RuntimeError", - "evalue": "element 0 of tensors does not require grad and does not have a grad_fn", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/miniconda3/envs/pytorch-CPU/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 145\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n", - "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn" - ] - } - ], - "source": [] - }, { "cell_type": "code", "execution_count": null, - "id": "prospective-nelson", + "id": "another-timing", "metadata": {}, "outputs": [], "source": [] diff --git a/julia_code/ActorCritic.jl b/julia_code/ActorCritic.jl new file mode 100644 index 0000000..c98abce --- /dev/null +++ b/julia_code/ActorCritic.jl @@ -0,0 +1,945 @@ +### A Pluto.jl notebook ### +# v0.17.0 + +using Markdown +using InteractiveUtils + +# ╔═╡ cc16838c-3b25-11ec-2489-11c3d35f26f4 +using Flux,LinearAlgebra,Zygote ,PlutoUI + +# ╔═╡ 458bb826-4eaf-42ca-b889-4d1c50a2ecae +using Flux.Optimise: AbstractOptimiser + +# ╔═╡ 7d6afccb-c45e-4371-827a-43d588d4945f +md""" +# Actor Critic Model +This is an implementation of an optimizer for an actor critic model. + +""" + +# ╔═╡ a810a371-d2e4-4d4f-9cf6-1b98582c4b0f +md""" +## Physical Model +""" + +# ╔═╡ 2e59f530-aeb0-4ce1-9312-e81a66520181 +abstract type AbstractPhysicalModel end + +# ╔═╡ c57e9967-75e8-47d2-836e-f28159dc6592 +#setup physical model +struct BasicModel <: AbstractPhysicalModel + #rate at which debris hits satellites + debris_collision_rate::Real + #rate at which satellites of different constellations collide + satellite_collision_rates::Matrix{Float64} + #rate at which debris exits orbits + decay_rate::Real + #rate at which satellites + autocatalysis_rate::Real + #ratio at which a collision between satellites produced debris + satellite_collision_debris_ratio::Real + #Ratio at which launches produce debris + launch_debris_ratio::Real +end + +# ╔═╡ 4153bfd6-2ec6-4025-93ba-8e6e061809e8 +md""" +## Setup NeuralNets +""" + +# ╔═╡ bbc7e4b0-5edb-4e5d-9512-1aaeae1bb4ec +md""" +Custom function to zero out parameter traces +""" + +# ╔═╡ 533250c6-fd76-47ef-a561-d567da04ae61 + + +# ╔═╡ a305c33f-3388-4120-b036-3093c3ed6aa3 +md""" +## setup Actor Critic struct and loop +""" + +# ╔═╡ 0d375c34-3226-40b6-adaf-2f531c00d3ac +function zero(a::Flux.Params) + Flux.Params([Base.zero(x) for x=a]) +end + +# ╔═╡ e79e4ecf-caec-4c1a-9626-0c1a8b5006ce +mutable struct ActorCritic + #parameters + λʷ::Real + λᶿ::Real + αʷ::Real + αᶿ::Real + αᴿ::Real + + #keep track of eligibility traces + zᶿ::Flux.Params + zʷ::Flux.Params + + #keep track of update rate + R̄::AbstractFloat + + #Inside generator + ActorCritic( + λʷ::Real + ,λᶿ::Real + ,αʷ::Real + ,αᶿ::Real + ,αᴿ::Real + ,θ::Flux.Params + ,w::Flux.Params + ,R::Real + ) = + begin + zθ = zero(θ) #custom zero handles params + zw = zero(w) #custom zero handles params + new(λʷ,λᶿ,αʷ,αᶿ,αᴿ,zθ,zw,R) + end +end + +# ╔═╡ 0ac4ac61-0ca2-4ed6-8855-85c9d1939afa +md""" +# Example use: Model setup +""" + +# ╔═╡ 7102c12c-581b-4204-9986-741889cb03f6 +#implement tranistion function +#percentage survival function +function survival( + stocks + ,debris + ,physical_model::AbstractPhysicalModel + ) + return exp.(-physical_model.satellite_collision_rates*stocks .- (physical_model.debris_collision_rate*debris)) +end + +# ╔═╡ 283bdd91-d742-4f4b-a19b-2756ac01ce2c +#stock update rules +function G( + stocks::Vector + ,debris::Vector + ,launches::Vector + , physical_model::AbstractPhysicalModel +) + return LinearAlgebra.diagm(survival(stocks,debris,physical_model) .- physical_model.decay_rate)*stocks + launches +end + +# ╔═╡ 8a3f9518-ae0a-46f3-b37d-1017ce99c70d +#debris evolution +function H(stocks,debris,launches,physical_model) + #get changes in debris from natural dynamics + natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris + + #get changes in debris from satellite loss + satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks + + #get changes in debris from launches + launch_debris = physical_model.launch_debris_ratio*sum(launches) + + #return total debris level + return natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris +end + + +# ╔═╡ e4c8abfa-e96c-44b7-9858-af46adc1164a +#implement reward function +begin + const payoff = 3*LinearAlgebra.I #- 0.02*ones(N_constellations,N_constellations) + #Define the market profit function + F(stocks,debris,launches) = (stocks - 3.0*launches .+ (debris*-0.2))[1] +end + +# ╔═╡ 42be8352-7ecf-4da4-bff0-61e328144ed1 +#test +F(4,2,2) + +# ╔═╡ e2135298-e280-40ac-a549-6a15f13b84ee +md""" +## Example use: Single actor. + +### Model and Optimizer Parameterization +""" + +# ╔═╡ 5a2dab62-985c-4689-ab48-da04981d15a0 +#Model shape +begin + const N_constellations = 1; + const N_debris = 1; + const N_states = N_constellations + N_debris; +end + +# ╔═╡ 3f14b399-cae8-433a-8729-fc91c9b0bee9 +# Launch function +launch_policy = Flux.Chain( + Flux.Parallel(vcat + #parallel joins together stocks and debris, along with intermediate interpretation + ,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu) + ,Flux.Dense(N_states*2, N_states*2,Flux.σ) + ) + ,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu) + ,Flux.Dense(N_states, N_states,Flux.σ) + ) + ) + #Apply some transformations + ,Flux.Dense(N_states*3,128,Flux.σ) + ,Flux.Dense(128,128,Flux.σ) + ,Flux.Dense(128,N_constellations,Flux.relu) + +) + +# ╔═╡ 89dabcc8-7f4e-4df4-8445-eae9f50ebfea +#inspect parameters +Flux.params(launch_policy) + +# ╔═╡ 15d5ce62-135d-493e-abad-3ec1f99bfd3b +#Test the function above +zero(Flux.params(launch_policy)) + +# ╔═╡ 2c9346eb-048d-4667-b22b-583b814ae1a3 +# Launch function +value = Flux.Chain( + Flux.Parallel(vcat + #parallel joins together stocks and debris, along with intermediate interpretation + ,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu) + ,Flux.Dense(N_states*2, N_states*2,Flux.σ) + ) + ,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu) + ,Flux.Dense(N_states, N_states,Flux.σ) + ) + ) + #Apply some transformations + ,Flux.Dense(N_states*3,128,Flux.σ) + ,Flux.Dense(128,128,Flux.σ) + ,Flux.Dense(128,1) + +); + +# ╔═╡ 95b9943b-8c13-4717-8239-2bae9f93ee63 +#create the current actor critic optimizer +optim = ActorCritic( + 0.5 + ,0.5 + ,2.0 + ,2.0 + ,2.0 + ,Flux.params(launch_policy) + ,Flux.params(value) + ,0.5 +) + +# ╔═╡ 43704f08-6f1a-47bb-87e3-483a529b1654 +#Starting States +begin + stock_state = ones(N_constellations); + debris_state = ones(N_debris); +end; + +# ╔═╡ 6f3649c1-2ade-437d-81cb-d9f46306d1e0 +#check launch policy +launch_policy((stock_state,debris_state)) + +# ╔═╡ 7dfc61a9-569a-4e97-a292-4c2b15a78e3a +value((stock_state,debris_state)) + +# ╔═╡ 2f0ff236-0dff-4cd8-bbb4-f1e81ecef308 +#= +Setup the physical model + +These values are just guesstimates +=# + +begin +#Getting loss parameters together. +loss_param = 2e-3; +loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I); + +#orbital decay rate +decay_param = 0.01; + +#debris generation parameters +autocatalysis_param = 0.001; +satellite_loss_debris_rate = 5.0; +launch_debris_rate = 0.05; + +#Todo, wrap physical model as a struct with the parameters +bm = BasicModel( + loss_param + ,loss_weights + ,decay_param + ,autocatalysis_param + ,satellite_loss_debris_rate + ,launch_debris_rate +); +end + + +# ╔═╡ df771ced-e541-4bb2-b6a3-bb562bbaa89c +md""" +### Evaluation loop +""" + +# ╔═╡ 6d9a71ea-b9d1-4efd-aa32-9053b96a66e1 +#=Actor Critic Loop +This iterates on the actor critic loop, +=# +with_terminal() do +for iit in 1:12 + w = Flux.params(value) + θ = Flux.params(launch_policy) + + #Get learning parameters + action = launch_policy((stock_state,debris_state)) + new_stock_state = G(stock_state,debris_state,action,bm) + new_debris_state = H(stock_state,debris_state,action,bm) + println(new_debris_state) + + #get current value + current_value = value((stock_state,debris_state)) + new_value = value((new_stock_state,new_debris_state)) + + #need to define + R = F(stock_state,debris_state,action) + + #FIX: R is a vector, so are the values. This needs fixed. + δ = ((R .- optim.R̄) .+ (new_value .- current_value))[1] #fix + + #store values + optim.R̄ = optim.R̄ + δ*optim.αᴿ + + #check for exit conditions. + #probably use grad=0 + continue #issue in calculating gradients + #update the learning traces + optim.zʷ = optim.λʷ .* optim.zʷ .+ Flux.gradient(value,w) + optim.zᶿ = optim.λᶿ .* optim.zᶿ .+ Flux.gradient(policy,θ) + + continue + #update the policies + w = w .+ δ* optim.αʷ .* optim.zʷ; + θ = θ .+ δ* optim.αᶿ .* optim.zᶿ; + + + println("its working, $iit") +end +end + +# ╔═╡ 9941e984-e51b-4e64-8d16-13c8071363b0 +#TODO:issue is here, taking gradient. I get to figure this out next. +Zygote.gradient(() -> sum(value(stock_state,debris_state),w)) + +# ╔═╡ dd9f5f04-31e4-441e-9f5f-3019a30692d6 +with_terminal() do + x = rand(Float32, 10) + m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax) + l(x) = Flux.Losses.crossentropy(m(x), [0.5, 0.5]) + grads = gradient(params(m)) do + l(x) + end + for p in params(m) + println(grads[p]) + end +end + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Flux = "~0.12.8" +PlutoUI = "~0.7.17" +Zygote = "~0.6.29" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.0.1" + +[[AbstractTrees]] +git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.3.4" + +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.1" + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[ArrayInterface]] +deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "d9352737cef8525944bf9ef34392d756321cbd54" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "3.1.38" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[CEnum]] +git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.1" + +[[CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "2c8329f16addffd09e6ca84c556e2185a4933c64" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.5.0" + +[[ChainRules]] +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"] +git-tree-sha1 = "035ef8a5382a614b2d8e3091b6fdbb1c2b050e11" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.12.1" + +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "3533f5a691e60601fe60c90d8bc47a27aa2907ec" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.11.0" + +[[CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.0" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.0" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.8" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.40.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[DataAPI]] +git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.9.0" + +[[DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.10" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.3.1" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[ExprTools]] +git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.6" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "8756f9935b7ccc9064c6eef0bff0ad643df733a3" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.12.7" + +[[FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "ArrayInterface", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NNlibCUDA", "Pkg", "Printf", "Random", "Reexport", "SHA", "SparseArrays", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] +git-tree-sha1 = "e8b37bb43c01eed0418821d1f9d20eca5ba6ab21" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.12.8" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "63777916efbcb0ab6173d09a658fb7f2783de485" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.21" + +[[Functors]] +git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.2.7" + +[[GPUArrays]] +deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "7772508f17f1d482fe0df72cabc5b55bec06bbe0" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.1.2" + +[[GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "77d915a0af27d474f0aaf12fcd46c400a552e84c" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.13.7" + +[[Hyperscript]] +deps = ["Test"] +git-tree-sha1 = "8d511d5b81240fc8e6802386302675bdf47737b9" +uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" +version = "0.0.4" + +[[HypertextLiteral]] +git-tree-sha1 = "5efcf53d798efede8fee5b2c8b09284be359bf24" +uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" +version = "0.9.2" + +[[IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.2" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.3" + +[[IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.1" + +[[IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.3.0" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.2" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile"] +git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.8.4" + +[[LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "46092047ca4edc10720ecab437c42283cd7c44f3" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.6.0" + +[[LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6a2af408fe809c4f1a54d2b3f188fdd3698549d6" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.11+0" + +[[LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[LogExpFunctions]] +deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.4" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NNlib]] +deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.7.29" + +[[NNlibCUDA]] +deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "04490d5e7570c038b1cb0f5c3627597181cc15a9" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.1.9" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[Parsers]] +deps = ["Dates"] +git-tree-sha1 = "d911b6a12ba974dabe2291c6d450094a7226b372" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.1.1" + +[[Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[PlutoUI]] +deps = ["Base64", "Dates", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "Markdown", "Random", "Reexport", "UUIDs"] +git-tree-sha1 = "615f3a1eff94add4bca9476ded096de60b46443b" +uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +version = "0.7.17" + +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Random123]] +deps = ["Libdl", "Random", "RandomNumbers"] +git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.4.2" + +[[RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.3" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "f0bccf98e16759818ffc5d97ac3ebf87eb950150" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.8.1" + +[[Static]] +deps = ["IfElse"] +git-tree-sha1 = "e7bc80dc93f50857a5d1e3c8121495852f407e6a" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.4.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.2.13" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.12" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.13" + +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.6" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.9.4" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "0fc9959bcabc4668c403810b4e851f6b8962eac9" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.29" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +""" + +# ╔═╡ Cell order: +# ╟─7d6afccb-c45e-4371-827a-43d588d4945f +# ╠═cc16838c-3b25-11ec-2489-11c3d35f26f4 +# ╟─a810a371-d2e4-4d4f-9cf6-1b98582c4b0f +# ╠═2e59f530-aeb0-4ce1-9312-e81a66520181 +# ╠═c57e9967-75e8-47d2-836e-f28159dc6592 +# ╟─4153bfd6-2ec6-4025-93ba-8e6e061809e8 +# ╟─bbc7e4b0-5edb-4e5d-9512-1aaeae1bb4ec +# ╠═533250c6-fd76-47ef-a561-d567da04ae61 +# ╠═3f14b399-cae8-433a-8729-fc91c9b0bee9 +# ╠═2c9346eb-048d-4667-b22b-583b814ae1a3 +# ╠═89dabcc8-7f4e-4df4-8445-eae9f50ebfea +# ╠═6f3649c1-2ade-437d-81cb-d9f46306d1e0 +# ╠═7dfc61a9-569a-4e97-a292-4c2b15a78e3a +# ╟─a305c33f-3388-4120-b036-3093c3ed6aa3 +# ╠═458bb826-4eaf-42ca-b889-4d1c50a2ecae +# ╠═0d375c34-3226-40b6-adaf-2f531c00d3ac +# ╠═15d5ce62-135d-493e-abad-3ec1f99bfd3b +# ╠═e79e4ecf-caec-4c1a-9626-0c1a8b5006ce +# ╟─0ac4ac61-0ca2-4ed6-8855-85c9d1939afa +# ╠═7102c12c-581b-4204-9986-741889cb03f6 +# ╠═283bdd91-d742-4f4b-a19b-2756ac01ce2c +# ╠═8a3f9518-ae0a-46f3-b37d-1017ce99c70d +# ╠═e4c8abfa-e96c-44b7-9858-af46adc1164a +# ╠═42be8352-7ecf-4da4-bff0-61e328144ed1 +# ╟─e2135298-e280-40ac-a549-6a15f13b84ee +# ╠═95b9943b-8c13-4717-8239-2bae9f93ee63 +# ╠═5a2dab62-985c-4689-ab48-da04981d15a0 +# ╠═43704f08-6f1a-47bb-87e3-483a529b1654 +# ╠═2f0ff236-0dff-4cd8-bbb4-f1e81ecef308 +# ╟─df771ced-e541-4bb2-b6a3-bb562bbaa89c +# ╠═6d9a71ea-b9d1-4efd-aa32-9053b96a66e1 +# ╠═9941e984-e51b-4e64-8d16-13c8071363b0 +# ╠═dd9f5f04-31e4-441e-9f5f-3019a30692d6 +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 diff --git a/julia_code/DerivativesOfProfit_enzyme.ipynb b/julia_code/DerivativesOfProfit_enzyme.ipynb new file mode 100644 index 0000000..61d207c --- /dev/null +++ b/julia_code/DerivativesOfProfit_enzyme.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "41dcca64-963f-488e-b92e-f1dc5109359a", + "metadata": {}, + "outputs": [], + "source": [ + "using Enzyme" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ca966ab8-469e-4f8c-af54-579c55f54bd4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "()" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "function mymul!(R, A, B)\n", + " @assert axes(A,2) == axes(B,1)\n", + " @inbounds @simd for i in eachindex(R)\n", + " R[i] = 0\n", + " end\n", + " @inbounds for j in axes(B, 2), i in axes(A, 1)\n", + " @inbounds @simd for k in axes(A,2)\n", + " R[i,j] += A[i,k] * B[k,j]\n", + " end\n", + " end\n", + " nothing\n", + "end\n", + "\n", + "\n", + "A = rand(5, 3)\n", + "B = rand(3, 7)\n", + "\n", + "R = zeros(size(A,1), size(B,2))\n", + "∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us\n", + "\n", + "∂z_∂A = zero(A)\n", + "∂z_∂B = zero(B)\n", + "\n", + "Enzyme.autodiff(mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7442bfb5-3146-493d-9abc-9afeb56c0471", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "true" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "R ≈ A * B &&\n", + "∂z_∂A ≈ ∂z_∂R * B' && # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1]\n", + "∂z_∂B ≈ A' * ∂z_∂R # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "36a4cd0f-c5e2-4a6f-b434-2d347686a08b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3×7 Matrix{Float64}:\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#reset\n", + "R = zeros(size(A,1), size(B,2))\n", + "∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us\n", + "\n", + "∂z_∂A = zero(A)\n", + "∂z_∂B = zero(B)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "49cbf4e1-1ef0-4428-90af-02ac2b46c2a8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "revenue! (generic function with 1 method)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function revenue!(R, A, B)\n", + " @assert axes(A,2) == axes(B,1)\n", + " @inbounds @simd for i in eachindex(R)\n", + " R[i] = 0\n", + " end\n", + " @inbounds for j in axes(B, 2), i in axes(A, 1)\n", + " @inbounds @simd for k in axes(A,2)\n", + " R[i,j] += A[i,k] * B[k,j]\n", + " end\n", + " end\n", + " nothing\n", + "end\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5fea97a5-39ac-4ade-9cb4-4aba66a80825", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 5;\n", + "constellations = 2;\n", + "payoff_mat = zeros(batch_size,1);\n", + "\n", + "stocks = rand(batch_size, constellations);\n", + "\n", + "payoffs = rand(constellations,1);" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d06ec650-621c-4a97-8bbd-0bfcfd4ab8d5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×1 Matrix{Float64}:\n", + " 0.05943992677268309\n", + " 0.16746133343364858\n", + " 0.22311130107900645\n", + " 0.1326381498910713\n", + " 0.23997313509634804" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "revenue!(payoff_mat, stocks,payoffs)\n", + "\n", + "payoff_mat" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "86425c6b-baa3-494d-8c2c-35eaf08cefaf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×1 Matrix{Float64}:\n", + " 1.0\n", + " 1.0\n", + " 1.0\n", + " 1.0\n", + " 1.0" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "payoff_mat = zero(payoff_mat)\n", + "∂payoff_mat = ones(size(payoff_mat)...)\n", + "\n", + "∂stocks = zero(stocks)\n", + "∂payoff_mat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbb5fa3f-b3c6-48bf-9df2-657d5d7aae18", + "metadata": {}, + "outputs": [], + "source": [ + "autodiff(revenue!\n", + " ,Duplicated(payoff_mat, ∂payoff_mat)\n", + " ,Const(payoffs)\n", + " ,Duplicated(stocks,∂stocks)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0872e372-f3d1-4fca-b89f-f134a3dc563d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "984b1bd7-e53a-41c3-bafd-f56aacdae4b7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.6.2", + "language": "julia", + "name": "julia-1.6" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/julia_code/MWE_MutabilityErrors.jl b/julia_code/MWE_MutabilityErrors.jl new file mode 100644 index 0000000..d18e254 --- /dev/null +++ b/julia_code/MWE_MutabilityErrors.jl @@ -0,0 +1,1002 @@ +### A Pluto.jl notebook ### +# v0.17.0 + +using Markdown +using InteractiveUtils + +# ╔═╡ 20b324a4-3b6c-11ec-3673-d98ec8af9009 +import Zygote, LinearAlgebra,Flux,BenchmarkTools,PlutoUI + +# ╔═╡ 28c0d31f-e90a-415c-b35e-312bdf771ddf +md""" +# Background + +I am trying to use the approach of +[Maliar et al](https://notes.quantecon.org/submission/5ddb3c926bad3800109084bf) +in solving general equilibrium models (DSGE models as they are typically called). +The goal is to approximate a policy function describing optimal behavior based on the current states. + + +The basic approach is as follows: + 1. Define a bellman equation with constraints. + 1. Use calculus and the envelope theorem to transform it into a set of +equations (euler equations and KKT conditions) that implicitly defines the policy function. + 1. Use the residuals from the euler equations to create a loss function. + 1. Use the loss function to calibrate (train) a functional approximation (NN is the cannonical example) of the policy function. + +In the original paper, they completed steps 1 through 3 algebraically, and hand coded step 3. +TensorFlow was used to complete step 4. + +I have completed steps 1 and 2 mathematically and found that the +euler equations can be written as functions of derivatives of the constraints. +This gives a very simple way to generate the Euler Equations. +I am currently working on implementing step 2 computationally. +In order to generalize the work of Maliar et al, I was planning on using +differentiable programming to construct jacobians of the constraints. + +### The mathematical model +The Bellman Equation with constraints is: + +$V(\theta) = \max_X F(\theta,X(\theta)) + \beta V(\theta^\prime)$ + +where + +$\theta^\prime = T(\theta,X(\theta))$ + +The two parts of the residuals can be written as + + 1. $\vec{0} = \frac{\partial F}{\partial \theta}(\theta,X(\theta)) + \beta \frac{\partial T}{\partial \theta}(\theta, X(\theta)) \cdot \frac{\partial V}{\partial \theta}(\theta^\prime) - \frac{\partial V}{\partial \theta}(\theta))$ + 1. $\vec{0} = \frac{\partial F}{\partial X}(\theta,X(\theta)) + \beta \frac{\partial T}{\partial X}(\theta, X(\theta)) \cdot \frac{\partial V}{\partial \theta}(\theta^\prime)$ + +I am planning on training two NN. + - $X: \theta \rightarrow \text{actions}$ + - $\partial V: \theta \rightarrow R^{|\theta|}$ + +### Where I am running into problems +I have constructed one of the residual functions that I need, and wrapped it into +a loss function. +I am now testing whether or not I can train $\partial V$ (given a fixed policies) +on the first of the conditions specified above. + +When I try to calculate the gradients, I get a bunch of nothings. + +# Code description +## Functions + +The physical model defines two major relationships + - How stocks transition between time periods + - How debris transistions between time periods. +These are function `G()` and `H()` respectively. +Because of the differences in how the dynamics work, I wrote these as two different functions. + +Theres is a `survival()` function that is just part of refactoring the code. + +There is also a function `F()` that determine the returns to a given action +in a single step. + +## Neural Networks +Because of the way the constraints are formed, I am approximating two functions. +One is the policy function, and the other is the derivatives of the value function. + +## Stucts +The only custom struct I am using so far is the `BasicModel`, which parameterizes +the physical model + +A similar economic model struct will be used in the future. + +""" + +# ╔═╡ 11a53b0c-465b-4f16-88d2-e0163e471fd6 +#setup physical model +struct BasicModel + #rate at which debris hits satellites + debris_collision_rate::Real + #rate at which satellites of different constellations collide + satellite_collision_rates::Matrix{Float64} + #rate at which debris exits orbits + decay_rate::Real + #rate at which satellites + autocatalysis_rate::Real + #ratio at which a collision between satellites produced debris + satellite_collision_debris_ratio::Real + #Ratio at which launches produce debris + launch_debris_ratio::Real +end + +# ╔═╡ 1d4955f9-1d1c-4572-bc7b-9f002ea54042 +#= +This section just sets information related to model shape +=# +begin + const N_constellations = 3; + const N_debris = 1; + const N_states = N_constellations + N_debris; +end + +# ╔═╡ 6175d40b-0fa6-4cb4-8ef3-9765953ce97e +#= +Setup the physical model, including parameter values. + +These values are just guesstimates +=# + +begin +#Getting loss parameters together. +loss_param = 2e-3; +loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I); + +#orbital decay rate +decay_param = 0.01; + +#debris generation parameters +autocatalysis_param = 0.001; +satellite_loss_debris_rate = 5.0; +launch_debris_rate = 0.05; + +#Todo, wrap physical model as a struct with the parameters +bm = BasicModel( + loss_param + ,loss_weights + ,decay_param + ,autocatalysis_param + ,satellite_loss_debris_rate + ,launch_debris_rate +); +end + +# ╔═╡ 7462740e-23da-4151-81b5-cc2e6cbcf2c8 +#implement survival function +function survival( + stocks::Array + ,debris::Array + ,physical_model::BasicModel + ) + return exp.( + -physical_model.satellite_collision_rates * stocks + .- (physical_model.debris_collision_rate*debris) + ) +end + +# ╔═╡ 1eee99c9-5fef-498d-8a57-9c18e2f1cf49 +#stock update rules +function G( + stocks::Array + ,debris::Array + ,launches::Array + , physical_model::BasicModel +) + return LinearAlgebra.diagm(survival(stocks,debris,physical_model) .- physical_model.decay_rate)*stocks + launches +end + +# ╔═╡ bc3492c6-5c3b-4b08-86d4-dd4026e25655 +#debris evolution +function H( + stocks::Array + ,debris::Array + ,launches::Array + , physical_model::BasicModel + ) + #get changes in debris from natural dynamics + natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris + + #get changes in debris from satellite loss + satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks + + #get changes in debris from launches + launch_debris = physical_model.launch_debris_ratio*sum(launches) + + #return total debris level + return natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris +end + +# ╔═╡ 56314025-fb0a-4e98-8628-091bda52708c +begin + number_params=10 +∂value = Flux.Chain( + Flux.Parallel(vcat + #parallel joins together stocks and debris + ,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu) + #,Flux.Dense(N_states*2, N_states*2,Flux.σ) + ) + ,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu) + #,Flux.Dense(N_states, N_states) + ) + ) + #Apply some transformations + ,Flux.Dense(N_states*3,number_params,Flux.σ) + #,Flux.Dense(number_params,number_params,Flux.σ) + + #Split out into partials related to stocks and debris + ,Flux.Parallel(vcat + ,Flux.Chain( + #Flux.Dense(number_params, number_params,Flux.relu) + #, + Flux.Dense(number_params, N_constellations,Flux.σ) + ) + ,Flux.Chain( + #Flux.Dense(number_params, number_params,Flux.relu) + #, + Flux.Dense(number_params, N_debris) + ) + ) +) +end + +# ╔═╡ f61682fe-da30-459f-b807-4fc3e8f36f32 + + +# ╔═╡ 31237f62-428a-4df3-9caf-1f5c5ade6ade +md""" +### Payout function +""" + +# ╔═╡ 55fd682d-2808-4786-ad52-647ac6892860 +begin + pay_matrix = zeros(N_constellations,N_constellations) + LinearAlgebra.I +end + +# ╔═╡ bf869218-034e-42f9-96c5-4267a8c8fcfb +function F1( + s #stock levels + ,d #debris levels + ,a #actions + #,econ::Int #TODO: a struct with various bits of info + #,taxes::Int #TODO: a struct with info on taxes + #,debris_interactions::Matrix #Ignoring for now + ,n::Matrix #constellation of interest + ) + return n*(pay_matrix*s) - n*5.0*a + #-taxes and econ should be structs with parameters later +end + +# ╔═╡ 4817d4fe-b86c-43e2-94d2-fa6c424fa001 +begin + n1 = [1.0 2 3] + d1 = [-0.002 0] +end + +# ╔═╡ 932b7a31-d5f4-44a3-a6c3-f608003a0a6f +md""" +## Building a function for the transition residuals (transversality conditions) +""" + +# ╔═╡ 29a8f5c6-d5b8-4ec6-95a4-89025245787d +loss(x) = sum(x.^2) + +# ╔═╡ ba7ca719-9687-4f1a-9a36-89b795a1bc13 +md""" +# Getting the gradients of a NN +""" + +# ╔═╡ ee99625e-68bb-4fc8-8664-fe5df6b52755 +md""" +Note how if I calculate the loss, I get no gradients. +""" + +# ╔═╡ 2df69994-fc19-478e-8ae9-20c3d542e2ff +md""" +If I try with the function, I get +""" + +# ╔═╡ c032df17-d414-4222-8634-d27344cbdd15 +md""" +So, what is going on here? + - With the first example + - Why do I get empty gradients/jacobians? + - why doesn't it throw the same error as the transitional residuals case? + - With the second example (transition residuals function) + - Why do I get this mutating error? + - I have not been able to identify why this occurs. +""" + +# ╔═╡ 287ff2f9-3469-4e29-91a3-91ed54109f5a +md""" +## Parameters for testing +""" + +# ╔═╡ d173ec63-b7c3-4042-862c-2626039d6d94 +β = 0.95 + +# ╔═╡ 8eca54ad-f95b-450c-8567-480f0ab7ea19 +function transition_residuals(G,H,F + ,s,s′ #stocks t and t+1 + ,d,d′ #debris '' + ,p #policy + ,n #the constellation we are dealing with + ,physics #Parameters of the physical model of the world + ;econ=0.0,taxes=0.0 #economic parameters (to be incorporated later) and tax policy + ) + + #calculate partials of transition functions + ∂G_∂s = Zygote.jacobian(stocks->G(stocks,d,p,physics),s)[1] + ∂G_∂d = Zygote.jacobian(debris->G(s,debris,p,physics),d)[1] + ∂H_∂s = Zygote.jacobian(stocks->H(stocks,d,p,physics),s)[1] + ∂H_∂d = Zygote.jacobian(debris->H(s,debris,p,physics),d)[1] + #concatenate to create iterated vector + ∂Vₜ₊₁_∂θ = vcat(hcat(∂G_∂s ,∂G_∂d),hcat(∂H_∂s ,∂H_∂d)) * ∂value((s′,d′)) + + #get partials of benefit function + ∂F_∂d = Zygote.jacobian(debris->F(s,debris,p,n1),d)[1] + ∂F_∂s = Zygote.jacobian(stocks->F(stocks,d,p,n1),s)[1] #this probaby should not be transposed + + #concatenate + ∂F_∂θ = hcat(∂F_∂s , ∂F_∂d )' + + #calculate the optimality residuals + ∂F_∂θ + β*∂Vₜ₊₁_∂θ - ∂value((s,d)) +end + +# ╔═╡ 274e3a3f-fe23-4d6f-893c-622db0413af0 +begin + s = [1.0,2,3] + d = [1.0] + p = [1.0,2,3] +end + +# ╔═╡ 5e754d71-f4c3-476c-9656-61f07687a534 +G(s,d,[1,1,3],bm) + +# ╔═╡ 749e6b65-fd8e-45ed-9a2b-f97274549933 +survival(s,d,bm) + +# ╔═╡ d38be057-952a-450e-ba23-bd2f40e79de3 +F1(s,d,p,n1) + +# ╔═╡ affe6dd2-05e2-4852-9ef0-9605bdb5e34b +Zygote.jacobian(stocks -> F1(stocks,d,p,n1),s) + +# ╔═╡ 1c8f2c25-b8da-4e31-adac-ad575203f9cf +begin + #calculate partials of transition functions + ∂G_∂s = Zygote.jacobian(stocks->G(stocks,d,p,bm),s)[1] + ∂G_∂d = Zygote.jacobian(debris->G(s,debris,p,bm),d)[1] + ∂H_∂s = Zygote.jacobian(stocks->H(stocks,d,p,bm),s)[1] + ∂H_∂d = Zygote.jacobian(debris->H(s,debris,p,bm),d)[1] + #concatenate to create iterated vector + ∂Vₜ₊₁_∂θ = vcat(hcat(∂G_∂s ,∂G_∂d),hcat(∂H_∂s ,∂H_∂d)) * ∂value((s,d)) + + #get partials of benefit function + ∂F_∂d = Zygote.jacobian(debris->F1(s,debris,p,n1),d)[1] + ∂F_∂s = Zygote.jacobian(stocks->F1(stocks,d,p,n1),s)[1] #this probaby should not be transposed + + #concatenate + ∂F_∂θ = hcat(∂F_∂s , ∂F_∂d )' + + #calculate the optimality residuals + a = ∂F_∂θ + β * ∂Vₜ₊₁_∂θ - ∂value((s,d)) +end + +# ╔═╡ 48b09bfb-87fe-4974-a4f3-7a6ece521da6 +sum(a.^2) + +# ╔═╡ e1d0aeb8-8fb8-48a8-9842-7888d9fae1ad +loss(a) + +# ╔═╡ 1e0cf592-fd96-4a01-84d1-078301ac7f99 +x = Flux.gradient(() -> loss(a),Flux.params(∂value)) + +# ╔═╡ 8dad1cc0-4cd1-45c5-9b92-8fff8d06fd9b +PlutoUI.with_terminal() do + println("beginning") + for pr in Flux.params(∂value) + println(x[pr]) + end + println("beginning") + for pr in x + println(pr) + end +end + +# ╔═╡ 9b59255c-77b4-4df6-a70b-e1fdaab7619b +grads_value1 = Flux.gradient(() -> loss(transition_residuals(G,H,F1,s,s,d,d,p,1,bm)),Flux.params(∂value)) + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +BenchmarkTools = "~1.2.0" +Flux = "~0.12.8" +PlutoUI = "~0.7.17" +Zygote = "~0.6.29" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.0.1" + +[[AbstractTrees]] +git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.3.4" + +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.1" + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[ArrayInterface]] +deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "d9352737cef8525944bf9ef34392d756321cbd54" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "3.1.38" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "61adeb0823084487000600ef8b1c00cc2474cd47" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.2.0" + +[[CEnum]] +git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.1" + +[[CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "2c8329f16addffd09e6ca84c556e2185a4933c64" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.5.0" + +[[ChainRules]] +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"] +git-tree-sha1 = "035ef8a5382a614b2d8e3091b6fdbb1c2b050e11" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.12.1" + +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "3533f5a691e60601fe60c90d8bc47a27aa2907ec" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.11.0" + +[[CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.0" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.0" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.8" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.40.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[DataAPI]] +git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.9.0" + +[[DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.10" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.3.1" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[ExprTools]] +git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.6" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "8756f9935b7ccc9064c6eef0bff0ad643df733a3" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.12.7" + +[[FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "ArrayInterface", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NNlibCUDA", "Pkg", "Printf", "Random", "Reexport", "SHA", "SparseArrays", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] +git-tree-sha1 = "e8b37bb43c01eed0418821d1f9d20eca5ba6ab21" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.12.8" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "63777916efbcb0ab6173d09a658fb7f2783de485" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.21" + +[[Functors]] +git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.2.7" + +[[GPUArrays]] +deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "7772508f17f1d482fe0df72cabc5b55bec06bbe0" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.1.2" + +[[GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "77d915a0af27d474f0aaf12fcd46c400a552e84c" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.13.7" + +[[Hyperscript]] +deps = ["Test"] +git-tree-sha1 = "8d511d5b81240fc8e6802386302675bdf47737b9" +uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" +version = "0.0.4" + +[[HypertextLiteral]] +git-tree-sha1 = "5efcf53d798efede8fee5b2c8b09284be359bf24" +uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" +version = "0.9.2" + +[[IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.2" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.3" + +[[IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.1" + +[[IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.3.0" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.2" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile"] +git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.8.4" + +[[LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "46092047ca4edc10720ecab437c42283cd7c44f3" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.6.0" + +[[LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6a2af408fe809c4f1a54d2b3f188fdd3698549d6" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.11+0" + +[[LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[LogExpFunctions]] +deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.4" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NNlib]] +deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.7.29" + +[[NNlibCUDA]] +deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "04490d5e7570c038b1cb0f5c3627597181cc15a9" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.1.9" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[Parsers]] +deps = ["Dates"] +git-tree-sha1 = "d911b6a12ba974dabe2291c6d450094a7226b372" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.1.1" + +[[Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[PlutoUI]] +deps = ["Base64", "Dates", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "Markdown", "Random", "Reexport", "UUIDs"] +git-tree-sha1 = "615f3a1eff94add4bca9476ded096de60b46443b" +uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +version = "0.7.17" + +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Random123]] +deps = ["Libdl", "Random", "RandomNumbers"] +git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.4.2" + +[[RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.3" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "f0bccf98e16759818ffc5d97ac3ebf87eb950150" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.8.1" + +[[Static]] +deps = ["IfElse"] +git-tree-sha1 = "e7bc80dc93f50857a5d1e3c8121495852f407e6a" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.4.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.2.13" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.12" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.13" + +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.6" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.9.4" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "0fc9959bcabc4668c403810b4e851f6b8962eac9" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.29" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +""" + +# ╔═╡ Cell order: +# ╠═20b324a4-3b6c-11ec-3673-d98ec8af9009 +# ╟─28c0d31f-e90a-415c-b35e-312bdf771ddf +# ╠═11a53b0c-465b-4f16-88d2-e0163e471fd6 +# ╠═1d4955f9-1d1c-4572-bc7b-9f002ea54042 +# ╠═6175d40b-0fa6-4cb4-8ef3-9765953ce97e +# ╠═7462740e-23da-4151-81b5-cc2e6cbcf2c8 +# ╠═1eee99c9-5fef-498d-8a57-9c18e2f1cf49 +# ╠═bc3492c6-5c3b-4b08-86d4-dd4026e25655 +# ╠═5e754d71-f4c3-476c-9656-61f07687a534 +# ╠═749e6b65-fd8e-45ed-9a2b-f97274549933 +# ╠═56314025-fb0a-4e98-8628-091bda52708c +# ╠═f61682fe-da30-459f-b807-4fc3e8f36f32 +# ╟─31237f62-428a-4df3-9caf-1f5c5ade6ade +# ╠═55fd682d-2808-4786-ad52-647ac6892860 +# ╠═bf869218-034e-42f9-96c5-4267a8c8fcfb +# ╠═4817d4fe-b86c-43e2-94d2-fa6c424fa001 +# ╠═d38be057-952a-450e-ba23-bd2f40e79de3 +# ╠═affe6dd2-05e2-4852-9ef0-9605bdb5e34b +# ╠═932b7a31-d5f4-44a3-a6c3-f608003a0a6f +# ╠═8eca54ad-f95b-450c-8567-480f0ab7ea19 +# ╠═1c8f2c25-b8da-4e31-adac-ad575203f9cf +# ╠═48b09bfb-87fe-4974-a4f3-7a6ece521da6 +# ╠═29a8f5c6-d5b8-4ec6-95a4-89025245787d +# ╠═ba7ca719-9687-4f1a-9a36-89b795a1bc13 +# ╠═ee99625e-68bb-4fc8-8664-fe5df6b52755 +# ╠═e1d0aeb8-8fb8-48a8-9842-7888d9fae1ad +# ╠═1e0cf592-fd96-4a01-84d1-078301ac7f99 +# ╠═8dad1cc0-4cd1-45c5-9b92-8fff8d06fd9b +# ╠═2df69994-fc19-478e-8ae9-20c3d542e2ff +# ╠═9b59255c-77b4-4df6-a70b-e1fdaab7619b +# ╠═c032df17-d414-4222-8634-d27344cbdd15 +# ╟─287ff2f9-3469-4e29-91a3-91ed54109f5a +# ╠═d173ec63-b7c3-4042-862c-2626039d6d94 +# ╠═274e3a3f-fe23-4d6f-893c-622db0413af0 +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 diff --git a/julia_code/Maliar_attempt.jl b/julia_code/Maliar_attempt.jl new file mode 100644 index 0000000..2177092 --- /dev/null +++ b/julia_code/Maliar_attempt.jl @@ -0,0 +1,925 @@ +### A Pluto.jl notebook ### +# v0.17.0 + +using Markdown +using InteractiveUtils + +# ╔═╡ 20b324a4-3b6c-11ec-3673-d98ec8af9009 +import Zygote, LinearAlgebra,Flux,BenchmarkTools,PlutoUI + +# ╔═╡ 64f31ef7-e0ba-4353-8e51-e8356a894656 +abstract type AbstractPhysicalModel end + +# ╔═╡ 11a53b0c-465b-4f16-88d2-e0163e471fd6 +#setup physical model +struct BasicModel <: AbstractPhysicalModel + #rate at which debris hits satellites + debris_collision_rate::Real + #rate at which satellites of different constellations collide + satellite_collision_rates::Matrix{Float64} + #rate at which debris exits orbits + decay_rate::Real + #rate at which satellites + autocatalysis_rate::Real + #ratio at which a collision between satellites produced debris + satellite_collision_debris_ratio::Real + #Ratio at which launches produce debris + launch_debris_ratio::Real +end + +# ╔═╡ 1d4955f9-1d1c-4572-bc7b-9f002ea54042 +begin + const N_constellations = 3; + const N_debris = 1; + const N_states = N_constellations + N_debris; +end + +# ╔═╡ 6175d40b-0fa6-4cb4-8ef3-9765953ce97e +#= +Setup the physical model + +These values are just guesstimates +=# + +begin +#Getting loss parameters together. +loss_param = 2e-3; +loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I); + +#orbital decay rate +decay_param = 0.01; + +#debris generation parameters +autocatalysis_param = 0.001; +satellite_loss_debris_rate = 5.0; +launch_debris_rate = 0.05; + +#Todo, wrap physical model as a struct with the parameters +bm = BasicModel( + loss_param + ,loss_weights + ,decay_param + ,autocatalysis_param + ,satellite_loss_debris_rate + ,launch_debris_rate +); +end + +# ╔═╡ 28c0d31f-e90a-415c-b35e-312bdf771ddf +md""" +# Setup Model Functions + + - Debris Transitions + - Derivatives of debris transitions +""" + +# ╔═╡ 7462740e-23da-4151-81b5-cc2e6cbcf2c8 +#implement tranistion function +#percentage survival function +function survival( + stocks + ,debris + ,physical_model::AbstractPhysicalModel + ) + return exp.( + -physical_model.satellite_collision_rates * stocks + .- (physical_model.debris_collision_rate*debris) + ) +end + +# ╔═╡ 1eee99c9-5fef-498d-8a57-9c18e2f1cf49 +#= Stock levels evolution + +=# +function G( + stocks::Vector + ,debris::Vector + ,launches::Vector + , physical_model::AbstractPhysicalModel +) + return LinearAlgebra.diagm(survival(stocks,debris,physical_model) .- physical_model.decay_rate)*stocks + launches +end + +# ╔═╡ bc3492c6-5c3b-4b08-86d4-dd4026e25655 +#= Debris evolution +The model is + +=# +function H(stocks,debris,launches,physical_model) + #get changes in debris from natural dynamics + natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris + + #get changes in debris from satellite loss + satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks + #Possible Issue: Broadcasts? ^^^ + + #get changes in debris from launches + launch_debris = physical_model.launch_debris_ratio*sum(launches) + #Possible Issue: Sum? ^^^ + + #return total debris level + return natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris +end + +# ╔═╡ 56314025-fb0a-4e98-8628-091bda52708c +begin + number_params=10 +∂value = Flux.Chain( + Flux.Parallel(vcat + #parallel joins together stocks and debris + ,Flux.Chain(Flux.Dense(N_constellations, N_states*2,Flux.relu) + #,Flux.Dense(N_states*2, N_states*2,Flux.σ) + ) + ,Flux.Chain(Flux.Dense(N_debris, N_states,Flux.relu) + #,Flux.Dense(N_states, N_states) + ) + ) + #Apply some transformations + ,Flux.Dense(N_states*3,number_params,Flux.σ) + #,Flux.Dense(number_params,number_params,Flux.σ) + + #Split out into partials related to stocks and debris + ,Flux.Parallel(vcat + ,Flux.Chain( + #Flux.Dense(number_params, number_params,Flux.relu) + #, + Flux.Dense(number_params, N_constellations,Flux.σ) + ) + ,Flux.Chain( + #Flux.Dense(number_params, number_params,Flux.relu) + #, + Flux.Dense(number_params, N_debris) + ) + ) +) +end + +# ╔═╡ f61682fe-da30-459f-b807-4fc3e8f36f32 + + +# ╔═╡ 31237f62-428a-4df3-9caf-1f5c5ade6ade +md""" +### Payout function +""" + +# ╔═╡ 55fd682d-2808-4786-ad52-647ac6892860 +begin + pay_matrix = zeros(N_constellations,N_constellations) + LinearAlgebra.I +end + +# ╔═╡ bf869218-034e-42f9-96c5-4267a8c8fcfb +function F1( + s #stock levels + ,d #debris levels + ,a #actions + #,econ::Int #TODO: a struct with various bits of info + #,taxes::Int #TODO: a struct with info on taxes + #,debris_interactions::Matrix #Ignoring for now + ,n::Matrix #constellation of interest + ) + return n*(pay_matrix*s) - n*5.0*a + #-taxes and econ should be structs with parameters later +end + +# ╔═╡ 4817d4fe-b86c-43e2-94d2-fa6c424fa001 +begin + n1 = [1.0 2 3] + d1 = [-0.002 0] +end + +# ╔═╡ 48a0d42f-bd00-4298-8b28-56acf2dbc8a7 + + +# ╔═╡ 5d85cdb6-a627-4dd5-b76c-79ed2bc019c6 + + +# ╔═╡ 5a5ca97f-28a9-4196-a1a6-40f3ae4bbb9c + + +# ╔═╡ 932b7a31-d5f4-44a3-a6c3-f608003a0a6f +md""" +## Building a function for the transition residuals (transversality conditions) +""" + +# ╔═╡ 29a8f5c6-d5b8-4ec6-95a4-89025245787d +loss(x) = sum(x.^2) + +# ╔═╡ ba7ca719-9687-4f1a-9a36-89b795a1bc13 +md""" +### Testing training the value partials +""" + +# ╔═╡ fda49870-c355-4b46-8d04-47268cb0372d +md""" +## Building a function for Optimality Residuals (Optimality Condition) +""" + +# ╔═╡ 287ff2f9-3469-4e29-91a3-91ed54109f5a +md""" +## Parameters for testing +""" + +# ╔═╡ d173ec63-b7c3-4042-862c-2626039d6d94 +β = 0.95 + +# ╔═╡ 8eca54ad-f95b-450c-8567-480f0ab7ea19 +function transition_residuals(G,H,F + ,s,s′ #stocks t and t+1 + ,d,d′ #debris '' + ,p #policy + ,n #the constellation we are dealing with + ,physics #Parameters of the physical model of the world + ;econ=0.0,taxes=0.0 #economic parameters (to be incorporated later) and tax policy + ) + + #calculate partials of transition functions + ∂G_∂s = Zygote.jacobian(stocks->G(stocks,d,p,physics),s)[1] + ∂G_∂d = Zygote.jacobian(debris->G(s,debris,p,physics),d)[1] + ∂H_∂s = Zygote.jacobian(stocks->H(stocks,d,p,physics),s)[1] + ∂H_∂d = Zygote.jacobian(debris->H(s,debris,p,physics),d)[1] + #concatenate to create iterated vector + ∂Vₜ₊₁_∂θ = vcat(hcat(∂G_∂s ,∂G_∂d),hcat(∂H_∂s ,∂H_∂d)) * ∂value((s′,d′)) + + #get partials of benefit function + ∂F_∂d = Zygote.jacobian(debris->F(s,debris,p,n1),d)[1] + ∂F_∂s = Zygote.jacobian(stocks->F(stocks,d,p,n1),s)[1] #this probaby should not be transposed + + #concatenate + ∂F_∂θ = hcat(∂F_∂s , ∂F_∂d )' + + #calculate the optimality residuals + ∂F_∂θ + β*∂Vₜ₊₁_∂θ - ∂value((s,d)) +end + +# ╔═╡ 274e3a3f-fe23-4d6f-893c-622db0413af0 +begin + s = [1.0,2,3] + d = [1.0] + p = [1.0,2,3] +end + +# ╔═╡ 5e754d71-f4c3-476c-9656-61f07687a534 +G(s,d,[1,1,3],bm) + +# ╔═╡ 749e6b65-fd8e-45ed-9a2b-f97274549933 +survival(s,d,bm) + +# ╔═╡ d38be057-952a-450e-ba23-bd2f40e79de3 +F1(s,d,p,n1) + +# ╔═╡ affe6dd2-05e2-4852-9ef0-9605bdb5e34b +Zygote.jacobian(stocks -> F1(stocks,d,p,n1),s) + +# ╔═╡ 1c8f2c25-b8da-4e31-adac-ad575203f9cf +begin + #calculate partials of transition functions + ∂G_∂s = Zygote.jacobian(stocks->G(stocks,d,p,bm),s)[1] + ∂G_∂d = Zygote.jacobian(debris->G(s,debris,p,bm),d)[1] + ∂H_∂s = Zygote.jacobian(stocks->H(stocks,d,p,bm),s)[1] + ∂H_∂d = Zygote.jacobian(debris->H(s,debris,p,bm),d)[1] + #concatenate to create iterated vector + ∂Vₜ₊₁_∂θ = vcat(hcat(∂G_∂s ,∂G_∂d),hcat(∂H_∂s ,∂H_∂d)) * ∂value((s,d)) + + #get partials of benefit function + ∂F_∂d = Zygote.jacobian(debris->F1(s,debris,p,n1),d)[1] + ∂F_∂s = Zygote.jacobian(stocks->F1(stocks,d,p,n1),s)[1] #this probaby should not be transposed + + #concatenate + ∂F_∂θ = hcat(∂F_∂s , ∂F_∂d )' + + #calculate the optimality residuals + a = ∂F_∂θ + β * ∂Vₜ₊₁_∂θ - ∂value((s,d)) +end + +# ╔═╡ 48b09bfb-87fe-4974-a4f3-7a6ece521da6 +sum(a.^2) + +# ╔═╡ e1d0aeb8-8fb8-48a8-9842-7888d9fae1ad +loss(a) + +# ╔═╡ 1e0cf592-fd96-4a01-84d1-078301ac7f99 +x = Flux.gradient(() -> loss(a), Flux.params(∂value)) + +# ╔═╡ 8dad1cc0-4cd1-45c5-9b92-8fff8d06fd9b +PlutoUI.with_terminal() do + println("beginning") +for pr in Flux.params(∂value) + println(x[pr]) +end +end + +# ╔═╡ 9b59255c-77b4-4df6-a70b-e1fdaab7619b +grads_value1 = Flux.gradient(() -> loss(transition_residuals(G,H,F1,s,s,d,d,p,1,bm)),Flux.params(∂value)) + +# ╔═╡ 00000000-0000-0000-0000-000000000001 +PLUTO_PROJECT_TOML_CONTENTS = """ +[deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +BenchmarkTools = "~1.2.0" +Flux = "~0.12.8" +PlutoUI = "~0.7.17" +Zygote = "~0.6.29" +""" + +# ╔═╡ 00000000-0000-0000-0000-000000000002 +PLUTO_MANIFEST_TOML_CONTENTS = """ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.0.1" + +[[AbstractTrees]] +git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.3.4" + +[[Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "84918055d15b3114ede17ac6a7182f68870c16f7" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.1" + +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[ArrayInterface]] +deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "d9352737cef8525944bf9ef34392d756321cbd54" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "3.1.38" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "61adeb0823084487000600ef8b1c00cc2474cd47" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.2.0" + +[[CEnum]] +git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.1" + +[[CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "2c8329f16addffd09e6ca84c556e2185a4933c64" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.5.0" + +[[ChainRules]] +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"] +git-tree-sha1 = "035ef8a5382a614b2d8e3091b6fdbb1c2b050e11" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.12.1" + +[[ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "3533f5a691e60601fe60c90d8bc47a27aa2907ec" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.11.0" + +[[CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.0" + +[[ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.0" + +[[Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.8" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "dce3e3fea680869eaa0b774b2e8343e9ff442313" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.40.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[DataAPI]] +git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.9.0" + +[[DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "7d9d316f04214f7efdbb6398d545446e246eff02" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.10" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "7220bc21c33e990c14f4a9a319b1d242ebc5b269" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.3.1" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[ExprTools]] +git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.6" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "8756f9935b7ccc9064c6eef0bff0ad643df733a3" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.12.7" + +[[FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[Flux]] +deps = ["AbstractTrees", "Adapt", "ArrayInterface", "CUDA", "CodecZlib", "Colors", "DelimitedFiles", "Functors", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "NNlibCUDA", "Pkg", "Printf", "Random", "Reexport", "SHA", "SparseArrays", "Statistics", "StatsBase", "Test", "ZipFile", "Zygote"] +git-tree-sha1 = "e8b37bb43c01eed0418821d1f9d20eca5ba6ab21" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.12.8" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "63777916efbcb0ab6173d09a658fb7f2783de485" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.21" + +[[Functors]] +git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.2.7" + +[[GPUArrays]] +deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "7772508f17f1d482fe0df72cabc5b55bec06bbe0" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.1.2" + +[[GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "77d915a0af27d474f0aaf12fcd46c400a552e84c" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.13.7" + +[[Hyperscript]] +deps = ["Test"] +git-tree-sha1 = "8d511d5b81240fc8e6802386302675bdf47737b9" +uuid = "47d2ed2b-36de-50cf-bf87-49c2cf4b8b91" +version = "0.0.4" + +[[HypertextLiteral]] +git-tree-sha1 = "5efcf53d798efede8fee5b2c8b09284be359bf24" +uuid = "ac1192a8-f4b3-4bfe-ba22-af5b92cd3ab2" +version = "0.9.2" + +[[IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.2" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.3" + +[[IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.1" + +[[IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.3.0" + +[[JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "8076680b162ada2a031f707ac7b4953e30667a37" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.2" + +[[Juno]] +deps = ["Base64", "Logging", "Media", "Profile"] +git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" +uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" +version = "0.8.4" + +[[LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "46092047ca4edc10720ecab437c42283cd7c44f3" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.6.0" + +[[LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6a2af408fe809c4f1a54d2b3f188fdd3698549d6" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.11+0" + +[[LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[LogExpFunctions]] +deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.4" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[Media]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" +uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" +version = "0.5.0" + +[[Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NNlib]] +deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "5203a4532ad28c44f82c76634ad621d7c90abcbd" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.7.29" + +[[NNlibCUDA]] +deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "04490d5e7570c038b1cb0f5c3627597181cc15a9" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.1.9" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[Parsers]] +deps = ["Dates"] +git-tree-sha1 = "d911b6a12ba974dabe2291c6d450094a7226b372" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.1.1" + +[[Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[PlutoUI]] +deps = ["Base64", "Dates", "Hyperscript", "HypertextLiteral", "IOCapture", "InteractiveUtils", "JSON", "Logging", "Markdown", "Random", "Reexport", "UUIDs"] +git-tree-sha1 = "615f3a1eff94add4bca9476ded096de60b46443b" +uuid = "7f904dfe-b85e-4ff6-b463-dae2292396a8" +version = "0.7.17" + +[[Preferences]] +deps = ["TOML"] +git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.2" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Random123]] +deps = ["Libdl", "Random", "RandomNumbers"] +git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.4.2" + +[[RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.3" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "f0bccf98e16759818ffc5d97ac3ebf87eb950150" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.8.1" + +[[Static]] +deps = ["IfElse"] +git-tree-sha1 = "e7bc80dc93f50857a5d1e3c8121495852f407e6a" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.4.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "3c76dde64d03699e074ac02eb2e8ba8254d428da" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.2.13" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[StatsAPI]] +git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.0.0" + +[[StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "eb35dcc66558b2dda84079b9a1be17557d32091a" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.12" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "7cb456f358e8f9d102a8b25e8dfedf58fa5689bc" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.13" + +[[TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.6" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.9.4" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "0fc9959bcabc4668c403810b4e851f6b8962eac9" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.29" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +""" + +# ╔═╡ Cell order: +# ╠═20b324a4-3b6c-11ec-3673-d98ec8af9009 +# ╠═64f31ef7-e0ba-4353-8e51-e8356a894656 +# ╠═11a53b0c-465b-4f16-88d2-e0163e471fd6 +# ╠═1d4955f9-1d1c-4572-bc7b-9f002ea54042 +# ╠═6175d40b-0fa6-4cb4-8ef3-9765953ce97e +# ╠═28c0d31f-e90a-415c-b35e-312bdf771ddf +# ╠═7462740e-23da-4151-81b5-cc2e6cbcf2c8 +# ╠═1eee99c9-5fef-498d-8a57-9c18e2f1cf49 +# ╠═bc3492c6-5c3b-4b08-86d4-dd4026e25655 +# ╠═5e754d71-f4c3-476c-9656-61f07687a534 +# ╠═749e6b65-fd8e-45ed-9a2b-f97274549933 +# ╠═56314025-fb0a-4e98-8628-091bda52708c +# ╠═f61682fe-da30-459f-b807-4fc3e8f36f32 +# ╟─31237f62-428a-4df3-9caf-1f5c5ade6ade +# ╠═55fd682d-2808-4786-ad52-647ac6892860 +# ╠═bf869218-034e-42f9-96c5-4267a8c8fcfb +# ╠═4817d4fe-b86c-43e2-94d2-fa6c424fa001 +# ╠═d38be057-952a-450e-ba23-bd2f40e79de3 +# ╠═affe6dd2-05e2-4852-9ef0-9605bdb5e34b +# ╠═48a0d42f-bd00-4298-8b28-56acf2dbc8a7 +# ╠═5d85cdb6-a627-4dd5-b76c-79ed2bc019c6 +# ╠═5a5ca97f-28a9-4196-a1a6-40f3ae4bbb9c +# ╠═932b7a31-d5f4-44a3-a6c3-f608003a0a6f +# ╠═8eca54ad-f95b-450c-8567-480f0ab7ea19 +# ╠═1c8f2c25-b8da-4e31-adac-ad575203f9cf +# ╠═48b09bfb-87fe-4974-a4f3-7a6ece521da6 +# ╠═29a8f5c6-d5b8-4ec6-95a4-89025245787d +# ╠═e1d0aeb8-8fb8-48a8-9842-7888d9fae1ad +# ╠═1e0cf592-fd96-4a01-84d1-078301ac7f99 +# ╠═8dad1cc0-4cd1-45c5-9b92-8fff8d06fd9b +# ╟─ba7ca719-9687-4f1a-9a36-89b795a1bc13 +# ╠═9b59255c-77b4-4df6-a70b-e1fdaab7619b +# ╠═fda49870-c355-4b46-8d04-47268cb0372d +# ╠═287ff2f9-3469-4e29-91a3-91ed54109f5a +# ╠═d173ec63-b7c3-4042-862c-2626039d6d94 +# ╠═274e3a3f-fe23-4d6f-893c-622db0413af0 +# ╟─00000000-0000-0000-0000-000000000001 +# ╟─00000000-0000-0000-0000-000000000002 diff --git a/julia_code/Untitled1.ipynb b/julia_code/Untitled1.ipynb new file mode 100644 index 0000000..6869a43 --- /dev/null +++ b/julia_code/Untitled1.ipynb @@ -0,0 +1,617 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 124, + "id": "3d049ef3-2630-4064-a2bd-ee28c7a89e9e", + "metadata": {}, + "outputs": [], + "source": [ + "using LinearAlgebra, Zygote, Tullio" + ] + }, + { + "cell_type": "markdown", + "id": "cfa5d466-5354-4ef4-a4ab-2bf2ed69626e", + "metadata": {}, + "source": [ + "### notes\n", + "\n", + "Notice how this approach estimates policy and the value function in a single go.\n", + "While you could eliminate the value function partials, it would be much more complex.\n", + "Note that in RL, both must be iterated on." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2bc6f9bd-8308-40a0-bee4-acbb1be4672d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#Model dimensions\n", + "N_constellations = 5\n", + "N_debris = 1\n", + "N_states = N_constellations + N_debris" + ] + }, + { + "cell_type": "markdown", + "id": "15670d3b-9651-4fc5-b98a-29ba60b4b38e", + "metadata": {}, + "source": [ + "## Built Payoff functions" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "id": "1e99275a-9ab5-4c9e-921b-9587145b1fb5", + "metadata": {}, + "outputs": [], + "source": [ + "stocks = rand(1:N_constellations,N_constellations);\n", + "debris = rand(1:3);\n", + "payoff = 3*I + ones(5,5) .+ [1,0,0,0,0]; #TODO: move this into a struct\n", + "β=0.95\n", + "launches = ones(N_constellations);" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "id": "a4e4bbaf-61de-4860-adc2-1e91f0626ead", + "metadata": {}, + "outputs": [], + "source": [ + "#Define the market profit function\n", + "F(stocks,debris,payoff,launches) = payoff*stocks + 3.0*launches .+ (debris*-0.2)\n", + "\n", + "#create derivative functions\n", + "∂f_∂stocks(st,debris,payoff,launches) = Zygote.jacobian(s -> F(s,debris,payoff,launches),st)[1];\n", + "∂f_∂debris(stocks,d,payoff,launches) = Zygote.jacobian(s -> F(stocks,s,payoff,launches),d)[1];\n", + "∂f_∂launches(stocks,debris,payoff,launches) = Zygote.jacobian(l -> F(stocks,debris,payoff,l),launches)[1];\n" + ] + }, + { + "cell_type": "code", + "execution_count": 158, + "id": "ad73e64a-2fa5-4bdf-a4e4-22357b53002c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 5.0 2.0 2.0 2.0 2.0\n", + " 1.0 4.0 1.0 1.0 1.0\n", + " 1.0 1.0 4.0 1.0 1.0\n", + " 1.0 1.0 1.0 4.0 1.0\n", + " 1.0 1.0 1.0 1.0 4.0" + ] + }, + "execution_count": 158, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂f_∂stocks(stocks,debris,payoff,launches) \n", + "#Rows are constellations, columns are derivatives indexed by constellations" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "id": "a6159740-8ef2-40e0-b10f-c451cf78418d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5-element Vector{Float64}:\n", + " -0.2\n", + " -0.2\n", + " -0.2\n", + " -0.2\n", + " -0.2" + ] + }, + "execution_count": 159, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂f_∂debris(stocks,debris,payoff,launches)" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "id": "b46da3d9-1765-4427-b3de-bf0a9fdd6576", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 3.0 0.0 0.0 0.0 0.0\n", + " 0.0 3.0 0.0 0.0 0.0\n", + " 0.0 0.0 3.0 0.0 0.0\n", + " 0.0 0.0 0.0 3.0 0.0\n", + " 0.0 0.0 0.0 0.0 3.0" + ] + }, + "execution_count": 160, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂f_∂launches(stocks,debris,payoff,launches)" + ] + }, + { + "cell_type": "markdown", + "id": "7a3a09cc-e48d-4c27-9e4f-02e45a5fc5c1", + "metadata": {}, + "source": [ + "## Building Physical Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "638ed4ac-36b0-4efb-bcb4-13d345470dd7", + "metadata": {}, + "outputs": [], + "source": [ + "struct BasicModel\n", + " #rate at which debris hits satellites\n", + " debris_collision_rate\n", + " #rate at which satellites of different constellations collide\n", + " satellite_collision_rates\n", + " #rate at which debris exits orbits\n", + " decay_rate\n", + " #rate at which satellites\n", + " autocatalysis_rate\n", + " #ratio at which a collision between satellites produced debris\n", + " satellite_collision_debris_ratio\n", + " #Ratio at which launches produce debris\n", + " launch_debris_ratio\n", + "end\n", + "\n", + "#Getting loss parameters together.\n", + "loss_param = 2e-3;\n", + "loss_weights = loss_param*(ones(N_constellations,N_constellations) - I);\n", + "\n", + "#orbital decay rate\n", + "decay_param = 0.01;\n", + "\n", + "#debris generation parameters\n", + "autocatalysis_param = 0.001;\n", + "satellite_loss_debris_rate = 5.0;\n", + "launch_debris_rate = 0.05;\n", + "\n", + "#Todo, wrap physical model as a struct with the parameters\n", + "bm = BasicModel(\n", + " loss_param\n", + " ,loss_weights\n", + " ,decay_param\n", + " ,autocatalysis_param\n", + " ,satellite_loss_debris_rate\n", + " ,launch_debris_rate\n", + ");" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "id": "ae04a4f2-7401-450e-ba98-bfec375b7646", + "metadata": {}, + "outputs": [], + "source": [ + "#percentage survival function\n", + "function survival(stocks,debris,physical_model) \n", + " exp.(-physical_model.satellite_collision_rates*stocks .- (physical_model.debris_collision_rate*debris));\n", + "end\n", + "\n", + "#stock update rules\n", + "function G(stocks,debris,launches, physical_model)\n", + " return diagm(survival(stocks,debris,physical_model) .- physical_model.decay_rate)*stocks + launches\n", + "end;\n", + "\n", + "#stock evolution wrt various things\n", + "∂G_∂launches(stocks,debris,launches,bm) = Zygote.jacobian(x -> G(stocks,debris,x,bm),launches)[1];\n", + "∂G_∂debris(stocks,debris,launches,bm) = Zygote.jacobian(x -> G(stocks,x,launches,bm),debris)[1];\n", + "∂G_∂stocks(stocks,debris,launches,bm) = Zygote.jacobian(x -> G(x,debris,launches,bm),stocks)[1];\n", + "\n", + "#debris evolution \n", + "function H(stocks,debris,launches,physical_model)\n", + " #get changes in debris from natural dynamics\n", + " natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris \n", + " \n", + " #get changes in debris from satellite loss\n", + " satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks \n", + " \n", + " #get changes in debris from launches\n", + " launch_debris = physical_model.launch_debris_ratio*sum(launches)\n", + " \n", + " #return total debris level\n", + " return natural_debris_dynamics + satellite_loss_debris + launch_debris\n", + "end;\n", + "\n", + "#get jacobians of debris dynamics\n", + "∂H_∂launches(stocks,debris,launches,physical_model) = Zygote.jacobian(x -> H(stocks,debris,x,physical_model),launches)[1];\n", + "∂H_∂debris(stocks,debris,launches,physical_model) = Zygote.jacobian(x -> H(stocks,x,launches,physical_model),debris)[1];\n", + "∂H_∂stocks(stocks,debris,launches,physical_model) = Zygote.jacobian(x -> H(x,debris,launches,physical_model),stocks)[1];" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "c6519b6e-f9ef-466e-a9ac-746b8f36b9e7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 1.0 0.0 0.0 0.0 0.0\n", + " 0.0 1.0 0.0 0.0 0.0\n", + " 0.0 0.0 1.0 0.0 0.0\n", + " 0.0 0.0 0.0 1.0 0.0\n", + " 0.0 0.0 0.0 0.0 1.0" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂G_∂launches(stocks,debris,launches,bm)" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "id": "1d46c09d-2c91-4302-8e8d-1c0abeb010aa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5-element Vector{Float64}:\n", + " -0.003843157756609293\n", + " -0.009665715046375067\n", + " -0.009665715046375067\n", + " -0.009665715046375067\n", + " -0.003843157756609293" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂G_∂debris(stocks,debris,launches,bm)" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "1f791c8a-5457-4ec2-bf0b-68eb6bdd578f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 0.950789 -0.00384316 -0.00384316 -0.00384316 -0.00384316\n", + " -0.00966572 0.956572 -0.00966572 -0.00966572 -0.00966572\n", + " -0.00966572 -0.00966572 0.956572 -0.00966572 -0.00966572\n", + " -0.00966572 -0.00966572 -0.00966572 0.956572 -0.00966572\n", + " -0.00384316 -0.00384316 -0.00384316 -0.00384316 0.950789" + ] + }, + "execution_count": 168, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂G_∂stocks(stocks,debris,launches,bm)" + ] + }, + { + "cell_type": "code", + "execution_count": 169, + "id": "d97f9af0-d3a9-4739-b569-fa100a1590f3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1×5 Matrix{Float64}:\n", + " 0.05 0.05 0.05 0.05 0.05" + ] + }, + "execution_count": 169, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂H_∂launches(stocks,debris,launches,bm)" + ] + }, + { + "cell_type": "code", + "execution_count": 170, + "id": "1a7c857f-7eee-4273-b954-3a7b84a3b3e0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1-element Vector{Float64}:\n", + " 1.174417303261719" + ] + }, + "execution_count": 170, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂H_∂debris(stocks,debris,launches,bm) " + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "id": "9ea44146-740c-47c0-9286-846585a4db39", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1×5 Matrix{Float64}:\n", + " 0.360254 0.302231 0.302231 0.302231 0.360254" + ] + }, + "execution_count": 171, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂H_∂stocks(stocks,debris,launches,bm)\n", + "#columns are derivatives" + ] + }, + { + "cell_type": "markdown", + "id": "17038bfe-c7c2-484b-9aaf-6b608bbbbfab", + "metadata": {}, + "source": [ + "## Build optimality conditions" + ] + }, + { + "cell_type": "markdown", + "id": "ef3c7625-e5ee-4220-8bed-b289d0baea1e", + "metadata": {}, + "source": [ + "## Build transition conditions\n", + "\n", + "I've built the transition conditions below." + ] + }, + { + "cell_type": "code", + "execution_count": 172, + "id": "c3ef6bcc-37ca-4f6c-b876-8ea2a068cbe1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 2.0 1.0 1.0 1.0 1.0\n", + " 1.0 2.0 1.0 1.0 1.0\n", + " 1.0 1.0 2.0 1.0 1.0\n", + " 1.0 1.0 1.0 2.0 1.0\n", + " 1.0 1.0 1.0 1.0 2.0" + ] + }, + "execution_count": 172, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "W_∂stocks= ones(5,5) + I\n", + "# columns are stocks\n", + "# rows are derivatives" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "id": "5441641b-308b-4fec-a946-2568e79a1203", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×1 Matrix{Float64}:\n", + " 2.0\n", + " 2.0\n", + " 2.0\n", + " 2.0\n", + " 2.0" + ] + }, + "execution_count": 174, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + " #temporary value partials\n", + "W_∂debris = 2*ones(5,1) #temporary value partials\n", + "#columns are\n", + "#rows are" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "819df0f9-e994-4461-9329-4c4607de8779", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 0.966286 -0.00195257 -0.00195257 -0.00195257 -0.00195257\n", + " -0.00785729 0.972161 -0.00785729 -0.00785729 -0.00785729\n", + " -0.00588119 -0.00588119 0.970199 -0.00588119 -0.00588119\n", + " -0.00195257 -0.00195257 -0.00195257 0.966286 -0.00195257\n", + " -0.00391296 -0.00391296 -0.00391296 -0.00391296 0.96824" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "∂G_∂stocks(stocks,debris,launches,bm)" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "id": "5cc19637-9727-479e-aa60-a30c71e9e8b8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×5 Matrix{Float64}:\n", + " 1.91695 1.91695 1.91695 1.91695 1.91695\n", + " 1.88146 1.88146 1.88146 1.88146 1.88146\n", + " 1.89335 1.89335 1.89335 1.89335 1.89335\n", + " 1.91695 1.91695 1.91695 1.91695 1.91695\n", + " 1.90518 1.90518 1.90518 1.90518 1.90518" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = ∂f_∂stocks(stocks,debris,payoff,launches)\n", + "b = ∂G_∂stocks(stocks,debris,launches,bm) * W_∂stocks #This last bit should eventually get replaced with a NN\n", + "#Need to check dimensionality above. Not sure which direction is derivatives and which is functions." + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "65917e4c-bd25-4683-8553-17bc841e594a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5×1 Matrix{Float64}:\n", + " -0.019525714195158184\n", + " -0.07857288258866406\n", + " -0.05881192039840531\n", + " -0.019525714195158184\n", + " -0.039129609402048404" + ] + }, + "execution_count": 119, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "c = (∂G_∂debris(stocks,debris,launches,bm) .* ones(1,5)) * W_∂debris" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "dd5ca5b4-c910-449c-a322-b24b3ef4ecf2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "49.45445380487922" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss(stocks,debris,payoff,launches)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "8ead6e86-3d03-48ff-bc3e-b3fa9808f981", + "metadata": {}, + "outputs": [], + "source": [ + "#TODO: create a launch model in flux and see if I can get it to do pullbacks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "658ba7f0-dfbc-42da-bb20-7f8bccbb257d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.6.2", + "language": "julia", + "name": "julia-1.6" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/julia_code/UsingFlux.ipynb b/julia_code/UsingFlux.ipynb new file mode 100644 index 0000000..31111e6 --- /dev/null +++ b/julia_code/UsingFlux.ipynb @@ -0,0 +1,327 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "0b5021da-575c-4db3-9e01-dc043a7c64b3", + "metadata": {}, + "outputs": [], + "source": [ + "using DiffEqFlux,Flux,Zygote, LinearAlgebra" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "32ca6032-9d48-4bb2-b16e-4a66473464cd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "const N_constellations = 1\n", + "const N_debris = 1\n", + "const N_states= N_constellations + N_debris" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "77213ba3-1645-45b2-903f-b7f2817cbb47", + "metadata": {}, + "outputs": [], + "source": [ + "#setup physical model\n", + "struct BasicModel\n", + " #rate at which debris hits satellites\n", + " debris_collision_rate\n", + " #rate at which satellites of different constellations collide\n", + " satellite_collision_rates\n", + " #rate at which debris exits orbits\n", + " decay_rate\n", + " #rate at which satellites\n", + " autocatalysis_rate\n", + " #ratio at which a collision between satellites produced debris\n", + " satellite_collision_debris_ratio\n", + " #Ratio at which launches produce debris\n", + " launch_debris_ratio\n", + "end\n", + "\n", + "#Getting loss parameters together.\n", + "loss_param = 2e-3;\n", + "loss_weights = loss_param*(ones(N_constellations,N_constellations) - I);\n", + "\n", + "#orbital decay rate\n", + "decay_param = 0.01;\n", + "\n", + "#debris generation parameters\n", + "autocatalysis_param = 0.001;\n", + "satellite_loss_debris_rate = 5.0;\n", + "launch_debris_rate = 0.05;\n", + "\n", + "#Todo, wrap physical model as a struct with the parameters\n", + "bm = BasicModel(\n", + " loss_param\n", + " ,loss_weights\n", + " ,decay_param\n", + " ,autocatalysis_param\n", + " ,satellite_loss_debris_rate\n", + " ,launch_debris_rate\n", + ");\n", + "\n", + "#implement tranistion function\n", + "#percentage survival function\n", + "function survival(stocks,debris,physical_model) \n", + " exp.(-physical_model.satellite_collision_rates*stocks .- (physical_model.debris_collision_rate*debris));\n", + "end\n", + "\n", + "#stock update rules\n", + "function G(stocks,debris,launches, physical_model)\n", + " return diagm(survival(stocks,debris,physical_model) .- physical_model.decay_rate)*stocks + launches\n", + "end;\n", + "\n", + "\n", + "#debris evolution \n", + "function H(stocks,debris,launches,physical_model)\n", + " #get changes in debris from natural dynamics\n", + " natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris \n", + " \n", + " #get changes in debris from satellite loss\n", + " satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks \n", + " \n", + " #get changes in debris from launches\n", + " launch_debris = physical_model.launch_debris_ratio*sum(launches)\n", + " \n", + " #return total debris level\n", + " return natural_debris_dynamics + satellite_loss_debris + launch_debris\n", + "end;\n", + "\n", + "\n", + "#implement reward function\n", + "const payoff = 3*I - 0.02*ones(N_constellations,N_constellations)\n", + "\n", + "#Define the market profit function\n", + "F(stocks,debris,launches) = payoff*stocks + 3.0*launches .+ (debris*-0.2)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "998a1ce8-a6ba-427d-a5d1-fece358146da", + "metadata": {}, + "outputs": [], + "source": [ + "# Launch function\n", + "launches = Chain(\n", + " Parallel(vcat\n", + " #parallel joins together stocks and debris, along with intermediate interpretation\n", + " ,Chain(Dense(N_constellations, N_states*2,relu)\n", + " ,Dense(N_states*2, N_states*2,relu)\n", + " )\n", + " ,Chain(Dense(N_debris, N_states,relu)\n", + " ,Dense(N_states, N_states,relu)\n", + " )\n", + " #chain gets applied to parallel\n", + " ,Dense(N_states*3,128,relu)\n", + " #,Dense(128,128,relu)\n", + " ,Dense(128,N_constellations,relu)\n", + " )\n", + ");\n", + "\n", + "#Value functions\n", + "∂value = Chain(\n", + " Parallel(vcat\n", + " #parallel joins together stocks and debris, along with intermediate interpretation\n", + " ,Chain(Dense(N_constellations, N_states*2,relu)\n", + " ,Dense(N_states*2, N_states*2,relu)\n", + " )\n", + " ,Chain(Dense(N_debris, N_states,relu)\n", + " ,Dense(N_states, N_states,relu)\n", + " )\n", + " #chain gets applied to parallel\n", + " ,Dense(N_states*3,128,relu)\n", + " #,Dense(128,128,relu)\n", + " ,Dense(128,N_states,relu)\n", + " )\n", + ");" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b4409af2-7f41-45bc-b7eb-4bda019e4092", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1-element Vector{Float64}:\n", + " 0.0" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#Extract parameter sets\n", + "\n", + "#= initialize Algorithm Parameters\n", + "Chose these randomly\n", + "=#\n", + "λʷ = 0.5\n", + "αʷ = 5.0\n", + "λᶿ = 0.5\n", + "αᶿ = 5.0\n", + "αʳ = 10\n", + "\n", + "# initialitze averaging returns\n", + "r = zeros(N_constellations)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0529c209-55c0-49c7-815b-47578b029593", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1-element Vector{Int64}:\n", + " 3" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# initial states\n", + "S₀ = rand(1:5,N_constellations)\n", + "D₀ = rand(1:3, N_debris)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ae7d4152-77b0-42ff-92f6-9d5d83d6a39d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Params([Float32[-0.110574484; 1.0583764; 0.039519094; 0.2908444], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.4765299 0.5994208 -0.43710196 0.2269359; 0.5550531 0.5423604 -0.796175 0.76214457; 0.59269524 0.7436546 0.02525105 0.85908467; 0.3774994 -0.111040816 0.84196734 -0.18133782], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.2003294; -1.24031], Float32[0.0, 0.0], Float32[-0.004074011 -0.84631246; -0.5459394 1.1513239], Float32[0.0, 0.0], Float32[-0.19545768 -0.20670874 … 0.06923863 -0.09825141; 0.097166725 0.06564395 … -0.1928437 0.19962357; … ; 0.025075339 -0.06016964 … 0.0838129 -0.11523932; 0.20085223 0.16679004 … 0.016495213 -0.1548977], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.21479565 0.0090183215 … -0.2022802 -0.19925424], Float32[0.0]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "launch_params = Flux.params(launches)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "232c0a44-be74-4431-a86e-dbc71e83c17a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "loss (generic function with 1 method)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "function loss(stocks,debris)\n", + " sum(launches((stocks,debris)))\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "e1e1cdef-b164-43b6-a80e-b8665bdf9b14", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Grads(...)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g = Flux.gradient(() -> loss(S₀,D₀), launch_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "35d1763b-4650-4916-957d-fbb436280e1f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Params([Float32[-0.110574484; 1.0583764; 0.039519094; 0.2908444], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.4765299 0.5994208 -0.43710196 0.2269359; 0.5550531 0.5423604 -0.796175 0.76214457; 0.59269524 0.7436546 0.02525105 0.85908467; 0.3774994 -0.111040816 0.84196734 -0.18133782], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.2003294; -1.24031], Float32[0.0, 0.0], Float32[-0.004074011 -0.84631246; -0.5459394 1.1513239], Float32[0.0, 0.0], Float32[-0.19545768 -0.20670874 … 0.06923863 -0.09825141; 0.097166725 0.06564395 … -0.1928437 0.19962357; … ; 0.025075339 -0.06016964 … 0.0838129 -0.11523932; 0.20085223 0.16679004 … 0.016495213 -0.1548977], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.21479565 0.0090183215 … -0.2022802 -0.19925424], Float32[0.0]])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "launch_params" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaad2871-54ed-4674-8405-d4ebb950851d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.6.2", + "language": "julia", + "name": "julia-1.6" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}