본문 바로가기
AI/MachineLearning

[ML] DCASE2023 : Mean-teacher CRNN.py

by TSpoons 2025. 3. 23.

Data Preprocessing

_get_logits_one_head

    def _get_logits_one_head(
        self, x, pad_mask, dense, dense_softmax, classes_mask=None
    ):
        strong = dense(x)  # [bs, frames, nclass]
        strong = self.sigmoid(strong)
        if classes_mask is not None:
            classes_mask = ~classes_mask[:, None].expand_as(strong)
        if self.attention in [True, "legacy"]:
            sof = dense_softmax(x)  # [bs, frames, nclass]
            if not pad_mask is None:
                sof = sof.masked_fill(pad_mask.transpose(1, 2), -1e30)  # mask attention

            if classes_mask is not None:
                # mask the invalid classes, cannot attend to these
                sof = sof.masked_fill(classes_mask, -1e30)
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)  # [bs, nclass]
        else:
            weak = strong.mean(1)

        if classes_mask is not None:
            # mask invalid
            strong = strong.masked_fill(classes_mask, 0.0)
            weak = weak.masked_fill(classes_mask[:, 0], 0.0)

        return strong.transpose(1, 2), weak  #[bs, nclass, frames], [bs, nclass]

강력한(Strong) 예측: 각 시간 프레임마다 소리 이벤트가 있는지 예측

  • 입력 특징(x)에 밀집층(dense)을 적용해 각 프레임별 각 클래스의 점수를 계산
  • 시그모이드 함수를 적용해 0~1 사이 값으로 변환

약한(Weak) 예측: 전체 오디오 클립에 소리 이벤트가 있는지만 예측

  • 강력한 예측과 어텐션 가중치를 곱한 후 합산하여 가중 평균을 계산
  • attention 사용 x : 단순히 모든 프레임의 강력한 예측 평균을 계산

마스킹 처리 : 유효하지 않은 클래스가 있으면 해당 예측을 0으로 설정

Example

weak : [0.8, 0.2, 0.0, 0.95] 첫 번째와 네 번째 소리 유형이 높은 확률로 존재

strong =[[[0.1, 0.2, 0.8, 0.9, 0.7], # 첫 번째 클래스(개 짖는 소리)
    [0.0, 0.0, 0.1, 0.2, 0.0], # 두 번째 클래스(자동차 경적) 
        [0.3, 0.2, 0.1, 0.0, 0.0]]] # 세 번째 클래스(사이렌)

_get_logits

다중 헤드 :

  • nclass가 여러 개이면
  • 여러 번 _get_logits_one_head 호출
  • 클래스 차원을 기준으로 펼쳐서 가중치 반환

단일 헤드:

  • nclass가 여러 개가 아니라 하나면

-_get_logits_one_head 함수로 동작 후 반환

def _get_logits(self, x, pad_mask, classes_mask=None):
        out_strong = []
        out_weak = []
        if isinstance(self.nclass, (tuple, list)):
            # instead of masking the softmax we can have multiple heads for each dataset:
            # maestro_synth, maestro_real and desed.
            # not sure which approach is better. We must try.
            for indx, c_classes in enumerate(self.nclass):
                dense_softmax = (
                    self.dense_softmax[indx] if hasattr(self, "dense_softmax") else None
                )
                c_strong, c_weak = self._get_logits_one_head(
                    x, pad_mask, self.dense[indx], dense_softmax, classes_mask
                )
                out_strong.append(c_strong)
                out_weak.append(c_weak)

            # concatenate over class dimension
            return torch.cat(out_strong, 1), torch.cat(out_weak, 1)
        else:
            dense_softmax = (
                self.dense_softmax if hasattr(self, "dense_softmax") else None
            )
            return self._get_logits_one_head(
                x, pad_mask, self.dense, dense_softmax, classes_mask
            )

apply_specaugment -데이터 증강(data augmentation)

def apply_specaugment(self, x):
    if self.training:

        timemask = torchaudio.transforms.TimeMasking(
            self.specaugm_t_l, True, self.specaugm_t_p
        )
        freqmask = torchaudio.transforms.TimeMasking(
            self.specaugm_f_l, True, self.specaugm_f_p
        )  # use time masking also here
        x = timemask(freqmask(x.transpose(1, -1)).transpose(1, -1))

    return x

훈련 중일 때 두 마스킹을 연속으로 적용

  • 잡음이 섞이거나 일부 주파수가 손실된 실제 상황에서도 잘 작동
  • 주파수 마스킹 → 시간 마스킹 -> return x

(SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition)

Ex) 입력 x: (1, 4, 5) (배치 크기 1, fequency bin 4, frame 5).

forward

    def forward(self, x, pad_mask=None, embeddings=None, classes_mask=None):

        x = self.apply_specaugment(x)
        x = x.transpose(1, 2).unsqueeze(1)

        # input size : (batch_size, n_channels, n_frames, n_freq)
        if self.cnn_integration:
            bs_in, nc_in = x.size(0), x.size(1)
            x = x.view(bs_in * nc_in, 1, *x.shape[2:])

        # conv features
        x = self.cnn(x)
        bs, chan, frames, freq = x.size()
        if self.cnn_integration:
            x = x.reshape(bs_in, chan * nc_in, frames, freq)

        if freq != 1:
            warnings.warn(
                f"Output shape is: {(bs, frames, chan * freq)}, from {freq} staying freq"
            )
            x = x.permute(0, 2, 1, 3)
            x = x.contiguous().view(bs, frames, chan * freq)
        else:
            x = x.squeeze(-1)
            x = x.permute(0, 2, 1)  # [bs, frames, chan]

        # rnn features
        if self.use_embeddings:
            if self.aggregation_type == "global":
                x = self.cat_tf(
                    torch.cat(
                        (
                            x,
                            self.shrink_emb(embeddings)
                            .unsqueeze(1)
                            .repeat(1, x.shape[1], 1),
                        ),
                        -1,
                    )
                )
            elif self.aggregation_type == "frame":
                # there can be some mismatch between seq length of cnn of crnn and the pretrained embeddings, we use an rnn
                # as an encoder and we use the last state
                last, _ = self.frame_embs_encoder(embeddings.transpose(1, 2))
                embeddings = last[:, -1]
                reshape_emb = (
                    self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)
                )

            elif self.aggregation_type == "interpolate":
                output_shape = (embeddings.shape[1], x.shape[1])
                reshape_emb = (
                    torch.nn.functional.interpolate(
                        embeddings.unsqueeze(1), size=output_shape, mode="nearest-exact"
                    )
                    .squeeze(1)
                    .transpose(1, 2)
                )

            elif self.aggregation_type == "pool1d":
                reshape_emb = torch.nn.functional.adaptive_avg_pool1d(
                    embeddings, x.shape[1]
                ).transpose(1, 2)
            else:
                raise NotImplementedError

        if self.use_embeddings:
            if self.dropstep_recurrent and self.training:
                dropstep = torchaudio.transforms.TimeMasking(
                    self.dropstep_recurrent_len, True, self.dropstep_recurrent
                )
                x = dropstep(x.transpose(1, -1)).transpose(1, -1)
                reshape_emb = dropstep(reshape_emb.transpose(1, -1)).transpose(1, -1)
            x = self.cat_tf(self.dropout(torch.cat((x, reshape_emb), -1)))
        else:
            if self.dropstep_recurrent and self.training:
                dropstep = torchaudio.transforms.TimeMasking(
                    self.dropstep_recurrent_len, True, self.dropstep_recurrent
                )
                x = dropstep(x.transpose(1, 2)).transpose(1, 2)
                x = self.dropout(x)

        x = self.rnn(x)
        x = self.dropout(x)

        return self._get_logits(x, pad_mask, classes_mask)

 

입력 변환
- SpecAugment 적용 (데이터 증강)
- 차원 변환하여 CNN 입력 형식으로 변환
CNN 처리
- 시간 및 주파수 특징 추출
- 차원 정리하여 RNN 입력으로 변환
Embeddings 추가 (선택)
- 글로벌 / 프레임 기반 / 보간 / 풀링 방식으로 결합
Dropstep 적용 (선택)
- TimeMasking을 적용하여 일반화 성능 향상

 

RNN 처리
- 시간 흐름을 고려한 특징 학습
- Dropout 적용하여 과적합 방지

 

최종 출력
- 로짓(logits) 반환 (소프트맥스나 시그모이드로 변환 가능)

 

 

 

 

Train

def train(self, mode=True):

    super(CRNN, self).train(mode)
    if self.freeze_bn:
        print("Freezing Mean/Var of BatchNorm2D.")
        if self.freeze_bn:
            print("Freezing Weight/Bias of BatchNorm2D.")
    if self.freeze_bn:
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                if self.freeze_bn:
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False

참고문헌

[https://arxiv.org/abs/1703.01780]

'AI > MachineLearning' 카테고리의 다른 글

[NLP] 임베딩(Embedding)  (0) 2025.02.27