|
|
module Orbits
|
|
|
|
|
|
using Flux,LinearAlgebra,Zygote,BSON
|
|
|
|
|
|
include("physical_model.jl")
|
|
|
export State, state_to_tuple, state_transition, BasicPhysics, PhysicalModel
|
|
|
|
|
|
include("benefit_functions.jl")
|
|
|
export LinearModel, EconomicModel
|
|
|
|
|
|
include("flux_helpers.jl")
|
|
|
export operator_policy_function_generator,
|
|
|
value_function_generator,
|
|
|
BranchGenerator,
|
|
|
cross_linked_planner_policy_function_generator
|
|
|
|
|
|
# Exports from below
|
|
|
export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor ,
|
|
|
train_planner!, train_operators! ,
|
|
|
BasicPhysics, survival_rates_1
|
|
|
|
|
|
# Code
|
|
|
|
|
|
|
|
|
# Construct and manage training data
|
|
|
# Interface type: subtypes are callable objects that generate random `State`
# batches for training (see `UniformDataConstructor` below).
abstract type DataConstructor end
|
|
|
|
|
|
# Generates training states by sampling satellites and debris uniformly
# between configured per-field bounds. Callable: see the functor below.
struct UniformDataConstructor <: DataConstructor

# Number of rows in each sampled block — presumably the number of orbital
# shells/altitudes; TODO confirm against physical_model.jl.
N::UInt64

# Lower bound for sampled satellite counts.
satellites_bottom::Float32

# Upper bound for sampled satellite counts.
satellites_top::Float32

# Lower bound for sampled debris counts.
debris_bottom::Float32

# Upper bound for sampled debris counts.
debris_top::Float32

end # struct
|
|
|
# Draw a fresh batch of randomly sampled states.
#
# Returns a `State` whose satellite block has size (N, 1, N_constellations)
# and whose debris block has size (N, 1, N_debris), with entries drawn from
# the range between the configured bottom and top bounds.
#
# NOTE(review): `rand(a:b)` with Float32 endpoints samples the DISCRETE set
# {a, a+1, a+2, …} — confirm a continuous uniform was not intended.
function (dc::UniformDataConstructor)(N_constellations, N_debris)
    satellite_draws = rand(dc.satellites_bottom:dc.satellites_top, dc.N, 1, N_constellations)
    debris_draws = rand(dc.debris_bottom:dc.debris_top, dc.N, 1, N_debris)
    return State(satellite_draws, debris_draws)
end # function
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Common supertype for the callable loss functors (`OperatorLoss`, `PlannerLoss`).
abstract type GeneralizedLoss end
|
|
|
|
|
|
struct OperatorLoss <: GeneralizedLoss
#=
Organizes information about a single constellation operator so it can be
used as a callable loss function for training.

The fields each identify one aspect of the operator's decision problem:
- The economic model describing payoffs and discounting.
- The estimated NN value function held by the operator.
- The estimated NN policy chain shared by all operators (presumably one
  branch per operator — confirm against flux_helpers.jl).
- A reference to the subset of policy parameters this operator may update.
- The physical model describing how satellites and debris progress.

The functor methods below compute a loss from the Bellman residuals and a
maximization condition. There are two versions:
- one returns the loss calculation (MAE) using the owned policy chain;
- one returns the loss calculation using an externally provided policy.
=#

# Economic model describing this operator's payoffs and discount factor β.
econ_model::EconomicModel

# Operator's value and policy functions, as well as which parameters the
# operator is allowed to train.
operator_value_fn::Flux.Chain

collected_policies::Flux.Chain #this is held by all operators

operator_policy_params::Flux.Params #but only some of it is available for training

physics::PhysicalModel #It would be nice to move this to somewhere else in the model.

end # struct
|
|
|
# overriding function to calculate operator loss
|
|
|
# Evaluate this operator's loss at `state` under an arbitrary policy.
#
# Computes the Bellman residual of the operator's value function together
# with a maximization condition and combines them with mean absolute error.
# `policy` allows a policy other than the operator's own `collected_policies`
# to be substituted (e.g. when cross-training against the planner's policy).
function (operator::OperatorLoss)(
    state::State,
    policy::Flux.Chain # allow for another policy to be substituted
)
    # Actions chosen by the supplied policy at the sampled states.
    a = policy(state)

    # Advance satellite stocks and debris one period under those actions.
    # FIX: was `stake_transition`, which is undefined anywhere in the module;
    # the transition exported alongside physical_model.jl is `state_transition`.
    state′ = state_transition(state, a)

    # Bellman residual: V(s) - u(s, a) - β·V(s′).
    bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state, a) - operator.econ_model.β * operator.operator_value_fn(state′)

    # Negated action value; small when the policy maximizes the payoff.
    maximization_condition = -operator.econ_model(state, a) - operator.econ_model.β * operator.operator_value_fn(state′)

    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function
|
|
|
# Evaluate this operator's loss at `state` using its own shared policy chain.
function (operator::OperatorLoss)(
    state::State
)
    # Delegate to the two-argument method with the operator's default policy.
    own_policy = operator.collected_policies
    return operator(state, own_policy)
end # function
|
|
|
|
|
|
# Run one round of training for a single operator on `data`:
# first update the policy parameters this operator owns, then fit the
# operator's value network under the updated policy.
function train_operator!(op::OperatorLoss, data::State, opt)
    # Only the operator's own slice of the shared policy chain is trainable.
    policy_parameters = op.operator_policy_params
    Flux.train!(op, policy_parameters, data, opt)

    # The value network's parameters are all owned by this operator.
    value_parameters = Flux.params(op.operator_value_fn)
    Flux.train!(op, value_parameters, data, opt)
end
|
|
|
|
|
|
|
|
|
#=
|
|
|
Describe the planner's loss function.
|
|
|
=#
|
|
|
struct PlannerLoss <: GeneralizedLoss
#=
Bundles the planner's decision problem so it can be used as a callable loss
for training. Ideally, with just a well-formed PlannerLoss and the training
functions below, we should be able to train the approximation.

NOTE(review): the original author flagged an unresolved issue with
appropriately training the value functions ("it is not happening...").
=#

# Planner's discount level (as it may disagree with the operators').
β::Float32

# Loss functors for every operator the planner aggregates over.
# NOTE(review): abstract eltype — `Vector{OperatorLoss}` (or a type
# parameter) would allow concrete specialization; confirm before changing.
operators::Array{GeneralizedLoss}

# Planner's policy network and the parameters it may train.
policy::Flux.Chain

policy_params::Flux.Params

# Planner's value network and the parameters it may train.
value::Flux.Chain

value_params::Flux.Params

# The physical world that describes how satellites and debris progress.
physical_model::PhysicalModel

end
|
|
|
# Evaluate the planner's loss at a batch of states.
#
# Mirrors the OperatorLoss functor: the Bellman residual of the planner's
# value function plus a maximization condition, combined with MAE. The
# per-period benefit is the sum of every operator's payoff under the
# planner's policy.
#
# FIX: the previous version referenced undefined variables (`s`, `d`, `s′`,
# `d′`) left over from a pre-`State` API — the `#TODO: States` markers in the
# original acknowledged this. It now uses the State-based API consistently
# with the OperatorLoss functor above.
function (planner::PlannerLoss)(
    state::State
)
    # Actions from the planner's own policy network.
    a = planner.policy(state)

    # Advance satellite stocks and debris one period.
    state′ = state_transition(state, a)

    # Total one-period benefit across all operators. A generator avoids
    # materializing an intermediate array (the original noted a mutation
    # issue here and suggested generators/comprehensions).
    benefit = sum(co.econ_model(state, a) for co in planner.operators)

    # Bellman residual: W(s) - Σᵢ uᵢ(s, a) - β·W(s′).
    bellman_residuals = planner.value(state) - benefit - (planner.β .* planner.value(state′))

    # Negated continuation value; small when the policy maximizes total payoff.
    maximization_condition = -benefit - planner.β .* planner.value(state′)

    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function
|
|
|
|
|
|
# Train the planner's policy and value networks for `N_epoch` epochs,
# drawing a fresh batch from `data_gen` each epoch.
#
# Returns the per-epoch planner loss so callers can inspect convergence.
#
# NOTE(review): `data_gen()` is called with no arguments, but
# `UniformDataConstructor`'s functor takes (N_constellations, N_debris) —
# confirm the intended constructor provides a zero-argument method.
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
    errors = []
    for i = 1:N_epoch
        data = data_gen()
        # Alternate updates: policy parameters first, then value parameters.
        Flux.train!(pl, pl.policy_params, data, opt)
        Flux.train!(pl, pl.value_params, data, opt)
        # FIX: was `append!(errors, error(pl, data1) / 200)` — `data1` is
        # undefined and `error` is Base's exception thrower, so this line
        # always raised. Record the planner loss on this epoch's data instead.
        push!(errors, pl(data))
    end
    return errors
end # function
|
|
|
|
|
|
|
|
|
# Train every operator held by the planner for `N_epoch` epochs, each epoch
# on one fresh batch shared by all operators.
#
# Returns `errors` for interface symmetry with `train_planner!`.
# NOTE(review): it is never filled in — consider recording per-epoch losses.
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
    errors = []
    for i = 1:N_epoch
        # FIX: the original called `data_gen()` twice in a row, silently
        # discarding the first draw. Draw a single batch per epoch.
        data = data_gen()
        for op in pl.operators
            train_operator!(op, data, opt)
        end
    end
    return errors
end # function
|
|
|
|
|
|
|
|
|
end # end module
|