You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
3.4 KiB
Julia

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

module Orbits
using Flux,LinearAlgebra,Zygote
include("physical_model.jl")
using .PhysicsModule
include("benefit_functions.jl")
using .BenefitFunctions
include("flux_helpers.jl")
using .NNTools
# Exports
# Code
# Common supertype for the loss structs defined in this module
# (`OperatorLoss` and `PlannerLoss` are both declared as subtypes of it).
abstract type GeneralizedLoss end
#=
This struct organizes information about a given constellation operator.
It is used to provide an iterable loss function for training.
=#
struct OperatorLoss <: GeneralizedLoss
    # Economic model describing this operator.
    econ_model::EconomicModel

    # The operator's own value-function network.
    operator_value_fn::Flux.Chain
    # Policy network; the same chain is held by all operators.
    collected_policies::Flux.Chain
    # The policy parameters this operator is allowed to train.
    operator_policy_params::Flux.Params

    physics::PhysicalModel
end
# overriding function to calculate operator loss
# Evaluate the operator's loss at stocks `s` and debris `d`, using the policy
# network stored on the operator (`collected_policies`).
#
# Returns `Flux.mae` between the squared Bellman residuals and the
# maximization condition.
function (operator::OperatorLoss)(
    s::Vector{Float32},
    d::Vector{Float32},
)
    # Actions chosen by the shared policy network at the current state.
    a = operator.collected_policies((s, d))

    # Next-period state. Both transitions are evaluated at the *current*
    # (s, d); the previous code overwrote `s` first, so the debris transition
    # and every later term silently used the updated stocks.
    # NOTE(review): `OperatorLoss` as declared in this file has no
    # `stocks_transition` / `debris_transition` field (only `physics`), so
    # unless `getproperty` is overloaded elsewhere these accesses will throw.
    # Presumably they should be routed through `operator.physics` — confirm.
    s_next = operator.stocks_transition(s, d, a)
    d_next = operator.debris_transition(s, d, a)

    # Current-period payoff and discounted value of the next state.
    flow_benefit = operator.econ_model(s, d, a)
    continuation = operator.econ_model.β * operator.operator_value_fn((s_next, d_next))

    # V(s, d) - u(s, d, a) - β V(s', d')
    bellman_residuals = operator.operator_value_fn((s, d)) - flow_benefit - continuation
    maximization_condition = -flow_benefit - continuation
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function
# Evaluate the operator's loss at `state`, using the supplied `policy` network
# instead of the one stored on the operator.
#
# Returns `Flux.mae` between the squared Bellman residuals and the
# maximization condition.
function (operator::OperatorLoss)(
    state::State,
    policy::Flux.Chain, # allow for another policy to be substituted
)
    # The previous body read `s` and `d` without ever defining them, so any
    # call raised UndefVarError.
    # NOTE(review): assumes `State` exposes `stocks` and `debris` fields —
    # confirm against physical_model.jl.
    s, d = state.stocks, state.debris

    # Actions chosen by the substituted policy at the current state.
    a = policy((s, d))

    # Next-period state, computed from the current (s, d) and kept in separate
    # variables so the current state stays available below.
    # NOTE(review): `OperatorLoss` as declared in this file has no
    # `stocks_transition` / `debris_transition` field (only `physics`);
    # presumably these should be routed through `operator.physics` — confirm.
    s_next = operator.stocks_transition(s, d, a)
    d_next = operator.debris_transition(s, d, a)

    # Current-period payoff and discounted value of the next state.
    flow_benefit = operator.econ_model(s, d, a)
    continuation = operator.econ_model.β * operator.operator_value_fn((s_next, d_next))

    # V(s, d) - u(s, d, a) - β V(s', d')
    bellman_residuals = operator.operator_value_fn((s, d)) - flow_benefit - continuation
    maximization_condition = -flow_benefit - continuation
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end # function
#=
Describes the planner's loss function.
=#
struct PlannerLoss <: GeneralizedLoss
    #=
    Goal: a well-formed PlannerLoss plus the training functions should be
    enough to train the approximation.
    Known issue: the value functions are not currently being trained
    appropriately.
    =#
    β::Float32
    operators::Array{GeneralizedLoss}

    policy::Flux.Chain
    policy_params::Flux.Params

    value::Flux.Chain
    value_params::Flux.Params

    physical_model::PhysicalModel
end
# Evaluate the planner's loss at `state`.
#
# The benefit is the sum of `econ_model(s, d, a)` over every entry of
# `planner.operators`. Returns `Flux.mae` between the squared Bellman
# residuals and the maximization condition.
function (planner::PlannerLoss)(
    state::State,
)
    # The previous body read `s` and `d` without ever defining them, so any
    # call raised UndefVarError.
    # NOTE(review): assumes `State` exposes `stocks` and `debris` fields —
    # confirm against physical_model.jl.
    s, d = state.stocks, state.debris

    a = planner.policy((s, d))

    # Next-period state. The previous code assigned this result to `state` and
    # never read it, so the continuation value was not evaluated at the
    # updated state.
    # NOTE(review): assumes `state_transition` returns a `State`; also,
    # `planner.physical_model` is never passed to it — confirm the intended
    # signature.
    next_state = state_transition(s, d, a)
    s_next, d_next = next_state.stocks, next_state.debris

    # Total benefit across all operators. `sum(f, itr)` avoids building the
    # intermediate array that `sum([...])` allocated.
    # NOTE(review): every element of `planner.operators` must have an
    # `econ_model` field; the container is typed `Array{GeneralizedLoss}`.
    benefit = sum(co -> co.econ_model(s, d, a), planner.operators)

    # Discounted value of the next state.
    continuation = planner.β .* planner.value((s_next, d_next))

    # V(s, d) - benefit - β V(s', d')
    bellman_residuals = planner.value((s, d)) - benefit - continuation
    maximization_condition = -benefit - continuation
    return Flux.mae(bellman_residuals .^ 2, maximization_condition)
end
end # end module