# Patch summary (commentary placed ahead of the first `diff --git` header).
#
# NOTE(review): the patch text below appears to have had its newlines collapsed
# onto a few long lines; restore the original line breaks before applying it.
#
# julia_code/BellmanResidualMinimization.jl
#   - Pluto version header changes from v0.17.0 to v0.17.1.
#
# julia_code/BellmanResidual_Operators.jl
#   - Pluto version header changes from v0.17.0 to v0.17.1, and the notebook
#     title changes to "Bellman Residual Operators".
#   - Removes the `fancyNN` Flux.Parallel experiment cells, the nested
#     `operators_policy_function_generator` (which wrapped its networks in
#     `Split`), and the cells that exercised them.
#   - Adds `operator_policy_function_generator(number_params=32)`, returning one
#     Flux.Chain: Parallel(vcat) over a Dense(N_constellations, number_params*2, relu)
#     branch and a Dense(N_debris, number_params, relu) branch, followed by
#     Dense(number_params*3, number_params, σ) and Dense(number_params, 1, relu).
#     The first branch is now number_params*2 wide (the removed version used
#     number_params), so the concatenated width equals the number_params*3
#     input of the following Dense layer.
#   - Adds `combine_policies(operators, data)`, which collects element [1] of
#     each operator's `policy_function(data)` into a vector.
#   - Adds the fields `policy_function::Flux.Chain` and
#     `policy_params::Flux.Params` to `ConstellationOperator`.
#   - `Ξ` now takes the operator array `cos` plus an `operator_number` index and
#     builds the action `a` with `combine_policies` instead of calling the
#     global `policy`.
#   - Adds the `tops` operators (each with its own policy network), changes
#     `planners_loss` to sum `Ξ` over `tops`, and adds an "Operators Problem"
#     training loop.
#   - Switches the value-inspection cell from `operators[i]` to `tops[i]` and
#     updates the cell-order listing to match the added/removed cells.
#
# NOTE(review): in the new "Operators Problem" loop, `Flux.Optimise.train!` is
# called with `co.policy_function` in the loss position — confirm this is
# intended; the other `train!` call in the same loop passes `planners_loss` there.
# NOTE(review): the same loop passes the bare name `ADAM`; confirm it is bound
# to an optimiser instance (its definition is not visible in this patch).
# NOTE(review): `planners_loss` reads the globals `tops` and `bm`, and the `co`
# bound by `enumerate(tops)` is not used in the loop body — verify.
diff --git a/julia_code/BellmanResidualMinimization.jl b/julia_code/BellmanResidualMinimization.jl index 571d881..5069b6e 100644 --- a/julia_code/BellmanResidualMinimization.jl +++ b/julia_code/BellmanResidualMinimization.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.17.0 +# v0.17.1 using Markdown using InteractiveUtils diff --git a/julia_code/BellmanResidual_Operators.jl b/julia_code/BellmanResidual_Operators.jl index 81e22a0..d9c6e13 100644 --- a/julia_code/BellmanResidual_Operators.jl +++ b/julia_code/BellmanResidual_Operators.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.17.0 +# v0.17.1 using Markdown using InteractiveUtils @@ -9,7 +9,7 @@ using PlutoUI, Flux,LinearAlgebra # ╔═╡ 66f0e667-d722-4e1e-807b-84a39cbc41b1 md""" -# Bellman Residual Minimization +# Bellman Residual Operators """ @@ -160,33 +160,38 @@ function planner_policy_function_generator(number_params=32) end -# ╔═╡ f2523e2c-2c56-4883-a074-5de7a0aed25b -begin - a = Flux.Chain( - Flux.Parallel( - vcat - ,Dense(1,2) - ,Dense(1,2) - ) - ,Dense(2,1) - ) +# ╔═╡ 6a3b5f7a-a535-450f-8c5f-19bdcc280146 +function operator_policy_function_generator(number_params=32) - c = Flux.Chain( - Flux.Parallel( - vcat - ,Dense(1,2) - ,Dense(1,2) + return Flux.Chain( + Flux.Parallel(vcat + #parallel joins together stocks and debris + ,Flux.Chain( + Flux.Dense(N_constellations, number_params*2,Flux.relu) + #,Flux.Dense(number_params, number_params,Flux.σ) + ) + ,Flux.Chain( + Flux.Dense(N_debris, number_params,Flux.relu) + #,Flux.Dense(number_params, number_params) + ) + ) + #Apply some transformations + ,Flux.Dense(number_params*3,number_params,Flux.σ) + ,Flux.Dense(number_params,1,Flux.relu) ) - ,Dense(2,1) - ) - fancyNN = Flux.Parallel(vcat, - a,c) + end -# ╔═╡ 7075d5fb-8273-498e-87bb-40e084c97601 -fancyNN(([1],[2]),([3],[4])) +# ╔═╡ 9504bf46-e380-4933-8693-03ef3e92a4e4 +ppf = planner_policy_function_generator() + +# ╔═╡ ac4e29e6-474e-489b-adc9-549dfe27a465 +function combine_policies(operators, data) + 
policies = [co.policy_function(data)[1] for co=operators] + policies +end # ╔═╡ 95bfc9d8-8427-41d6-9f0f-f155296eef91 #not needed yet @@ -218,43 +223,6 @@ begin end -# ╔═╡ 6a3b5f7a-a535-450f-8c5f-19bdcc280146 -function operators_policy_function_generator(number_params=32) - function f() - return Flux.Chain( - Flux.Parallel(vcat - #parallel joins together stocks and debris - ,Flux.Chain( - Flux.Dense(N_constellations, number_params,Flux.relu) - #,Flux.Dense(number_params, number_params,Flux.σ) - ) - ,Flux.Chain( - Flux.Dense(N_debris, number_params,Flux.relu) - #,Flux.Dense(number_params, number_params) - ) - ) - #Apply some transformations - ,Flux.Dense(number_params*3,number_params,Flux.σ) - ,Flux.Dense(number_params,1,Flux.relu) - ) - end - - a = [f() for i=1:N_constellations] - b = [passthrough for i=1:N_constellations] - - return Flux.Chain( - Split(a) - #,Flux.Parallel(vcat, b) - ) - -end - -# ╔═╡ 3d9a2425-d549-48a4-badb-34c0c07aeecc -b = operators_policy_function_generator() - -# ╔═╡ b73396ce-f5ef-46bc-a92c-94a48d9b4551 -Split(() -> 1, () -> 2) - # ╔═╡ bbca5143-f314-40ea-a20e-8a043272e362 md""" # Defining economic parameters and payoff functions @@ -309,6 +277,8 @@ struct ConstellationOperator payoff_fn::Function econ_params::EconomicParameters value::Flux.Chain + policy_function::Flux.Chain + policy_params::Flux.Params end #TODO: create a function that takes this struct and checks backprop @@ -317,6 +287,36 @@ md""" # Loss function specification """ +# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 +#Constellation level loss function +function Ξ( + s::Vector{Float32} + ,d::Vector{Float32} + , physical_model::PhysicalParameters + ,cos::Array{ConstellationOperator} + ,operator_number::Int +) + co = cos[operator_number] + + a = combine_policies(cos,(s,d)) + s′ = G(s,d,a,physical_model) + d′ = H(s,d,a,physical_model) + + bellman_residuals = co.value((s,d)) - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) + maximization_condition = - 
co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) + + return sum([bellman_residuals.^2 maximization_condition]) +end + +# ╔═╡ 9d4a668a-23e3-4f36-86f4-60e242caee3b +begin + s1 = ones(Float32,N_constellations) + d1 = ones(Float32,N_debris) +end + +# ╔═╡ d8deba52-dc0c-470e-81bf-f9d7cc595a41 +ppf((s1,d1)) + # ╔═╡ b433a7ec-8264-48d6-8b95-53d2ec4bad05 md""" # examples of parameter models @@ -349,6 +349,50 @@ begin =# end +# ╔═╡ 08ef7fbf-005b-40b5-acde-f42750c04cd3 +begin + a = operator_policy_function_generator() + b = operator_policy_function_generator() + c = operator_policy_function_generator() + d = operator_policy_function_generator() + + tops = [ + ConstellationOperator(payoff1,em2_a,value_function_generator(),a,params(a)) + ,ConstellationOperator(payoff1,em2_b,value_function_generator(),b,params(b)) + ,ConstellationOperator(payoff1,em2_c,value_function_generator(),c,params(c)) + ,ConstellationOperator(payoff1,em2_d,value_function_generator(),d,params(d)) + ] + +end + +# ╔═╡ 2f14fb8e-7f71-420f-bd24-f94a4b37b0a8 +a + +# ╔═╡ 9bd80252-bf0f-421b-a747-9b41cbc82edf +a((s1,d1)) + +# ╔═╡ 7630fccb-5169-4d46-96ea-1968baed89a2 +e = combine_policies(tops,([1,2,3,4.0],[2.0]) ) + +# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29 +begin #testing + Ξ(s1,d1,bm,tops,1) +end + +# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58 +#planner's loss function +function planners_loss( + s::Vector{Float32} + ,d::Vector{Float32} +) + l = 0.0 + for (i,co) in enumerate(tops) + l += Ξ(s,d,bm,tops,i) + end + return l +end + + # ╔═╡ dc614254-c211-4552-b985-03020bfc5ab3 em3 = CESParams(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002])) #= @@ -366,24 +410,6 @@ md""" # ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f policy = planner_policy_function_generator(); -# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 -#Constellation level loss function -function Ξ( - s::Vector{Float32} - ,d::Vector{Float32} - , physical_model::PhysicalParameters - ,co::ConstellationOperator -) - a = policy((s,d)) - s′ = 
G(s,d,a,physical_model) - d′ = H(s,d,a,physical_model) - - bellman_residuals = co.value((s,d)) - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) - maximization_condition = - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) - - return sum([bellman_residuals.^2 maximization_condition]) -end - # ╔═╡ f30904a7-5caa-449a-a5bd-f2aa78777a9a begin #setup the operators @@ -397,30 +423,6 @@ begin @assert length(operators) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized" end -# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29 -begin #testing - s1 = ones(Float32,N_constellations) - d1 = ones(Float32,N_debris) - Ξ(s1,d1,bm,operators[1]) -end - -# ╔═╡ caaabe93-cc09-45c3-9c3f-be4aeb281099 -b((s1,d1)) - -# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58 -#planner's loss function -function planners_loss( - s::Vector{Float32} - ,d::Vector{Float32} -) - l = 0.0 - for co in operators - l += Ξ(s,d,bm,co) - end - return l -end - - # ╔═╡ 5abebc1a-370c-4f5f-8826-dc0b143d5166 md""" ## Constructing data and training @@ -459,6 +461,23 @@ for epoch in 1:20 end end +# ╔═╡ e8dbe65e-7df7-4810-8e83-72f0b18d0f1d +# Operators Problem +for epoch in 1:20 + data1 = [(rand(1:500f0, N_constellations),rand(1:500f0, N_debris)) for n=1:N] + + + + #Sweep through training the value functions + for co in tops + Flux.Optimise.train!(planners_loss, params(co.value), data1, ADAM) + Flux.Optimise.train!(co.policy_function, co.policy_params, data1, ADAM) + end +end + +# ╔═╡ d33a0310-07b7-40d1-b3ae-1cbd6977ef6e + + # ╔═╡ 02f3fe78-e7a7-453f-9ddf-acddf08d8676 begin local accum = 0.0 @@ -473,10 +492,10 @@ end begin n=15 - [operators[1].value(data[n]) - ,operators[2].value(data[n]) - ,operators[3].value(data[n]) - ,operators[4].value(data[n])] + [tops[1].value(data[n]) + ,tops[2].value(data[n]) + ,tops[3].value(data[n]) + ,tops[4].value(data[n])] end # ╔═╡ c50b1d39-fe87-441b-935c-c5fe971d09ef @@ -1071,11 +1090,13 @@ uuid = 
"3f19e933-33d8-53b3-aaab-bd5110c3b7a0" # ╠═f7aabe43-9a2c-4fe0-8099-c29cdf66566c # ╠═d816b252-bdca-44ba-ac5c-cb21163a1e9a # ╠═6a3b5f7a-a535-450f-8c5f-19bdcc280146 -# ╠═3d9a2425-d549-48a4-badb-34c0c07aeecc -# ╠═caaabe93-cc09-45c3-9c3f-be4aeb281099 -# ╠═b73396ce-f5ef-46bc-a92c-94a48d9b4551 -# ╠═f2523e2c-2c56-4883-a074-5de7a0aed25b -# ╠═7075d5fb-8273-498e-87bb-40e084c97601 +# ╠═08ef7fbf-005b-40b5-acde-f42750c04cd3 +# ╠═2f14fb8e-7f71-420f-bd24-f94a4b37b0a8 +# ╠═9bd80252-bf0f-421b-a747-9b41cbc82edf +# ╠═9504bf46-e380-4933-8693-03ef3e92a4e4 +# ╠═d8deba52-dc0c-470e-81bf-f9d7cc595a41 +# ╠═ac4e29e6-474e-489b-adc9-549dfe27a465 +# ╠═7630fccb-5169-4d46-96ea-1968baed89a2 # ╠═95bfc9d8-8427-41d6-9f0f-f155296eef91 # ╠═bbca5143-f314-40ea-a20e-8a043272e362 # ╠═340da189-f443-4376-a82d-7699a21ab7a2 @@ -1084,6 +1105,7 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" # ╠═f8d582cb-10cf-4c72-8127-787f662e0567 # ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c # ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 +# ╠═9d4a668a-23e3-4f36-86f4-60e242caee3b # ╠═43b99708-0052-4b78-886c-92ac2b532f29 # ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58 # ╠═b433a7ec-8264-48d6-8b95-53d2ec4bad05 @@ -1098,6 +1120,8 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" # ╠═6bf8d29a-7990-4e91-86e6-d9894ed3db27 # ╠═e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d # ╠═74f5fde3-0593-46fc-a688-f1db7ab28c64 +# ╠═e8dbe65e-7df7-4810-8e83-72f0b18d0f1d +# ╠═d33a0310-07b7-40d1-b3ae-1cbd6977ef6e # ╠═02f3fe78-e7a7-453f-9ddf-acddf08d8676 # ╠═c50b1d39-fe87-441b-935c-c5fe971d09ef # ╠═14e61097-f28f-4029-b6b4-5fb119620fc3