|
|
|
|
@ -25,9 +25,6 @@ Number of Overall States: $(const N_states = N_constellations + N_debris)
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 90446134-4e45-471c-857d-4e165e51937a
|
|
|
|
|
begin
|
|
|
|
|
md"""
|
|
|
|
|
Parameterization
|
|
|
|
|
"""
|
|
|
|
|
abstract type PhysicalParameters end
|
|
|
|
|
|
|
|
|
|
#setup physical model
|
|
|
|
|
@ -123,17 +120,17 @@ md"""
|
|
|
|
|
function value_function_generator(number_params=10)
|
|
|
|
|
return Flux.Chain(
|
|
|
|
|
Flux.Parallel(vcat
|
|
|
|
|
#parallel joins together stocks and debris
|
|
|
|
|
#parallel joins together stocks and debris, after a little bit of preprocessing
|
|
|
|
|
,Flux.Chain(
|
|
|
|
|
Flux.Dense(N_constellations, N_states*2,Flux.relu)
|
|
|
|
|
#,Flux.Dense(N_states*2, N_states*2,Flux.σ)
|
|
|
|
|
,Flux.Dense(N_states*2, N_states*2,Flux.σ)
|
|
|
|
|
)
|
|
|
|
|
,Flux.Chain(
|
|
|
|
|
Flux.Dense(N_debris, N_states,Flux.relu)
|
|
|
|
|
#,Flux.Dense(N_states, N_states)
|
|
|
|
|
,Flux.Dense(N_states, N_states)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
#Apply some transformations
|
|
|
|
|
#Apply some transformations to the preprocessed data.
|
|
|
|
|
,Flux.Dense(N_states*3,number_params,Flux.σ)
|
|
|
|
|
,Flux.Dense(number_params,1,Flux.σ)
|
|
|
|
|
)
|
|
|
|
|
@ -147,11 +144,11 @@ function policy_function_generator(number_params=10)
|
|
|
|
|
#parallel joins together stocks and debris
|
|
|
|
|
,Flux.Chain(
|
|
|
|
|
Flux.Dense(N_constellations, N_states*2,Flux.relu)
|
|
|
|
|
#,Flux.Dense(N_states*2, N_states*2,Flux.σ)
|
|
|
|
|
,Flux.Dense(N_states*2, N_states*2,Flux.σ)
|
|
|
|
|
)
|
|
|
|
|
,Flux.Chain(
|
|
|
|
|
Flux.Dense(N_debris, N_states,Flux.relu)
|
|
|
|
|
#,Flux.Dense(N_states, N_states)
|
|
|
|
|
,Flux.Dense(N_states, N_states)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
#Apply some transformations
|
|
|
|
|
@ -162,22 +159,32 @@ function policy_function_generator(number_params=10)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 95bfc9d8-8427-41d6-9f0f-f155296eef91
|
|
|
|
|
#not needed
|
|
|
|
|
#not needed yet
|
|
|
|
|
begin
|
|
|
|
|
#=
|
|
|
|
|
Test a return in tuples. Just to see what can happen.
|
|
|
|
|
#= CUSTOM LAYERS
|
|
|
|
|
|
|
|
|
|
=#
|
|
|
|
|
#Custom passthrough layer
|
|
|
|
|
passthrough(x::Array) = x
|
|
|
|
|
Tuple(a::Array,b::Array) = (a,b)
|
|
|
|
|
Flux.Parallel(Tuple,
|
|
|
|
|
passthrough,passthrough
|
|
|
|
|
)(([1,2],[3,4]));
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
|
|
|
|
|
begin
|
|
|
|
|
value = value_function_generator();
|
|
|
|
|
policy = policy_function_generator();
|
|
|
|
|
|
|
|
|
|
# custom split layer
|
|
|
|
|
"""
    Split(paths...)

Layer that applies every branch stored in `paths` to the same input array
and returns the branch outputs together, one entry per branch, in the order
the branches are stored.
"""
struct Split{T}
    paths::T
end

# Convenience constructor: `Split(f, g, h)` gathers the branches into a tuple.
Split(paths...) = Split(paths)

# Make the layer known to Flux's functor machinery (per the Flux docs this is
# what lets the parameters held inside `paths` be found).
Flux.@functor Split

# `map` over the tuple of paths already yields a tuple with one output per
# path. The previous `tuple(map(...))` wrapped that result a second time,
# returning a 1-tuple whose only element was the real result, so callers
# could not index or destructure the individual branch outputs.
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
|
|
|
|
|
|
|
|
|
|
### TESTING ###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#multiple branches
|
|
|
|
|
Flux.Parallel(vcat,
|
|
|
|
|
passthrough, passthrough, passthrough
|
|
|
|
|
)(([1],[2,3],[4]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 206ac4cc-5102-4381-ad8a-777b02dc4d5a
|
|
|
|
|
@ -190,9 +197,6 @@ begin
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
|
|
|
|
|
em = EconModel1(0.95, [1 0 0 ], [5 0 0 ])
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 1cbaa2e5-55e4-46f9-82d0-04b481470094
|
|
|
|
|
function payoff1(
|
|
|
|
|
s::Vector
|
|
|
|
|
@ -203,66 +207,122 @@ function payoff1(
|
|
|
|
|
return em.payoff_array*s - em.policy_costs*a
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ f8d582cb-10cf-4c72-8127-787f662e0567
|
|
|
|
|
"""
    ConstellationOperator(payoff_fn, econ_params, value)

Groups the information describing a single constellation operator.
"""
struct ConstellationOperator
    # Payoff function of this operator; the loss `Ξ` calls it as
    # `payoff_fn(s, d, a, econ_params)`.
    payoff_fn::Function
    # Economic parameters handed to `payoff_fn`; `Ξ` also reads the
    # field `β` from it.
    econ_params::EconomicParameters
    # Flux chain used as this operator's value function; `Ξ` evaluates
    # it on a `(s, d)` tuple.
    value::Flux.Chain
end
|
|
|
|
|
#TODO: create a function that takes this struct and checks backprop
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c
|
|
|
|
|
md"""
|
|
|
|
|
# Loss function specification
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# ╔═╡ b433a7ec-8264-48d6-8b95-53d2ec4bad05
|
|
|
|
|
md"""
|
|
|
|
|
# Testing
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
|
|
|
|
|
policy = policy_function_generator();
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
|
|
|
|
|
#Constellation level loss function
|
|
|
|
|
begin
    """
        Ξ(s, d, physical_model, co)

    Constellation-level loss for operator `co` at the state `(s, d)`.

    The action is taken from the global `policy` network, the next state is
    produced by `G` and `H` under `physical_model`, and the returned scalar is
    the sum of the squared Bellman residuals and the maximization condition,
    both computed from `co.payoff_fn`, `co.econ_params` and `co.value`.
    """
    function Ξ(
        s::Vector,
        d::Vector,
        physical_model::PhysicalParameters,
        co::ConstellationOperator
    )
        # Action chosen by the (global) policy network at the current state.
        a = policy((s, d))

        # Next-period state under the physical model.
        s′ = G(s, d, a, physical_model)
        d′ = H(s, d, a, physical_model)

        # Both loss terms share the payoff and the discounted continuation
        # value, so each is evaluated once.
        flow_payoff = co.payoff_fn(s, d, a, co.econ_params)
        continuation = co.econ_params.β * co.value((s′, d′))

        bellman_residuals = co.value((s, d)) - flow_payoff - continuation
        maximization_condition = -flow_payoff - continuation

        return sum([bellman_residuals.^2 maximization_condition])
    end

    """
        Ξ(s, d, physical_model, econ_params, payoff_fn, co)

    Six-argument form kept so existing callers keep working. The result only
    ever depended on the fields of `co`: the values computed from
    `econ_params`, `payoff_fn` and the global `value` network were overwritten
    before being used, so that dead computation has been dropped and the two
    extra arguments are ignored.
    """
    function Ξ(
        s::Vector,
        d::Vector,
        physical_model::PhysicalParameters,
        econ_params::EconomicParameters,
        payoff_fn::Function,
        co::ConstellationOperator
    )
        return Ξ(s, d, physical_model, co)
    end
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854
|
|
|
|
|
# ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
|
|
|
|
|
em1 = EconModel1(0.95, [1 0 0 ], [5 0 0 ])
|
|
|
|
|
|
|
|
|
|
# ╔═╡ f30904a7-5caa-449a-a5bd-f2aa78777a9a
|
|
|
|
|
begin
    # Number of random (stocks, debris) observations to draw.
    N = 12
    data = [(rand(1:500, N_constellations), rand(1:500, N_debris)) for _ in 1:N]

    # Set up the operators: one per constellation, each with its own freshly
    # generated value network. Building the list from `N_constellations`
    # (instead of three copy-pasted entries) keeps it in step with the model
    # size.
    operators = [ConstellationOperator(payoff1, em1, value_function_generator()) for _ in 1:N_constellations]

    # Check whether or not we've matched the setup correctly.
    @assert length(operators) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized"
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58
|
|
|
|
|
#loss function
|
|
|
|
|
loss(s,d) = Ξ(s,d,bm,em,payoff1)
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 20c777b5-4295-4478-8f53-b18cd409c8ae
|
|
|
|
|
s1 = ones(N_constellations)
|
|
|
|
|
# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29
|
|
|
|
|
begin
|
|
|
|
|
s1 = ones(N_constellations)
|
|
|
|
|
d1 = ones(N_debris)
|
|
|
|
|
Ξ(s1,d1,bm,operators[1])
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 1d65707c-6333-4252-ace6-bad47146ba06
|
|
|
|
|
d1 = ones(N_debris)
|
|
|
|
|
# ╔═╡ dff642d9-ec5a-4fed-a059-6c07760a3a58
|
|
|
|
|
#planner's loss function
|
|
|
|
|
"""
    planners_loss(s, d)

Adds up the constellation-level loss `Ξ` over every entry of the global
`operators` collection, evaluated at `(s, d)` with the global physical
model `bm`. Returns `0.0` when `operators` is empty.
"""
function planners_loss(s, d)
    total = 0.0
    for operator in operators
        total = total + Ξ(s, d, bm, operator)
    end
    total
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 43b99708-0052-4b78-886c-92ac2b532f29
|
|
|
|
|
Ξ(s1,d1,bm,em,payoff1)
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 39433c1a-c3ac-45b0-b1bf-ff2d42ca9cbb
|
|
|
|
|
# ╔═╡ 5abebc1a-370c-4f5f-8826-dc0b143d5166
|
|
|
|
|
md"""
|
|
|
|
|
## Constructing data
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854
|
|
|
|
|
begin
|
|
|
|
|
N=20 #increase later
|
|
|
|
|
data = [(rand(1:500, N_constellations),rand(1:500, N_debris)) for n=1:N]
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 6bf8d29a-7990-4e91-86e6-d9894ed3db27
|
|
|
|
|
#optimizer
|
|
|
|
|
ADAM = Flux.Optimise.ADAM()
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 74f5fde3-0593-46fc-a688-f1db7ab28c64
|
|
|
|
|
for epoch in 1:200
|
|
|
|
|
#train the policy function
|
|
|
|
|
Flux.Optimise.train!(loss, params(policy), data, ADAM)
|
|
|
|
|
ADAM = Flux.Optimise.ADAM(0.01)
|
|
|
|
|
|
|
|
|
|
#Sweep through the value functions
|
|
|
|
|
# ╔═╡ e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d
|
|
|
|
|
begin
    # Accumulate the planner's loss over all observations, left to right
    # starting from 0.0; each observation is splatted into `planners_loss`.
    accum1 = foldl((running, obs) -> running + planners_loss(obs...), data; init = 0.0)

    # Average over the N observations; this is the value the cell displays.
    accum1 / N
end
|
|
|
|
|
|
|
|
|
|
#Train the value functions
|
|
|
|
|
Flux.Optimise.train!(loss, params(value), data, ADAM)
|
|
|
|
|
# ╔═╡ 74f5fde3-0593-46fc-a688-f1db7ab28c64
|
|
|
|
|
# Social planners problem
|
|
|
|
|
for _ in 1:20
    # Update the social planner's policy network against the planner's loss.
    Flux.Optimise.train!(planners_loss, params(policy), data, ADAM)

    # Then sweep over the operators, updating each one's value network in turn
    # against the same loss.
    foreach(operators) do co
        Flux.Optimise.train!(planners_loss, params(co.value), data, ADAM)
    end
end
|
|
|
|
|
|
|
|
|
|
# ╔═╡ 02f3fe78-e7a7-453f-9ddf-acddf08d8676
|
|
|
|
|
begin
|
|
|
|
|
accum = 0.0
|
|
|
|
|
for d in data
|
|
|
|
|
accum += loss(d...)
|
|
|
|
|
accum += planners_loss(d...)
|
|
|
|
|
end
|
|
|
|
|
accum/N
|
|
|
|
|
end
|
|
|
|
|
@ -838,28 +898,31 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
|
|
|
|
|
|
|
|
|
|
# ╔═╡ Cell order:
|
|
|
|
|
# ╠═0829eb90-1d74-46b6-80ec-482a2b71c6fe
|
|
|
|
|
# ╠═66f0e667-d722-4e1e-807b-84a39cbc41b1
|
|
|
|
|
# ╟─90446134-4e45-471c-857d-4e165e51937a
|
|
|
|
|
# ╟─66f0e667-d722-4e1e-807b-84a39cbc41b1
|
|
|
|
|
# ╠═9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede
|
|
|
|
|
# ╟─90446134-4e45-471c-857d-4e165e51937a
|
|
|
|
|
# ╟─5b45b29e-f0f4-41e9-91e7-d444687feb4e
|
|
|
|
|
# ╟─152f3a3c-a565-41bb-8e59-6ab0d2315ffb
|
|
|
|
|
# ╟─25ac9438-2b1d-4f6b-9ff1-1695e1d52b51
|
|
|
|
|
# ╠═29ff1777-d276-4e8f-8582-4ca191f2e2ff
|
|
|
|
|
# ╠═f7aabe43-9a2c-4fe0-8099-c29cdf66566c
|
|
|
|
|
# ╠═d816b252-bdca-44ba-ac5c-cb21163a1e9a
|
|
|
|
|
# ╟─95bfc9d8-8427-41d6-9f0f-f155296eef91
|
|
|
|
|
# ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
|
|
|
|
|
# ╠═95bfc9d8-8427-41d6-9f0f-f155296eef91
|
|
|
|
|
# ╠═206ac4cc-5102-4381-ad8a-777b02dc4d5a
|
|
|
|
|
# ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
|
|
|
|
|
# ╠═1cbaa2e5-55e4-46f9-82d0-04b481470094
|
|
|
|
|
# ╠═f8d582cb-10cf-4c72-8127-787f662e0567
|
|
|
|
|
# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c
|
|
|
|
|
# ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
|
|
|
|
|
# ╠═a20959be-65e4-4b69-9521-503bc59f0854
|
|
|
|
|
# ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58
|
|
|
|
|
# ╠═20c777b5-4295-4478-8f53-b18cd409c8ae
|
|
|
|
|
# ╠═1d65707c-6333-4252-ace6-bad47146ba06
|
|
|
|
|
# ╠═43b99708-0052-4b78-886c-92ac2b532f29
|
|
|
|
|
# ╠═39433c1a-c3ac-45b0-b1bf-ff2d42ca9cbb
|
|
|
|
|
# ╠═dff642d9-ec5a-4fed-a059-6c07760a3a58
|
|
|
|
|
# ╟─b433a7ec-8264-48d6-8b95-53d2ec4bad05
|
|
|
|
|
# ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
|
|
|
|
|
# ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
|
|
|
|
|
# ╠═f30904a7-5caa-449a-a5bd-f2aa78777a9a
|
|
|
|
|
# ╟─5abebc1a-370c-4f5f-8826-dc0b143d5166
|
|
|
|
|
# ╠═a20959be-65e4-4b69-9521-503bc59f0854
|
|
|
|
|
# ╠═6bf8d29a-7990-4e91-86e6-d9894ed3db27
|
|
|
|
|
# ╠═e7ee1a0f-ab9b-439e-a7be-4a6d3b8f160d
|
|
|
|
|
# ╠═74f5fde3-0593-46fc-a688-f1db7ab28c64
|
|
|
|
|
# ╠═02f3fe78-e7a7-453f-9ddf-acddf08d8676
|
|
|
|
|
# ╟─00000000-0000-0000-0000-000000000001
|
|
|
|
|
|