working convergence, with a saved model

temporaryWork^2
will king 4 years ago
parent be47bfc2f1
commit 333bc5000c

@ -7,6 +7,17 @@ using InteractiveUtils
# ╔═╡ 0829eb90-1d74-46b6-80ec-482a2b71c6fe
using PlutoUI, Flux,LinearAlgebra
# ╔═╡ f360390b-1d15-4443-aad6-a5800a6ba776
using BSON: @save,@load
# ╔═╡ 9a1ebef5-ed53-4e0b-89e7-20a20819b032
using Dates,BSON
# ╔═╡ fe446c7e-ed8c-4a0b-8bc9-024806f9f352
md"""
$(Dates.now())
"""
# ╔═╡ 66f0e667-d722-4e1e-807b-84a39cbc41b1
md"""
# Bellman Residual Minimization of Operators
@ -133,9 +144,10 @@ begin
return LinearAlgebra.diagm(survival(stocks,debris,self.physical_model) .- self.physical_model.decay_rate)*stocks + launches
end
end
#TODO: investigate combining the stocks and debris using some sort of state struct <: State type. Maybe combine with the physical model?
# ╔═╡ 6bf01e30-d054-4a98-9e0d-58d7da4e5524
begin
begin # build the transition models
H = BasicDebrisEvolution(bm)
G = BasicStockEvolution(bm)
end;
@ -189,8 +201,53 @@ function planner_policy_function_generator(number_params=32)
end
# ╔═╡ e4e2f50e-05fc-42da-a6d3-d62c4f901d84
function operator_policy_function_generator(number_params=32)
return Flux.Chain(
Flux.Parallel(vcat
#parallel joins together stocks and debris
,Flux.Chain(
Flux.Dense(N_constellations, number_params*2,Flux.relu)
,Flux.Dense(number_params*2, number_params*2,Flux.tanh)
)
,Flux.Chain(
Flux.Dense(N_debris, number_params,Flux.relu)
,Flux.Dense(number_params, number_params)
)
)
#Apply some transformations
,Flux.Dense(number_params*3,number_params,Flux.tanh)
,Flux.Dense(number_params,number_params,Flux.tanh)
,Flux.Dense(number_params,1,x -> Flux.relu.(sinh.(x)))
)
end
# ╔═╡ ae6681e0-c796-4145-a313-75f74b4993ad
#add branch generator
begin #move to module
#= TupleDuplicator
This is used to create a tuple of size n with deepcopies of any object x
=#
struct TupleDuplicator
n::Int
end
(f::TupleDuplicator)(x) = tuple([deepcopy(x) for i=1:f.n]...)
#=
This generates a policy function full of branches with the properly scaled sides
=#
struct BranchGenerator
n::UInt8
#limit to 2^8 operators
end
function (b::BranchGenerator)(branch::Flux.Chain,join_fn::Function)
# used to deepcopy the branches and duplicate the inputs in the returned chain
f = TupleDuplicator(b.n)
return Flux.Chain(f,Flux.Parallel(join_fn,f(branch)))
end
end
# ╔═╡ bbca5143-f314-40ea-a20e-8a043272e362
md"""
@ -201,13 +258,14 @@ md"""
abstract type EconomicModel end
# ╔═╡ 206ac4cc-5102-4381-ad8a-777b02dc4d5a
begin #basic linear model
struct EconModel1 <: EconomicModel
begin
#basic linear model
struct LinearModel <: EconomicModel
β::Float32
payoff_array::Array{Float32}
policy_costs::Array{Float32}
end
function (em::EconModel1)(
function (em::LinearModel)(
s::Vector{Float32}
,d::Vector{Float32}
,a::Vector{Float32}
@ -217,20 +275,47 @@ begin #basic linear model
end
# ╔═╡ eebb8706-a431-4fd1-b7a5-40f07a63d5cb
begin #basic CES model
struct CESParams <: EconomicModel
begin
#basic CES model
struct CES <: EconomicModel
β::Float32
r::Float32 #elasticity of subsititution
payoff_array::Array{Float32}
policy_costs::Array{Float32}
debris_costs::Array{Float32}
end
function (em::CESParams)(
function (em::CES)(
s::Vector{Float32}
,d::Vector{Float32}
,a::Vector{Float32}
)
return (em.payoff_array*(s.^em.r) - em.debris_costs*(d.^em.r)).^(1/em.r) - em.policy_costs*a
#issue here with multiplication
r1 = em.payoff_array .* (s.^em.r)
r2 = - em.debris_costs .* (d.^em.r)
r3 = - em.policy_costs .* (a.^em.r)
return (r1 + r2 + r3) .^ (1/em.r)
end
#basic CRRA
struct CRRA <: EconomicModel
β::Float32
σ::Float32 #elasticity of subsititution
payoff_array::Array{Float32}
policy_costs::Array{Float32}
debris_costs::Array{Float32}
end
function (em::CRRA)(
s::Vector{Float32}
,d::Vector{Float32}
,a::Vector{Float32}
)
#issue here
core = (em.payoff_array*s - em.debris_costs*d - em.policy_costs).^(1 - em.σ)
return (core-1)/(1-em.σ)
end
end
@ -241,10 +326,10 @@ md"""
# ╔═╡ 65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
begin
em1_a = EconModel1(0.95, [1.0 0 0 0], [5 0 0 0])
em1_b = EconModel1(0.95, [0 1.0 0 0], [0 5 0 0])
em1_c = EconModel1(0.95, [0 0 1.0 0], [0 0 5 0])
em1_d = EconModel1(0.95, [0 0 0 1.0], [0 0 0 5])
em1_a = LinearModel(0.95, [1.0 0 0 0], [5 0 0 0])
em1_b = LinearModel(0.95, [0 1.0 0 0], [0 5 0 0])
em1_c = LinearModel(0.95, [0 0 1.0 0], [0 0 5 0])
em1_d = LinearModel(0.95, [0 0 0 1.0], [0 0 0 5])
#=
This is the most basic profit model
@ -255,10 +340,10 @@ end
# ╔═╡ 19ccfc3a-6dbb-4c64-bf03-e2e219ef0efe
begin
em2_a = EconModel1(0.95, [1 -0.02 -0.02 0], [5.0 0 0 0])
em2_b = EconModel1(0.95, [-0.02 1 -0.02 0], [0.0 5 0 0])
em2_c = EconModel1(0.95, [0 -0.02 1 -0.02], [0.0 0 5 0])
em2_d = EconModel1(0.95, [0 -0.02 -0.02 1], [0.0 0 0 5])
em2_a = LinearModel(0.95, [1 -0.02 -0.02 0], [5.0 0 0 0])
em2_b = LinearModel(0.95, [-0.02 1 -0.02 0], [0.0 5 0 0])
em2_c = LinearModel(0.95, [0 -0.02 1 -0.02], [0.0 0 5 0])
em2_d = LinearModel(0.95, [0 -0.02 -0.02 1], [0.0 0 0 5])
#=
This is a simple addition to the basic model, where you lose some benefit based
the size of your competitor's satellites.
@ -267,18 +352,32 @@ begin
end
# ╔═╡ dc614254-c211-4552-b985-03020bfc5ab3
em3 = CESParams(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002]))
#=
This is a variation on a CES model.
begin
em3_a = CES(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002]))
em3_b = CES(0.95,0.6,[0 1 0 0], [0 5 0 0], Vector([0.002]))
em3_c = CES(0.95,0.6,[0 0 1 0], [0 0 5 0], Vector([0.002]))
em3_d = CES(0.95,0.6,[0 0 0 1], [0 0 0 5], Vector([0.002]))
#=
This is a variation on a CES model.
The model is CES the relationship between payoffs and debris.
In this particular specification, the only interaction is in debris
=#
The model is CES the relationship between payoffs and debris.
In this particular specification, the only interaction is in debris
=#
end
# ╔═╡ 128ab84e-9682-4c4e-a301-1e27d9c199a4
begin
em4_a = CRRA(0.95,0.6,[1 0 0 0], [5 0 0 0], Vector([0.002]))
em4_b = CRRA(0.95,0.6,[0 1 0 0], [0 5 0 0], Vector([0.002]))
em4_c = CRRA(0.95,0.6,[0 0 1 0], [0 0 5 0], Vector([0.002]))
em4_d = CRRA(0.95,0.6,[0 0 0 1], [0 0 0 5], Vector([0.002]))
end
# ╔═╡ 52eb1934-77b1-4da3-a36a-afb951f76784
#TODO: create a generator for economic models
# ╔═╡ 717a07a4-5254-40aa-8244-1a2010fedec8
# ╔═╡ 4452ab2c-813b-4b7b-8cb3-388ec507a24c
md"""
## Setup Operators Policy Functions
"""
# ╔═╡ 5946daa3-4608-43f3-8933-dd3eb3f4541c
md"""
@ -292,6 +391,15 @@ begin #testing data
end
# ╔═╡ aa079b60-af16-4824-8a1d-b0c68a0cacbc
begin
#test economic models
em1_a(s1,d1,s1)
em2_a(s1,d1,s1)
#em3_a(s1,d1,s1) #ERROR
#em4_a(s1,d1,s1) #ERROR
end
# ╔═╡ 41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
#Constellation operator loss function
function Ξ(
@ -317,7 +425,7 @@ function Ξ(
return sum(bellman_residuals.^2, maximization_condition)
end
# ╔═╡ 89258b2e-cee5-4cbd-a42e-21301b7a2549
# ╔═╡ 382b6dee-5fa7-4ca1-b804-4b50b0fb65a7
begin
#=
This struct organizes information about a given constellation operator
@ -339,28 +447,50 @@ begin
function (operator::ConstellationOperatorLoss)(
s::Vector{Float32}
,d::Vector{Float32}
; #set default policy to the one held by the ConstellationOperator
policy::Flux.Chain = operator.collected_policies
)
#get actions
a = policy((s,d))
a = operator.collected_policies((s,d))
#get updated stocks and debris
s = stocks_transition(s ,d ,a)
d = debris_transition(s ,d ,a)
s = operator.stocks_transition(s ,d ,a)
d = operator.debris_transition(s ,d ,a)
bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s,d))
maximization_condition = - operator.econ_model(s,d,a,co.econ_params) - co.econ_params.β*co.value((s,d))
maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s,d))
return sum(bellman_residuals.^2, maximization_condition)
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end
function (operator::ConstellationOperatorLoss)(
s::Vector{Float32}
,d::Vector{Float32}
,policy::Flux.Chain #allow for another policy to be subsituted
)
#get actions
a = policy((s,d))
#get updated stocks and debris
s = operator.stocks_transition(s ,d ,a)
d = operator.debris_transition(s ,d ,a)
bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s,d))
maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s,d))
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end
end
# ╔═╡ dc855793-d234-4835-8674-bf56aa0ad0db
begin
#do the same thing with the planner's problem
struct PlannerLoss
#=
@ -369,22 +499,40 @@ begin
There is an issue with appropriately training the value functions.
In this case, it is not happening...
=#
β::Float32
operators::Array{ConstellationOperatorLoss}
planner_policy::Flux.Chain
planner_params::Flux.Params
policy::Flux.Chain
policy_params::Flux.Params
value::Flux.Chain
value_params::Flux.Params
debris_transition::BasicDebrisEvolution
stocks_transition::BasicStockEvolution
end
function (pl::PlannerLoss)(
function (planner::PlannerLoss)(
s::Vector{Float32}
,d::Vector{Float32}
)
l = 0.0
for co in pl.operators
l += co(s ,d ,pl.planner_policy)
end
return l
end
a = planner.policy((s ,d))
#get updated stocks and debris
s = planner.stocks_transition(s ,d ,a)
d = planner.debris_transition(s ,d ,a)
#calculate the total benefit from each of the models
benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators])
#issue here with mutating. Maybe generators/list comprehensions?
#TODO: Training Functions as routines
bellman_residuals = planner.value((s,d)) - benefit - planner.β .* planner.value((s,d))
maximization_condition = - benefit - planner.β .* planner.value((s,d))
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end
end
# ╔═╡ cd55e232-493d-4849-8bd7-b0ba85e21bab
@ -392,24 +540,80 @@ md"""
# Set up Operators and Planner
"""
# ╔═╡ 176981f2-b7f5-4a0f-aae0-940e1db778c7
#=
Setup various policy/planning functions
=#
begin
bg4 = BranchGenerator(N_constellations)
#create identically structured policiy functions
operators_policy = bg4(operator_policy_function_generator(),vcat)
planners_policy = bg4(operator_policy_function_generator(),vcat)
#create the planners_value function
planners_value = value_function_generator(64)
end;
# ╔═╡ 08ef7fbf-005b-40b5-acde-f42750c04cd3
begin #setup operators
const tops = [
ConstellationOperator(,em2_a,value_function_generator())
,ConstellationOperator(payoff1,em2_b,value_function_generator())
,ConstellationOperator(payoff1,em2_c,value_function_generator())
,ConstellationOperator(payoff1,em2_d,value_function_generator())
#Create policies
const operator_array = [
#first operator
ConstellationOperatorLoss(
em2_a
,value_function_generator()
,operators_policy #this is held by all operators
,params(operators_policy[2][1]) #first operator gets first branch of params
,H
,G
)
,ConstellationOperatorLoss(
em2_b
,value_function_generator()
,operators_policy #this is held by all operators
,params(operators_policy[2][2]) #first operator gets first branch of params
,H
,G
)
,ConstellationOperatorLoss(
em2_c
,value_function_generator()
,operators_policy #this is held by all operators
,params(operators_policy[2][3]) #first operator gets first branch of params
,H
,G
)
,ConstellationOperatorLoss(
em2_d
,value_function_generator()
,operators_policy #this is held by all operators
,params(operators_policy[2][4]) #first operator gets first branch of params
,H
,G
)
]
#TODO: setup operator losses
#sanity Check time
@assert length(tops) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized"
#sanity check time
@assert length(operator_array) == N_constellations "Mismatch in predetermined number of constellations and the number of operators initialized"
end
# ╔═╡ fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
#build out planner loss
planner_policy = planner_policy_function_generator();
# ╔═╡ 1eae5450-f9ef-49ec-9dcd-48c025562ecf
pl = PlannerLoss(
0.95
,operator_array
,planners_policy
,params(planners_policy)
,planners_value
,params(planners_value)
,H
,G
)
# ╔═╡ f68e2322-937d-44ce-8aa9-757b27b00603
pl(s1,d1)
# ╔═╡ 5abebc1a-370c-4f5f-8826-dc0b143d5166
md"""
@ -418,13 +622,17 @@ md"""
# ╔═╡ a20959be-65e4-4b69-9521-503bc59f0854
begin
struct DataConstructor
struct UniformDataConstructor
N::UInt64
satellites_bottom::Float32
satellites_top::Float32
debris_bottom::Float32
debris_top::Float32
end
(dc::DataConstructor)(bottom=1f0,top=500f0) = [(rand(bottom:top, N_constellations),rand(bottom:top, N_debris)) for n=1:dc.N]
(dc::UniformDataConstructor)() = [(rand(dc.satellites_bottom:dc.satellites_top, N_constellations),rand(dc.debris_bottom:dc.debris_top, N_debris)) for n=1:dc.N]
#create a data constructor
dc = DataConstructor(200)
dc = UniformDataConstructor(200,0,200,0,4000)
#get a bit of test data
data1 = dc()
@ -432,19 +640,61 @@ end
# ╔═╡ 6bf8d29a-7990-4e91-86e6-d9894ed3db27
#optimizer
opt = Flux.Optimise.ADAM(0.1)
opt = Flux.Optimise.ADAM()
# ╔═╡ e6654ee2-776c-4b26-a5a3-324188062966
pl(s1,d1)[1]
# ╔═╡ 2b115364-a8ce-4eb9-a324-7ac213061b83
begin
N_epoch = 6
data_gen = UniformDataConstructor(20_000,0,200,0,4000)
end
# ╔═╡ 0679c91b-9824-4017-83fd-4001dc8c2d54
function error(pl,data)
total_error = 0.0f0
for d in data
total_error += pl(d...)
end
return total_error
end
# ╔═╡ 30417194-6cd1-4e10-9e25-1fa1c0761b9a
begin
errors = []
for i in 1:N_epoch
data = data_gen()
Flux.train!(pl, pl.policy_params, data, opt)
Flux.train!(pl, pl.value_params, data, opt)
append!(errors,error(pl,data1)/200)
end
end
# ╔═╡ 7f42b201-01ec-4d63-b8de-92e735f28e4a
errors
#first run 7017,6562,5613,5405,5535,4462
# ╔═╡ 65cf41b8-adf4-42b9-82dc-d2fbafccda81
begin
model_name = "$(Dates.now())_planner.bson"
@save model_name pl
end
# ╔═╡ 913a0953-6d5c-4694-9a64-95c54f3398e4
@load model_name testload
# ╔═╡ 00000000-0000-0000-0000-000000000001
PLUTO_PROJECT_TOML_CONTENTS = """
[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PlutoUI = "7f904dfe-b85e-4ff6-b463-dae2292396a8"
[compat]
BSON = "~0.3.4"
Flux = "~0.12.8"
PlutoUI = "~0.7.18"
"""
@ -494,6 +744,11 @@ git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.2.0"
[[BSON]]
git-tree-sha1 = "ebcd6e22d69f21249b7b8668351ebf42d6dc87a1"
uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
version = "0.3.4"
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -1008,6 +1263,9 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
# ╔═╡ Cell order:
# ╠═0829eb90-1d74-46b6-80ec-482a2b71c6fe
# ╠═f360390b-1d15-4443-aad6-a5800a6ba776
# ╠═9a1ebef5-ed53-4e0b-89e7-20a20819b032
# ╠═fe446c7e-ed8c-4a0b-8bc9-024806f9f352
# ╟─66f0e667-d722-4e1e-807b-84a39cbc41b1
# ╟─9fa41b7c-1923-4c1e-bfc6-20ce4a1a2ede
# ╟─0d209540-29ad-4b2f-9e91-fac2bbee47ff
@ -1015,32 +1273,45 @@ uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
# ╠═f8779824-9d59-4458-926f-10beb3cd3866
# ╠═152f3a3c-a565-41bb-8e59-6ab0d2315ffb
# ╠═25ac9438-2b1d-4f6b-9ff1-1695e1d52b51
# ╠═6bf01e30-d054-4a98-9e0d-58d7da4e5524
# ╠═9b1c2f62-ec19-4667-8f93-f8b52750e317
# ╟─9b1c2f62-ec19-4667-8f93-f8b52750e317
# ╟─90446134-4e45-471c-857d-4e165e51937a
# ╠═6bf01e30-d054-4a98-9e0d-58d7da4e5524
# ╟─29ff1777-d276-4e8f-8582-4ca191f2e2ff
# ╟─f7aabe43-9a2c-4fe0-8099-c29cdf66566c
# ╟─d816b252-bdca-44ba-ac5c-cb21163a1e9a
# ╠═f7aabe43-9a2c-4fe0-8099-c29cdf66566c
# ╠═d816b252-bdca-44ba-ac5c-cb21163a1e9a
# ╠═e4e2f50e-05fc-42da-a6d3-d62c4f901d84
# ╠═ae6681e0-c796-4145-a313-75f74b4993ad
# ╠═bbca5143-f314-40ea-a20e-8a043272e362
# ╠═340da189-f443-4376-a82d-7699a21ab7a2
# ╠═206ac4cc-5102-4381-ad8a-777b02dc4d5a
# ╠═eebb8706-a431-4fd1-b7a5-40f07a63d5cb
# ╠═b433a7ec-8264-48d6-8b95-53d2ec4bad05
# ╠═65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
# ╠═19ccfc3a-6dbb-4c64-bf03-e2e219ef0efe
# ╠═dc614254-c211-4552-b985-03020bfc5ab3
# ╠═4452ab2c-813b-4b7b-8cb3-388ec507a24c
# ╟─65e0b1fa-d5e1-4ff6-8736-c9d6b5f40150
# ╟─19ccfc3a-6dbb-4c64-bf03-e2e219ef0efe
# ╟─dc614254-c211-4552-b985-03020bfc5ab3
# ╟─128ab84e-9682-4c4e-a301-1e27d9c199a4
# ╠═aa079b60-af16-4824-8a1d-b0c68a0cacbc
# ╠═52eb1934-77b1-4da3-a36a-afb951f76784
# ╠═717a07a4-5254-40aa-8244-1a2010fedec8
# ╠═5946daa3-4608-43f3-8933-dd3eb3f4541c
# ╠═9d4a668a-23e3-4f36-86f4-60e242caee3b
# ╠═41271ab4-1ec7-431f-9efb-0f7c3da2d8b4
# ╠═89258b2e-cee5-4cbd-a42e-21301b7a2549
# ╠═382b6dee-5fa7-4ca1-b804-4b50b0fb65a7
# ╠═dc855793-d234-4835-8674-bf56aa0ad0db
# ╠═f68e2322-937d-44ce-8aa9-757b27b00603
# ╠═cd55e232-493d-4849-8bd7-b0ba85e21bab
# ╠═176981f2-b7f5-4a0f-aae0-940e1db778c7
# ╠═08ef7fbf-005b-40b5-acde-f42750c04cd3
# ╠═fb6aacff-c42d-4ec1-88cb-5ce1b2e8874f
# ╠═1eae5450-f9ef-49ec-9dcd-48c025562ecf
# ╠═5abebc1a-370c-4f5f-8826-dc0b143d5166
# ╠═a20959be-65e4-4b69-9521-503bc59f0854
# ╠═6bf8d29a-7990-4e91-86e6-d9894ed3db27
# ╠═e6654ee2-776c-4b26-a5a3-324188062966
# ╠═2b115364-a8ce-4eb9-a324-7ac213061b83
# ╠═0679c91b-9824-4017-83fd-4001dc8c2d54
# ╠═30417194-6cd1-4e10-9e25-1fa1c0761b9a
# ╠═7f42b201-01ec-4d63-b8de-92e735f28e4a
# ╠═65cf41b8-adf4-42b9-82dc-d2fbafccda81
# ╠═913a0953-6d5c-4694-9a64-95c54f3398e4
# ╟─00000000-0000-0000-0000-000000000001
# ╟─00000000-0000-0000-0000-000000000002

Loading…
Cancel
Save