16. 시냅스 가중치 얻기


get_weights

시냅스 가중치 또는 웨이트 값은 Neural Network의 뉴런과 뉴런 노드 사이를 연결하는 강도를 나타내는 숫자입니다.

tf.keras.layers 모듈의 모든 레이어는 get_weights() 메서드를 포함합니다.

get_weights()메서드를 이용해서, 미리 구성한 Neural Network에서 뉴런층의 시냅스 가중치를 얻어보겠습니다.


예제

import tensorflow as tf

# 1. MNIST 데이터셋 임포트
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 2. 데이터 전처리
x_train, x_test = x_train/255.0, x_test/255.0

# 3. 모델 구성
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# 4. 웨이트 얻기
weights = model.get_weights()

print(weights)
print(len(weights))

구성한 Neural Network 모델에는 입력층을 제외하고 두 개의 뉴런층이 있습니다.

get_weights() 메서드를 이용해서 모델의 시냅스 가중치 (weights)를 weights 변수에 저장하고 출력했습니다.

출력 결과는 아래와 같습니다.

[array([[ 0.01922707,  0.01191682, -0.06732179, ..., -0.04985018,
      0.04876792, -0.06567953],
    [-0.02828409,  0.00119013, -0.02863812, ..., -0.0467992 ,
     -0.00200056, -0.05697531],
    [ 0.06635921, -0.02362482, -0.00820123, ...,  0.03124215,
      0.00039714, -0.00920851],
    ...,
    [-0.05759408,  0.06655261,  0.0457562 , ...,  0.0062023 ,
      0.05761695, -0.04720653],
    [ 0.04525699,  0.06081389, -0.0400049 , ..., -0.06753884,
     -0.0037506 ,  0.02214601],
    [ 0.02717116,  0.0228968 , -0.04719903, ...,  0.00125863,
     -0.03331475, -0.04829547]], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
    0., 0.], dtype=float32), array([[ 0.06985401,  0.02613375, -0.0551643 , ...,  0.06326195,
      0.08241054,  0.07814605],
    [-0.04089634,  0.03981646, -0.09707076, ...,  0.08483479,
     -0.08038449,  0.0651061 ],
    [ 0.01726168,  0.05927653, -0.08869828, ..., -0.05225329,
     -0.00716908,  0.09950536],
    ...,
    [-0.05344408, -0.03864221, -0.05626503, ..., -0.01657531,
      0.00561044, -0.054561  ],
    [-0.075242  , -0.06006503,  0.01800562, ...,  0.09833171,
     -0.0370728 , -0.06306928],
    [-0.1053426 ,  0.04195554,  0.03840356, ...,  0.06571037,
     -0.05478839,  0.00051194]], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]
4

이 모델의 시냅스 가중치는 모두 네 개의 NumPy 어레이로 구성되어 있음을 알 수 있습니다.

각 어레이의 형태를 출력해보면

print(weights[0].shape)
print(weights[1].shape)
print(weights[2].shape)
print(weights[3].shape)
(784, 512)
(512,)
(512, 10)
(10,)

첫번째 어레이는 784x512개의 값을 갖는 2차원 어레이로서 입력층 (input layer)과 은닉층 (hidden layer)을 연결하는 가중치를 나타내는 값입니다.

두번째 어레이는 512개의 0으로 이루어져 있으며, 은닉층 (hidden layer)의 바이어스 (bias) 값을 나타냅니다.

세번째 어레이는 512x10개의 값을 갖는 2차원 어레이로서 은닉층 (hidden layer)과 출력층 (output layer)을 연결하는 가중치를 나타내는 값입니다.

네번째 어레이는 10개의 0으로 이루어져 있으며, 출력층 (output layer)의 바이어스 (bias) 값을 나타냅니다.



웨이트 저장하기

import numpy as np

np.savetxt('weights[0].csv', weights[0])
np.savetxt('weights[1].csv', weights[1])
np.savetxt('weights[2].csv', weights[2])
np.savetxt('weights[3].csv', weights[3])

NumPy의 np.savetxt() 함수를 이용하면, 각각의 가중치 값을 csv 파일로 저장할 수 있습니다.


관련 페이지


이전글/다음글