You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Orbits/julia_code/UsingFlux.ipynb

328 lines
9.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0b5021da-575c-4db3-9e01-dc043a7c64b3",
"metadata": {},
"outputs": [],
"source": [
"using DiffEqFlux,Flux,Zygote, LinearAlgebra"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "32ca6032-9d48-4bb2-b16e-4a66473464cd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# problem size: number of satellite constellations and of debris state variables\n",
"const N_constellations = 1\n",
"const N_debris = 1\n",
"# total number of state variables\n",
"const N_states = N_constellations + N_debris"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "77213ba3-1645-45b2-903f-b7f2817cbb47",
"metadata": {},
"outputs": [],
"source": [
"#setup physical model\n",
"#Parameters of the satellite/debris dynamics; fields are positional (see `bm` below).\n",
"struct BasicModel\n",
"    #rate at which debris hits satellites\n",
"    debris_collision_rate\n",
"    #matrix of rates at which satellites of different constellations collide\n",
"    satellite_collision_rates\n",
"    #rate at which objects decay out of orbit (applied to debris in H and to satellites in G)\n",
"    decay_rate\n",
"    #rate at which existing debris generates new debris\n",
"    autocatalysis_rate\n",
"    #debris produced per satellite lost to collisions\n",
"    satellite_collision_debris_ratio\n",
"    #debris produced per unit launched\n",
"    launch_debris_ratio\n",
"end\n",
"\n",
"#Collision (loss) parameters.\n",
"loss_param = 2e-3;\n",
"#constellations only collide with each other: zero diagonal, loss_param elsewhere\n",
"loss_weights = loss_param*(ones(N_constellations,N_constellations) - I);\n",
"\n",
"#orbital decay rate\n",
"decay_param = 0.01;\n",
"\n",
"#debris generation parameters\n",
"autocatalysis_param = 0.001;\n",
"satellite_loss_debris_rate = 5.0;\n",
"launch_debris_rate = 0.05;\n",
"\n",
"#bundle the parameters into the physical model (argument order = field order)\n",
"bm = BasicModel(\n",
"    loss_param\n",
"    ,loss_weights\n",
"    ,decay_param\n",
"    ,autocatalysis_param\n",
"    ,satellite_loss_debris_rate\n",
"    ,launch_debris_rate\n",
");\n",
"\n",
"#Transition functions.\n",
"\n",
"#Fraction of each constellation's satellites that survive one period:\n",
"#exp(-(collisions with other constellations + collisions with debris)).\n",
"#NOTE(review): the debris term has length N_debris and is broadcast against\n",
"#N_constellations entries, so this assumes N_debris == 1 or N_debris == N_constellations -- confirm.\n",
"function survival(stocks,debris,physical_model)\n",
"    return exp.(-physical_model.satellite_collision_rates*stocks .- (physical_model.debris_collision_rate*debris))\n",
"end\n",
"\n",
"#Stock update: each stock is scaled by (survival - decay), then new launches are added.\n",
"function G(stocks,debris,launches, physical_model)\n",
"    return diagm(survival(stocks,debris,physical_model) .- physical_model.decay_rate)*stocks + launches\n",
"end;\n",
"\n",
"\n",
"#Debris update for one period.\n",
"function H(stocks,debris,launches,physical_model)\n",
"    #debris that does not decay, plus debris generated by existing debris\n",
"    natural_debris_dynamics = (1-physical_model.decay_rate+physical_model.autocatalysis_rate) * debris\n",
"\n",
"    #debris from satellites lost this period: lost fraction (1 - survival) weighted by stock; a scalar\n",
"    satellite_loss_debris = physical_model.satellite_collision_debris_ratio * (1 .- survival(stocks,debris,physical_model))'*stocks\n",
"\n",
"    #debris from launches; a scalar\n",
"    launch_debris = physical_model.launch_debris_ratio*sum(launches)\n",
"\n",
"    #broadcast (.+) so the scalar contributions are added to the debris vector;\n",
"    #plain `+` between a Vector and a Number throws a MethodError\n",
"    return natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris\n",
"end;\n",
"\n",
"\n",
"#Reward function.\n",
"#payoff matrix: 2.98 on the diagonal (own stock), -0.02 off the diagonal (other constellations)\n",
"const payoff = 3*I - 0.02*ones(N_constellations,N_constellations)\n",
"\n",
"#Market profit per constellation: stock payoff + 3.0 per unit launched - 0.2 per unit of debris.\n",
"#NOTE(review): launches enter with a positive coefficient -- confirm the intended sign.\n",
"F(stocks,debris,launches) = payoff*stocks + 3.0*launches .+ (debris*-0.2)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "998a1ce8-a6ba-427d-a5d1-fece358146da",
"metadata": {},
"outputs": [],
"source": [
"#Build a network on the state tuple (stocks, debris) with n_out relu outputs.\n",
"#Each input has its own two-layer relu encoder; Parallel(vcat, ...) concatenates the\n",
"#encodings (N_states*2 + N_states = N_states*3 values), and a dense head maps them to\n",
"#the outputs. The head must sit in the outer Chain, after Parallel: passed as extra\n",
"#Parallel branches it is never applied to the concatenated encodings.\n",
"function state_network(n_out)\n",
"    return Chain(\n",
"        Parallel(vcat,\n",
"            Chain(Dense(N_constellations, N_states*2, relu),\n",
"                  Dense(N_states*2, N_states*2, relu)),\n",
"            Chain(Dense(N_debris, N_states, relu),\n",
"                  Dense(N_states, N_states, relu)),\n",
"        ),\n",
"        Dense(N_states*3, 128, relu),\n",
"        Dense(128, n_out, relu),\n",
"    )\n",
"end\n",
"\n",
"# Launch function: one output per constellation\n",
"launches = state_network(N_constellations);\n",
"\n",
"#Value functions: one output per state variable\n",
"#NOTE(review): relu on the output forces these to be non-negative -- confirm this is intended.\n",
"∂value = state_network(N_states);"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b4409af2-7f41-45bc-b7eb-4bda019e4092",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1-element Vector{Float64}:\n",
" 0.0"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Extract parameter sets\n",
"\n",
"#= initialize Algorithm Parameters\n",
"These were chosen arbitrarily.\n",
"=#\n",
"λʷ = 0.5\n",
"αʷ = 5.0\n",
"λᶿ = 0.5\n",
"αᶿ = 5.0\n",
"αʳ = 10\n",
"\n",
"# initialize averaging returns, one per constellation\n",
"r = zeros(N_constellations)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0529c209-55c0-49c7-815b-47578b029593",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1-element Vector{Int64}:\n",
" 3"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initial states: integer satellite stocks drawn from 1:5, debris levels drawn from 1:3\n",
"S₀ = rand(1:5, N_constellations)\n",
"D₀ = rand(1:3, N_debris)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ae7d4152-77b0-42ff-92f6-9d5d83d6a39d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Params([Float32[-0.110574484; 1.0583764; 0.039519094; 0.2908444], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.4765299 0.5994208 -0.43710196 0.2269359; 0.5550531 0.5423604 -0.796175 0.76214457; 0.59269524 0.7436546 0.02525105 0.85908467; 0.3774994 -0.111040816 0.84196734 -0.18133782], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.2003294; -1.24031], Float32[0.0, 0.0], Float32[-0.004074011 -0.84631246; -0.5459394 1.1513239], Float32[0.0, 0.0], Float32[-0.19545768 -0.20670874 … 0.06923863 -0.09825141; 0.097166725 0.06564395 … -0.1928437 0.19962357; … ; 0.025075339 -0.06016964 … 0.0838129 -0.11523932; 0.20085223 0.16679004 … 0.016495213 -0.1548977], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.21479565 0.0090183215 … -0.2022802 -0.19925424], Float32[0.0]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# trainable arrays of the launch network (implicit-parameter style)\n",
"launch_params = launches |> Flux.params"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "232c0a44-be74-4431-a86e-dbc71e83c17a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"loss (generic function with 1 method)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# scalar objective: sum of the launch network's outputs at the state (stocks, debris)\n",
"loss(stocks, debris) = sum(launches((stocks, debris)))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e1e1cdef-b164-43b6-a80e-b8665bdf9b14",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Grads(...)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# gradients of the loss at the initial state with respect to the launch network's parameters\n",
"g = Flux.gradient(launch_params) do\n",
"    loss(S₀, D₀)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "35d1763b-4650-4916-957d-fbb436280e1f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Params([Float32[-0.110574484; 1.0583764; 0.039519094; 0.2908444], Float32[0.0, 0.0, 0.0, 0.0], Float32[0.4765299 0.5994208 -0.43710196 0.2269359; 0.5550531 0.5423604 -0.796175 0.76214457; 0.59269524 0.7436546 0.02525105 0.85908467; 0.3774994 -0.111040816 0.84196734 -0.18133782], Float32[0.0, 0.0, 0.0, 0.0], Float32[-1.2003294; -1.24031], Float32[0.0, 0.0], Float32[-0.004074011 -0.84631246; -0.5459394 1.1513239], Float32[0.0, 0.0], Float32[-0.19545768 -0.20670874 … 0.06923863 -0.09825141; 0.097166725 0.06564395 … -0.1928437 0.19962357; … ; 0.025075339 -0.06016964 … 0.0838129 -0.11523932; 0.20085223 0.16679004 … 0.016495213 -0.1548977], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.21479565 0.0090183215 … -0.2022802 -0.19925424], Float32[0.0]])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"launch_params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eaad2871-54ed-4674-8405-d4ebb950851d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.6.2",
"language": "julia",
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}