diff --git a/julia_code/.vscode/settings.json b/julia_code/.vscode/settings.json index 25c729f..7bbc8c6 100644 --- a/julia_code/.vscode/settings.json +++ b/julia_code/.vscode/settings.json @@ -1,3 +1,4 @@ { - "python.pythonPath": "/bin/python3" + "python.pythonPath": "/bin/python3", + "editor.detectIndentation": false } \ No newline at end of file diff --git a/julia_code/Module/src/Orbits.jl b/julia_code/Module/src/Orbits.jl index a165c1a..50073b8 100644 --- a/julia_code/Module/src/Orbits.jl +++ b/julia_code/Module/src/Orbits.jl @@ -15,68 +15,91 @@ export operator_policy_function_generator, cross_linked_planner_policy_function_generator # Exports from below -export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor, - train_planner!, train_operators! +export GeneralizedLoss ,OperatorLoss ,PlannerLoss ,UniformDataConstructor , + train_planner!, train_operators! , + BasicPhysics, survival_rates_1 # Code -abstract type GeneralizedLoss end - -#= -This struct organizes information about a given constellation operator -Used to provide an interable loss function for training -=# -struct OperatorLoss <: GeneralizedLoss - #econ model describing operator - econ_model::EconomicModel - #Operator's value and policy functions, as well as which parameters the operator can train. - operator_value_fn::Flux.Chain - collected_policies::Flux.Chain #this is held by all operators - operator_policy_params::Flux.Params - physics::PhysicalModel -end -# overriding function to calculate operator loss -function (operator::OperatorLoss)( - state::State -) +#Construct and manage data +abstract type DataConstructor end - #get actions - a = operator.collected_policies((s,d)) +struct UniformDataConstructor <: DataConstructor + N::UInt64 + satellites_bottom::Float32 + satellites_top::Float32 + debris_bottom::Float32 + debris_top::Float32 +end # struct +function (dc::UniformDataConstructor)(N_constellations,N_debris) + return State( + rand(dc.satellites_bottom:dc.satellites_top, dc.N, 1, N_constellations) + , rand(dc.debris_bottom:dc.debris_top, dc.N, 1, N_debris) + ) +end # function - #get updated stocks and debris - state′ = state_transition(operator.physics,state,a) - bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) - maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state′)) - return Flux.mae(bellman_residuals.^2 ,maximization_condition) -end # struct -function (operator::OperatorLoss)( - state::State - ,policy::Flux.Chain #allow for another policy to be subsituted - ) - #get actions - a = policy(state) +abstract type GeneralizedLoss end - #get updated stocks and debris - state′ = stake_transition(state,a) +struct OperatorLoss <: GeneralizedLoss + #= + This struct organizes information about a given constellation operator + It is used to provide an iterable loss function for training. + The fields each identify one aspect of the operator's decision problem: + - The economic model describing payoffs and discounting. + - The estimated NN value function held by the operator. + - The estimated NN policy function that describes each of the + - Each operator holds a reference to the parameters they can update. + - The physical world that describes how satellites progress. + There is an overriding function that uses these details to calculate the residual function based on + the bellman residuals and a maximization condition. + There are two versions of this function. + - return an actual loss calculation (MAE) using the owned policy function. + - return the loss calculation using a provided policy function. + =# + #econ model describing operator + econ_model::EconomicModel + #Operator's value and policy functions, as well as which parameters the operator can train. + operator_value_fn::Flux.Chain + collected_policies::Flux.Chain #this is held by all operators + operator_policy_params::Flux.Params #but only some of it is available for training + physics::PhysicalModel #It would be nice to move this to somewhere else in the model. +end # struct +# overriding function to calculate operator loss +function (operator::OperatorLoss)( + state::State + ,policy::Flux.Chain #allow for another policy to be subsituted +) + #get actions + a = policy(state) + #get updated stocks and debris + state′ = stake_transition(state,a) - bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) + bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) - maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) + maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state′) - return Flux.mae(bellman_residuals.^2 ,maximization_condition) + return Flux.mae(bellman_residuals.^2 ,maximization_condition) end #function - -function train_operators!(op::OperatorLoss, data::State, opt) - Flux.train!(op, op.operator_policy_params, data, opt) - Flux.train!(op, Flux.params(op.operator_value_fn), data, opt) +function (operator::OperatorLoss)( + state::State +) + #just use the included policy. + return operator(state,operator.collected_policies) +end # function + +function train_operator!(op::OperatorLoss, data::State, opt) + #Train the policy functions + Flux.train!(op, op.operator_policy_params, data, opt) + #Train the value function + Flux.train!(op, Flux.params(op.operator_value_fn), data, opt) end @@ -90,6 +113,7 @@ struct PlannerLoss <: GeneralizedLoss There is an issue with appropriately training the value functions. In this case, it is not happening... =# + #planner discount level (As it may disagree with operators) β::Float32 operators::Array{GeneralizedLoss} @@ -104,64 +128,47 @@ end function (planner::PlannerLoss)( state::State ) - a = planner.policy((s ,d)) +#TODO! Rewrite to use a new states setup. + a = planner.policy((s ,d)) #TODO: States #get updated stocks and debris - state′ = state_transition(s ,d ,a) - + state′ = state_transition(s ,d ,a) #TODO: States #calculate the total benefit from each of the models - benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators]) + benefit = sum([ co.econ_model(s ,d ,a) for co in planner.operators]) #TODO: States #issue here with mutating. Maybe generators/list comprehensions? - bellman_residuals = planner.value((s,d)) - benefit - planner.β .* planner.value((s′,d′)) + bellman_residuals = planner.value((s,d)) - benefit - ( planner.β .* planner.value((s′,d′)) )#TODO: States - maximization_condition = - benefit - planner.β .* planner.value((s′,d′)) + maximization_condition = - benefit - planner.β .* planner.value((s′,d′)) #TODO: States return Flux.mae(bellman_residuals.^2 ,maximization_condition) -end +end # function -function train_planner!(pl::PlannerLoss, N_epoch::Int, opt) +function train_planner!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor) errors = [] for i = 1:N_epoch - data = data_gen() + data = data_gen() Flux.train!(pl, pl.policy_params, data, opt) Flux.train!(pl, pl.value_params, data, opt) append!(errors, error(pl, data1) / 200) end return errors -end +end # function -function train_operators!(pl::PlannerLoss, N_epoch::Int, opt) +function train_operators!(pl::PlannerLoss, N_epoch::Int, opt, data_gen::DataConstructor) errors = [] for i = 1:N_epoch + data = data_gen() data = data_gen() for op in pl.operators - train_operators!(op,data,opt) + train_operator!(op,data,opt) end end return errors -end +end # function -#Construct and manage data -abstract type DataConstructor end - -struct UniformDataConstructor <: DataConstructor - N::UInt64 - satellites_bottom::Float32 - satellites_top::Float32 - debris_bottom::Float32 - debris_top::Float32 -end -function (dc::UniformDataConstructor)() - #currently ignores the quantity of data it should construct. - State( - rand(dc.satellites_bottom:dc.satellites_top, N_constellations) - , rand(dc.debris_bottom:dc.debris_top, N_debris) - ) -end - end # end module diff --git a/julia_code/Module/src/physical_model.jl b/julia_code/Module/src/physical_model.jl index c6896bc..f9fa4b9 100644 --- a/julia_code/Module/src/physical_model.jl +++ b/julia_code/Module/src/physical_model.jl @@ -1,21 +1,29 @@ -#==# -struct State - stocks::Array{Float32} - debris::Array{Float32} -end +#=Satellite State=# +abstract type State end -function state_to_tuple(s::State) - return (s.stocks ,s.debris) +struct SingleStates <: State + stocks::Vector{Float32} + debris::Vector{Float32} end -### Physics +struct MultiStates <: State + stocks::Array{Float32} + debris::Array{Float32} +end +#function state_to_tuple(s::State) +# return (s.stocks ,s.debris) +#end +#=Physical Model +This contains parameters describing the physical model. +=# abstract type PhysicalModel end +#=Basic implementation of a physical model=# struct BasicPhysics <: PhysicalModel #rate at which debris hits satellites debris_collision_rate::Real @@ -30,43 +38,44 @@ struct BasicPhysics <: PhysicalModel #Ratio at which launches produce debris launch_debris_ratio::Real end -function state_transition( - physics::BasicPhysics - ,state::State - ,launches::Vector{Float32} - ) - #= - Physical Transitions - =# - survival_rate = survival(state,physics) - - # Debris transitions - - # get changes in debris from natural dynamics - natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris - - # get changes in debris from satellite loss - satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rate)' * state.stocks - - # get changes in debris from launches - launch_debris = physics.launch_debris_ratio * sum(launches) - - # total debris level - debris′ = natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris - - - # Stocks Transitions - stocks′ = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches - - return State(stocks′,debris′) +function state_transition( + physics::BasicPhysics + ,state::State + ,launches::Vector{Float32} + ,survival_rate::Function +) + #= + Physical Transitions + =# + survival_rates = survival_rate(state,physics) + + # Debris transitions + + # get changes in debris from natural dynamics + natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris + + # get changes in debris from satellite loss + satellite_loss_debris = physics.satellite_collision_debris_ratio * (1 .- survival_rates)' * state.stocks + + # get changes in debris from launches + launch_debris = physics.launch_debris_ratio * sum(launches) + + # total debris level + debris′ = natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris + + + # Stocks Transitions + stocks′ = (LinearAlgebra.diagm(survival_rates) .- physics.decay_rate)*state.stocks + launches + + return State(stocks′,debris′) end - -function survival( - state::State - ,physical_model::BasicPhysics - ) - return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris)) +function survival_rates_1( + #This function describes the rate at which satellites survive each period. + state::State + ,physical_model::BasicPhysics +) +#TODO! get this to broadcast correctly. + return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris)) end - diff --git a/julia_code/Module/tests/test_states_and_transitions.jl b/julia_code/Module/tests/test_states_and_transitions.jl new file mode 100644 index 0000000..1d2d278 --- /dev/null +++ b/julia_code/Module/tests/test_states_and_transitions.jl @@ -0,0 +1,62 @@ +using Test, Flux, LinearAlgebra + +include("../src/Orbits.jl") +using .Orbits + +#= +The purpose of this document is to organize tests of the state structs and state transition functions +=# + +@testset "States and Physical models testing" verbose=true begin + n_const = 2 + n_debr = 3 + n_data = 5 + + #built structs + u = UniformDataConstructor(n_data,0,5,2,3) + s = u(n_const,n_debr) + b = BasicPhysics( + 0.05 + ,0.02*LinearAlgebra.ones(n_const,n_const) + ,0.1 + ,0.002 + ,0.002 + ,0.2 + ) + + a2 = ones(n_const,n_data) + + #test that dimensions match etc + @test size(b.satellite_collision_rates)[1] == size(s.stocks)[1] + @test size(b.satellite_collision_rates)[2] == size(s.stocks)[1] + @test n_data == size(s.stocks)[2] + @test n_data == size(s.debris)[2] + @test size(s.stocks) == size(a2) + + @testset "DataConstructor and states" begin + @test u.N == 5 + + @test length(s.debris) != 3 + @test length(s.stocks) != 2 + + @test length(s.stocks) == 10 + @test length(s.debris) == 15 + + @test size(s.stocks) == (2,5) + @test size(s.debris) == (3,5) + end + + @testset "BasicPhysics" begin + @testset "Survival Functions" verbose = true begin + @test survival_rates_1(s,b) <: AbstractArray + end + + @testset "Transitions" begin + @test true + end + end + + + + +end #States and physcial models testing