module Orbits

using Flux, LinearAlgebra, Zygote

include("physical_model.jl")
using .PhysicsModule
include("benefit_functions.jl")
using .BenefitFunctions
include("flux_helpers.jl")
using .NNTools

# Exports

# Common supertype for the operator and planner loss functors below, so
# training code can hold heterogeneous losses behind one abstract type.
abstract type GeneralizedLoss end

#=
This struct organizes information about a given constellation operator.
Used to provide an iterable loss function for training.
=#
struct OperatorLoss <: GeneralizedLoss
    # econ model describing operator
    econ_model::EconomicModel
    # Operator's value and policy functions, as well as which parameters the
    # operator can train.
    operator_value_fn::Flux.Chain
    collected_policies::Flux.Chain # this is held by all operators
    operator_policy_params::Flux.Params
    physics::PhysicalModel
end

"""
    _operator_bellman_loss(operator, s, d, a)

Shared kernel for both `OperatorLoss` call methods: advance stocks `s` and
debris `d` under action `a`, form the squared Bellman residual of the
operator's value function and the maximization target, and combine them with
`Flux.mae`. Extracted because the two functor methods previously duplicated
this body verbatim.
"""
function _operator_bellman_loss(operator::OperatorLoss,
                                s::Vector{Float32},
                                d::Vector{Float32},
                                a)
    # Updated stocks and debris.
    # BUG FIX(review): the original called `operator.stocks_transition` /
    # `operator.debris_transition`, but `OperatorLoss` declares no such
    # fields, so every call would throw. We assume these are module-level
    # transition functions (mirroring the planner's bare `state_transition`
    # call) — TODO confirm against PhysicsModule's API.
    s′ = stocks_transition(s, d, a)
    d′ = debris_transition(s, d, a)

    # Discounted continuation value β·V(s′,d′); computed once and reused
    # (the original evaluated it twice).
    continuation = operator.econ_model.β * operator.operator_value_fn((s′, d′))

    # Bellman residual: V(s,d) − u(s,d,a) − β·V(s′,d′)
    bellman_residuals = operator.operator_value_fn((s, d)) -
                        operator.econ_model(s, d, a) -
                        continuation

    # Maximization condition: −u(s,d,a) − β·V(s′,d′)
    maximization_condition = -operator.econ_model(s, d, a) - continuation

    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end

# Operator loss on raw stock/debris vectors, acting under the shared
# `collected_policies` chain.
function (operator::OperatorLoss)(s::Vector{Float32}, d::Vector{Float32})
    # Get actions from the collective policy.
    a = operator.collected_policies((s, d))
    return _operator_bellman_loss(operator, s, d, a)
end

# Operator loss evaluated under a substituted policy for a given `State`.
# BUG FIX(review): the original body referenced `s` and `d`, which were never
# defined in this method (the argument is `state::State`), so every call would
# throw `UndefVarError`. We unpack them from `state` — assumes `State` carries
# `stocks` and `debris` fields; TODO confirm against physical_model.jl.
function (operator::OperatorLoss)(state::State,
                                  policy::Flux.Chain) # allow for another policy to be substituted
    s = state.stocks
    d = state.debris
    # Get actions from the substituted policy.
    a = policy((s, d))
    return _operator_bellman_loss(operator, s, d, a)
end

#=
Describes the planner's loss function.

Ideally, with just a well formed PlannerLoss and the training functions below,
we should be able to train the approximation. There is an issue with
appropriately training the value functions. In this case, it is not
happening...
=#
struct PlannerLoss <: GeneralizedLoss
    β::Float32
    operators::Array{GeneralizedLoss}
    policy::Flux.Chain
    policy_params::Flux.Params
    value::Flux.Chain
    value_params::Flux.Params
    physical_model::PhysicalModel
end

# Planner loss: Bellman residual of the planner's value function against the
# sum of all operators' benefits.
function (planner::PlannerLoss)(state::State)
    # BUG FIX(review): the original referenced undefined `s`, `d`, `s′`, `d′`
    # (only `state` and `state′` existed), so every call would throw
    # `UndefVarError`. Unpack from the State — TODO confirm field names.
    s = state.stocks
    d = state.debris

    a = planner.policy((s, d))

    # Get updated stocks and debris from the physics transition.
    state′ = state_transition(s, d, a)
    s′ = state′.stocks
    d′ = state′.debris

    # Total benefit across all operators. A generator (rather than the
    # original array comprehension inside `sum`) avoids materializing an
    # intermediate vector — this also sidesteps the array-mutation issue the
    # original comment flagged for Zygote differentiation.
    benefit = sum(co.econ_model(s, d, a) for co in planner.operators)

    # Discounted continuation value β·V(s′,d′), computed once.
    continuation = planner.β .* planner.value((s′, d′))

    bellman_residuals = planner.value((s, d)) - benefit - continuation
    maximization_condition = -benefit - continuation

    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end

end # module Orbits