From 90478868e15e426237cf5e0a379bd08f9f9ffc3a Mon Sep 17 00:00:00 2001 From: youainti Date: Thu, 11 Nov 2021 16:20:49 -0800 Subject: [PATCH] Working branched policy operator and in the middle of cleaning up how loss function etc work --- julia_code/BellmanResidual_Operators.jl | 374 +++++++++++------------ julia_code/ExperimentWithFluxParallel.jl | 97 +++++- 2 files changed, 264 insertions(+), 207 deletions(-) diff --git a/julia_code/BellmanResidual_Operators.jl b/julia_code/BellmanResidual_Operators.jl index 0b4b067..cb5e8b9 100644 --- a/julia_code/BellmanResidual_Operators.jl +++ b/julia_code/BellmanResidual_Operators.jl @@ -9,61 +9,79 @@ using PlutoUI, Flux,LinearAlgebra # ╔═╡ 66f0e667-d722-4e1e-807b-84a39cbc41b1 md""" -# Bellman Residual Operators +# Bellman Residual Minimization of Operators """ # ╔═╡ 9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede md""" -Number of Constellations: $(const N_constellations = 4) +Define the basic dimensions of the model + - Number of Constellations: $(const N_constellations = 4) + - Number of Debris Trackers: $(const N_debris = 1) + - Number of Overall States: $(const N_states = N_constellations + N_debris) +""" -Number of Debris Trackers: $(const N_debris = 1) +# ╔═╡ 0d209540-29ad-4b2f-9e91-fac2bbee47ff +md""" +## Describe the physics of the world -Number of Overall States: $(const N_states = N_constellations + N_debris) +Three key functions + - Survival function, describing how collisions occur + - Debris Evolution + - Constellation Satellite Stocks Evolution """ +# ╔═╡ 9b1c2f62-ec19-4667-8f93-f8b52750e317 + md""" + ### Parameterize the physical world + """ + # ╔═╡ 90446134-4e45-471c-857d-4e165e51937a begin - abstract type PhysicalParameters end + begin + abstract type PhysicalParameters end + + #setup physical model + struct BasicModel <: PhysicalParameters + #rate at which debris hits satellites + debris_collision_rate::Real + #rate at which satellites of different constellations collide + satellite_collision_rates::Matrix{Float64} + 
#rate at which debris exits orbits + decay_rate::Real + #rate at which satellites + autocatalysis_rate::Real + #ratio at which a collision between satellites produced debris + satellite_collision_debris_ratio::Real + #Ratio at which launches produce debris + launch_debris_ratio::Real + end - #setup physical model - struct BasicModel <: PhysicalParameters - #rate at which debris hits satellites - debris_collision_rate::Real - #rate at which satellites of different constellations collide - satellite_collision_rates::Matrix{Float64} - #rate at which debris exits orbits - decay_rate::Real - #rate at which satellites - autocatalysis_rate::Real - #ratio at which a collision between satellites produced debris - satellite_collision_debris_ratio::Real - #Ratio at which launches produce debris - launch_debris_ratio::Real + #Getting loss parameters together. + loss_param = 2e-3; + loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I); + + #orbital decay rate + decay_param = 0.01; + + #debris generation parameters + autocatalysis_param = 0.001; + satellite_loss_debris_rate = 5.0; + launch_debris_rate = 0.05; + + #Todo, wrap physical model as a struct with the parameters + bm = BasicModel( + loss_param + ,loss_weights + ,decay_param + ,autocatalysis_param + ,satellite_loss_debris_rate + ,launch_debris_rate + ) end - - #Getting loss parameters together. 
- loss_param = 2e-3; - loss_weights = loss_param*(ones(N_constellations,N_constellations) - LinearAlgebra.I); - #orbital decay rate - decay_param = 0.01; - - #debris generation parameters - autocatalysis_param = 0.001; - satellite_loss_debris_rate = 5.0; - launch_debris_rate = 0.05; - - #Todo, wrap physical model as a struct with the parameters - bm = BasicModel( - loss_param - ,loss_weights - ,decay_param - ,autocatalysis_param - ,satellite_loss_debris_rate - ,launch_debris_rate - ) + end # ╔═╡ 5b45b29e-f0f4-41e9-91e7-d444687feb4e @@ -160,41 +178,8 @@ function planner_policy_function_generator(number_params=32) end -# ╔═╡ 6a3b5f7a-a535-450f-8c5f-19bdcc280146 -function operator_policy_function_generator(number_params=32) - - return Flux.Chain( - Flux.Parallel(vcat - #parallel joins together stocks and debris - ,Flux.Chain( - Flux.Dense(N_constellations, number_params*2,Flux.relu) - #,Flux.Dense(number_params, number_params,Flux.σ) - ) - ,Flux.Chain( - Flux.Dense(N_debris, number_params,Flux.relu) - #,Flux.Dense(number_params, number_params) - ) - ) - #Apply some transformations - ,Flux.Dense(number_params*3,number_params,Flux.σ) - ,Flux.Dense(number_params,1,Flux.relu) - ) - - - -end - -# ╔═╡ 9504bf46-e380-4933-8693-03ef3e92a4e4 -ppf = planner_policy_function_generator() - -# ╔═╡ ac4e29e6-474e-489b-adc9-549dfe27a465 -function combine_policies(operators, data) - policies = [co.policy_function(data)[1] for co=operators] - policies -end - # ╔═╡ 95bfc9d8-8427-41d6-9f0f-f155296eef91 -#not needed yet +#I don't think that these are going to be needed begin #= CUSTOM LAYERS @@ -277,55 +262,13 @@ struct ConstellationOperator payoff_fn::Function econ_params::EconomicParameters value::Flux.Chain - loss_fn::Function - policy_params::Flux.Params + #policy_params::Flux.Params #cutting this for now end #TODO: create a function that takes this struct and checks backprop -# ╔═╡ 9a97a41f-475b-458e-9f7c-aabed95c6f54 - - -# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c -md""" -# Loss 
function specification -""" - -# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 -#Constellation level loss function -function Ξ( - s::Vector{Float32} - ,d::Vector{Float32} - , physical_model::PhysicalParameters - ,cos::Array{ConstellationOperator} - ,operator_number::Int -) - co = cos[operator_number] - - a = combine_policies(cos,(s,d)) - s′ = G(s,d,a,physical_model) - d′ = H(s,d,a,physical_model) - - bellman_residuals = co.value((s,d)) - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) - maximization_condition = - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) - - return sum([bellman_residuals.^2 maximization_condition]) -end - -# ╔═╡ 9d4a668a-23e3-4f36-86f4-60e242caee3b -begin - s1 = ones(Float32,N_constellations) - d1 = ones(Float32,N_debris) -end - -# ╔═╡ d8deba52-dc0c-470e-81bf-f9d7cc595a41 -ppf((s1,d1)) - -# ╔═╡ 560ff22a-44b6-4c5c-a0fe-9bc95e97ef06 - - # ╔═╡ b433a7ec-8264-48d6-8b95-53d2ec4bad05 md""" -# examples of parameter models +### examples of parameter models """ # ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150 @@ -355,50 +298,6 @@ begin =# end -# ╔═╡ 08ef7fbf-005b-40b5-acde-f42750c04cd3 -begin - a = operator_policy_function_generator() - b = operator_policy_function_generator() - c = operator_policy_function_generator() - d = operator_policy_function_generator() - - tops = [ - ConstellationOperator(payoff1,em2_a,value_function_generator(),a,params(a)) - ,ConstellationOperator(payoff1,em2_b,value_function_generator(),b,params(b)) - ,ConstellationOperator(payoff1,em2_c,value_function_generator(),c,params(c)) - ,ConstellationOperator(payoff1,em2_d,value_function_generator(),d,params(d)) - ] - -end - -# ╔═╡ 2f14fb8e-7f71-420f-bd24-f94a4b37b0a8 -a - -# ╔═╡ 9bd80252-bf0f-421b-a747-9b41cbc82edf -a((s1,d1)) - -# ╔═╡ 7630fccb-5169-4d46-96ea-1968baed89a2 -e = combine_policies(tops,([1,2,3,4.0],[2.0]) ) - -# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29 -begin #testing - Ξ(s1,d1,bm,tops,1) -end - -# ╔═╡ 
dff642d9-ec5a-4fed-a059-6c07760a3a58 -#planner's loss function -function planners_loss( - s::Vector{Float32} - ,d::Vector{Float32} -) - l = 0.0 - for (i,co) in enumerate(tops) - l += Ξ(s,d,bm,tops,i) - end - return l -end - - # ╔═╡ dc614254-c211-4552-b985-03020bfc5ab3 em3 = CESParams(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002])) #= @@ -408,27 +307,105 @@ The model is CES the relationship between payoffs and debris. In this particular specification, the only interaction is in debris =# +# ╔═╡ 4452ab2c-813b-4b7b-8cb3-388ec507a24c +md""" +## Setup Operators Policy Functions +""" + +# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c +md""" +# Loss function specification +""" + +# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 +#Constellation operator loss function +function Ξ( + s::Vector{Float32} + ,d::Vector{Float32} + ,physical_model::PhysicalParameters + ,co::ConstellationOperator + ,policy::Flux.Chain +) + #get actions + a = policy((s,d)) + + #get updated stocks and debris + s′ = G(s,d,a,physical_model) + d′ = H(s,d,a,physical_model) + + #value at (s,d) minus the payoff minus β times the value at the updated state (s′,d′) + bellman_residuals = co.value((s,d)) - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) + + maximization_condition = - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s′,d′)) + + return sum(bellman_residuals.^2)+sum(maximization_condition) #two separate sums: sum(x, y) would try to call x as a function +end + +# ╔═╡ 9d4a668a-23e3-4f36-86f4-60e242caee3b +begin #testing data + s1 = ones(Float32,N_constellations) + d1 = ones(Float32,N_debris) +end + # ╔═╡ cd55e232-493d-4849-8bd7-b0ba85e21bab md""" # Start setting things up """ -# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f -policy = planner_policy_function_generator(); - -# ╔═╡ f30904a7-5caa-449a-a5bd-f2aa78777a9a +# ╔═╡ 08ef7fbf-005b-40b5-acde-f42750c04cd3 begin - #setup the operators - operators = [ ConstellationOperator(payoff1,em2_a,value_function_generator()) - ,ConstellationOperator(payoff1,em2_b,value_function_generator()) - ,ConstellationOperator(payoff1,em2_c,value_function_generator()) - 
,ConstellationOperator(payoff1,em2_d,value_function_generator()) + const tops = [ + ConstellationOperator(payoff1,em2_a,value_function_generator()) + ,ConstellationOperator(payoff1,em2_b,value_function_generator()) + ,ConstellationOperator(payoff1,em2_c,value_function_generator()) + ,ConstellationOperator(payoff1,em2_d,value_function_generator()) ] - #check whether or not we've matched the setup correctly. - @assert length(operators) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized" + #sanity Check time + @assert length(tops) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized" end +# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29 +begin #testing + Ξ(s1,d1,bm,tops[1],collected_policies) +end + +# ╔═╡ 89258b2e-cee5-4cbd-a42e-21301b7a2549 +begin + #= + Used to provide an iterable loss function for training + =# + struct OperatorLoss + physical_model::PhysicalParameters + co::ConstellationOperator + collected_policies::Flux.Chain + policy_params::Flux.Params + end + function (self::OperatorLoss)( + s::Vector{Float32} + ,d::Vector{Float32} + ) + Ξ(s ,d ,self.physical_model ,self.co ,self.collected_policies ) + end + + #do the same thing with the planner's problem + struct PlannerLoss + end + function planners_loss( + s::Vector{Float32} + ,d::Vector{Float32} + ) + l = 0.0 + for co in tops + l += Ξ(s,d,bm,co,policy) #Ξ takes one operator plus a policy chain; policy is the global from cell fb6aacff (presumably the planner's policy - confirm) + end + return l + end +end + +# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f +policy = planner_policy_function_generator(); + # ╔═╡ 5abebc1a-370c-4f5f-8826-dc0b143d5166 md""" ## Constructing data and training @@ -436,13 +413,22 @@ md""" # ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854 begin - N=200 #increase later - data = [(rand(1:500f0, N_constellations),rand(1:500f0, N_debris)) for n=1:N] + struct DataConstructor + N::UInt64 + end + (dc::DataConstructor)(bottom=1f0,top=500f0) = [(rand(bottom:top, N_constellations),rand(bottom:top, 
N_debris)) for n=1:dc.N] + + data = DataConstructor(200) + + data1 = data() end # ╔═╡ 6bf8d29a-7990-4e91-86e6-d9894ed3db27 #optimizer -ADAM = Flux.Optimise.ADAM(0.1) +opt = Flux.Optimise.ADAM(0.1) + +# ╔═╡ 30417194-6cd1-4e10-9e25-1fa1c0761b9a + # ╔═╡ e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d begin @@ -1056,45 +1042,39 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" # ╔═╡ Cell order: # ╠═0829eb90-1d74-46b6-80ec-482a2b71c6fe # ╟─66f0e667-d722-4e1e-807b-84a39cbc41b1 -# ╠═9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede +# ╟─9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede +# ╟─0d209540-29ad-4b2f-9e91-fac2bbee47ff +# ╟─5b45b29e-f0f4-41e9-91e7-d444687feb4e +# ╟─152f3a3c-a565-41bb-8e59-6ab0d2315ffb +# ╟─25ac9438-2b1d-4f6b-9ff1-1695e1d52b51 +# ╠═9b1c2f62-ec19-4667-8f93-f8b52750e317 # ╟─90446134-4e45-471c-857d-4e165e51937a -# ╠═5b45b29e-f0f4-41e9-91e7-d444687feb4e -# ╠═152f3a3c-a565-41bb-8e59-6ab0d2315ffb -# ╠═25ac9438-2b1d-4f6b-9ff1-1695e1d52b51 -# ╠═29ff1777-d276-4e8f-8582-4ca191f2e2ff +# ╟─29ff1777-d276-4e8f-8582-4ca191f2e2ff # ╠═f7aabe43-9a2c-4fe0-8099-c29cdf66566c # ╠═d816b252-bdca-44ba-ac5c-cb21163a1e9a -# ╠═6a3b5f7a-a535-450f-8c5f-19bdcc280146 -# ╠═08ef7fbf-005b-40b5-acde-f42750c04cd3 -# ╠═2f14fb8e-7f71-420f-bd24-f94a4b37b0a8 -# ╠═9bd80252-bf0f-421b-a747-9b41cbc82edf -# ╠═9504bf46-e380-4933-8693-03ef3e92a4e4 -# ╠═d8deba52-dc0c-470e-81bf-f9d7cc595a41 -# ╠═ac4e29e6-474e-489b-adc9-549dfe27a465 -# ╠═7630fccb-5169-4d46-96ea-1968baed89a2 # ╠═95bfc9d8-8427-41d6-9f0f-f155296eef91 # ╠═bbca5143-f314-40ea-a20e-8a043272e362 # ╠═340da189-f443-4376-a82d-7699a21ab7a2 # ╠═206ac4cc-5102-4381-ad8a-777b02dc4d5a # ╠═eebb8706-a431-4fd1-b7a5-40f07a63d5cb # ╠═f8d582cb-10cf-4c72-8127-787f662e0567 -# ╠═9a97a41f-475b-458e-9f7c-aabed95c6f54 -# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c -# ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 -# ╠═9d4a668a-23e3-4f36-86f4-60e242caee3b -# ╠═43b99708-0052-4b78-886c-92ac2b532f29 -# ╠═560ff22a-44b6-4c5c-a0fe-9bc95e97ef06 -# ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58 # 
╠═b433a7ec-8264-48d6-8b95-53d2ec4bad05 # ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150 # ╠═19ccfc3a-6dbb-4c64-bf03-e2e219ef0efe # ╠═dc614254-c211-4552-b985-03020bfc5ab3 +# ╠═4452ab2c-813b-4b7b-8cb3-388ec507a24c +# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c +# ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4 +# ╠═9d4a668a-23e3-4f36-86f4-60e242caee3b +# ╠═43b99708-0052-4b78-886c-92ac2b532f29 +# ╠═89258b2e-cee5-4cbd-a42e-21301b7a2549 # ╠═cd55e232-493d-4849-8bd7-b0ba85e21bab +# ╠═08ef7fbf-005b-40b5-acde-f42750c04cd3 # ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f -# ╠═f30904a7-5caa-449a-a5bd-f2aa78777a9a # ╠═5abebc1a-370c-4f5f-8826-dc0b143d5166 # ╠═a20959be-65e4-4b69-9521-503bc59f0854 # ╠═6bf8d29a-7990-4e91-86e6-d9894ed3db27 +# ╠═30417194-6cd1-4e10-9e25-1fa1c0761b9a # ╠═e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d # ╠═02f3fe78-e7a7-453f-9ddf-acddf08d8676 # ╠═c50b1d39-fe87-441b-935c-c5fe971d09ef diff --git a/julia_code/ExperimentWithFluxParallel.jl b/julia_code/ExperimentWithFluxParallel.jl index fcd0e2e..56e8190 100644 --- a/julia_code/ExperimentWithFluxParallel.jl +++ b/julia_code/ExperimentWithFluxParallel.jl @@ -60,10 +60,13 @@ b(([1],[1])) b([1],[2]) # ╔═╡ c1420fc5-de42-4b32-9e13-e859e111996b - +md""" +Again, very similar to previous work, but without the vcat +""" # ╔═╡ 46c9f81c-3a8c-463a-952e-32287d11f1db -c = Flux.Parallel(vcat +c = + Flux.Parallel(vcat #add chain here ,Flux.Chain( Flux.Parallel(vcat @@ -92,10 +95,13 @@ c = Flux.Parallel(vcat c(([1],[2]),([3],[4]),([1],[2])) # ╔═╡ 10938bcf-dbc5-4109-b8a2-ee21c482f610 -c(([1],[2],[3],[4],[1],[2])) +c([1],[2],[3],[4],[1],[2]) #ignores the last 3 # ╔═╡ e318c3f0-3828-41fc-9fd0-de2ae3d19e2f -c([1],[2],[3],[4]) +c([1],[2],[3]) + +# ╔═╡ 0da65b18-e511-4203-a5af-cdb0096acbbf +c(([1],[2])) #don't run the 3rd branch # ╔═╡ ec1bd1c4-fc35-482b-b186-7f95cb332463 md""" @@ -105,11 +111,73 @@ In the second case, each sub-branch gets its own details. 
The big test will be whether training on each parallel branch will affect anything other than that entry """ -# ╔═╡ 7efd82a6-6968-4b44-82f8-a955edf77532 -f(x) = (x,x) +# ╔═╡ 0d85a68f-3a5d-4aab-ab0a-b9830e50c383 +begin #move to module + #= TupleDuplicator + This is used to create a tuple of size n with deepcopies of any object x + =# + struct TupleDuplicator + n::Int + end + (f::TupleDuplicator)(x) = tuple([deepcopy(x) for i=1:f.n]...) + + #= + This generates a policy function full of branches with the properly scaled sides + =# + struct BranchGenerator + n::UInt8 #limit to 2^8 operators + end + function (b::BranchGenerator)(branch::Flux.Chain,join_fn::Function) + # used to deepcopy the branches and duplicate the inputs in the returned chain + f = TupleDuplicator(b.n) + + return Flux.Chain( + f + ,Flux.Parallel(join_fn + ,f(branch) + ) + ) + end +end + +# ╔═╡ 2f6b7042-5cf5-4312-a617-dbeb08e05175 +bg3 = BranchGenerator(3) + +# ╔═╡ b9fe5e74-a524-4c74-8c88-b26204ffa57b +begin + #Setup branch to duplicate + d = Flux.Chain( + Flux.Parallel(vcat + ,Dense(1,2,Flux.σ) + ,Dense(1,2) + ) + ,Dense(4,1) + ) + + #build branch + e = bg3(d, vcat) +end + +# ╔═╡ 5a03a348-93bf-4298-b483-57f1256a09fb +e(([2],[1])) + +# ╔═╡ 28ff091c-fa33-485a-b425-2197a7915419 +loss(x) = sum(abs2,e(x)) #force to zero + +# ╔═╡ 3f0c5427-6a45-44cb-a83e-8f3829c5f3cf +loss(([2],[1])) + +# ╔═╡ 03e2712c-0edb-4fd0-ac1f-cd34be4bcc03 +opt = Flux.Optimise.ADAGrad() + +# ╔═╡ 64644280-577a-4d31-9863-6c9914bba94c +params = Flux.params(e[2][1]) + +# ╔═╡ 19a3a52f-3ecf-46ff-ac4d-5d44ef3a7821 +Flux.train!(loss, params, [(([2],[1]),)], opt) #one sample whose single argument is the ([2],[1]) tuple; (([2],[1])) is not a 1-tuple and would be iterated as two samples -# ╔═╡ 92dc1ae3-1286-4c51-91fb-c75abd897674 -f((2,3)) +# ╔═╡ ad690dd7-6fad-4194-9cf2-ab33b1b23a11 +e(([2],[1])) # ╔═╡ 00000000-0000-0000-0000-000000000001 PLUTO_PROJECT_TOML_CONTENTS = """ @@ -704,8 +772,17 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" # ╠═8c651a48-a919-4e13-95ed-e0bd929e8b64 # ╠═10938bcf-dbc5-4109-b8a2-ee21c482f610 # ╠═e318c3f0-3828-41fc-9fd0-de2ae3d19e2f +# 
╠═0da65b18-e511-4203-a5af-cdb0096acbbf # ╠═ec1bd1c4-fc35-482b-b186-7f95cb332463 -# ╠═7efd82a6-6968-4b44-82f8-a955edf77532 -# ╠═92dc1ae3-1286-4c51-91fb-c75abd897674 +# ╠═0d85a68f-3a5d-4aab-ab0a-b9830e50c383 +# ╠═2f6b7042-5cf5-4312-a617-dbeb08e05175 +# ╠═b9fe5e74-a524-4c74-8c88-b26204ffa57b +# ╠═5a03a348-93bf-4298-b483-57f1256a09fb +# ╠═28ff091c-fa33-485a-b425-2197a7915419 +# ╠═3f0c5427-6a45-44cb-a83e-8f3829c5f3cf +# ╠═03e2712c-0edb-4fd0-ac1f-cd34be4bcc03 +# ╠═64644280-577a-4d31-9863-6c9914bba94c +# ╠═19a3a52f-3ecf-46ff-ac4d-5d44ef3a7821 +# ╠═ad690dd7-6fad-4194-9cf2-ab33b1b23a11 # ╟─00000000-0000-0000-0000-000000000001 # ╟─00000000-0000-0000-0000-000000000002