partial refactoring and redesign to use struct of arrays on states

temporaryWork^2
will king 4 years ago
parent 51eff92f7e
commit 3a65710526

@ -1,3 +1,4 @@
{ {
"python.pythonPath": "/bin/python3" "python.pythonPath": "/bin/python3",
"editor.detectIndentation": false
} }

@ -15,67 +15,90 @@ export operator_policy_function_generator,
cross_linked_planner_policy_function_generator cross_linked_planner_policy_function_generator
# Exports from below # Exports from below
export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor, export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor ,
train_planner!, train_operators! train_planner!, train_operators! ,
BasicPhysics, survival_rates_1
# Code # Code
abstract type GeneralizedLoss end
#= #Construct and manage data
This struct organizes information about a given constellation operator abstract type DataConstructor end
Used to provide an iterable loss function for training
=#
struct OperatorLoss <: GeneralizedLoss
#econ model describing operator
econ_model::EconomicModel
#Operator's value and policy functions, as well as which parameters the operator can train.
operator_value_fn::Flux.Chain
collected_policies::Flux.Chain #this is held by all operators
operator_policy_params::Flux.Params
physics::PhysicalModel
end
# overriding function to calculate operator loss
function (operator::OperatorLoss)(
state::State
)
#=
Data constructor holding a sample count and lower/upper bounds for satellite
stocks and debris. When an instance is called, the bounds are turned into
`bottom:top` ranges and handed to `rand`, with `N` used as the first dimension
of each sampled array.
=#
struct UniformDataConstructor <: DataConstructor
N::UInt64 # first dimension of every sampled array
satellites_bottom::Float32 # lower endpoint of the satellite stock range
satellites_top::Float32 # upper endpoint of the satellite stock range
debris_bottom::Float32 # lower endpoint of the debris range
debris_top::Float32 # upper endpoint of the debris range
end # struct
#=
Draw a random `State` from the constructor's bounds.
Stocks are sampled with size (dc.N, 1, N_constellations) and debris with size
(dc.N, 1, N_debris); each is drawn by `rand` from the matching `bottom:top` range.
=#
function (dc::UniformDataConstructor)(N_constellations,N_debris)
    stock_range = dc.satellites_bottom:dc.satellites_top
    debris_range = dc.debris_bottom:dc.debris_top

    # Stocks are drawn before debris so the random stream is consumed in the same order.
    stocks = rand(stock_range, dc.N, 1, N_constellations)
    debris = rand(debris_range, dc.N, 1, N_debris)

    return State(stocks, debris)
end # function
#get actions
a = operator.collected_policies((s,d))
#get updated stocks and debris
state = state_transition(operator.physics,state,a)
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state))
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
abstract type GeneralizedLoss end
struct OperatorLoss <: GeneralizedLoss
#=
This struct organizes information about a given constellation operator.
It is used to provide an iterable loss function for training.
The fields each identify one aspect of the operator's decision problem:
- The economic model describing payoffs and discounting (`econ_model`).
- The estimated NN value function held by the operator (`operator_value_fn`).
- The estimated NN policy function, which is held by all operators (`collected_policies`).
- A reference to the policy parameters this operator can update (`operator_policy_params`).
- The physical world that describes how satellites progress (`physics`).
Instances are callable: the call uses these details to calculate a loss based on
the bellman residuals and a maximization condition.
There are two versions of this call.
- return the loss calculation (MAE) using the owned policy function.
- return the loss calculation using a provided policy function.
=#
#econ model describing operator
econ_model::EconomicModel
#Operator's value and policy functions, as well as which parameters the operator can train.
operator_value_fn::Flux.Chain
collected_policies::Flux.Chain #this is held by all operators
operator_policy_params::Flux.Params #but only some of it is available for training
physics::PhysicalModel #It would be nice to move this to somewhere else in the model.
end # struct
# overriding function to calculate operator loss
#=
Calculate the operator's loss at `state` when actions come from `policy`.

Arguments
- state: the current `State`.
- policy: the policy network used to choose actions; this allows a policy other
  than the operator's own to be substituted.

Returns `Flux.mae` of the squared bellman residuals against the maximization condition.
=#
function (operator::OperatorLoss)(
    state::State
    ,policy::Flux.Chain #allow for another policy to be substituted
)
    #get actions
    a = policy(state)

    #get updated stocks and debris
    #The previous call was misspelled (`stake_transition`) and did not pass the physics model.
    #NOTE(review): `survival_rates_1` is passed because `state_transition` takes a
    #survival-rate function as its last argument -- confirm this is the intended one.
    state = state_transition(operator.physics, state, a, survival_rates_1)

    #NOTE(review): both value terms below are evaluated at the post-transition state,
    #as in the original code -- confirm this is the intended bellman equation.
    bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
    maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)

    return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end #function
#=
Convenience method: evaluate the operator's loss at `state` using the policy
the operator already holds (`collected_policies`), by delegating to the
two-argument call.
=#
(operator::OperatorLoss)(state::State) = operator(state, operator.collected_policies)
function train_operators!(op::OperatorLoss, data::State, opt) function train_operator!(op::OperatorLoss, data::State, opt)
#Train the policy functions
Flux.train!(op, op.operator_policy_params, data, opt) Flux.train!(op, op.operator_policy_params, data, opt)
#Train the value function
Flux.train!(op, Flux.params(op.operator_value_fn), data, opt) Flux.train!(op, Flux.params(op.operator_value_fn), data, opt)
end end
@ -90,6 +113,7 @@ struct PlannerLoss <: GeneralizedLoss
There is an issue with appropriately training the value functions. There is an issue with appropriately training the value functions.
In this case, it is not happening... In this case, it is not happening...
=# =#
#planner discount level (As it may disagree with operators)
β::Float32 β::Float32
operators::Array{GeneralizedLoss} operators::Array{GeneralizedLoss}
@ -104,25 +128,25 @@ end
function (planner::PlannerLoss)( function (planner::PlannerLoss)(
state::State state::State
) )
a = planner.policy((s ,d)) #TODO! Rewrite to use a new states setup.
a = planner.policy((s ,d)) #TODO: States
#get updated stocks and debris #get updated stocks and debris
state = state_transition(s ,d ,a) state = state_transition(s ,d ,a) #TODO: States
#calculate the total benefit from each of the models #calculate the total benefit from each of the models
benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators]) benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators]) #TODO: States
#issue here with mutating. Maybe generators/list comprehensions? #issue here with mutating. Maybe generators/list comprehensions?
bellman_residuals = planner.value((s,d)) - benefit - planner.β .* planner.value((s,d)) bellman_residuals = planner.value((s,d)) - benefit - ( planner.β .* planner.value((s,d)) )#TODO: States
maximization_condition = - benefit - planner.β .* planner.value((s,d)) maximization_condition = - benefit - planner.β .* planner.value((s,d)) #TODO: States
return Flux.mae(bellman_residuals.^2 ,maximization_condition) return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end end # function
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt) function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
errors = [] errors = []
for i = 1:N_epoch for i = 1:N_epoch
data = data_gen() data = data_gen()
@ -131,37 +155,20 @@ function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
append!(errors, error(pl, data1) / 200) append!(errors, error(pl, data1) / 200)
end end
return errors return errors
end end # function
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt) function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
errors = [] errors = []
for i = 1:N_epoch for i = 1:N_epoch
data = data_gen()
data = data_gen() data = data_gen()
for op in pl.operators for op in pl.operators
train_operators!(op,data,opt) train_operator!(op,data,opt)
end end
end end
return errors return errors
end end # function
#Construct and manage data
abstract type DataConstructor end
struct UniformDataConstructor <: DataConstructor
N::UInt64
satellites_bottom::Float32
satellites_top::Float32
debris_bottom::Float32
debris_top::Float32
end
function (dc::UniformDataConstructor)()
#currently ignores the quantity of data it should construct.
State(
rand(dc.satellites_bottom:dc.satellites_top, N_constellations)
, rand(dc.debris_bottom:dc.debris_top, N_debris)
)
end
end # end module end # end module

@ -1,21 +1,29 @@
#==# #=Satellite State=#
struct State abstract type State end
stocks::Array{Float32}
debris::Array{Float32}
end
function state_to_tuple(s::State) struct SingleStates <: State
return (s.stocks ,s.debris) stocks::Vector{Float32}
debris::Vector{Float32}
end end
### Physics struct MultiStates <: State
stocks::Array{Float32}
debris::Array{Float32}
end
#function state_to_tuple(s::State)
# return (s.stocks ,s.debris)
#end
#=Physical Model
This contains parameters describing the physical model.
=#
abstract type PhysicalModel end abstract type PhysicalModel end
#=Basic implementation of a physical model=#
struct BasicPhysics <: PhysicalModel struct BasicPhysics <: PhysicalModel
#rate at which debris hits satellites #rate at which debris hits satellites
debris_collision_rate::Real debris_collision_rate::Real
@ -30,15 +38,17 @@ struct BasicPhysics <: PhysicalModel
#Ratio at which launches produce debris #Ratio at which launches produce debris
launch_debris_ratio::Real launch_debris_ratio::Real
end end
function state_transition( function state_transition(
physics::BasicPhysics physics::BasicPhysics
,state::State ,state::State
,launches::Vector{Float32} ,launches::Vector{Float32}
) ,survival_rate::Function
)
#= #=
Physical Transitions Physical Transitions
=# =#
survival_rate = survival(state,physics) survival_rates = survival_rate(state,physics)
# Debris transitions # Debris transitions
@ -46,7 +56,7 @@ function state_transition(
natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris
# get changes in debris from satellite loss # get changes in debris from satellite loss
satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rate)' * state.stocks satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rates)' * state.stocks
# get changes in debris from launches # get changes in debris from launches
launch_debris = physics.launch_debris_ratio * sum(launches) launch_debris = physics.launch_debris_ratio * sum(launches)
@ -56,17 +66,16 @@ function state_transition(
# Stocks Transitions # Stocks Transitions
stocks = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches stocks = (LinearAlgebra.diagm(survival_rates) .- physics.decay_rate)*state.stocks + launches
return State(stocks,debris) return State(stocks,debris)
end end
function survival_rates_1(
function survival( #This function describes the rate at which satellites survive each period.
state::State state::State
,physical_model::BasicPhysics ,physical_model::BasicPhysics
) )
#TODO! get this to broadcast correctly.
return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris)) return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris))
end end

@ -0,0 +1,62 @@
using Test, Flux, LinearAlgebra
include("../src/Orbits.jl")
using .Orbits
#=
The purpose of this document is to organize tests of the state structs and state transition functions
=#
# Tests for the state structs, the uniform data constructor and the basic
# physical model: dimension agreement first, then per-component test sets.
@testset "States and Physical models testing" verbose=true begin
    n_const = 2
    n_debr = 3
    n_data = 5

    #build structs
    u = UniformDataConstructor(n_data,0,5,2,3)
    s = u(n_const,n_debr)
    b = BasicPhysics(
        0.05
        ,0.02*LinearAlgebra.ones(n_const,n_const)
        ,0.1
        ,0.002
        ,0.002
        ,0.2
    )
    a2 = ones(n_const,n_data)

    #test that dimensions match etc
    @test size(b.satellite_collision_rates)[1] == size(s.stocks)[1]
    @test size(b.satellite_collision_rates)[2] == size(s.stocks)[1]
    @test n_data == size(s.stocks)[2]
    @test n_data == size(s.debris)[2]
    @test size(s.stocks) == size(a2)

    @testset "DataConstructor and states" begin
        @test u.N == 5
        @test length(s.debris) != 3
        @test length(s.stocks) != 2
        @test length(s.stocks) == 10
        @test length(s.debris) == 15
        @test size(s.stocks) == (2,5)
        @test size(s.debris) == (3,5)
    end

    @testset "BasicPhysics" begin
        @testset "Survival Functions" verbose = true begin
            # `isa` checks the returned value; `<:` only compares types and
            # raises a TypeError when its left operand is a value.
            @test survival_rates_1(s,b) isa AbstractArray
        end
        @testset "Transitions" begin
            @test true
        end
    end
end #States and physical models testing
Loading…
Cancel
Save