Working branched policy operator and in the middle of cleaning up how loss function etc work

temporaryWork^2
youainti 5 years ago
parent 76b8cfdc8f
commit 90478868e1

@ -9,22 +9,37 @@ using PlutoUI, Flux,LinearAlgebra
# ╔═╡ 66f0e667-d722-4e1e-807b-84a39cbc41b1
md"""
# Bellman Residual Operators
# Bellman Residual Minimization of Operators
"""
# ╔═╡ 9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede
md"""
Number of Constellations: $(const N_constellations = 4)
Define the basic dimensions of the model
- Number of Constellations: $(const N_constellations = 4)
- Number of Debris Trackers: $(const N_debris = 1)
- Number of Overall States: $(const N_states = N_constellations + N_debris)
"""
Number of Debris Trackers: $(const N_debris = 1)
# ╔═╡ 0d209540-29ad-4b2f-9e91-fac2bbee47ff
md"""
## Describe the physics of the world
Number of Overall States: $(const N_states = N_constellations + N_debris)
Three key functions
- Survival function, describing how collisions occur
- Debris Evolution
- Constellation Satellite Stocks Evolution
"""
# ╔═╡ 9b1c2f62-ec19-4667-8f93-f8b52750e317
md"""
### Parameterize the physical world
"""
# ╔═╡ 90446134-4e45-471c-857d-4e165e51937a
begin
begin
abstract type PhysicalParameters end
#setup physical model
@ -64,6 +79,9 @@ begin
,satellite_loss_debris_rate
,launch_debris_rate
)
end
end
# ╔═╡ 5b45b29e-f0f4-41e9-91e7-d444687feb4e
@ -160,41 +178,8 @@ function planner_policy_function_generator(number_params=32)
end
# ╔═╡ 6a3b5f7a-a535-450f-8c5f-19bdcc280146
#=
Build the policy network of a single constellation operator.

The returned chain first sends the stocks and the debris through separate
dense branches, joins the branch outputs with vcat, and then maps the joined
features down to a single output.
number_params sets the width of the hidden layers.
=#
function operator_policy_function_generator(number_params=32)
    #branch applied to the satellite stocks (twice as wide as the debris branch)
    stocks_branch = Flux.Chain(
        Flux.Dense(N_constellations, number_params*2, Flux.relu)
    )
    #branch applied to the debris
    debris_branch = Flux.Chain(
        Flux.Dense(N_debris, number_params, Flux.relu)
    )
    #parallel joins together stocks and debris: 2*number_params + number_params rows
    joined = Flux.Parallel(vcat, stocks_branch, debris_branch)

    #transformations applied to the joined features
    hidden = Flux.Dense(number_params*3, number_params, Flux.σ)
    output = Flux.Dense(number_params, 1, Flux.relu)

    return Flux.Chain(joined, hidden, output)
end
# ╔═╡ 9504bf46-e380-4933-8693-03ef3e92a4e4
ppf = planner_policy_function_generator()
# ╔═╡ ac4e29e6-474e-489b-adc9-549dfe27a465
#=
Evaluate the policy function of every operator on the same data and collect
the first entry of each result, one entry per operator.
=#
function combine_policies(operators, data)
    #each policy returns a collection; only its first entry is kept
    first_action(op) = op.policy_function(data)[1]
    return [first_action(op) for op in operators]
end
# ╔═╡ 95bfc9d8-8427-41d6-9f0f-f155296eef91
#not needed yet
#I don't think that these are going to be needed
begin
#= CUSTOM LAYERS
@ -277,55 +262,13 @@ struct ConstellationOperator
payoff_fn::Function
econ_params::EconomicParameters
value::Flux.Chain
loss_fn::Function
policy_params::Flux.Params
#policy_params::Flux.Params #cutting this for now
end
#TODO: create a function that takes this struct and checks backprop
# ╔═╡ 9a97a41f-475b-458e-9f7c-aabed95c6f54
# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c
md"""
# Loss function specification
"""
# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
#=
Constellation level loss function.

Computes the loss of operator number operator_number at the state (s,d):
the squared Bellman residual plus the (negated) objective of the maximization.

s               current satellite stocks
d               current debris
physical_model  parameters handed to the transition functions G and H
cos             all constellation operators; their policies jointly set the actions
operator_number index into cos of the operator whose loss is returned
=#
function Ξ(
    s::Vector{Float32}
    ,d::Vector{Float32}
    , physical_model::PhysicalParameters
    ,cos::Array{ConstellationOperator}
    ,operator_number::Int
    )
    co = cos[operator_number]

    #actions chosen by every operator at the current state
    a = combine_policies(cos,(s,d))

    #next period state; kept under separate names so that the current state
    #is still available for the value and payoff terms below
    s_next = G(s,d,a,physical_model)
    #NOTE(review): H receives the already updated stocks, exactly as in the
    #previous call order - confirm whether it should see s instead
    d_next = H(s_next,d,a,physical_model)

    #current payoff and discounted continuation value
    payoff = co.payoff_fn(s,d,a,co.econ_params)
    continuation = co.econ_params.β*co.value((s_next,d_next))

    #V(s,d) - u(s,d,a) - β V(s',d'): value today against payoff plus continuation
    bellman_residuals = co.value((s,d)) - payoff - continuation
    maximization_condition = - payoff - continuation

    return sum([bellman_residuals.^2 maximization_condition])
end
# ╔═╡ 9d4a668a-23e3-4f36-86f4-60e242caee3b
begin
s1 = ones(Float32,N_constellations)
d1 = ones(Float32,N_debris)
end
# ╔═╡ d8deba52-dc0c-470e-81bf-f9d7cc595a41
ppf((s1,d1))
# ╔═╡ 560ff22a-44b6-4c5c-a0fe-9bc95e97ef06
# ╔═╡ b433a7ec-8264-48d6-8b95-53d2ec4bad05
md"""
# examples of parameter models
### examples of parameter models
"""
# ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
@ -355,80 +298,114 @@ begin
=#
end
# ╔═╡ 08ef7fbf-005b-40b5-acde-f42750c04cd3
begin
a = operator_policy_function_generator()
b = operator_policy_function_generator()
c = operator_policy_function_generator()
d = operator_policy_function_generator()
tops = [
ConstellationOperator(payoff1,em2_a,value_function_generator(),a,params(a))
,ConstellationOperator(payoff1,em2_b,value_function_generator(),b,params(b))
,ConstellationOperator(payoff1,em2_c,value_function_generator(),c,params(c))
,ConstellationOperator(payoff1,em2_d,value_function_generator(),d,params(d))
]
end
# ╔═╡ 2f14fb8e-7f71-420f-bd24-f94a4b37b0a8
a
# ╔═╡ dc614254-c211-4552-b985-03020bfc5ab3
em3 = CESParams(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002]))
#=
This is a variation on a CES model.
# ╔═╡ 9bd80252-bf0f-421b-a747-9b41cbc82edf
a((s1,d1))
The model is CES the relationship between payoffs and debris.
In this particular specification, the only interaction is in debris
=#
# ╔═╡ 7630fccb-5169-4d46-96ea-1968baed89a2
e = combine_policies(tops,([1,2,3,4.0],[2.0]) )
# ╔═╡ 4452ab2c-813b-4b7b-8cb3-388ec507a24c
md"""
## Setup Operators Policy Functions
"""
# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29
begin #testing
Ξ(s1,d1,bm,tops,1)
end
# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c
md"""
# Loss function specification
"""
# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58
#planner's loss function
function planners_loss(
# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
#Constellation operator loss function
function Ξ(
s::Vector{Float32}
,d::Vector{Float32}
,physical_model::PhysicalParameters
,co::ConstellationOperator
,policy::Flux.Chain
)
l = 0.0
for (i,co) in enumerate(tops)
l += Ξ(s,d,bm,tops,i)
end
return l
end
#get actions
a = policy((s,d))
#get updated stocks and debris
s = G(s,d,a,physical_model)
d = H(s,d,a,physical_model)
# ╔═╡ dc614254-c211-4552-b985-03020bfc5ab3
em3 = CESParams(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002]))
#=
This is a variation on a CES model.
The model is CES the relationship between payoffs and debris.
In this particular specification, the only interaction is in debris
=#
bellman_residuals = co.value((s,d)) - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s,d))
maximization_condition = - co.payoff_fn(s,d,a,co.econ_params) - co.econ_params.β*co.value((s,d))
return sum(bellman_residuals.^2, maximization_condition)
end
# ╔═╡ 9d4a668a-23e3-4f36-86f4-60e242caee3b
begin #testing data
s1 = ones(Float32,N_constellations)
d1 = ones(Float32,N_debris)
end
# ╔═╡ cd55e232-493d-4849-8bd7-b0ba85e21bab
md"""
# Start setting things up
"""
# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
policy = planner_policy_function_generator();
# ╔═╡ f30904a7-5caa-449a-a5bd-f2aa78777a9a
# ╔═╡ 08ef7fbf-005b-40b5-acde-f42750c04cd3
begin
#setup the operators
operators = [ ConstellationOperator(payoff1,em2_a,value_function_generator())
const tops = [
ConstellationOperator(payoff1,em2_a,value_function_generator())
,ConstellationOperator(payoff1,em2_b,value_function_generator())
,ConstellationOperator(payoff1,em2_c,value_function_generator())
,ConstellationOperator(payoff1,em2_d,value_function_generator())
]
#check whether or not we've matched the setup correctly.
@assert length(operators) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized"
#sanity Check time
@assert length(tops) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized"
end
# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29
begin #testing
Ξ(s1,d1,bm,tops[1],collected_policies)
end
# ╔═╡ 89258b2e-cee5-4cbd-a42e-21301b7a2549
begin
    #=
    OperatorLoss
    Used to provide an iterable loss function for training.
    Bundles everything Ξ needs besides the state, so that the object can be
    called with just (s,d).
    =#
    struct OperatorLoss
        physical_model::PhysicalParameters
        co::ConstellationOperator
        collected_policies::Flux.Chain
        policy_params::Flux.Params
    end

    #evaluate the operator's loss Ξ at the state (s,d) using the stored fields
    function (loss::OperatorLoss)(
        s::Vector{Float32}
        ,d::Vector{Float32}
        )
        return Ξ(s, d, loss.physical_model, loss.co, loss.collected_policies)
    end

    #do the same thing with the planner's problem
    struct PlannerLoss
    end

    #=
    Planner's loss: the sum of Ξ over every entry of the global tops,
    evaluated at the state (s,d) with the global physical model bm.
    =#
    function planners_loss(
        s::Vector{Float32}
        ,d::Vector{Float32}
        )
        total = 0.0
        for i in eachindex(tops)
            total += Ξ(s,d,bm,tops,i)
        end
        return total
    end
end
# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
policy = planner_policy_function_generator();
# ╔═╡ 5abebc1a-370c-4f5f-8826-dc0b143d5166
md"""
## Constructing data and training
@ -436,13 +413,22 @@ md"""
# ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854
begin
#number of samples in the directly constructed data set below
N=200 #increase later
#N tuples, each holding a length N_constellations vector and a length N_debris vector drawn from the range 1:500f0
data = [(rand(1:500f0, N_constellations),rand(1:500f0, N_debris)) for n=1:N]
#callable object storing how many samples to generate
struct DataConstructor
N::UInt64
end
#calling the object builds dc.N tuples in the same layout, drawing from bottom:top (defaults 1f0 and 500f0)
(dc::DataConstructor)(bottom=1f0,top=500f0) = [(rand(bottom:top, N_constellations),rand(bottom:top, N_debris)) for n=1:dc.N]
#NOTE: data is rebound here from the array above to a DataConstructor for 200 samples
data = DataConstructor(200)
#one realized data set generated with the default bounds
data1 = data()
end
# ╔═╡ 6bf8d29a-7990-4e91-86e6-d9894ed3db27
#optimizer
ADAM = Flux.Optimise.ADAM(0.1)
opt = Flux.Optimise.ADAM(0.1)
# ╔═╡ 30417194-6cd1-4e10-9e25-1fa1c0761b9a
# ╔═╡ e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d
begin
@ -1056,45 +1042,39 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
# ╔═╡ Cell order:
# ╠═0829eb90-1d74-46b6-80ec-482a2b71c6fe
# ╟─66f0e667-d722-4e1e-807b-84a39cbc41b1
# ╠═9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede
# ╟─9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede
# ╟─0d209540-29ad-4b2f-9e91-fac2bbee47ff
# ╟─5b45b29e-f0f4-41e9-91e7-d444687feb4e
# ╟─152f3a3c-a565-41bb-8e59-6ab0d2315ffb
# ╟─25ac9438-2b1d-4f6b-9ff1-1695e1d52b51
# ╠═9b1c2f62-ec19-4667-8f93-f8b52750e317
# ╟─90446134-4e45-471c-857d-4e165e51937a
# ╠═5b45b29e-f0f4-41e9-91e7-d444687feb4e
# ╠═152f3a3c-a565-41bb-8e59-6ab0d2315ffb
# ╠═25ac9438-2b1d-4f6b-9ff1-1695e1d52b51
# ╠═29ff1777-d276-4e8f-8582-4ca191f2e2ff
# ╟─29ff1777-d276-4e8f-8582-4ca191f2e2ff
# ╠═f7aabe43-9a2c-4fe0-8099-c29cdf66566c
# ╠═d816b252-bdca-44ba-ac5c-cb21163a1e9a
# ╠═6a3b5f7a-a535-450f-8c5f-19bdcc280146
# ╠═08ef7fbf-005b-40b5-acde-f42750c04cd3
# ╠═2f14fb8e-7f71-420f-bd24-f94a4b37b0a8
# ╠═9bd80252-bf0f-421b-a747-9b41cbc82edf
# ╠═9504bf46-e380-4933-8693-03ef3e92a4e4
# ╠═d8deba52-dc0c-470e-81bf-f9d7cc595a41
# ╠═ac4e29e6-474e-489b-adc9-549dfe27a465
# ╠═7630fccb-5169-4d46-96ea-1968baed89a2
# ╠═95bfc9d8-8427-41d6-9f0f-f155296eef91
# ╠═bbca5143-f314-40ea-a20e-8a043272e362
# ╠═340da189-f443-4376-a82d-7699a21ab7a2
# ╠═206ac4cc-5102-4381-ad8a-777b02dc4d5a
# ╠═eebb8706-a431-4fd1-b7a5-40f07a63d5cb
# ╠═f8d582cb-10cf-4c72-8127-787f662e0567
# ╠═9a97a41f-475b-458e-9f7c-aabed95c6f54
# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c
# ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
# ╠═9d4a668a-23e3-4f36-86f4-60e242caee3b
# ╠═43b99708-0052-4b78-886c-92ac2b532f29
# ╠═560ff22a-44b6-4c5c-a0fe-9bc95e97ef06
# ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58
# ╠═b433a7ec-8264-48d6-8b95-53d2ec4bad05
# ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
# ╠═19ccfc3a-6dbb-4c64-bf03-e2e219ef0efe
# ╠═dc614254-c211-4552-b985-03020bfc5ab3
# ╠═4452ab2c-813b-4b7b-8cb3-388ec507a24c
# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c
# ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
# ╠═9d4a668a-23e3-4f36-86f4-60e242caee3b
# ╠═43b99708-0052-4b78-886c-92ac2b532f29
# ╠═89258b2e-cee5-4cbd-a42e-21301b7a2549
# ╠═cd55e232-493d-4849-8bd7-b0ba85e21bab
# ╠═08ef7fbf-005b-40b5-acde-f42750c04cd3
# ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
# ╠═f30904a7-5caa-449a-a5bd-f2aa78777a9a
# ╠═5abebc1a-370c-4f5f-8826-dc0b143d5166
# ╠═a20959be-65e4-4b69-9521-503bc59f0854
# ╠═6bf8d29a-7990-4e91-86e6-d9894ed3db27
# ╠═30417194-6cd1-4e10-9e25-1fa1c0761b9a
# ╠═e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d
# ╠═02f3fe78-e7a7-453f-9ddf-acddf08d8676
# ╠═c50b1d39-fe87-441b-935c-c5fe971d09ef

@ -60,10 +60,13 @@ b(([1],[1]))
b([1],[2])
# ╔═╡ c1420fc5-de42-4b32-9e13-e859e111996b
md"""
Again, very similar to previous work, but without the vcat
"""
# ╔═╡ 46c9f81c-3a8c-463a-952e-32287d11f1db
c = Flux.Parallel(vcat
c =
Flux.Parallel(vcat
#add chain here
,Flux.Chain(
Flux.Parallel(vcat
@ -92,10 +95,13 @@ c = Flux.Parallel(vcat
c(([1],[2]),([3],[4]),([1],[2]))
# ╔═╡ 10938bcf-dbc5-4109-b8a2-ee21c482f610
c(([1],[2],[3],[4],[1],[2]))
c([1],[2],[3],[4],[1],[2]) #ignores the last 3
# ╔═╡ e318c3f0-3828-41fc-9fd0-de2ae3d19e2f
c([1],[2],[3],[4])
c([1],[2],[3])
# ╔═╡ 0da65b18-e511-4203-a5af-cdb0096acbbf
c(([1],[2])) #don't run the 3rd branch
# ╔═╡ ec1bd1c4-fc35-482b-b186-7f95cb332463
md"""
@ -105,11 +111,73 @@ In the second case, each sub-branch gets its own details.
The big test will be whether training on each parallel branch will affect anything other than that entry
"""
# ╔═╡ 7efd82a6-6968-4b44-82f8-a955edf77532
#duplicate the argument into a 2-tuple; both entries are the same object
function f(x)
    return (x, x)
end
# ╔═╡ 0d85a68f-3a5d-4aab-ab0a-b9830e50c383
begin #move to module
    #= TupleDuplicator
    Callable object that turns any object x into a tuple of size n
    holding deepcopies of x.
    =#
    struct TupleDuplicator
        n::Int
    end

    function (f::TupleDuplicator)(x)
        #every entry is its own deepcopy, so the entries do not share state
        copies = [deepcopy(x) for _ in 1:f.n]
        return tuple(copies...)
    end

    #= BranchGenerator
    Generates a policy function made of n parallel branches.
    Calling it with a branch and a join function returns a chain that
    duplicates its input n times and feeds each copy to its own deepcopy
    of the branch; the branch outputs are combined with join_fn.
    =#
    struct BranchGenerator
        n::UInt8 #limit to 2^8 operators
    end

    function (b::BranchGenerator)(branch::Flux.Chain,join_fn::Function)
        #the same duplicator copies the branch now and the inputs on every call of the chain
        duplicate = TupleDuplicator(b.n)
        parallel_branches = Flux.Parallel(join_fn, duplicate(branch))
        return Flux.Chain(duplicate, parallel_branches)
    end
end
# ╔═╡ 2f6b7042-5cf5-4312-a617-dbeb08e05175
bg3 = BranchGenerator(3)
# ╔═╡ b9fe5e74-a524-4c74-8c88-b26204ffa57b
begin
#Setup branch to duplicate
d = Flux.Chain(
Flux.Parallel(vcat
,Dense(1,2,Flux.σ)
,Dense(1,2)
)
,Dense(4,1)
)
#build branch
e = bg3(d, vcat)
end
# ╔═╡ 5a03a348-93bf-4298-b483-57f1256a09fb
e(([2],[1]))
# ╔═╡ 28ff091c-fa33-485a-b425-2197a7915419
#sum of squared outputs of e at x; minimizing it forces the output to zero
function loss(x)
    output = e(x)
    return sum(abs2, output)
end
# ╔═╡ 3f0c5427-6a45-44cb-a83e-8f3829c5f3cf
loss(([2],[1]))
# ╔═╡ 03e2712c-0edb-4fd0-ac1f-cd34be4bcc03
opt = Flux.Optimise.ADAGrad()
# ╔═╡ 64644280-577a-4d31-9863-6c9914bba94c
params = Flux.params(e[2][1])
# ╔═╡ 19a3a52f-3ecf-46ff-ac4d-5d44ef3a7821
#train! iterates over its data argument and splats each element into loss.
#The extra parentheses in (([2],[1])) do not create a tuple, so the two vectors
#were fed to loss one at a time. Wrapping the sample as a one-element collection
#of 1-tuples passes ([2],[1]) to loss as a single argument.
Flux.train!(loss, params, [(([2],[1]),)], opt)
# ╔═╡ 92dc1ae3-1286-4c51-91fb-c75abd897674
f((2,3))
# ╔═╡ ad690dd7-6fad-4194-9cf2-ab33b1b23a11
e(([2],[1]))
# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
@ -704,8 +772,17 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
# ╠═8c651a48-a919-4e13-95ed-e0bd929e8b64
# ╠═10938bcf-dbc5-4109-b8a2-ee21c482f610
# ╠═e318c3f0-3828-41fc-9fd0-de2ae3d19e2f
# ╠═0da65b18-e511-4203-a5af-cdb0096acbbf
# ╠═ec1bd1c4-fc35-482b-b186-7f95cb332463
# ╠═7efd82a6-6968-4b44-82f8-a955edf77532
# ╠═92dc1ae3-1286-4c51-91fb-c75abd897674
# ╠═0d85a68f-3a5d-4aab-ab0a-b9830e50c383
# ╠═2f6b7042-5cf5-4312-a617-dbeb08e05175
# ╠═b9fe5e74-a524-4c74-8c88-b26204ffa57b
# ╠═5a03a348-93bf-4298-b483-57f1256a09fb
# ╠═28ff091c-fa33-485a-b425-2197a7915419
# ╠═3f0c5427-6a45-44cb-a83e-8f3829c5f3cf
# ╠═03e2712c-0edb-4fd0-ac1f-cd34be4bcc03
# ╠═64644280-577a-4d31-9863-6c9914bba94c
# ╠═19a3a52f-3ecf-46ff-ac4d-5d44ef3a7821
# ╠═ad690dd7-6fad-4194-9cf2-ab33b1b23a11
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002

Loading…
Cancel
Save