1. The Dead Neuron Problem (비미분성 문제)

SNN 모델을 PyTorch의 BPTT(Backpropagation Through Time)를 사용해 학습시킬 때 치명적인 수학적 모순에 직면하게 됩니다. 뉴런의 출력 스파이크 $S[t]$는 막 전위 $U[t]$가 임계값 $U_{thr}$을 넘을 때 발생하는 Heaviside Step Function (계단 함수)으로 정의됩니다.

$$ S[t] = \Theta(U[t] - U_{thr}) $$

문제는 이 계단 함수를 미분할 때 발생합니다. 임계값을 제외한 모든 구간에서 미분값(기울기)이 0이 되고, 임계값 지점에서는 무한대($\infty$)가 됩니다. 기울기가 연속적으로 0이 되어 역전파 체인 룰(Chain Rule)에 의해 기울기가 소실되는 현상을 Dead Neuron Problem이라고 부릅니다. 이로 인해 가중치 학습이 불가능해집니다.

2. Surrogate Gradient Descent (대리 기울기) 접근법

이러한 미분 불가능성 문제를 피하기 위해 고안된 우회법이 바로 대리 기울기(Surrogate Gradient)입니다.

원리는 단순합니다. 순전파(Forward Pass) 시에는 원래대로 완벽한 계단 함수(Heaviside step function)를 사용하여 0과 1의 정확한 이진 스파이크를 발생시킵니다. 하지만 역전파(Backward Pass) 시에는 계단 함수 대신 미분 가능한 부드러운 대체 함수(Surrogate Function)의 미분값을 빌려와 사용하는 것입니다.

대표적인 Surrogate Functions

  • Fast Sigmoid: 시그모이드(Sigmoid) 함수를 변형한 형태로, 계산 비용이 저렴하여 가장 대중적으로 사용됩니다.
  • ATan (ArcTanget): 아크탄젠트 곡선을 미분 대체제로 사용하는 방식입니다.
  • Straight Through Estimator (STE): 역전파 시 기울기를 단순히 상수 1로 통과시키는 기법입니다.

3. snnTorch에 Surrogate Gradient 적용하기

snnTorch 패키지에서는 snntorch.surrogate 모듈을 통해 다양한 Surrogate 기울기 함수를 제공합니다. 이를 뉴런 층(layer) 선언 시 파라미터로 넘겨주기만 하면 됩니다.

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate

# 1. 사용할 Surrogate Gradient 함수 정의 (기울기 가파름 정도 조절 가능)
spike_grad = surrogate.fast_sigmoid(slope=25)
# spike_grad = surrogate.atan()  # 다른 함수 선택 가능

# 2. 뉴런 인스턴스화 시 'spike_grad' 인자에 전달
beta = 0.9
lif_neuron = snn.Leaky(beta=beta, spike_grad=spike_grad)

# Convolutional SNN 예시
net = nn.Sequential(
    nn.Conv2d(1, 12, kernel_size=5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True), # 은닉층
    nn.Flatten(),
    nn.Linear(12 * 12 * 12, 10),
    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True) # 출력층
)

Slope 변수의 역할

예제 코드의 slope=25는 Surrogate 함수의 가파름 정도(Steepness)를 결정합니다. slope 값을 높이면 역전파 시 사용되는 함수가 원래의 계단 함수에 더 가까워지지만(정확성 향상), 오차 역전파 과정에서 기울기 폭발(Gradient Explosion)이 일어날 확률이 높아집니다. 반대로 slope 값을 낮추면 학습이 안정적이지만 근사 오차가 커집니다. 따라서 하이퍼파라미터 튜닝이 필요합니다.

4. Surrogate 함수 종류 및 수식 비교

snnTorch surrogate 모듈이 제공하는 주요 함수들과 그 특성을 비교합니다.

함수 snnTorch 호출 특징
Fast Sigmoid surrogate.fast_sigmoid(slope=25) 계산 비용 저렴, 가장 많이 쓰임
Arctangent (ATan) surrogate.atan(alpha=2.0) 완만한 기울기, 기울기 소실 억제
Straight Through Estimator surrogate.straight_through_estimator() 역전파 시 기울기를 그대로 통과
Heaviside (기본값, 학습 불가) 기본값 (Surrogate 미지정) Dead neuron 문제. 학습용으로 사용 불가

5. 완전한 Convolutional SNN 학습 예시 (MNIST)

Surrogate Gradient를 적용한 Conv SNN으로 MNIST를 학습하는 전체 코드입니다.

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate, utils, functional as SF
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ── 설정 ──────────────────────────────────────────────
num_steps  = 25
batch_size = 128
num_epochs = 3
beta       = 0.9
lr         = 1e-3

spike_grad = surrogate.fast_sigmoid(slope=25)

# ── 데이터 ──────────────────────────────────────────
transform = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0,), (1,))
])
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=transform),
    batch_size=batch_size, shuffle=False, drop_last=True
)

# ── Convolutional SNN 모델 ─────────────────────────
#  입력: (T=25, B=128, 1, 28, 28)
#  Conv 블록 → Pooling → LIF → Flatten → FC → LIF(출력)
net = nn.Sequential(
    nn.Conv2d(1, 12, kernel_size=5),              # (B, 12, 24, 24)
    nn.MaxPool2d(2),                               # (B, 12, 12, 12)
    snn.Leaky(beta=beta, spike_grad=spike_grad,
              init_hidden=True),                   # (B, 12, 12, 12) 스파이크
    nn.Conv2d(12, 64, kernel_size=5),              # (B, 64, 8, 8)
    nn.MaxPool2d(2),                               # (B, 64, 4, 4)
    snn.Leaky(beta=beta, spike_grad=spike_grad,
              init_hidden=True),
    nn.Flatten(),                                  # (B, 64*4*4 = 1024)
    nn.Linear(64 * 4 * 4, 10),                    # (B, 10)
    snn.Leaky(beta=beta, spike_grad=spike_grad,
              init_hidden=True, output=True),      # 출력층
)

optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss_fn   = SF.ce_rate_loss()

# ── 학습 ──────────────────────────────────────────
for epoch in range(num_epochs):
    net.train()
    for i, (data, targets) in enumerate(train_loader):
        optimizer.zero_grad()

        # 타임 스텝 루프
        spk_rec = []
        utils.reset(net)
        for _ in range(num_steps):
            spk_out, _ = net(data)
            spk_rec.append(spk_out)
        spk_rec = torch.stack(spk_rec)  # (T, B, 10)

        loss_val = loss_fn(spk_rec, targets)
        loss_val.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch+1}, Iter {i}] Loss: {loss_val.item():.4f}")

# ── 평가 ──────────────────────────────────────────
net.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, targets in test_loader:
        spk_rec = []
        utils.reset(net)
        for _ in range(num_steps):
            spk_out, _ = net(data)
            spk_rec.append(spk_out)
        spk_rec = torch.stack(spk_rec)
        predicted = spk_rec.sum(dim=0).argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total   += targets.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")
# 일반적으로 3 에폭 학습 시 약 98~99% 정확도 달성
목록으로 돌아가기