partial refactoring and redesign to use struct of arrays on states

temporaryWork^2
will king 4 years ago
parent 51eff92f7e
commit 3a65710526

@ -1,3 +1,4 @@
{
"python.pythonPath": "/bin/python3"
"python.pythonPath": "/bin/python3",
"editor.detectIndentation": false
}

@ -15,67 +15,90 @@ export operator_policy_function_generator,
cross_linked_planner_policy_function_generator
# Exports from below
export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor,
train_planner!, train_operators!
export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor ,
train_planner!, train_operators! ,
BasicPhysics, survival_rates_1
# Code
abstract type GeneralizedLoss end
#=
This struct organizes information about a given constellation operator
Used to provide an interable loss function for training
=#
struct OperatorLoss <: GeneralizedLoss
#econ model describing operator
econ_model::EconomicModel
#Operator's value and policy functions, as well as which parameters the operator can train.
operator_value_fn::Flux.Chain
collected_policies::Flux.Chain #this is held by all operators
operator_policy_params::Flux.Params
physics::PhysicalModel
end
# overriding function to calculate operator loss
function (operator::OperatorLoss)(
state::State
)
#Construct and manage data
abstract type DataConstructor end
struct UniformDataConstructor <: DataConstructor
N::UInt64
satellites_bottom::Float32
satellites_top::Float32
debris_bottom::Float32
debris_top::Float32
end # struct
function (dc::UniformDataConstructor)(N_constellations,N_debris)
return State(
rand(dc.satellites_bottom:dc.satellites_top, dc.N, 1, N_constellations)
, rand(dc.debris_bottom:dc.debris_top, dc.N, 1, N_debris)
)
end # function
#get actions
a = operator.collected_policies((s,d))
#get updated stocks and debris
state = state_transition(operator.physics,state,a)
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state))
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
abstract type GeneralizedLoss end
struct OperatorLoss <: GeneralizedLoss
#=
This struct organizes information about a given constellation operator
It is used to provide an iterable loss function for training.
The fields each identify one aspect of the operator's decision problem:
- The economic model describing payoffs and discounting.
- The estimated NN value function held by the operator.
- The estimated NN policy function that describes each of the
- Each operator holds a reference to the parameters they can update.
- The physical world that describes how satellites progress.
There is an overriding function that uses these details to calculate the residual function based on
the bellman residuals and a maximization condition.
There are two versions of this function.
- return an actual loss calculation (MAE) using the owned policy function.
- return the loss calculation using a provided policy function.
=#
#econ model describing operator
econ_model::EconomicModel
#Operator's value and policy functions, as well as which parameters the operator can train.
operator_value_fn::Flux.Chain
collected_policies::Flux.Chain #this is held by all operators
operator_policy_params::Flux.Params #but only some of it is available for training
physics::PhysicalModel #It would be nice to move this to somewhere else in the model.
end # struct
# overriding function to calculate operator loss
function (operator::OperatorLoss)(
state::State
,policy::Flux.Chain #allow for another policy to be subsituted
)
)
#get actions
a = policy(state)
#get updated stocks and debris
state = stake_transition(state,a)
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end #function
function (operator::OperatorLoss)(
state::State
)
#just use the included policy.
return operator(state,operator.collected_policies)
end # function
function train_operators!(op::OperatorLoss, data::State, opt)
function train_operator!(op::OperatorLoss, data::State, opt)
#Train the policy functions
Flux.train!(op, op.operator_policy_params, data, opt)
#Train the value function
Flux.train!(op, Flux.params(op.operator_value_fn), data, opt)
end
@ -90,6 +113,7 @@ struct PlannerLoss <: GeneralizedLoss
There is an issue with appropriately training the value functions.
In this case, it is not happening...
=#
#planner discount level (As it may disagree with operators)
β::Float32
operators::Array{GeneralizedLoss}
@ -104,25 +128,25 @@ end
function (planner::PlannerLoss)(
state::State
)
a = planner.policy((s ,d))
#TODO! Rewrite to use a new states setup.
a = planner.policy((s ,d)) #TODO: States
#get updated stocks and debris
state = state_transition(s ,d ,a)
state = state_transition(s ,d ,a) #TODO: States
#calculate the total benefit from each of the models
benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators])
benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators]) #TODO: States
#issue here with mutating. Maybe generators/list comprehensions?
bellman_residuals = planner.value((s,d)) - benefit - planner.β .* planner.value((s,d))
bellman_residuals = planner.value((s,d)) - benefit - ( planner.β .* planner.value((s,d)) )#TODO: States
maximization_condition = - benefit - planner.β .* planner.value((s,d))
maximization_condition = - benefit - planner.β .* planner.value((s,d)) #TODO: States
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end
end # function
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
errors = []
for i = 1:N_epoch
data = data_gen()
@ -131,37 +155,20 @@ function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
append!(errors, error(pl, data1) / 200)
end
return errors
end
end # function
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt)
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
errors = []
for i = 1:N_epoch
data = data_gen()
data = data_gen()
for op in pl.operators
train_operators!(op,data,opt)
train_operator!(op,data,opt)
end
end
return errors
end
end # function
#Construct and manage data
abstract type DataConstructor end
struct UniformDataConstructor <: DataConstructor
N::UInt64
satellites_bottom::Float32
satellites_top::Float32
debris_bottom::Float32
debris_top::Float32
end
function (dc::UniformDataConstructor)()
#currently ignores the quantity of data it should construct.
State(
rand(dc.satellites_bottom:dc.satellites_top, N_constellations)
, rand(dc.debris_bottom:dc.debris_top, N_debris)
)
end
end # end module

@ -1,21 +1,29 @@
#==#
struct State
stocks::Array{Float32}
debris::Array{Float32}
end
#=Satellite State=#
abstract type State end
function state_to_tuple(s::State)
return (s.stocks ,s.debris)
struct SingleStates <: State
stocks::Vector{Float32}
debris::Vector{Float32}
end
### Physics
struct MultiStates <: State
stocks::Array{Float32}
debris::Array{Float32}
end
#function state_to_tuple(s::State)
# return (s.stocks ,s.debris)
#end
#=Physical Model
This contains parameters describing the physical model.
=#
abstract type PhysicalModel end
#=Basic implementation of a physical model=#
struct BasicPhysics <: PhysicalModel
#rate at which debris hits satellites
debris_collision_rate::Real
@ -30,15 +38,17 @@ struct BasicPhysics <: PhysicalModel
#Ratio at which launches produce debris
launch_debris_ratio::Real
end
function state_transition(
physics::BasicPhysics
,state::State
,launches::Vector{Float32}
)
,survival_rate::Function
)
#=
Physical Transitions
=#
survival_rate = survival(state,physics)
survival_rates = survival_rate(state,physics)
# Debris transitions
@ -46,7 +56,7 @@ function state_transition(
natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris
# get changes in debris from satellite loss
satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rate)' * state.stocks
satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rates)' * state.stocks
# get changes in debris from launches
launch_debris = physics.launch_debris_ratio * sum(launches)
@ -56,17 +66,16 @@ function state_transition(
# Stocks Transitions
stocks = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches
stocks = (LinearAlgebra.diagm(survival_rates) .- physics.decay_rate)*state.stocks + launches
return State(stocks,debris)
end
function survival(
function survival_rates_1(
#This function describes the rate at which satellites survive each period.
state::State
,physical_model::BasicPhysics
)
)
#TODO! get this to broadcast correctly.
return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris))
end

@ -0,0 +1,62 @@
using Test, Flux, LinearAlgebra
include("../src/Orbits.jl")
using .Orbits
#=
The purpose of this document is to organize tests of the state structs and state transition functions
=#
@testset "States and Physical models testing" verbose=true begin
n_const = 2
n_debr = 3
n_data = 5
#built structs
u = UniformDataConstructor(n_data,0,5,2,3)
s = u(n_const,n_debr)
b = BasicPhysics(
0.05
,0.02*LinearAlgebra.ones(n_const,n_const)
,0.1
,0.002
,0.002
,0.2
)
a2 = ones(n_const,n_data)
#test that dimensions match etc
@test size(b.satellite_collision_rates)[1] == size(s.stocks)[1]
@test size(b.satellite_collision_rates)[2] == size(s.stocks)[1]
@test n_data == size(s.stocks)[2]
@test n_data == size(s.debris)[2]
@test size(s.stocks) == size(a2)
@testset "DataConstructor and states" begin
@test u.N == 5
@test length(s.debris) != 3
@test length(s.stocks) != 2
@test length(s.stocks) == 10
@test length(s.debris) == 15
@test size(s.stocks) == (2,5)
@test size(s.debris) == (3,5)
end
@testset "BasicPhysics" begin
@testset "Survival Functions" verbose = true begin
@test survival_rates_1(s,b) <: AbstractArray
end
@testset "Transitions" begin
@test true
end
end
end #States and physcial models testing
Loading…
Cancel
Save