From af22c3aa4f2f70498cd292187c641d7d0720606e Mon Sep 17 00:00:00 2001 From: asu Date: Thu, 11 Apr 2024 15:46:18 +0200 Subject: [PATCH] Fix in-place input normalization when using `sentence`/`speaker` norm --- speechbrain/processing/features.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/speechbrain/processing/features.py b/speechbrain/processing/features.py index ee7731389e..20b5756d5a 100644 --- a/speechbrain/processing/features.py +++ b/speechbrain/processing/features.py @@ -1082,6 +1082,11 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0): current_means = [] current_stds = [] + if self.norm_type == "sentence" or self.norm_type == "speaker": + # we will do in-place slice assignments over `out` + out = torch.empty_like(x) + # otherwise don't assign it yet + for snt_id in range(N_batches): # Avoiding padded time steps actual_size = torch.round(lengths[snt_id] * x.shape[1]).int() @@ -1095,7 +1100,7 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0): current_stds.append(current_std) if self.norm_type == "sentence": - x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data + out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data if self.norm_type == "speaker": spk_id = int(spk_ids[snt_id][0]) @@ -1141,14 +1146,14 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0): speaker_mean = current_mean.data speaker_std = current_std.data - x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std + out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std if self.norm_type == "batch" or self.norm_type == "global": current_mean = torch.mean(torch.stack(current_means), dim=0) current_std = torch.mean(torch.stack(current_stds), dim=0) if self.norm_type == "batch": - x = (x - current_mean.data) / (current_std.data) + out = (x - current_mean.data) / (current_std.data) if self.norm_type == "global": if self.training: @@ -1175,9 +1180,11 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0): self.count = self.count + 1 - x = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x)) + out = (x - self.glob_mean.data.to(x)) / ( + self.glob_std.data.to(x) + ) - return x + return out def _compute_current_stats(self, x): """Computes mean and std