|
|
|
|
@ -12,7 +12,7 @@ include("flux_helpers.jl")
|
|
|
|
|
using .NNTools
|
|
|
|
|
|
|
|
|
|
# Exports
|
|
|
|
|
|
|
|
|
|
export GeneralizedLoss , OperatorLoss, PlannerLoss, UniformDataConstructor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Code
|
|
|
|
|
@ -34,8 +34,7 @@ struct OperatorLoss <: GeneralizedLoss
|
|
|
|
|
end
|
|
|
|
|
# overriding function to calculate operator loss
|
|
|
|
|
function (operator::OperatorLoss)(
|
|
|
|
|
s::Vector{Float32}
|
|
|
|
|
,d::Vector{Float32}
|
|
|
|
|
state::State
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -43,13 +42,12 @@ function (operator::OperatorLoss)(
|
|
|
|
|
a = operator.collected_policies((s,d))
|
|
|
|
|
|
|
|
|
|
#get updated stocks and debris
|
|
|
|
|
s′ = operator.stocks_transition(s ,d ,a)
|
|
|
|
|
d′ = operator.debris_transition(s ,d ,a)
|
|
|
|
|
state′ = state_transition(operator.physics,state,a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′))
|
|
|
|
|
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′)
|
|
|
|
|
|
|
|
|
|
maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′))
|
|
|
|
|
maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state′))
|
|
|
|
|
|
|
|
|
|
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
|
|
|
|
|
end # struct
|
|
|
|
|
@ -60,20 +58,25 @@ function (operator::OperatorLoss)(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#get actions
|
|
|
|
|
a = policy((s,d))
|
|
|
|
|
a = policy(state)
|
|
|
|
|
|
|
|
|
|
#get updated stocks and debris
|
|
|
|
|
s′ = operator.stocks_transition(s ,d ,a)
|
|
|
|
|
d′ = operator.debris_transition(s ,d ,a)
|
|
|
|
|
state′ = stake_transition(state,a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′))
|
|
|
|
|
bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′)
|
|
|
|
|
|
|
|
|
|
maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′))
|
|
|
|
|
maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′)
|
|
|
|
|
|
|
|
|
|
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
|
|
|
|
|
end #function
|
|
|
|
|
|
|
|
|
|
function train_operators!(op::OperatorLoss, data::State, opt)
|
|
|
|
|
Flux.train!(op, op.operator_policy_params, data, opt)
|
|
|
|
|
Flux.train!(op, Flux.params(op.operator_value_fn), data, opt)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#=
|
|
|
|
|
Describe the Planners loss function
|
|
|
|
|
=#
|
|
|
|
|
@ -116,4 +119,45 @@ function (planner::PlannerLoss)(
|
|
|
|
|
return Flux.mae(bellman_residuals.^2 ,maximization_condition)
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
|
|
|
|
|
errors = []
|
|
|
|
|
for i = 1:N_epoch
|
|
|
|
|
data = data_gen()
|
|
|
|
|
Flux.train!(pl, pl.policy_params, data, opt)
|
|
|
|
|
Flux.train!(pl, pl.value_params, data, opt)
|
|
|
|
|
append!(errors, error(pl, data1) / 200)
|
|
|
|
|
end
|
|
|
|
|
return errors
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt)
|
|
|
|
|
errors = []
|
|
|
|
|
for i = 1:N_epoch
|
|
|
|
|
data = data_gen()
|
|
|
|
|
for op in pl.operators
|
|
|
|
|
train_operators!(op,data,opt)
|
|
|
|
|
end
|
|
|
|
|
end
|
|
|
|
|
return errors
|
|
|
|
|
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#Construct and manage data
|
|
|
|
|
abstract type DataConstructor end
|
|
|
|
|
|
|
|
|
|
struct UniformDataConstructor <: DataConstructor
|
|
|
|
|
N::UInt64
|
|
|
|
|
satellites_bottom::Float32
|
|
|
|
|
satellites_top::Float32
|
|
|
|
|
debris_bottom::Float32
|
|
|
|
|
debris_top::Float32
|
|
|
|
|
end
|
|
|
|
|
function (dc::UniformDataConstructor)()
|
|
|
|
|
#currently ignores the quantity of data it should construct.
|
|
|
|
|
State(
|
|
|
|
|
rand(dc.satellites_bottom:dc.satellites_top, N_constellations)
|
|
|
|
|
, rand(dc.debris_bottom:dc.debris_top, N_debris)
|
|
|
|
|
)
|
|
|
|
|
end
|
|
|
|
|
end # end module
|
|
|
|
|
|