#=
TupleDuplicator

Callable that creates a tuple of size `n` holding independent deep copies of
whatever object it is given. It is used both to duplicate inputs and to
duplicate model branches, so every copy can be used or trained independently.
=#
struct TupleDuplicator
    n::Int

    # A negative size is a configuration mistake; reject it instead of silently
    # producing an empty tuple. Non-Int sizes are converted as before.
    function TupleDuplicator(n)
        len = convert(Int, n)
        len >= 0 ||
            throw(ArgumentError("TupleDuplicator size must be non-negative, got $n"))
        return new(len)
    end
end

# Convert a `State` into its tuple form with `state_to_tuple`, then duplicate it.
# NOTE(review): originally flagged "BROKEN: Fails in test, but works in REPL";
# the cause is not visible in this file — TODO investigate.
function (f::TupleDuplicator)(s::State)
    st = state_to_tuple(s)
    return f(st)
end

# Generic case: an `f.n`-tuple of deep copies of `x`.
(f::TupleDuplicator)(x) = ntuple(_ -> deepcopy(x), f.n)

#=
BranchGenerator

Callable that replicates a `Flux.Parallel` branch `n` times and wires the copies
into a chain in which every copy receives its own duplicate of the input.
=#
struct BranchGenerator
    n::UInt
end

# Return a `Flux.Chain` that:
#   1. destructures the incoming state into a tuple (`state_to_tuple`),
#   2. duplicates that tuple `b.n` times,
#   3. feeds each duplicate to one of `b.n` deep copies of `branch` inside a
#      `Flux.Parallel`, whose outputs are combined with `join_fn`.
# Because the branches are deep copies, each one has its own parameters.
function (b::BranchGenerator)(branch::Flux.Parallel, join_fn::Function)
    duplicate = TupleDuplicator(b.n)
    return Flux.Chain(
        state_to_tuple,
        duplicate,
        Flux.Parallel(join_fn, duplicate(branch)...),
    )
end

# Neural Network Generators

# Shared input stage of every generator below: a `Flux.Parallel` with one
# sub-network for the constellation stocks (output width `2 * number_params`)
# and one for the debris (output width `number_params`), joined with `vcat`
# into `3 * number_params` features.
function _stock_debris_encoder(
    n_constellations, n_debris, number_params;
    stock_activation, debris_activation=identity,
)
    return Flux.Parallel(
        vcat,
        Flux.Chain(
            Flux.Dense(n_constellations, number_params * 2, Flux.relu),
            Flux.Dense(number_params * 2, number_params * 2, stock_activation),
        ),
        Flux.Chain(
            Flux.Dense(n_debris, number_params, Flux.relu),
            Flux.Dense(number_params, number_params, debris_activation),
        ),
    )
end

# Value-function network: stocks/debris encoder, two σ hidden layers, and a
# single linear output.
# `n_constellations` / `n_debris` default to the globals `N_constellations` /
# `N_debris`, but can be passed explicitly like in
# `operator_policy_function_generator`.
function value_function_generator(
    number_params=32; n_constellations=N_constellations, n_debris=N_debris
)
    return Flux.Chain(
        _stock_debris_encoder(
            n_constellations, n_debris, number_params;
            stock_activation=Flux.σ, debris_activation=Flux.σ,
        ),
        # Apply some transformations to the preprocessed data.
        Flux.Dense(number_params * 3, number_params, Flux.σ),
        Flux.Dense(number_params, number_params, Flux.σ),
        Flux.Dense(number_params, 1),
    )
end

# Planner policy network: stocks/debris encoder (linear debris head), two σ
# hidden layers, and one relu output per constellation.
# `n_constellations` / `n_debris` default to the globals `N_constellations` /
# `N_debris`, but can be passed explicitly.
function cross_linked_planner_policy_function_generator(
    number_params=32; n_constellations=N_constellations, n_debris=N_debris
)
    return Flux.Chain(
        _stock_debris_encoder(
            n_constellations, n_debris, number_params;
            stock_activation=Flux.σ, debris_activation=identity,
        ),
        # Apply some transformations
        Flux.Dense(number_params * 3, number_params, Flux.σ),
        Flux.Dense(number_params, number_params, Flux.σ),
        Flux.Dense(number_params, n_constellations, Flux.relu),
    )
end

# Operator policy network: stocks/debris encoder (tanh stock head, linear debris
# head), two tanh hidden layers, and a single non-negative output.
function operator_policy_function_generator(
    N_constellations::Int,
    N_debris,
    number_params=32,
)
    return Flux.Chain(
        _stock_debris_encoder(
            N_constellations, N_debris, number_params;
            stock_activation=Flux.tanh, debris_activation=identity,
        ),
        # Apply some transformations
        Flux.Dense(number_params * 3, number_params, Flux.tanh),
        Flux.Dense(number_params, number_params, Flux.tanh),
        # relu(sinh(.)) keeps the output non-negative. `Dense` already broadcasts
        # its activation element-wise, so the activation is written for a scalar.
        Flux.Dense(number_params, 1, x -> Flux.relu(sinh(x))),
    )
end