# NOTE(review): removed web-file-viewer artifacts (topic-limit text, line/size
# counts, "ambiguous Unicode" warnings) that were captured by a scrape and are
# not part of the source; they made the file unparseable.

module Orbits
using Flux,LinearAlgebra,Zygote,BSON
include("physical_model.jl")
export State, state_to_tuple, state_transition, BasicPhysics, PhysicalModel
include("benefit_functions.jl")
export LinearModel, EconomicModel
include("flux_helpers.jl")
export operator_policy_function_generator,
value_function_generator,
BranchGenerator,
cross_linked_planner_policy_function_generator
# Exports from below
export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor ,
train_planner!, train_operators! ,
BasicPhysics, survival_rates_1
# Code
#Construct and manage data
# Interface for callable objects that generate random `State` samples used as
# training data; concrete subtypes implement call syntax (see
# `UniformDataConstructor` below) and are consumed by the `train_*!` loops.
abstract type DataConstructor end
"""
    UniformDataConstructor(N, satellites_bottom, satellites_top, debris_bottom, debris_top)

Callable generator of random `State` training samples. Each call draws
satellite and debris values between the configured bottom/top bounds
(see the functor method below for the exact sampling).
"""
struct UniformDataConstructor <: DataConstructor
N::UInt64 # leading array dimension of each generated sample
satellites_bottom::Float32 # lower bound for sampled satellite values
satellites_top::Float32 # upper bound for sampled satellite values
debris_bottom::Float32 # lower bound for sampled debris values
debris_top::Float32 # upper bound for sampled debris values
end # struct
"""
    (dc::UniformDataConstructor)(N_constellations=1, N_debris=1)

Draw a random `State` with satellite values uniform on
`[dc.satellites_bottom, dc.satellites_top]` (array of size
`dc.N × 1 × N_constellations`) and debris values uniform on
`[dc.debris_bottom, dc.debris_top]` (size `dc.N × 1 × N_debris`).

Fixes relative to the original:
- `rand(bottom:top, …)` on `Float32` bounds sampled a unit-stepped range
  (only `bottom, bottom+1, …`), collapsing to `bottom` alone whenever
  `top - bottom < 1`; this draws continuously over the interval instead.
- Default argument values are provided so the zero-argument calls made by
  `train_planner!`/`train_operators!` (`data_gen()`) work.
"""
function (dc::UniformDataConstructor)(N_constellations = 1, N_debris = 1)
    # scale-and-shift of unit uniforms keeps everything Float32 and
    # covers the whole interval, not just unit-spaced points
    sats = dc.satellites_bottom .+
           (dc.satellites_top - dc.satellites_bottom) .* rand(Float32, dc.N, 1, N_constellations)
    debris = dc.debris_bottom .+
             (dc.debris_top - dc.debris_bottom) .* rand(Float32, dc.N, 1, N_debris)
    return State(sats, debris)
end # function
# Common supertype for the callable loss objects below (`OperatorLoss`,
# `PlannerLoss`); each subtype is invoked as `loss(state)` inside Flux.train!.
abstract type GeneralizedLoss end
"""
    OperatorLoss

Callable loss for a single constellation operator; see the field comments and
the functor methods below for the residual computation.
"""
struct OperatorLoss <: GeneralizedLoss
#=
This struct organizes information about a given constellation operator
It is used to provide an iterable loss function for training.
The fields each identify one aspect of the operator's decision problem:
- The economic model describing payoffs and discounting.
- The estimated NN value function held by the operator.
- The estimated NN policy function that describes each of the
- Each operator holds a reference to the parameters they can update.
- The physical world that describes how satellites progress.
There is an overriding function that uses these details to calculate the residual function based on
the bellman residuals and a maximization condition.
There are two versions of this function.
- return an actual loss calculation (MAE) using the owned policy function.
- return the loss calculation using a provided policy function.
=#
#econ model describing operator
econ_model::EconomicModel
#Operator's value and policy functions, as well as which parameters the operator can train.
operator_value_fn::Flux.Chain
collected_policies::Flux.Chain #this is held by all operators
operator_policy_params::Flux.Params #but only some of it is available for training
physics::PhysicalModel #It would be nice to move this to somewhere else in the model.
end # struct
# overriding function to calculate operator loss
"""
    (operator::OperatorLoss)(state::State, policy::Flux.Chain)

Compute the operator's training loss at `state` using the supplied `policy`
(allowing a policy other than the owned one to be substituted).

Fixes relative to the original:
- `stake_transition` was undefined (typo); the physical model exports
  `state_transition`, which is what the planner loss also uses.
- The original overwrote `state` before evaluating the value function, so
  every Bellman term was taken at the post-transition state
  (V(s') - r(s',a) - βV(s')); the standard residual V(s) - r(s,a) - βV(s')
  is used here. NOTE(review): confirm this matches the intended model.
"""
function (operator::OperatorLoss)(
    state::State,
    policy::Flux.Chain, # allow another policy to be substituted
)
    # actions chosen at the current state
    a = policy(state)
    # advance stocks and debris one period (keep `state` for the V(s)/r(s,a) terms)
    next_state = state_transition(state, a)
    # Bellman residual: V(s) - r(s,a) - β V(s')
    bellman_residuals = operator.operator_value_fn(state) -
        operator.econ_model(state, a) -
        operator.econ_model.β * operator.operator_value_fn(next_state)
    # negated action value; training pushes the policy toward its maximizer
    maximization_condition = -operator.econ_model(state, a) -
        operator.econ_model.β * operator.operator_value_fn(next_state)
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function
# Single-argument form: evaluate the loss with the operator's own stored
# (shared) policy network.
(operator::OperatorLoss)(state::State) = operator(state, operator.collected_policies)
# One training pass for a single operator: first update the slice of the
# shared policy network this operator owns, then update the operator's own
# value network, both against the same data and optimiser.
function train_operator!(op::OperatorLoss, data::State, opt)
    for trainable in (op.operator_policy_params, Flux.params(op.operator_value_fn))
        Flux.train!(op, trainable, data, opt)
    end
end
#=
Describe the Planners loss function
=#
"""
    PlannerLoss

Callable loss for the social planner, aggregating the benefit across all
operators; see the functor method below for the residual computation.
"""
struct PlannerLoss <: GeneralizedLoss
#=
Ideally, with just a well formed PlannerLoss and the training functions below, we should be able to train the approximation.
There is an issue with appropriately training the value functions.
In this case, it is not happening...
=#
#planner discount level (As it may disagree with operators)
β::Float32
# NOTE(review): abstract element type; a concrete Vector{OperatorLoss} would
# specialize better — confirm nothing else is ever stored here.
operators::Array{GeneralizedLoss}
policy::Flux.Chain # planner's policy network
policy_params::Flux.Params # parameters trained for the policy
value::Flux.Chain # planner's value network
value_params::Flux.Params # parameters trained for the value function
physical_model::PhysicalModel # world dynamics
end
"""
    (planner::PlannerLoss)(state::State)

Compute the planner's training loss at `state`.

This completes the `#TODO: States` rewrite the original left unfinished: it
referenced undefined variables `s` and `d` from a pre-`State` interface. The
computation now mirrors the `OperatorLoss` functor, with the per-period
benefit summed across all operators' economic models, and uses a generator
instead of an array comprehension (the original comment flagged a mutation
issue under Zygote).
"""
function (planner::PlannerLoss)(
    state::State,
)
    # actions chosen by the planner's policy at the current state
    a = planner.policy(state)
    # advance stocks and debris one period (keep `state` for the V(s)/benefit terms)
    next_state = state_transition(state, a)
    # total per-period benefit across all operators' economic models
    benefit = sum(co.econ_model(state, a) for co in planner.operators)
    # Bellman residual: V(s) - benefit(s,a) - β V(s')
    bellman_residuals = planner.value(state) .- benefit .-
        planner.β .* planner.value(next_state)
    # negated planner action value; training pushes the policy toward its maximizer
    maximization_condition = .-benefit .- planner.β .* planner.value(next_state)
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function
"""
    train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)

Run `N_epoch` training epochs on the planner: each epoch draws fresh data
from `data_gen`, trains the policy parameters and then the value parameters,
and records the post-update loss.

Returns the vector of per-epoch losses.

Fixes relative to the original: `error(pl, data1)` referenced the undefined
`data1` and called `Base.error`, which throws — the intent was to record the
planner loss itself. The unexplained `/ 200` scaling was dropped; the raw
loss is recorded.
"""
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
    errors = Float32[] # concrete eltype instead of Vector{Any}
    for _ in 1:N_epoch
        data = data_gen()
        Flux.train!(pl, pl.policy_params, data, opt)
        Flux.train!(pl, pl.value_params, data, opt)
        # record the loss after this epoch's updates
        push!(errors, pl(data))
    end
    return errors
end # function
"""
    train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)

Run `N_epoch` training epochs over every operator held by the planner: each
epoch draws one fresh data sample from `data_gen` and calls
`train_operator!` on each operator with it.

Fix relative to the original: `data = data_gen()` appeared twice per epoch,
generating (and discarding) an extra sample each iteration.

NOTE(review): `errors` is never populated and is always returned empty —
kept to preserve the return interface, but presumably per-epoch operator
losses were meant to be collected here; confirm against callers.
"""
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
    errors = []
    for _ in 1:N_epoch
        data = data_gen()
        for op in pl.operators
            train_operator!(op, data, opt)
        end
    end
    return errors
end # function
end # end module