partial refactoring and redesign to use struct of arrays on states

temporaryWork^2
will king 4 years ago
parent 51eff92f7e
commit 3a65710526

@ -1,3 +1,4 @@
{
"python.pythonPath": "/bin/python3"
"python.pythonPath": "/bin/python3",
"editor.detectIndentation": false
}

@ -16,66 +16,89 @@ export operator_policy_function_generator,
# Exports from below
export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor ,
train_planner!, train_operators!
train_planner!, train_operators! ,
BasicPhysics, survival_rates_1
# Code
#Construct and manage data
abstract type DataConstructor end
struct UniformDataConstructor <: DataConstructor
N::UInt64
satellites_bottom::Float32
satellites_top::Float32
debris_bottom::Float32
debris_top::Float32
end # struct
function (dc::UniformDataConstructor)(N_constellations,N_debris)
return State(
rand(dc.satellites_bottom:dc.satellites_top, dc.N, 1, N_constellations)
, rand(dc.debris_bottom:dc.debris_top, dc.N, 1, N_debris)
)
end # function
abstract type GeneralizedLoss end
struct OperatorLoss <: GeneralizedLoss
#=
This struct organizes information about a given constellation operator
Used to provide an interable loss function for training
It is used to provide an iterable loss function for training.
The fields each identify one aspect of the operator's decision problem:
- The economic model describing payoffs and discounting.
- The estimated NN value function held by the operator.
- The estimated NN policy function that describes each of the
- Each operator holds a reference to the parameters they can update.
- The physical world that describes how satellites progress.
There is an overriding function that uses these details to calculate the residual function based on
the bellman residuals and a maximization condition.
There are two versions of this function.
- return an actual loss calculation (MAE) using the owned policy function.
- return the loss calculation using a provided policy function.
=#
struct OperatorLoss <: GeneralizedLoss
#econ model describing operator
econ_model::EconomicModel
#Operator's value and policy functions, as well as which parameters the operator can train.
operator_value_fn::Flux.Chain
collected_policies::Flux.Chain #this is held by all operators
operator_policy_params::Flux.Params
physics::PhysicalModel
end
# overriding function to calculate operator loss
function (operator::OperatorLoss)(
state::State
)
#get actions
a = operator.collected_policies((s,d))
#get updated stocks and debris
state = state_transition(operator.physics,state,a)
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state))
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
operator_policy_params::Flux.Params #but only some of it is available for training
physics::PhysicalModel #It would be nice to move this to somewhere else in the model.
end # struct
# overriding function to calculate operator loss
function (operator::OperatorLoss)(
state::State
,policy::Flux.Chain #allow for another policy to be subsituted
)
#get actions
a = policy(state)
#get updated stocks and debris
state = stake_transition(state,a)
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end #function
function (operator::OperatorLoss)(
state::State
)
#just use the included policy.
return operator(state,operator.collected_policies)
end # function
function train_operators!(op::OperatorLoss, data::State, opt)
function train_operator!(op::OperatorLoss, data::State, opt)
#Train the policy functions
Flux.train!(op, op.operator_policy_params, data, opt)
#Train the value function
Flux.train!(op, Flux.params(op.operator_value_fn), data, opt)
end
@ -90,6 +113,7 @@ struct PlannerLoss <: GeneralizedLoss
There is an issue with appropriately training the value functions.
In this case, it is not happening...
=#
#planner discount level (As it may disagree with operators)
β::Float32
operators::Array{GeneralizedLoss}
@ -104,25 +128,25 @@ end
function (planner::PlannerLoss)(
state::State
)
a = planner.policy((s ,d))
#TODO! Rewrite to use a new states setup.
a = planner.policy((s ,d)) #TODO: States
#get updated stocks and debris
state = state_transition(s ,d ,a)
state = state_transition(s ,d ,a) #TODO: States
#calculate the total benefit from each of the models
benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators])
benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators]) #TODO: States
#issue here with mutating. Maybe generators/list comprehensions?
bellman_residuals = planner.value((s,d)) - benefit - planner.β .* planner.value((s,d))
bellman_residuals = planner.value((s,d)) - benefit - ( planner.β .* planner.value((s,d)) )#TODO: States
maximization_condition = - benefit - planner.β .* planner.value((s,d))
maximization_condition = - benefit - planner.β .* planner.value((s,d)) #TODO: States
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end
end # function
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
errors = []
for i = 1:N_epoch
data = data_gen()
@ -131,37 +155,20 @@ function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
append!(errors, error(pl, data1) / 200)
end
return errors
end
end # function
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt)
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
errors = []
for i = 1:N_epoch
data = data_gen()
data = data_gen()
for op in pl.operators
train_operators!(op,data,opt)
train_operator!(op,data,opt)
end
end
return errors
end
#Construct and manage data
abstract type DataConstructor end
end # function
struct UniformDataConstructor <: DataConstructor
N::UInt64
satellites_bottom::Float32
satellites_top::Float32
debris_bottom::Float32
debris_top::Float32
end
function (dc::UniformDataConstructor)()
#currently ignores the quantity of data it should construct.
State(
rand(dc.satellites_bottom:dc.satellites_top, N_constellations)
, rand(dc.debris_bottom:dc.debris_top, N_debris)
)
end
end # end module

@ -1,21 +1,29 @@
#==#
struct State
stocks::Array{Float32}
debris::Array{Float32}
end
#=Satellite State=#
abstract type State end
function state_to_tuple(s::State)
return (s.stocks ,s.debris)
struct SingleStates <: State
stocks::Vector{Float32}
debris::Vector{Float32}
end
### Physics
struct MultiStates <: State
stocks::Array{Float32}
debris::Array{Float32}
end
#function state_to_tuple(s::State)
# return (s.stocks ,s.debris)
#end
#=Physical Model
This contains parameters describing the physical model.
=#
abstract type PhysicalModel end
#=Basic implementation of a physical model=#
struct BasicPhysics <: PhysicalModel
#rate at which debris hits satellites
debris_collision_rate::Real
@ -30,15 +38,17 @@ struct BasicPhysics <: PhysicalModel
#Ratio at which launches produce debris
launch_debris_ratio::Real
end
function state_transition(
physics::BasicPhysics
,state::State
,launches::Vector{Float32}
,survival_rate::Function
)
#=
Physical Transitions
=#
survival_rate = survival(state,physics)
survival_rates = survival_rate(state,physics)
# Debris transitions
@ -46,7 +56,7 @@ function state_transition(
natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris
# get changes in debris from satellite loss
satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rate)' * state.stocks
satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rates)' * state.stocks
# get changes in debris from launches
launch_debris = physics.launch_debris_ratio * sum(launches)
@ -56,17 +66,16 @@ function state_transition(
# Stocks Transitions
stocks = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches
stocks = (LinearAlgebra.diagm(survival_rates) .- physics.decay_rate)*state.stocks + launches
return State(stocks,debris)
end
function survival(
function survival_rates_1(
#This function describes the rate at which satellites survive each period.
state::State
,physical_model::BasicPhysics
)
#TODO! get this to broadcast correctly.
return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris))
end

@ -0,0 +1,62 @@
using Test, Flux, LinearAlgebra
include("../src/Orbits.jl")
using .Orbits
#=
The purpose of this document is to organize tests of the state structs and state transition functions
=#
@testset "States and Physical models testing" verbose=true begin
n_const = 2
n_debr = 3
n_data = 5
#built structs
u = UniformDataConstructor(n_data,0,5,2,3)
s = u(n_const,n_debr)
b = BasicPhysics(
0.05
,0.02*LinearAlgebra.ones(n_const,n_const)
,0.1
,0.002
,0.002
,0.2
)
a2 = ones(n_const,n_data)
#test that dimensions match etc
@test size(b.satellite_collision_rates)[1] == size(s.stocks)[1]
@test size(b.satellite_collision_rates)[2] == size(s.stocks)[1]
@test n_data == size(s.stocks)[2]
@test n_data == size(s.debris)[2]
@test size(s.stocks) == size(a2)
@testset "DataConstructor and states" begin
@test u.N == 5
@test length(s.debris) != 3
@test length(s.stocks) != 2
@test length(s.stocks) == 10
@test length(s.debris) == 15
@test size(s.stocks) == (2,5)
@test size(s.debris) == (3,5)
end
@testset "BasicPhysics" begin
@testset "Survival Functions" verbose = true begin
@test survival_rates_1(s,b) <: AbstractArray
end
@testset "Transitions" begin
@test true
end
end
end #States and physcial models testing
Loading…
Cancel
Save