https://github.com/DCASE-REPO/DESED_task/blob/master/desed_task/nnet/CNN.py
DESED_task/desed_task/nnet/CNN.py at master · DCASE-REPO/DESED_task
Domestic environment sound event detection task. Contribute to DCASE-REPO/DESED_task development by creating an account on GitHub.
github.com
1. real CNN input
[CNN INPUT SHAPE] torch.Size([48, 1, 626, 128])
- batch를 48 개씩 묶어서 626*128(H*W) 전달
2. cnn_init_
n_in_channel # 입력 채널 수 (예: 1이면 모노, 3이면 RGB 이미지 등)
activation = "relu" # 활성화 함수 종류 (relu, leakyrelu, glu, cg 중 선택)
conv_dropout = 0 # dropout 옵션
kernel_size = [3,3,3] # 각 conv layer의 커널 크기 (예: [3,3,3])
padding = [1,1,1] # padding 값
stride = [1.1.1] # stride 값
nb_filters = [64, 64, 64] # 각 layer의 출력 채널 수 (예: [64, 64, 64])
pooling = [(1,4), (1,4), (1,4)] # 각 레이어 후에 사용하는 pooling 크기
normalization = "batch" # "batch" 또는 "layer" 중 선택
3. Conv2d
kernal_size: 한 번의 conv 연산에서 바라보는 지역 범위(window)
stride: shape를 비율로 줄이기 위해 연속적인 컨볼루션을 진행할 때 건너뛰기할 칸 크기
padding: 출력크기(shape)를 조절하기 위해 설정 ( p = (k-1)/2이 되도록 설정)
- padding1로 18 *18로 만든 후 한 칸씩(stride =1) Conv2D 진행 => 16*16 tensor 생성
def conv(i, normalization="batch", dropout=None, activ="relu"):
nIn = n_in_channel if i == 0 else nb_filters[i - 1]
nOut = nb_filters[i]
cnn.add_module(
"conv{0}".format(i),
nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]),
)
- dilation: receptive field가 넒어지도록 함(kernal_size의 간격이 넒어짐)
output : torch.Size([1, 64, 128, 32])
4. normalization
- BatchNorm2d : 학습이 더 안정적으로 진행되도록 모든 채널이 평균 0, 표준편차 1로 바꿈
채널 단위로 (H × W) 위치에서 값을 정규화
[B, C, H, W] -> [B, C, H, W]
if normalization == "batch":
cnn.add_module(
"batchnorm{0}".format(i),
nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99),
)
elif normalization == "layer":
cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut)
5. Activation Function : ReLU
- 인공신경망에 비선형성을 부여
if activ.lower() == "leakyrelu":
cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2))
elif activ.lower() == "relu":
cnn.add_module("relu{0}".format(i), nn.ReLU())
elif activ.lower() == "glu":
cnn.add_module("glu{0}".format(i), GLU(nOut))
elif activ.lower() == "cg":
cnn.add_module("cg{0}".format(i), ContextGating(nOut)
5-2 LeakyReLU
5-3 GLU(gated linear unit)
class GLU(nn.Module):
def __init__(self, input_num):
super(GLU, self).__init__()
self.sigmoid = nn.Sigmoid()
self.linear = nn.Linear(input_num, input_num)
def forward(self, x):
lin = self.linear(x.permute(0, 2, 3, 1))
lin = lin.permute(0, 3, 1, 2)
sig = self.sigmoid(x)
res = lin * sig
return res
선형 변환과 시그모이드를 각각 적용하여 원소별 곱(element-wise multiplication)하여 중요한 정보만 통과시키는 게이팅(gating) 메커니즘을 구현
입력 텐서를 nn.Linear에 맞게 채널 축을 마지막으로 옮긴 후 선형 변환을 적용하고, 원래 차원으로 복원한 결과를 시그모이드 함수로 처리된 원본 입력과 곱함으로써, 정보의 흐름을 선택적으로 조절하는 역할
5-4 ContextGating
class ContextGating(nn.Module):
def __init__(self, input_num):
super(ContextGating, self).__init__()
self.sigmoid = nn.Sigmoid()
self.linear = nn.Linear(input_num, input_num)
def forward(self, x):
lin = self.linear(x.permute(0, 2, 3, 1))
lin = lin.permute(0, 3, 1, 2)
sig = self.sigmoid(lin)
res = x * sig
return res
입력 특성(x)에 대해 선형 변환을 수행한 후, 시그모이드 함수에 통과시켜 문맥에 따라 정보 흐름을 조절하는 게이트를 생성
문맥을 반영해 게이트를 생성하고, 그 결과를 다시 원본 입력에 곱함으로써, 보다 세밀하게 어떤 정보를 강조할지를 학습할 수 있도록 돕는 구조
6. Dropout
- 뉴런 활성화 임의로 제거
if dropout is not None:
cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout))
7. AvgPool2d
- 연산을 조정하고 과적합을 방지하기 위해 pooling 작업
conv를 3번 진행하므로 pooling도 3번 진행하도록 설정
[1, 64, 128, 862] -> [1, 64, 128, 215] -> [1, 64, 128, 53] -> [1, 64, 128, 13]
# 128x862x64
for i in range(len(nb_filters)):
conv(i, normalization=normalization, dropout=conv_dropout, activ=activation)
cnn.add_module(
"pooling{0}".format(i), nn.AvgPool2d(pooling[i])
) # bs x tframe x mels
8. CNN output test
if __name__ == "__main__":
x= torch.randn(1,1,128,862)
cnn = CNN(n_in_channel=1)
output = cnn(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
Input shape: torch.Size([1, 1, 128, 862])
Output shape: torch.Size([1, 64, 128, 13])
실제 출력
torch.Size([48,64,626,2])
'AI' 카테고리의 다른 글
[SED] DCASE 2023 Task 4 Baseline test (0) | 2025.03.14 |
---|---|
[FL] Federated Learning tutorial (flwr) (0) | 2025.03.03 |
[ML] 강화학습(RL)의 이해 (0) | 2025.02.02 |
[PyTorch] MNIST 문자 인식 모델 (2) | 2024.10.09 |
window에서 tensorflow-gpu 사용하기 (0) | 2024.05.29 |