diff --git a/julia_code/2021-11-17T08:45:09.302_planner.bson b/julia_code/2021-11-17T08:45:09.302_planner.bson deleted file mode 100644 index a4f9187..0000000 Binary files a/julia_code/2021-11-17T08:45:09.302_planner.bson and /dev/null differ diff --git a/julia_code/Module/src/Orbits.jl b/julia_code/Module/src/Orbits.jl index 7f231a8..385d3a4 100644 --- a/julia_code/Module/src/Orbits.jl +++ b/julia_code/Module/src/Orbits.jl @@ -12,7 +12,7 @@ include("flux_helpers.jl") using .NNTools # Exports - +export GeneralizedLoss , OperatorLoss, PlannerLoss, UniformDataConstructor # Code @@ -34,8 +34,7 @@ struct OperatorLoss <: GeneralizedLoss end # overriding function to calculate operator loss function (operator::OperatorLoss)( - s::Vector{Float32} - ,d::Vector{Float32} + state::State ) @@ -43,37 +42,41 @@ function (operator::OperatorLoss)( a = operator.collected_policies((s,d)) #get updated stocks and debris - s′ = operator.stocks_transition(s ,d ,a) - d′ = operator.debris_transition(s ,d ,a) + state′ = state_transition(operator.physics,state,a) - bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′)) + bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) - maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′)) + maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state′)) return Flux.mae(bellman_residuals.^2 ,maximization_condition) end # struct function (operator::OperatorLoss)( state::State ,policy::Flux.Chain #allow for another policy to be subsituted -) + ) #get actions - a = policy((s,d)) + a = policy(state) #get updated stocks and debris - s′ = operator.stocks_transition(s ,d ,a) - d′ = operator.debris_transition(s ,d ,a) + state′ = stake_transition(state,a) - bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′)) + bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) - maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s′,d′)) + maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) return Flux.mae(bellman_residuals.^2 ,maximization_condition) end #function +function train_operators!(op::OperatorLoss, data::State, opt) + Flux.train!(op, op.operator_policy_params, data, opt) + Flux.train!(op, Flux.params(op.operator_value_fn), data, opt) +end + + #= Describe the Planners loss function =# @@ -116,4 +119,45 @@ function (planner::PlannerLoss)( return Flux.mae(bellman_residuals.^2 ,maximization_condition) end +function train_planner!(pl::PlannerLoss, N_epoch::Int, opt) + errors = [] + for i = 1:N_epoch + data = data_gen() + Flux.train!(pl, pl.policy_params, data, opt) + Flux.train!(pl, pl.value_params, data, opt) + append!(errors, error(pl, data1) / 200) + end + return errors +end + + +function train_operators!(pl::PlannerLoss, N_epoch::Int, opt) + errors = [] + for i = 1:N_epoch + data = data_gen() + for op in pl.operators + train_operators!(op,data,opt) + end + end + return errors +end + + +#Construct and manage data +abstract type DataConstructor end + +struct UniformDataConstructor <: DataConstructor + N::UInt64 + satellites_bottom::Float32 + satellites_top::Float32 + debris_bottom::Float32 + debris_top::Float32 +end +function (dc::UniformDataConstructor)() + #currently ignores the quantity of data it should construct. + State( + rand(dc.satellites_bottom:dc.satellites_top, N_constellations) + , rand(dc.debris_bottom:dc.debris_top, N_debris) + ) +end end # end module diff --git a/julia_code/Module/src/benefit_functions.jl b/julia_code/Module/src/benefit_functions.jl index 782d284..1957891 100644 --- a/julia_code/Module/src/benefit_functions.jl +++ b/julia_code/Module/src/benefit_functions.jl @@ -1,11 +1,11 @@ module BenefitFunctions -export LinearModel, BenefitFunction +export LinearModel, BenefitFunction, EconomicModel -import("physical_model.lj") -using PhysicalModel: State +include("physical_model.jl") +using .PhysicsModule: State #= Benefit Functions: @@ -22,9 +22,9 @@ struct LinearModel <: EconomicModel end function (em::LinearModel)( state::State - ,a::Vector{Float32} + ,actions::Vector{Float32} ) - return em.payoff_array*s - em.policy_costs*a + return em.payoff_array*state.stocks - em.policy_costs*actions end #basic CES model @@ -37,13 +37,13 @@ struct CES <: EconomicModel end function (em::CES)( state::State - ,a::Vector{Float32} + ,actions::Vector{Float32} ) #issue here with multiplication - r1 = em.payoff_array .* (s.^em.r) - r2 = - em.debris_costs .* (d.^em.r) - r3 = - em.policy_costs .* (a.^em.r) + r1 = em.payoff_array .* (state.stocks.^em.r) + r2 = - em.debris_costs .* (state.debris.^em.r) + r3 = - em.policy_costs .* (actions.^em.r) return (r1 + r2 + r3) .^ (1/em.r) end @@ -58,10 +58,10 @@ struct CRRA <: EconomicModel end function (em::CRRA)( state::State - ,a::Vector{Float32} + ,actions::Vector{Float32} ) #issue here with multiplication - core = (em.payoff_array*s - em.debris_costs*d - em.policy_costs).^(1 - em.σ) + core = (em.payoff_array*state.stocks - em.debris_costs*state.debris - em.policy_costs*actions).^(1 - em.σ) return (core-1)/(1-em.σ) end diff --git a/julia_code/Module/src/flux_helpers.jl b/julia_code/Module/src/flux_helpers.jl index 1a7046e..4503b88 100644 --- a/julia_code/Module/src/flux_helpers.jl +++ b/julia_code/Module/src/flux_helpers.jl @@ -1,17 +1,24 @@ -module NN Tools +module NNTools export BranchGenerator, value_function_generator using Flux, BSON, Zygote +include("physical_model.jl") +using .PhysicsModule: State,tuple + + + #= TupleDuplicator This is used to create a tuple of size n with deepcopies of any object x =# struct TupleDuplicator n::Int end +#make tuples consisting of copies of whatever was provided (f::TupleDuplicator)(x) = tuple([deepcopy(x) for i=1:f.n]...) - +#do the same but for the case of states, coerce them to tuples first +(f::TupleDuplicator)(x::State) = tuple([deepcopy(tuple(x)) for i=1:f.n]...) diff --git a/julia_code/Module/src/fundamentals.jl b/julia_code/Module/src/fundamentals.jl deleted file mode 100644 index 0d41125..0000000 --- a/julia_code/Module/src/fundamentals.jl +++ /dev/null @@ -1,15 +0,0 @@ -module Fundamentals - -#Add exports here -export State - - -#= -The state of the world. -=# -struct State - stocks::Array{Float32} - debris::Array{Float32} -end - -end #end module diff --git a/julia_code/Module/src/physical_model.jl b/julia_code/Module/src/physical_model.jl index a68e296..2e0f8e5 100644 --- a/julia_code/Module/src/physical_model.jl +++ b/julia_code/Module/src/physical_model.jl @@ -1,7 +1,7 @@ module PhysicsModule #Add exports here -export State, PhysicalModel, BasicModel +export State ,tuple ,state_transition ,PhysicalModel ,BasicModel using Flux, LinearAlgebra @@ -13,6 +13,9 @@ struct State debris::Array{Float32} end +function tuple(st::State) + return (st.stocks, st.debris) +end ### Physics @@ -43,7 +46,7 @@ function state_transition( =# survival_rate = survival(state,physics) - # Debris + # Debris transitions # get changes in debris from natural dynamics natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris @@ -58,7 +61,7 @@ function state_transition( debris′ = natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris - #stocks + # Stocks Transitions stocks′ = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches @@ -69,10 +72,8 @@ end function survival( state::State ,physical_model::BasicModel - ) - return exp.( - -(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks - .- (physical_model.debris_collision_rate * state.debris) - ) + ) + return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris)) end + end #end module