{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "41dcca64-963f-488e-b92e-f1dc5109359a", "metadata": {}, "outputs": [], "source": [ "using Enzyme" ] }, { "cell_type": "code", "execution_count": 2, "id": "ca966ab8-469e-4f8c-af54-579c55f54bd4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "()" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "function mymul!(R, A, B)\n", " @assert axes(A,2) == axes(B,1)\n", " @inbounds @simd for i in eachindex(R)\n", " R[i] = 0\n", " end\n", " @inbounds for j in axes(B, 2), i in axes(A, 1)\n", " @inbounds @simd for k in axes(A,2)\n", " R[i,j] += A[i,k] * B[k,j]\n", " end\n", " end\n", " nothing\n", "end\n", "\n", "\n", "A = rand(5, 3)\n", "B = rand(3, 7)\n", "\n", "R = zeros(size(A,1), size(B,2))\n", "∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us\n", "\n", "∂z_∂A = zero(A)\n", "∂z_∂B = zero(B)\n", "\n", "Enzyme.autodiff(mymul!, Const, Duplicated(R, ∂z_∂R), Duplicated(A, ∂z_∂A), Duplicated(B, ∂z_∂B))" ] }, { "cell_type": "code", "execution_count": 3, "id": "7442bfb5-3146-493d-9abc-9afeb56c0471", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "true" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "R ≈ A * B &&\n", "∂z_∂A ≈ ∂z_∂R * B' && # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[1]\n", "∂z_∂B ≈ A' * ∂z_∂R # equivalent to Zygote.pullback(*, A, B)[2](∂z_∂R)[2]" ] }, { "cell_type": "code", "execution_count": 4, "id": "36a4cd0f-c5e2-4a6f-b434-2d347686a08b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3×7 Matrix{Float64}:\n", " 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", " 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n", " 0.0 0.0 0.0 0.0 0.0 0.0 0.0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# reset the output matrix and the gradient/shadow buffers before re-running autodiff\n", "R = zeros(size(A,1), size(B,2))\n", "∂z_∂R = rand(size(R)...) # Some gradient/tangent passed to us\n",
"\n", "∂z_∂A = zero(A)\n", "∂z_∂B = zero(B)" ] }, { "cell_type": "code", "execution_count": 5, "id": "49cbf4e1-1ef0-4428-90af-02ac2b46c2a8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "revenue! (generic function with 1 method)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function revenue!(R, A, B)\n", " @assert axes(A,2) == axes(B,1)\n", " @inbounds @simd for i in eachindex(R)\n", " R[i] = 0\n", " end\n", " @inbounds for j in axes(B, 2), i in axes(A, 1)\n", " @inbounds @simd for k in axes(A,2)\n", " R[i,j] += A[i,k] * B[k,j]\n", " end\n", " end\n", " nothing\n", "end\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "5fea97a5-39ac-4ade-9cb4-4aba66a80825", "metadata": {}, "outputs": [], "source": [ "batch_size = 5;\n", "constellations = 2;\n", "payoff_mat = zeros(batch_size,1);\n", "\n", "stocks = rand(batch_size, constellations);\n", "\n", "payoffs = rand(constellations,1);" ] }, { "cell_type": "code", "execution_count": 7, "id": "d06ec650-621c-4a97-8bbd-0bfcfd4ab8d5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5×1 Matrix{Float64}:\n", " 0.05943992677268309\n", " 0.16746133343364858\n", " 0.22311130107900645\n", " 0.1326381498910713\n", " 0.23997313509634804" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "revenue!(payoff_mat, stocks,payoffs)\n", "\n", "payoff_mat" ] }, { "cell_type": "code", "execution_count": 12, "id": "86425c6b-baa3-494d-8c2c-35eaf08cefaf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5×1 Matrix{Float64}:\n", " 1.0\n", " 1.0\n", " 1.0\n", " 1.0\n", " 1.0" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "payoff_mat = zero(payoff_mat)\n", "∂payoff_mat = ones(size(payoff_mat)...)\n", "\n", "∂stocks = zero(stocks)\n", "∂payoff_mat" ] }, { "cell_type": "code", "execution_count": null, "id": 
"fbb5fa3f-b3c6-48bf-9df2-657d5d7aae18", "metadata": {}, "outputs": [], "source": [ "# Arguments must follow revenue!(R, A, B) as in the forward call above:\n", "# R = payoff_mat (output + seed), A = stocks (differentiated), B = payoffs (held constant).\n", "Enzyme.autodiff(revenue!, Const,\n", "    Duplicated(payoff_mat, ∂payoff_mat),\n", "    Duplicated(stocks, ∂stocks),\n", "    Const(payoffs))" ] }, { "cell_type": "code", "execution_count": null, "id": "0872e372-f3d1-4fca-b89f-f134a3dc563d", "metadata": {}, "outputs": [], "source": [ "# Check the primal result and the gradient w.r.t. stocks against the analytic pullback\n", "# of A * B with respect to A (seed ∂R = ones): ∂A = ∂R * B'\n", "payoff_mat ≈ stocks * payoffs &&\n", "∂stocks ≈ ones(size(payoff_mat)...) * payoffs'" ] }, { "cell_type": "code", "execution_count": null, "id": "984b1bd7-e53a-41c3-bafd-f56aacdae4b7", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.6.2", "language": "julia", "name": "julia-1.6" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.6.2" } }, "nbformat": 4, "nbformat_minor": 5 }