AI/DeepLearning

배치 정규화 Batch Normalization + Pytorch 실습

lgvv 2025. 11. 13. 02:44

배치 정규화 Batch Normalization + Pytorch 실습

 

각 층의 활성화값 분포를 관찰해보며, 가중치의 초깃값을 적절히 설정하면 각 층의 활성화값 분포가 적당히 퍼지면서 학습이 원활하게 수행됨.

 

그렇다면 각 층이 활성화를 적당히 퍼뜨리도록 강제하면 더 좋지 않을까란 생각에서 배치 정규화가 출발함.

 

 

 

배치 정규화 알고리즘

 

2015년에 제안되었으나 많은 연구자가 현재까지도 사용하고 있음.

 

배치정규화의 장점

  • 학습을 빨리 진행할 수 있음.
  • 초깃값에 크게 의존하지 않음.
  • 오버피팅을 억제 (드롭아웃 등의 필요성 감소)

 

배치 정규화의 효과


대부분의 초깃값 표준편차에서 학습 진도가 빠름.

배치 정규화를 이용하지 않은 경우 초깃값이 잘 분포되지 않으면 학습이 전혀 되지 않는 경우도 있음.

 

 

PyTorch를 이용한 배치 정규화(Batch Normalization) 구현 및 비교

샘플 코드를 구현하고 결과를 확인

import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------------------
# 1. 배치 정규화를 사용하지 않은 단순 모델
# -----------------------------------------
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 입력층에서 첫 번째 은닉층
        self.fc2 = nn.Linear(128, 64)   # 두 번째 은닉층
        self.fc3 = nn.Linear(64, 10)    # 출력층

    def forward(self, x):
        x = x.view(-1, 784) # 이미지를 1D 벡터로 펼치기
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# -----------------------------------------
# 2. 배치 정규화를 적용한 모델
# -----------------------------------------
class BatchNormMLP(nn.Module):
    def __init__(self):
        super(BatchNormMLP, self).__init__()
        # 선형 계층 뒤에 배치 정규화 계층 추가
        self.fc1 = nn.Linear(784, 128)
        self.bn1 = nn.BatchNorm1d(128) # 1D 데이터용 배치 정규화 (은닉층 노드 수만큼 지정)
        
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784) # 이미지를 1D 벡터로 펼치기
        
        # 순서: 선형 변환 -> 배치 정규화 -> 활성화 함수(ReLU)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        
        x = self.fc3(x)
        return x

# -----------------------------------------
# 모델 사용 예시
# -----------------------------------------

# 더미 입력 데이터 생성 (미니배치 크기 32, 이미지 크기 28x28)
dummy_input = torch.randn(32, 28, 28)

# 배치 정규화 없는 모델 인스턴스화
model_simple = SimpleMLP()
output_simple = model_simple(dummy_input)
print(f"단순 MLP 출력 형태: {output_simple.shape}")

# 배치 정규화 적용 모델 인스턴스화
model_bn = BatchNormMLP()
output_bn = model_bn(dummy_input)
print(f"배치 정규화 MLP 출력 형태: {output_bn.shape}")

# 학습 모드 설정 (학습 시에는 bn 계층이 미니배치 통계 사용)
model_bn.train() 

# 추론 모드 설정 (추론 시에는 저장된 이동 평균/분산 사용)
model_bn.eval()

 

 

아래는 출력 결과

단순 MLP 출력 형태: torch.Size([32, 10])
배치 정규화 MLP 출력 형태: torch.Size([32, 10])
BatchNormMLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

 

 

 

PyTorch를 활용한 배치 정규화 효과 시각적 비교

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm # 학습 진행률을 보여주는 라이브러리 (선택 사항)

# --- (이전 SimpleMLP 및 BatchNormMLP 클래스 정의 코드는 여기에 그대로 유지) ---
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class BatchNormMLP(nn.Module):
    def __init__(self):
        super(BatchNormMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x
# -----------------------------------------------------------------------

# -----------------------------------------
# 3. 실제 학습 및 시각화 코드 추가
# -----------------------------------------

# 하이퍼파라미터 설정
NUM_EPOCHS = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.01

# 1. 데이터 로드 (MNIST 사용)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

def train_model(model, name):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)
    model.train()
    
    losses = []
    print(f"\n--- 학습 시작: {name} ---")
    
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        for data in tqdm(trainloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(trainloader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
    
    return losses

# 2. 두 모델 학습 실행
losses_simple = train_model(SimpleMLP(), "SimpleMLP")
losses_bn = train_model(BatchNormMLP(), "BatchNormMLP")

# 3. 결과 시각화 (눈으로 보기)
plt.figure(figsize=(10, 5))
plt.plot(range(1, NUM_EPOCHS + 1), losses_simple, label='SimpleMLP (No BN)', marker='o')
plt.plot(range(1, NUM_EPOCHS + 1), losses_bn, label='BatchNormMLP (With BN)', marker='o')
plt.title('Training Loss Comparison')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.legend()
plt.grid(True)
plt.show() # <-- 이 명령어로 그래프 창이 뜹니다!

 

 

 

출력결과

100.0%
100.0%
100.0%
100.0%

--- 학습 시작: SimpleMLP ---
Epoch 1/10: 100%|██████████| 938/938 [00:03<00:00, 311.22it/s]
Epoch 1, Average Loss: 1.0485
Epoch 2/10: 100%|██████████| 938/938 [00:02<00:00, 316.81it/s]
Epoch 2, Average Loss: 0.3868
Epoch 3/10: 100%|██████████| 938/938 [00:03<00:00, 310.33it/s]
Epoch 3, Average Loss: 0.3244
Epoch 4/10: 100%|██████████| 938/938 [00:02<00:00, 313.54it/s]
Epoch 4, Average Loss: 0.2905
Epoch 5/10: 100%|██████████| 938/938 [00:02<00:00, 319.21it/s]
Epoch 5, Average Loss: 0.2648
Epoch 6/10: 100%|██████████| 938/938 [00:02<00:00, 318.42it/s]
Epoch 6, Average Loss: 0.2419
Epoch 7/10: 100%|██████████| 938/938 [00:03<00:00, 305.64it/s]
Epoch 7, Average Loss: 0.2225
Epoch 8/10: 100%|██████████| 938/938 [00:03<00:00, 307.17it/s]
Epoch 8, Average Loss: 0.2045
Epoch 9/10: 100%|██████████| 938/938 [00:03<00:00, 311.74it/s]
Epoch 9, Average Loss: 0.1891
Epoch 10/10: 100%|██████████| 938/938 [00:03<00:00, 307.87it/s]
Epoch 10, Average Loss: 0.1757

--- 학습 시작: BatchNormMLP ---
Epoch 1/10: 100%|██████████| 938/938 [00:03<00:00, 272.93it/s]
Epoch 1, Average Loss: 0.5523
Epoch 2/10: 100%|██████████| 938/938 [00:03<00:00, 273.80it/s]
Epoch 2, Average Loss: 0.2110
Epoch 3/10: 100%|██████████| 938/938 [00:03<00:00, 280.90it/s]
Epoch 3, Average Loss: 0.1486
Epoch 4/10: 100%|██████████| 938/938 [00:03<00:00, 279.51it/s]
Epoch 4, Average Loss: 0.1186
Epoch 5/10: 100%|██████████| 938/938 [00:03<00:00, 273.34it/s]
Epoch 5, Average Loss: 0.0981
Epoch 6/10: 100%|██████████| 938/938 [00:03<00:00, 280.56it/s]
Epoch 6, Average Loss: 0.0824
Epoch 7/10: 100%|██████████| 938/938 [00:03<00:00, 275.68it/s]
Epoch 7, Average Loss: 0.0717
Epoch 8/10: 100%|██████████| 938/938 [00:03<00:00, 281.33it/s]
Epoch 8, Average Loss: 0.0633
Epoch 9/10: 100%|██████████| 938/938 [00:03<00:00, 273.99it/s]
Epoch 9, Average Loss: 0.0562
Epoch 10/10: 100%|██████████| 938/938 [00:03<00:00, 273.89it/s]
Epoch 10, Average Loss: 0.0504