이전 Tutorial 2에서 설계된 LIF neuron model은 복잡하고 hyper parameter의 조정이 필요하다. 이때 파라미터의 추적은 어렵고 SNN으로의 확장을 할 시 더 복잡해 짐으로 단순화를 진행한다.

LIF neuron model 단순화

감쇠율 : $\beta$

Euler 방법을 이용해 passive membrane 모델의 해는 다음과 같다.

$$ V_m(t+Δt)=V_m(t)+τΔt(−V_m(t)+I(t)R) $$

이때 입력 전류 $I(t)$가 없는 경우를 가정한다.

$$ V_m(t+Δt)=V_m(t)−\frac{Δt}{τ}V_m(t) $$

이때 막전위의 감쇠율을 $\beta$라고 하며 아래와 같다.

$$ \beta = \frac{V_m(t+\triangle t)}{V_m(t)}=1-\frac{\triangle t}{\tau} $$

가중 입력 전류

t가 시퀀스의 의 시간 단계라고 가정하면 $\triangle t = 1$로 볼 수 있다. 또한 hyper parameter수를 줄이기 위해 $R = 1$로 가정하면 아래와 같은 식이 출력된다.

$$ \beta = 1 - \frac{1}{C} \to (1-\beta)I_{in}=\frac{1}{\tau}I_{in} $$

이때 $1-\beta$를 입력 전류의 가중치라고 보며 membrane 전위에 순간적으로 기여한다고 가정한다. 또한 시간 구간이 짧아서 neuron은 하나의 Spike만 발생할 수 있다고 가정한다.

$$ V(t+1) = \beta V(t) + (1-\beta)I_{in}(t+1) $$

deeplearning에선 입력의 가중치 계수가 학습 가능한 parameter로 사용된다. 이때 신호 $V(t)$와 가중치 W의 상호작용을 단순화 하기 위해 둘을 곱한 결과로 표현한다.

$$ V(t+1) = \beta V(t) + WX(t+1) $$

Spikint & Reset

막 전위가 임계값을 초과하면 뉴런이 출력 스파이크를 발생시킨다.

$$ S[t] = 1, if ;;V(t)>V_{thr} \ ;;;0, otherwise $$

Spike가 발생하면 membrane 전위는 초기화가 되어야 한다. 이때 감소에 의한 리셋(reset by substraction) 모델은 다음과 같다.

$$ V(t+1) = \beta V(t) + WX(t+1)-S(t)V_{thr} $$

이때 W는 학습 가능한 파라미터 이며 $V_{thr}$은 종종 1로 설정된다.


def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
  spk = (mem > threshold)  # 막 전위가 임계값을 초과하면 spk=1, 그렇지 않으면 0
  mem = beta * mem + w * x - spk * threshold
  return spk, mem
delta_t = torch.tensor(1e-3)
tau = torch.tensor(5e-3)
beta = torch.exp(-delta_t / tau)
print(f"The decay rate is: {beta:.3f}")
num_steps = 200

# 입력/출력 초기화 및 작은 스텝 전류 입력
x = torch.cat((torch.zeros(10), torch.ones(190) * 0.5), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = []
spk_rec = []
# 뉴런 파라미터
w = 0.4
beta = 0.819

# 뉴런 시뮬레이션
for step in range(num_steps):
  spk, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
  mem_rec.append(mem)
  spk_rec.append(spk)

snn.Leaky 클래스 직접 활용하기

직접 구현한 함수 대신 snnTorch의 snn.Leaky 클래스를 사용하면 훨씬 간결하게 동일한 LIF 뉴런을 구현할 수 있습니다. 내부적으로 위의 수식과 동일하게 동작하지만, PyTorch의 autograd와 완벽하게 통합됩니다.

import snntorch as snn
import torch

# snn.Leaky 인스턴스화: beta만 지정하면 됨
lif = snn.Leaky(beta=0.819)

# 은닉 상태(막 전위) 초기화
mem = lif.init_leaky()  # 0으로 초기화된 텐서 반환

# 입력 데이터 (임의의 spike train)
num_steps = 200
x = torch.cat((torch.zeros(10), torch.ones(190) * 0.5), 0)

mem_rec, spk_rec = [], []

for step in range(num_steps):
    spk, mem = lif(x[step], mem)
    mem_rec.append(mem.item())
    spk_rec.append(spk.item())

print(f"총 스파이크 수: {sum(spk_rec)}")
print(f"최대 막 전위: {max(mem_rec):.4f}")

완전 연결(Feedforward) SNN 구현

실제 딥러닝에서는 Linear 레이어와 LIF 뉴런을 번갈아 쌓아 다층 SNN을 구성합니다. MNIST 분류 예시로 2층 Feedforward SNN을 구현해 봅니다.

import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils

# 하이퍼파라미터
num_steps = 25  # 시뮬레이션 타임 스텝
beta = 0.95     # 막 전위 감쇠율

# 네트워크 구조 정의
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 레이어 정의
        self.fc1 = nn.Linear(28*28, 1000)
        self.lif1 = snn.Leaky(beta=beta)

        self.fc2 = nn.Linear(1000, 10)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        # 은닉 상태 초기화
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        spk_rec = []  # 출력층 스파이크 기록

        for step in range(num_steps):
            # 첫 번째 레이어
            cur1 = self.fc1(x)         # 선형 변환
            spk1, mem1 = self.lif1(cur1, mem1)  # LIF 뉴런 통과

            # 두 번째 레이어 (출력층)
            cur2 = self.fc2(spk1)      # 첫 번째 레이어의 스파이크를 입력
            spk2, mem2 = self.lif2(cur2, mem2)
            spk_rec.append(spk2)

        return torch.stack(spk_rec)    # (num_steps, batch, 10)

net = Net()
print(net)

# 추론: 총 타임 스텝의 스파이크 합산 → Rate decoding으로 분류
# ex) 각 배치 샘플에 대해 가장 많이 발화한 출력 뉴런의 인덱스 = 예측 클래스
with torch.no_grad():
    data = torch.randn(32, 28*28)         # 더미 입력 데이터
    output = net(data)                    # (25, 32, 10)
    predicted = output.sum(dim=0).argmax(dim=1)  # (32,) 예측 클래스
    print(f"예측 클래스 형태: {predicted.shape}")

핵심 파라미터 요약

  • β (beta, 감쇠율): 0과 1 사이의 값. 1에 가까울수록 막 전위가 느리게 감쇠하여 장기 기억 유지. 0에 가까울수록 빠르게 초기화되어 최근 입력에만 민감.
  • threshold: 스파이크 발생 임계값 (기본값 1.0). 낮을수록 쉽게 발화하고 높을수록 강한 자극 필요.
  • reset_mechanism: 스파이크 후 리셋 방식. "subtract"(기본) 또는 "zero".
  • num_steps: 시뮬레이션 타임 스텝 수. 많을수록 정확하지만 연산량 증가.
목록으로 돌아가기