1. Spiking Neural Network의 학습 원리
일반적인 인공 신경망(ANN)에서는 기울기 하강법(Gradient Descent)과 역전파(Backpropagation) 알고리즘을 사용해 가중치를 업데이트합니다. SNN 역시 동일한 원리에 기초하지만, 차이점은 데이터가 시간적인 차원(Temporal Dimension)을 갖고 있다는 점입니다.
SNN에서는 RNN(Recurrent Neural Network)과 마찬가지로 시간의 흐름에 따라 네트워크를 펼치는 BPTT (Backpropagation Through Time) 방식이 사용됩니다. SNN의 차이점은 각 시간 스텝의 출력이 실숫값이 아닌 $0$ 또는 $1$의 스파이크라는 것입니다.
2. SNN을 위한 Loss 계산 기법
SNN에서 분류(Classification) 작업을 수행할 때, 가장 일반적인 접근법은 각 클래스 레이블에 할당된 출력 뉴런 중 가장 스파이크를 많이 발생시킨(Highest Firing Rate) 뉴런을 정답으로 간주하는 Rate-coding 접근법을 사용합니다.
따라서 손실 함수(Loss function)는 전체 시뮬레이션 시간 동안의 스파이크 수를 합산(Sum)하거나 막 전위(Membrane Potential)의 누적값을 기반으로 Cross Entropy Loss 혹은 MSE Loss를 계산합니다.
3. 파이토치(PyTorch)를 통한 FC-SNN 구현 및 학습
아래는 MNIST 데이터셋 이미지를 이진 스파이크 열로 변환하여 학습하는 완전 연결(Fully Connected) SNN의 구현 예제입니다.
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
# 하이퍼파라미터 및 네트워크 변수
num_steps = 25 # 시뮬레이션 타임 스텝
batch_size = 128
beta = 0.95 # 신경망 막 전위 감쇠율
# 대리 기울기(Surrogate Gradient) 함수 적용 -> Tutorial 6에서 자세히 다룸
spike_grad = surrogate.fast_sigmoid(slope=25)
# 피드포워드 SNN 아키텍처 정의 (PyTorch nn.Sequential 활용)
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 1000),
snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
nn.Linear(1000, 10),
snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
)
# 손실 함수 및 옵티마이저 (기존 PyTorch와 완전 동일)
optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, betas=(0.9, 0.999))
loss_fn = nn.CrossEntropyLoss()
# 단일 배치를 이용한 훈련 루프 스니펫
def forward_pass(net, data):
spk_rec = [] # 출력 스파이크 기록
utils.reset(net) # 각 배치마다 네트워크의 Hidden state 리셋
for step in range(num_steps): # BPTT 시간 롤아웃
spk_out, mem_out = net(data)
spk_rec.append(spk_out)
return torch.stack(spk_rec)
# 학습 과정
net.train()
for data, targets in train_loader: # train_loader 가정
optimizer.zero_grad()
# 순전파
spk_rec = forward_pass(net, data)
# 전체 타임 스텝의 스파이크 합산 -> 손실 계산 (Rate coding)
loss_val = torch.zeros((1), dtype=dtype)
for step in range(num_steps):
loss_val += loss_fn(spk_rec[step], targets)
# 역전파
loss_val.backward()
optimizer.step()
훈련 시 핵심 요약 사항
- `init_hidden=True`: snnTorch 내부에서 뉴런의 은닉 상태(Membrane Potential)를 자동 관리해 주어 코드가 간결해집니다.
- 은닉 상태 초기화(`utils.reset(net)`): 시퀀스가 끝날 때마다 상태(전위)를 리셋하여 다음 배치 데이터가 이전 데이터의 간섭을 받지 않게 해야 합니다.
- 시간 루프(for loop): 순전파 시에는 `num_steps`만큼 루프를 돌면서 결과를 리스트나 텐서로 쌓은 후 손실을 계산해야 합니다.
4. 완전한 학습 루프 구현 (MNIST 완전예제)
다음은 MNIST 데이터셋으로 SNN을 완전하게 학습하고 평가하는 전체 코드입니다.
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils, surrogate, functional as SF
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# ── 하이퍼파라미터 ─────────────────────────────────────────
num_steps = 25
batch_size = 128
num_epochs = 5
beta = 0.95
lr = 2e-3
# ── 데이터 로드 ──────────────────────────────────────────
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(),
transforms.ToTensor(),
transforms.Normalize((0,), (1,)),
])
mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=True)
# ── 모델 정의 ─────────────────────────────────────────────
spike_grad = surrogate.fast_sigmoid(slope=25)
net = nn.Sequential(
nn.Flatten(),
nn.Linear(784, 1000),
snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
nn.Linear(1000, 10),
snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True),
)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))
loss_fn = SF.ce_rate_loss() # snnTorch 제공 Rate coding 기반 CE Loss
# ── 순전파 헬퍼 ──────────────────────────────────────────
def forward_pass(net, data):
spk_rec = []
utils.reset(net)
for _ in range(num_steps):
spk_out, mem_out = net(data)
spk_rec.append(spk_out)
return torch.stack(spk_rec) # (T, B, C)
# ── 정확도 계산 헬퍼 ─────────────────────────────────────
def batch_accuracy(loader):
total, correct = 0, 0
net.eval()
with torch.no_grad():
for data, targets in loader:
spk_rec = forward_pass(net, data)
# 가장 많이 발화한 뉴런의 인덱스 = 예측 클래스
predicted = spk_rec.sum(dim=0).argmax(dim=1)
correct += (predicted == targets).float().sum().item()
total += targets.size(0)
return 100 * correct / total
# ── 학습 루프 ─────────────────────────────────────────────
for epoch in range(num_epochs):
net.train()
epoch_loss = 0.0
for data, targets in train_loader:
optimizer.zero_grad()
spk_rec = forward_pass(net, data)
# Rate coding 기반 Cross Entropy Loss
loss_val = loss_fn(spk_rec, targets)
loss_val.backward()
optimizer.step()
epoch_loss += loss_val.item()
train_acc = batch_accuracy(train_loader)
test_acc = batch_accuracy(test_loader)
print(f"Epoch {epoch+1:2d} | Loss: {epoch_loss/len(train_loader):.4f} "
f"| Train Acc: {train_acc:.2f}% | Test Acc: {test_acc:.2f}%")
# 예시 출력:
# Epoch 1 | Loss: 0.9823 | Train Acc: 87.45% | Test Acc: 87.92%
# Epoch 5 | Loss: 0.3412 | Train Acc: 95.78% | Test Acc: 95.40%
# ── 모델 저장 ────────────────────────────────────────────
torch.save(net.state_dict(), "snn_mnist.pth")
print("모델 저장 완료: snn_mnist.pth")
5. Loss 함수 선택 가이드
snnTorch의 functional 모듈은 SNN에 최적화된 다양한 손실 함수를 제공합니다.
SF.ce_rate_loss(): 전체 타임 스텝의 발화율(Firing Rate)을 소프트맥스에 통과시켜 Cross Entropy Loss 계산. 가장 일반적. Rate coding과 잘 어울림.SF.ce_max_membrane_loss(): 출력 뉴런의 최대 막 전위를 기준으로 Cross Entropy 계산. 막 전위 기반 접근.SF.mse_count_loss(): 정답 클래스가 목표 발화수에 가깝도록 MSE Loss 적용. 발화 횟수를 직접 제어할 때 사용.