module Orbits

using Flux, LinearAlgebra, Zygote, BSON

include("physical_model.jl")
export State, state_to_tuple, state_transition, BasicPhysics, PhysicalModel

include("benefit_functions.jl")
export LinearModel, EconomicModel

include("flux_helpers.jl")
export operator_policy_function_generator, value_function_generator,
    BranchGenerator, cross_linked_planner_policy_function_generator

# Exports from below
export GeneralizedLoss, OperatorLoss, PlannerLoss, UniformDataConstructor,
    train_planner!, train_operators!, BasicPhysics, survival_rates_1

# ---------------------------------------------------------------------------
# Construct and manage data
# ---------------------------------------------------------------------------

"""
Abstract parent for callable objects that produce random `State` samples
used as training data for the loss functions below.
"""
abstract type DataConstructor end

"""
    UniformDataConstructor

Callable data generator: draws satellite and debris stocks from the closed
ranges `[satellites_bottom, satellites_top]` and `[debris_bottom, debris_top]`,
producing arrays with leading dimension `N`.
"""
struct UniformDataConstructor <: DataConstructor
    N::UInt64
    satellites_bottom::Float32
    satellites_top::Float32
    debris_bottom::Float32
    debris_top::Float32
end # struct

# Sample a random `State` holding `N_constellations` satellite slices and
# `N_debris` debris slices.
# NOTE(review): `rand(a:b)` with Float32 endpoints samples the unit-step grid
# a, a+1, …, b — confirm integer-spaced draws (rather than continuous
# uniform draws) are intended.
function (dc::UniformDataConstructor)(N_constellations, N_debris)
    return State(
        rand(dc.satellites_bottom:dc.satellites_top, dc.N, 1, N_constellations),
        rand(dc.debris_bottom:dc.debris_top, dc.N, 1, N_debris),
    )
end # function

# ---------------------------------------------------------------------------
# Loss functions
# ---------------------------------------------------------------------------

abstract type GeneralizedLoss end

struct OperatorLoss <: GeneralizedLoss
    #=
    This struct organizes information about a given constellation operator.
    It is used to provide an iterable loss function for training.

    The fields each identify one aspect of the operator's decision problem:
    - The economic model describing payoffs and discounting.
    - The estimated NN value function held by the operator.
    - The estimated NN policy function describing all operators' choices.
    - Each operator holds a reference to the parameters they can update.
    - The physical world that describes how satellites progress.

    There is an overriding (functor) method that uses these details to
    calculate the residual function based on the Bellman residuals and a
    maximization condition. There are two versions of this function:
    - return an actual loss calculation (MAE) using the owned policy function.
    - return the loss calculation using a provided policy function.
    =#
    # Economic model describing the operator's payoffs and discount factor β.
    econ_model::EconomicModel
    # Operator's value and policy functions, plus which parameters it may train.
    operator_value_fn::Flux.Chain
    collected_policies::Flux.Chain  # this is held by all operators
    operator_policy_params::Flux.Params  # but only some of it is trainable here
    physics::PhysicalModel  # It would be nice to move this elsewhere in the model.
end # struct

# Functor: compute the operator's loss at `state` under an arbitrary policy.
# Combines the squared Bellman residual with a maximization condition via MAE.
function (operator::OperatorLoss)(
    state::State,
    policy::Flux.Chain,  # allow another policy to be substituted
)
    # Get actions prescribed by the (possibly substituted) policy.
    a = policy(state)
    # Get updated stocks and debris.
    # (fixed: was `stake_transition`, an undefined name — the physical-model
    # export is `state_transition`)
    state′ = state_transition(state, a)
    bellman_residuals =
        operator.operator_value_fn(state) -
        operator.econ_model(state, a) -
        operator.econ_model.β * operator.operator_value_fn(state′)
    maximization_condition =
        -operator.econ_model(state, a) -
        operator.econ_model.β * operator.operator_value_fn(state′)
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function

# Convenience functor: evaluate the loss using the operator's own policy.
function (operator::OperatorLoss)(state::State)
    return operator(state, operator.collected_policies)
end # function

# Train one operator on `data`: first its policy parameters, then its value
# function, each via a single `Flux.train!` pass with optimizer `opt`.
function train_operator!(op::OperatorLoss, data::State, opt)
    # Train the policy functions.
    Flux.train!(op, op.operator_policy_params, data, opt)
    # Train the value function.
    Flux.train!(op, Flux.params(op.operator_value_fn), data, opt)
end

#=
Describe the Planner's loss function
=#
struct PlannerLoss <: GeneralizedLoss
    #=
    Ideally, with just a well-formed PlannerLoss and the training functions
    below, we should be able to train the approximation.
    There is an issue with appropriately training the value functions.
    In this case, it is not happening...
    =#
    # Planner discount level (as it may disagree with the operators').
    β::Float32
    operators::Array{GeneralizedLoss}
    policy::Flux.Chain
    policy_params::Flux.Params
    value::Flux.Chain
    value_params::Flux.Params
    physical_model::PhysicalModel
end

# Functor: the planner's loss at `state` — squared Bellman residual of the
# utilitarian (summed-operator) benefit, combined with a maximization
# condition via MAE.
# (fixed: previously referenced undefined `s`, `d`, `s′`, `d′` placeholders
# marked TODO; now uses the `State` API consistently with OperatorLoss)
function (planner::PlannerLoss)(state::State)
    # Planner chooses actions for all constellations at once.
    a = planner.policy(state)
    # Get updated stocks and debris.
    state′ = state_transition(state, a)
    # Total benefit summed across every operator's economic model
    # (lazy generator — avoids materializing an intermediate array).
    benefit = sum(co.econ_model(state, a) for co in planner.operators)
    bellman_residuals =
        planner.value(state) - benefit - planner.β .* planner.value(state′)
    maximization_condition = -benefit - planner.β .* planner.value(state′)
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function

# Run `N_epoch` planner training epochs; each epoch draws fresh data,
# trains the policy then the value parameters, and records the loss.
# Returns the vector of per-epoch losses.
# NOTE(review): `data_gen()` is called with no arguments, but
# `UniformDataConstructor` requires `(N_constellations, N_debris)` — confirm
# the constructor subtype used here defines a zero-argument method.
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
    errors = []
    for i in 1:N_epoch
        data = data_gen()
        Flux.train!(pl, pl.policy_params, data, opt)
        Flux.train!(pl, pl.value_params, data, opt)
        # Record the post-update loss on this epoch's data.
        # (fixed: was `error(pl, data1)` — `data1` is undefined and `error`
        # is Base's exception-thrower, so every epoch would have thrown)
        push!(errors, pl(data))
    end
    return errors
end # function

# Run `N_epoch` operator training epochs; each epoch draws fresh data and
# trains every operator on it. Returns the vector of per-epoch total losses.
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor)
    errors = []
    for i in 1:N_epoch
        # (fixed: data was generated twice per epoch, discarding the first draw)
        data = data_gen()
        for op in pl.operators
            train_operator!(op, data, opt)
        end
        # Record the combined operator loss so the returned `errors` vector
        # is actually populated (it was always empty before).
        push!(errors, sum(op(data) for op in pl.operators))
    end
    return errors
end # function

end # end module