Refactored into files

temporaryWork^2
will king 4 years ago
parent cd6c414ca2
commit db44fa86cc

@ -12,7 +12,7 @@ include("flux_helpers.jl")
using .NNTools using .NNTools
# Exports # Exports
export GeneralizedLoss , OperatorLoss, PlannerLoss, UniformDataConstructor
# Code # Code
@ -34,8 +34,7 @@ struct OperatorLoss <: GeneralizedLoss
end end
# overriding function to calculate operator loss # overriding function to calculate operator loss
function (operator::OperatorLoss)( function (operator::OperatorLoss)(
s::Vector{Float32} state::State
,d::Vector{Float32}
) )
@ -43,13 +42,12 @@ function (operator::OperatorLoss)(
a = operator.collected_policies((s,d)) a = operator.collected_policies((s,d))
#get updated stocks and debris #get updated stocks and debris
s = operator.stocks_transition(s ,d ,a) state = state_transition(operator.physics,state,a)
d = operator.debris_transition(s ,d ,a)
bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s,d)) bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s,d)) maximization_condition = - operator.econ_model(state ,a) - operator.econ_model.β * operator.operator_value_fn((state))
return Flux.mae(bellman_residuals.^2 ,maximization_condition) return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end # struct end # struct
@ -60,20 +58,25 @@ function (operator::OperatorLoss)(
#get actions #get actions
a = policy((s,d)) a = policy(state)
#get updated stocks and debris #get updated stocks and debris
s = operator.stocks_transition(s ,d ,a) state = stake_transition(state,a)
d = operator.debris_transition(s ,d ,a)
bellman_residuals = operator.operator_value_fn((s,d)) - operator.econ_model(s,d,a) - operator.econ_model.β * operator.operator_value_fn((s,d)) bellman_residuals = operator.operator_value_fn(state) - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
maximization_condition = - operator.econ_model(s ,d ,a) - operator.econ_model.β * operator.operator_value_fn((s,d)) maximization_condition = - operator.econ_model(state,a) - operator.econ_model.β * operator.operator_value_fn(state)
return Flux.mae(bellman_residuals.^2 ,maximization_condition) return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end #function end #function
"""
    train_operators!(op::OperatorLoss, data::State, opt)

Run two `Flux.train!` passes that both use `op` itself as the loss: the first
over `op.operator_policy_params`, the second over the parameters of
`op.operator_value_fn`. The result of the second `Flux.train!` call is returned.

# Arguments
- `op`: the operator loss functor being trained.
- `data`: passed straight through to `Flux.train!` as its data argument.
- `opt`: the optimiser handed to `Flux.train!`.
"""
function train_operators!(op::OperatorLoss, data::State, opt)
    # Pass 1: policy parameters.
    policy_params = op.operator_policy_params
    Flux.train!(op, policy_params, data, opt)

    # Pass 2: value-function parameters, collected only after pass 1 has run.
    value_params = Flux.params(op.operator_value_fn)
    return Flux.train!(op, value_params, data, opt)
end
#= #=
Describe the Planners loss function Describe the Planners loss function
=# =#
@ -116,4 +119,45 @@ function (planner::PlannerLoss)(
return Flux.mae(bellman_residuals.^2 ,maximization_condition) return Flux.mae(bellman_residuals.^2 ,maximization_condition)
end end
"""
    train_planner!(pl::PlannerLoss, N_epoch::Int, opt)

Train the planner for `N_epoch` epochs and return the per-epoch error values.

Each epoch draws a fresh batch with `data_gen()`, runs one `Flux.train!` pass
over `pl.policy_params` and one over `pl.value_params` (both with `pl` as the
loss), and records `error(pl, data) / 200` for that batch.

# Arguments
- `pl`: the planner loss functor being trained.
- `N_epoch`: number of epochs; a value below 1 performs no training.
- `opt`: the optimiser handed to `Flux.train!`.

# Returns
A vector with one entry appended per epoch.
"""
function train_planner!(pl::PlannerLoss, N_epoch::Int, opt)
    errors = []
    for _ in 1:N_epoch
        # NOTE(review): `data_gen` is not defined in the visible part of the
        # file; it is assumed to be supplied by the enclosing module.
        data = data_gen()
        Flux.train!(pl, pl.policy_params, data, opt)
        Flux.train!(pl, pl.value_params, data, opt)
        # Score the batch that was just trained on. The previous code read an
        # undefined `data1` here.
        # NOTE(review): `error` must be a module-level metric; if it resolves to
        # `Base.error` this line throws instead of measuring anything. The
        # divisor 200 is kept as found; its meaning is not visible here.
        append!(errors, error(pl, data) / 200)
    end
    return errors
end
"""
    train_operators!(pl::PlannerLoss, N_epoch::Int, opt)

Train every operator in `pl.operators` for `N_epoch` epochs.

Each epoch draws one batch with `data_gen()` and passes that same batch to
`train_operators!(op, batch, opt)` for each operator in turn.

# Returns
The `errors` vector. Nothing is ever appended to it in this method, so it is
always empty on return.
"""
function train_operators!(pl::PlannerLoss, N_epoch::Int, opt)
    errors = []
    for _ in 1:N_epoch
        # One batch per epoch, shared by all operators.
        batch = data_gen()
        foreach(op -> train_operators!(op, batch, opt), pl.operators)
    end
    return errors
end
#Construct and manage data
"""
    DataConstructor

Abstract supertype for objects that construct data.
"""
abstract type DataConstructor end

"""
    UniformDataConstructor(N, satellites_bottom, satellites_top, debris_bottom, debris_top)

Settings for a data constructor that draws between a lower and an upper bound.

# Fields
- `N`: presumably the number of samples requested (TODO confirm; the sampling
  functor's own comment says it currently ignores the quantity).
- `satellites_bottom`, `satellites_top`: lower and upper bound for satellite stocks.
- `debris_bottom`, `debris_top`: lower and upper bound for debris.
"""
struct UniformDataConstructor <: DataConstructor
    N::UInt64
    satellites_bottom::Float32
    satellites_top::Float32
    debris_bottom::Float32
    debris_top::Float32
end
"""
    (dc::UniformDataConstructor)()

Draw one random `State` from the bounds stored in `dc`.

Stocks are sampled from the range `dc.satellites_bottom:dc.satellites_top` and
debris from `dc.debris_bottom:dc.debris_top`. These are unit-step ranges, so
draws are taken from the values bottom, bottom + 1, and so on.
"""
function (dc::UniformDataConstructor)()
    # Currently ignores the quantity of data it should construct.
    # NOTE(review): `N_constellations` and `N_debris` are not defined in the
    # visible part of the file; assumed to be module-level values.
    stock_range = dc.satellites_bottom:dc.satellites_top
    stock_draws = rand(stock_range, N_constellations)

    debris_range = dc.debris_bottom:dc.debris_top
    debris_draws = rand(debris_range, N_debris)

    return State(stock_draws, debris_draws)
end
end # end module end # end module

@ -1,11 +1,11 @@
module BenefitFunctions module BenefitFunctions
export LinearModel, BenefitFunction export LinearModel, BenefitFunction, EconomicModel
import("physical_model.lj") include("physical_model.jl")
using PhysicalModel: State using .PhysicsModule: State
#= #=
Benefit Functions: Benefit Functions:
@ -22,9 +22,9 @@ struct LinearModel <: EconomicModel
end end
function (em::LinearModel)( function (em::LinearModel)(
state::State state::State
,a::Vector{Float32} ,actions::Vector{Float32}
) )
return em.payoff_array*s - em.policy_costs*a return em.payoff_array*state.stocks - em.policy_costs*actions
end end
#basic CES model #basic CES model
@ -37,13 +37,13 @@ struct CES <: EconomicModel
end end
function (em::CES)( function (em::CES)(
state::State state::State
,a::Vector{Float32} ,actions::Vector{Float32}
) )
#issue here with multiplication #issue here with multiplication
r1 = em.payoff_array .* (s.^em.r) r1 = em.payoff_array .* (state.stocks.^em.r)
r2 = - em.debris_costs .* (d.^em.r) r2 = - em.debris_costs .* (state.debris.^em.r)
r3 = - em.policy_costs .* (a.^em.r) r3 = - em.policy_costs .* (actions.^em.r)
return (r1 + r2 + r3) .^ (1/em.r) return (r1 + r2 + r3) .^ (1/em.r)
end end
@ -58,10 +58,10 @@ struct CRRA <: EconomicModel
end end
function (em::CRRA)( function (em::CRRA)(
state::State state::State
,a::Vector{Float32} ,actions::Vector{Float32}
) )
#issue here with multiplication #issue here with multiplication
core = (em.payoff_array*s - em.debris_costs*d - em.policy_costs).^(1 - em.σ) core = (em.payoff_array*state.stocks - em.debris_costs*state.debris - em.policy_costs*actions).^(1 - em.σ)
return (core-1)/(1-em.σ) return (core-1)/(1-em.σ)
end end

@ -4,14 +4,21 @@ export BranchGenerator, value_function_generator
using Flux, BSON, Zygote using Flux, BSON, Zygote
include("physical_model.jl")
using .PhysicsModule: State,tuple
#= TupleDuplicator #= TupleDuplicator
This is used to create a tuple of size n with deepcopies of any object x This is used to create a tuple of size n with deepcopies of any object x
=# =#
struct TupleDuplicator struct TupleDuplicator
n::Int n::Int
end end
#make tuples consisting of copies of whatever was provided
(f::TupleDuplicator)(x) = tuple([deepcopy(x) for i=1:f.n]...) (f::TupleDuplicator)(x) = tuple([deepcopy(x) for i=1:f.n]...)
#do the same but for the case of states, coerce them to tuples first
# The inner `tuple(x)` turns the State into its (stocks, debris) pair. The outer
# collection is built with `Tuple(...)` because `tuple` in this file is the
# function imported from PhysicsModule, which (as far as this diff shows) only
# has a method for a single State, so a call like `tuple(a, b, ...)` would
# raise a MethodError.
(f::TupleDuplicator)(x::State) = Tuple(deepcopy(tuple(x)) for _ in 1:f.n)

@ -1,15 +0,0 @@
module Fundamentals

# Add exports here
export State

"""
    State(stocks, debris)

The state of the world.

# Fields
- `stocks`: `Float32` array of stocks.
- `debris`: `Float32` array of debris.
"""
struct State
    stocks::Array{Float32}
    debris::Array{Float32}
end

end # end module

@ -1,7 +1,7 @@
module PhysicsModule module PhysicsModule
#Add exports here #Add exports here
export State, PhysicalModel, BasicModel export State ,tuple ,state_transition ,PhysicalModel ,BasicModel
using Flux, LinearAlgebra using Flux, LinearAlgebra
@ -13,6 +13,9 @@ struct State
debris::Array{Float32} debris::Array{Float32}
end end
"""
    tuple(st::State)

Return the fields of `st` as the pair `(st.stocks, st.debris)`.
"""
function tuple(st::State)
    return (st.stocks, st.debris)
end

# Defining `tuple` in this module creates a new function that hides the builtin
# of the same name, both here and in every module that imports it. A builtin
# cannot be given new methods, so this fallback forwards every non-State call to
# `Core.tuple`, keeping calls such as `tuple(a, b, c)` working for importers.
tuple(args...) = Core.tuple(args...)
### Physics ### Physics
@ -43,7 +46,7 @@ function state_transition(
=# =#
survival_rate = survival(state,physics) survival_rate = survival(state,physics)
# Debris # Debris transitions
# get changes in debris from natural dynamics # get changes in debris from natural dynamics
natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris natural_debris_dynamics = (1 - physics.decay_rate + physics.autocatalysis_rate) * state.debris
@ -58,7 +61,7 @@ function state_transition(
debris = natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris debris = natural_debris_dynamics .+ satellite_loss_debris .+ launch_debris
#stocks # Stocks Transitions
stocks = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches stocks = (LinearAlgebra.diagm(survival_rate) .- physics.decay_rate)*state.stocks + launches
@ -70,9 +73,7 @@ function survival(
state::State state::State
,physical_model::BasicModel ,physical_model::BasicModel
) )
return exp.( return exp.(-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks .- (physical_model.debris_collision_rate * state.debris))
-(physical_model.satellite_collision_rates .+ physical_model.decay_rate) * state.stocks
.- (physical_model.debris_collision_rate * state.debris)
)
end end
end #end module end #end module

Loading…
Cancel
Save