You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
203 lines
6.7 KiB
Plaintext
203 lines
6.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "3867293b-8f5e-426a-a223-834c7e04daef",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "e73a6d0e-4db9-4b0d-ab8f-e421373f1944",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LinearNet(torch.nn.Module):\n",
|
|
" def __init__(self, input_size,output_size,layers_size):\n",
|
|
" super().__init__()\n",
|
|
" \n",
|
|
" #So, this next section constructs different layers within the NN\n",
|
|
" #sinlge linear section\n",
|
|
" self.linear_step1 = torch.nn.Linear(input_size,layers_size)\n",
|
|
" #single linear section\n",
|
|
" self.linear_step2 = torch.nn.Linear(layers_size,output_size)\n",
|
|
" \n",
|
|
" def forward(self, input_values):\n",
|
|
" intermediate_values = self.linear_step1(input_values)\n",
|
|
" out_values = self.linear_step2(intermediate_values)\n",
|
|
" \n",
|
|
" return out_values"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2870086a-cb32-4f7e-9ee3-2e7a545c86cb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LinearNet2(torch.nn.Module):\n",
|
|
" def __init__(self, input_size,output_size,layers_size):\n",
|
|
" super().__init__()\n",
|
|
" \n",
|
|
" #combined together into a set of sequential workers\n",
|
|
" self.sequential_layers = torch.nn.Sequential(\n",
|
|
" torch.nn.Linear(input_size,layers_size)\n",
|
|
" torch.nn.Linear(layers_size,output_size)\n",
|
|
" )\n",
|
|
" \n",
|
|
" def forward(self, input_values):\n",
|
|
" return self.sequential_layers(input_values)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "fb6c7507-5abf-428e-8644-d0e42525368d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = LinearNet(input_size = 5, output_size=5, layers_size=15)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "b3244863-39f7-4fa6-a284-059b858569cf",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"LinearNet(\n",
|
|
" (linear_step1): Linear(in_features=5, out_features=15, bias=True)\n",
|
|
" (linear_step2): Linear(in_features=15, out_features=5, bias=True)\n",
|
|
")\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(model)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "3480a0ec-7b5b-46d7-bdd2-631e57be9c85",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"data_in = torch.tensor([1.5,2,3,4,5])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "e1ace0be-9b05-4c77-b076-a05b68cf8f51",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([0.3597, 0.7032, 0.0924, 0.7974, 3.1524], grad_fn=<AddBackward0>)"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.forward(data_in)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "464ae8e4-f71e-494f-9845-fefe5c75a7b1",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"OrderedDict([('linear_step1.weight',\n",
|
|
" tensor([[ 0.2498, -0.2177, 0.2323, 0.3493, -0.2541],\n",
|
|
" [-0.2169, 0.0535, -0.1223, 0.0237, -0.0184],\n",
|
|
" [-0.1554, -0.0134, 0.2918, 0.3542, 0.3464],\n",
|
|
" [-0.2378, 0.3828, -0.3026, -0.1545, -0.1484],\n",
|
|
" [ 0.3282, -0.1492, 0.3551, -0.0447, -0.3294],\n",
|
|
" [ 0.2789, -0.1546, 0.2821, 0.0136, 0.4210],\n",
|
|
" [ 0.2911, -0.2191, 0.0493, 0.1006, 0.0470],\n",
|
|
" [-0.2269, 0.1705, 0.1198, 0.4040, 0.2512],\n",
|
|
" [-0.2696, -0.4259, 0.4229, 0.1412, 0.3553],\n",
|
|
" [ 0.0293, 0.4044, 0.3961, -0.3992, 0.2586],\n",
|
|
" [-0.3101, -0.0327, 0.1832, 0.0295, -0.3185],\n",
|
|
" [ 0.0637, -0.0770, 0.2297, -0.1567, 0.4379],\n",
|
|
" [-0.0540, -0.1769, 0.3407, 0.1942, 0.3494],\n",
|
|
" [-0.3609, -0.3536, 0.2491, -0.0490, -0.1199],\n",
|
|
" [ 0.2946, -0.0782, -0.0580, 0.2313, -0.0696]])),\n",
|
|
" ('linear_step1.bias',\n",
|
|
" tensor([-0.4207, -0.1624, 0.0212, -0.0988, -0.2106, 0.2991, -0.3496, -0.1799,\n",
|
|
" -0.4257, -0.3384, -0.0020, 0.1267, 0.0252, 0.0037, 0.1784])),\n",
|
|
" ('linear_step2.weight',\n",
|
|
" tensor([[ 0.1404, -0.1424, 0.1518, -0.1080, 0.1269, -0.2030, -0.0533, -0.2240,\n",
|
|
" 0.0364, -0.0393, 0.1619, 0.1242, 0.0731, -0.1545, 0.2024],\n",
|
|
" [-0.2529, 0.0578, 0.1629, -0.0352, -0.2128, 0.0429, 0.0261, 0.2264,\n",
|
|
" -0.0470, 0.0277, 0.0272, -0.1074, -0.1334, 0.0792, -0.0173],\n",
|
|
" [ 0.0459, 0.2224, -0.2272, 0.0123, -0.0676, 0.2378, 0.2166, -0.0981,\n",
|
|
" 0.1010, -0.1593, -0.2422, -0.1253, 0.0899, -0.0760, -0.0816],\n",
|
|
" [ 0.1763, 0.2344, 0.0591, -0.2299, 0.1116, 0.0604, 0.2032, 0.1298,\n",
|
|
" 0.0509, 0.2581, 0.2425, -0.0920, 0.0098, 0.1353, -0.2110],\n",
|
|
" [ 0.0726, -0.1959, 0.2114, -0.0732, -0.1089, 0.0836, -0.1061, 0.1640,\n",
|
|
" 0.1221, 0.0281, -0.2401, 0.1108, 0.1354, 0.1903, -0.1006]])),\n",
|
|
" ('linear_step2.bias',\n",
|
|
" tensor([0.2462, 0.0098, 0.1239, 0.0689, 0.2404]))])"
|
|
]
|
|
},
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.state_dict()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "04bca0f6-dfb9-4013-ac2a-f4efd2f93c15",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|