|
|
|
|
|
|
|
|
"""
    TupleDuplicator(n)

Callable object that turns any value `x` into an `n`-element tuple whose entries
are independent `deepcopy`s of `x`. When called on a `State`, the state is first
flattened with `state_to_tuple` and that tuple is what gets duplicated.

The field `n` is stored as an `Int`; anything convertible to `Int` may be passed
to the constructor.
"""
struct TupleDuplicator
    n::Int  # how many copies each call produces
end
|
|
|
# Make tuples consisting of copies of whatever was provided.
#
# `State` method: flatten the state with `state_to_tuple`, then re-dispatch on
# the result (presumably a `Tuple`, which is handled by the generic method that
# deep-copies it `f.n` times; confirm `state_to_tuple` never returns a `State`,
# or this would recurse).
# NOTE(review): the original author marked this method "BROKEN: Fails in test,
# but works in REPL." The cause is not visible from this file; TODO investigate.
(f::TupleDuplicator)(s::State) = f(state_to_tuple(s))
|
|
|
|
|
|
# Generic method: return a tuple of `f.n` independent deep copies of `x`
# (the empty tuple `()` when `f.n <= 0`).
(f::TupleDuplicator)(x) = Tuple(deepcopy(x) for _ in 1:f.n)
|
|
|
|
|
|
|
|
|
|
|
|
"""
    BranchGenerator(n)

Callable object that, given a `Flux.Parallel` branch and a join function, builds
a `Flux.Chain` running `n` independent deep copies of that branch side by side.
`n` is stored as a `UInt`.
"""
struct BranchGenerator
    n::UInt  # number of branch copies to build
end
|
|
|
# Build a chain that runs `b.n` independent copies of `branch` in parallel.
#
# The returned `Flux.Chain`:
#   1. destructures its input with `state_to_tuple`,
#   2. duplicates that tuple `b.n` times (deep copies, via `TupleDuplicator`),
#   3. passes the copies to a `Flux.Parallel` over `b.n` deep copies of `branch`,
#      whose outputs are combined with `join_fn`.
# Per Flux's `Parallel` semantics for tuple inputs, copy `i` of the input is
# routed to copy `i` of the branch.
function (b::BranchGenerator)(branch::Flux.Parallel, join_fn::Function)
    duplicate = TupleDuplicator(b.n)
    branch_copies = duplicate(branch)  # `b.n` independent deep copies of `branch`
    parallel_block = Flux.Parallel(join_fn, branch_copies...)
    return Flux.Chain(state_to_tuple, duplicate, parallel_block)
end
|
|
|
|
|
|
|
|
|
# Neural Network Generators
|
|
|
|
|
|
# Build the value-function network.
#
# Layout (p = number_params):
#   Parallel(vcat,
#       stock encoder:  Dense(n_constellations => 2p, relu) -> Dense(2p => 2p, σ),
#       debris encoder: Dense(n_debris => p, relu)          -> Dense(p => p, σ))
#   -> Dense(3p => p, σ) -> Dense(p => p, σ) -> Dense(p => 1)
# The `vcat` join concatenates both encodings (2p + p = 3p features).
# NOTE(review): the input is presumably a `(stocks, debris)` tuple so each encoder
# gets its own element; confirm with callers.
#
# Arguments
# - `number_params`: base hidden width `p` (default 32).
# - `n_constellations`, `n_debris` (keywords): input sizes of the stock and debris
#   encoders. They default to the globals `N_constellations` / `N_debris`, read at
#   call time, so existing calls `value_function_generator()` /
#   `value_function_generator(p)` behave exactly as before.
#
# Returns a `Flux.Chain` with a single linear output.
function value_function_generator(
    number_params=32;
    n_constellations=N_constellations,
    n_debris=N_debris,
)
    p = number_params
    return Flux.Chain(
        # Parallel joins together stocks and debris, after a little preprocessing.
        Flux.Parallel(
            vcat,
            Flux.Chain(
                Flux.Dense(n_constellations, 2p, Flux.relu),
                Flux.Dense(2p, 2p, Flux.σ),
            ),
            Flux.Chain(
                Flux.Dense(n_debris, p, Flux.relu),
                Flux.Dense(p, p, Flux.σ),
            ),
        ),
        # Apply some transformations to the preprocessed data.
        Flux.Dense(3p, p, Flux.σ),
        Flux.Dense(p, p, Flux.σ),
        Flux.Dense(p, 1),
    )
end
|
|
|
|
|
|
# Build the cross-linked planner policy network.
#
# Layout (p = number_params):
#   Parallel(vcat,
#       stock encoder:  Dense(n_constellations => 2p, relu) -> Dense(2p => 2p, σ),
#       debris encoder: Dense(n_debris => p, relu)          -> Dense(p => p))  # linear
#   -> Dense(3p => p, σ) -> Dense(p => p, σ) -> Dense(p => n_constellations, relu)
# The `vcat` join concatenates both encodings (2p + p = 3p features). The final
# `relu` makes every output component non-negative.
# NOTE(review): the input is presumably a `(stocks, debris)` tuple so each encoder
# gets its own element; confirm with callers.
#
# Arguments
# - `number_params`: base hidden width `p` (default 32).
# - `n_constellations`, `n_debris` (keywords): stock input size (also the output
#   width) and debris input size. They default to the globals `N_constellations` /
#   `N_debris`, read at call time, so existing calls behave exactly as before.
#
# Returns a `Flux.Chain` with `n_constellations` outputs.
function cross_linked_planner_policy_function_generator(
    number_params=32;
    n_constellations=N_constellations,
    n_debris=N_debris,
)
    p = number_params
    return Flux.Chain(
        # Parallel joins together stocks and debris.
        Flux.Parallel(
            vcat,
            Flux.Chain(
                Flux.Dense(n_constellations, 2p, Flux.relu),
                Flux.Dense(2p, 2p, Flux.σ),
            ),
            Flux.Chain(
                Flux.Dense(n_debris, p, Flux.relu),
                Flux.Dense(p, p),
            ),
        ),
        # Apply some transformations.
        Flux.Dense(3p, p, Flux.σ),
        Flux.Dense(p, p, Flux.σ),
        Flux.Dense(p, n_constellations, Flux.relu),
    )
end
|
|
|
|
|
|
# Build the operator policy network.
#
# Layout (p = number_params):
#   Parallel(vcat,
#       stock encoder:  Dense(N_constellations => 2p, relu) -> Dense(2p => 2p, tanh),
#       debris encoder: Dense(N_debris => p, relu)          -> Dense(p => p))  # linear
#   -> Dense(3p => p, tanh) -> Dense(p => p, tanh) -> Dense(p => 1, relu∘sinh)
# The `vcat` join concatenates both encodings (2p + p = 3p features).
# NOTE(review): the input is presumably a `(stocks, debris)` tuple so each encoder
# gets its own element; confirm with callers.
#
# Arguments
# - `N_constellations::Int`: input size of the stock encoder.
# - `N_debris`: input size of the debris encoder. Both arguments shadow any
#   globals of the same name.
# - `number_params`: base hidden width `p` (default 32).
#
# Returns a `Flux.Chain` with one output. Its activation is `relu(sinh(x))`:
# 0 for x <= 0 and `sinh(x)` otherwise, so the output is non-negative.
function operator_policy_function_generator(
    N_constellations::Int,
    N_debris,
    number_params=32,
)
    p = number_params
    return Flux.Chain(
        # Parallel joins together stocks and debris.
        Flux.Parallel(
            vcat,
            Flux.Chain(
                Flux.Dense(N_constellations, 2p, Flux.relu),
                Flux.Dense(2p, 2p, Flux.tanh),
            ),
            Flux.Chain(
                Flux.Dense(N_debris, p, Flux.relu),
                Flux.Dense(p, p),
            ),
        ),
        # Apply some transformations.
        Flux.Dense(3p, p, Flux.tanh),
        Flux.Dense(p, p, Flux.tanh),
        # `Dense` broadcasts its activation over the output elementwise, so the
        # activation receives scalars. The former `x -> relu.(sinh.(x))` dotted
        # those scalars redundantly; this scalar form computes the same values.
        Flux.Dense(p, 1, x -> Flux.relu(sinh(x))),
    )
end
|
|
|
|