diff --git a/julia_code/BellmanResidualMinimization.jl b/julia_code/BellmanResidualMinimization.jl index 2ee4111..34147aa 100644 --- a/julia_code/BellmanResidualMinimization.jl +++ b/julia_code/BellmanResidualMinimization.jl @@ -25,9 +25,6 @@ Number of Overall States: $(const N_states = N_constellations + N_debris) # ╔═╡ 90446134-4e45-471c-857d-4e165e51937a begin - md""" - Parameterization - """ abstract type PhysicalParameters end #setup physical model @@ -123,17 +120,17 @@ md""" function value_function_generator(number_params=10) return Flux.Chain( Flux.Parallel(vcat - #parallel joins together stocks and debris + #parallel joins together stocks and debris, after a little bit of preprocessing ,Flux.Chain( Flux.Dense(N_constellations, N_states*2,Flux.relu) - #,Flux.Dense(N_states*2, N_states*2,Flux.σ) + ,Flux.Dense(N_states*2, N_states*2,Flux.σ) ) ,Flux.Chain( Flux.Dense(N_debris, N_states,Flux.relu) - #,Flux.Dense(N_states, N_states) + ,Flux.Dense(N_states, N_states) ) ) - #Apply some transformations + #Apply some transformations to the preprocessed data. ,Flux.Dense(N_states*3,number_params,Flux.σ) ,Flux.Dense(number_params,1,Flux.σ) ) @@ -147,11 +144,11 @@ function policy_function_generator(number_params=10) #parallel joins together stocks and debris ,Flux.Chain( Flux.Dense(N_constellations, N_states*2,Flux.relu) - #,Flux.Dense(N_states*2, N_states*2,Flux.σ) + ,Flux.Dense(N_states*2, N_states*2,Flux.σ) ) ,Flux.Chain( Flux.Dense(N_debris, N_states,Flux.relu) - #,Flux.Dense(N_states, N_states) + ,Flux.Dense(N_states, N_states) ) ) #Apply some transformations @@ -162,22 +159,32 @@ function policy_function_generator(number_params=10) end # ╔═╡ 95bfc9d8-8427-41d6-9f0f-f155296eef91 -#not needed +#not needed yet begin - #= - Test a return in tuples. Just to see what can happen. + #= CUSTOM LAYERS + =# + #Custom passthrough layer passthrough(x::Array) = x - Tuple(a::Array,b::Array) = (a,b) - Flux.Parallel(Tuple, - passthrough,passthrough - )(([1,2],[3,4])); -end -# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f -begin - value = value_function_generator(); - policy = policy_function_generator(); + + # custom split layer + struct Split{T} + paths::T + end + Split(paths...) = Split(paths) + Flux.@functor Split + (m::Split)(x::AbstractArray) = tuple(map(f -> f(x), m.paths)) + + ### TESTING ### + + + #multiple branches + Flux.Parallel(vcat, + passthrough, passthrough, passthrough + )(([1],[2,3],[4])) + + end # ╔═╡ 206ac4cc-5102-4381-ad8a-777b02dc4d5a @@ -190,9 +197,6 @@ begin end end -# ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150 -em = EconModel1(0.95, [1 0 0 ], [5 0 0 ]) - # ╔═╡ 1cbaa2e5-55e4-46f9-82d0-04b481470094 function payoff1( s::Vector @@ -203,66 +207,122 @@ function payoff1( return em.payoff_array*s - em.policy_costs*a end +# ╔═╡ f8d582cb-10cf-4c72-8127-787f662e0567 +#= +This struct organizes information about a given constellation operator +=# +struct ConstellationOperator + payoff_fn::Function + econ_params::EconomicParameters + value::Flux.Chain +end +#TODO: create a function that takes this struct and checks backprop + +# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c +md""" +# Loss function specification +""" + +# ╔═╡ b433a7ec-8264-48d6-8b95-53d2ec4bad05 +md""" +# Testing +""" + +# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f +policy = policy_function_generator(); + # ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 +#Constellation level loss function function Ξ( s::Vector ,d::Vector , physical_model::PhysicalParameters - ,econ_params::EconomicParameters - ,payoff_fn::Function + ,co::ConstellationOperator ) a = policy((s,d)) s′ = G(s,d,a,physical_model) d′ = H(s,d,a,physical_model) - bellman_residuals = value((s,d)) - payoff_fn(s,d,a,econ_params) - econ_params.β*value((s′,d′)) - maximization_condition = - payoff_fn(s,d,a,econ_params) - econ_params.β*value((s′,d′)) + bellman_residuals = co.value((s,d)) - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) + maximization_condition = - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) return sum([bellman_residuals.^2 maximization_condition]) end -# ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854 +# ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150 +em1 = EconModel1(0.95, [1 0 0 ], [5 0 0 ]) + +# ╔═╡ f30904a7-5caa-449a-a5bd-f2aa78777a9a begin - N=12 - data = [(rand(1:500, N_constellations),rand(1:500, N_debris)) for n=1:N] + #setup the operators + operators = [ ConstellationOperator(payoff1,em1,value_function_generator()) + ,ConstellationOperator(payoff1,em1,value_function_generator()) + ,ConstellationOperator(payoff1,em1,value_function_generator()) + ] + + #check whether or not we've matched the setup correctly. + @assert length(operators) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized" end -# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58 -#loss function -loss(s,d) = Ξ(s,d,bm,em,payoff1) - -# ╔═╡ 20c777b5-4295-4478-8f53-b18cd409c8ae -s1 = ones(N_constellations) - -# ╔═╡ 1d65707c-6333-4252-ace6-bad47146ba06 -d1 = ones(N_debris) - # ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29 -Ξ(s1,d1,bm,em,payoff1) +begin + s1 = ones(N_constellations) + d1 = ones(N_debris) + Ξ(s1,d1,bm,operators[1]) +end + +# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58 +#planner's loss function +function planners_loss(s,d) + l = 0.0 + for co in operators + l += Ξ(s,d,bm,co) + end + return l +end + -# ╔═╡ 39433c1a-c3ac-45b0-b1bf-ff2d42ca9cbb +# ╔═╡ 5abebc1a-370c-4f5f-8826-dc0b143d5166 +md""" +## Constructing data +""" +# ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854 +begin + N=20 #increase later + data = [(rand(1:500, N_constellations),rand(1:500, N_debris)) for n=1:N] +end # ╔═╡ 6bf8d29a-7990-4e91-86e6-d9894ed3db27 #optimizer -ADAM = Flux.Optimise.ADAM() +ADAM = Flux.Optimise.ADAM(0.01) -# ╔═╡ 74f5fde3-0593-46fc-a688-f1db7ab28c64 -for epoch in 1:200 - #train the policy funciton - Flux.Optimise.train!(loss, params(policy), data, ADAM) +# ╔═╡ e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d +begin + accum1 = 0.0 + for d in data + accum1 += planners_loss(d...) + end + accum1/N +end - #Sweep through the value functions:w +# ╔═╡ 74f5fde3-0593-46fc-a688-f1db7ab28c64 +# Social planners problem +for epoch in 1:20 + #train the social planner's policy funciton + Flux.Optimise.train!(planners_loss, params(policy), data, ADAM) - #Train the value functions - Flux.Optimise.train!(loss, params(value), data, ADAM) + #Sweep through training the value functions + for co in operators + Flux.Optimise.train!(planners_loss, params(co.value), data, ADAM) + end end # ╔═╡ 02f3fe78-e7a7-453f-9ddf-acddf08d8676 begin accum = 0.0 for d in data - accum += loss(d...) + accum += planners_loss(d...) end accum/N end @@ -838,28 +898,31 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" # ╔═╡ Cell order: # ╠═0829eb90-1d74-46b6-80ec-482a2b71c6fe -# ╠═66f0e667-d722-4e1e-807b-84a39cbc41b1 -# ╟─90446134-4e45-471c-857d-4e165e51937a +# ╟─66f0e667-d722-4e1e-807b-84a39cbc41b1 # ╠═9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede +# ╟─90446134-4e45-471c-857d-4e165e51937a # ╟─5b45b29e-f0f4-41e9-91e7-d444687feb4e # ╟─152f3a3c-a565-41bb-8e59-6ab0d2315ffb # ╟─25ac9438-2b1d-4f6b-9ff1-1695e1d52b51 # ╠═29ff1777-d276-4e8f-8582-4ca191f2e2ff # ╠═f7aabe43-9a2c-4fe0-8099-c29cdf66566c # ╠═d816b252-bdca-44ba-ac5c-cb21163a1e9a -# ╟─95bfc9d8-8427-41d6-9f0f-f155296eef91 -# ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f +# ╠═95bfc9d8-8427-41d6-9f0f-f155296eef91 # ╠═206ac4cc-5102-4381-ad8a-777b02dc4d5a -# ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150 # ╠═1cbaa2e5-55e4-46f9-82d0-04b481470094 +# ╠═f8d582cb-10cf-4c72-8127-787f662e0567 +# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c # ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 -# ╠═a20959be-65e4-4b69-9521-503bc59f0854 -# ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58 -# ╠═20c777b5-4295-4478-8f53-b18cd409c8ae -# ╠═1d65707c-6333-4252-ace6-bad47146ba06 # ╠═43b99708-0052-4b78-886c-92ac2b532f29 -# ╠═39433c1a-c3ac-45b0-b1bf-ff2d42ca9cbb +# ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58 +# ╟─b433a7ec-8264-48d6-8b95-53d2ec4bad05 +# ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f +# ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150 +# ╠═f30904a7-5caa-449a-a5bd-f2aa78777a9a +# ╟─5abebc1a-370c-4f5f-8826-dc0b143d5166 +# ╠═a20959be-65e4-4b69-9521-503bc59f0854 # ╠═6bf8d29a-7990-4e91-86e6-d9894ed3db27 +# ╠═e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d # ╠═74f5fde3-0593-46fc-a688-f1db7ab28c64 # ╠═02f3fe78-e7a7-453f-9ddf-acddf08d8676 # ╟─00000000-0000-0000-0000-000000000001