Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions speechbrain/processing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,11 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0):
current_means = []
current_stds = []

if self.norm_type == "sentence" or self.norm_type == "speaker":
# we will do in-place slice assignments over `out`
out = torch.empty_like(x)
# otherwise don't assign it yet

for snt_id in range(N_batches):
# Avoiding padded time steps
actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
Expand All @@ -1095,7 +1100,7 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0):
current_stds.append(current_std)

if self.norm_type == "sentence":
x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
out[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

if self.norm_type == "speaker":
spk_id = int(spk_ids[snt_id][0])
Expand Down Expand Up @@ -1141,14 +1146,14 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0):
speaker_mean = current_mean.data
speaker_std = current_std.data

x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
out[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

if self.norm_type == "batch" or self.norm_type == "global":
current_mean = torch.mean(torch.stack(current_means), dim=0)
current_std = torch.mean(torch.stack(current_stds), dim=0)

if self.norm_type == "batch":
x = (x - current_mean.data) / (current_std.data)
out = (x - current_mean.data) / (current_std.data)

if self.norm_type == "global":
if self.training:
Expand All @@ -1175,9 +1180,11 @@ def forward(self, x, lengths, spk_ids=torch.tensor([]), epoch=0):

self.count = self.count + 1

x = (x - self.glob_mean.data.to(x)) / (self.glob_std.data.to(x))
out = (x - self.glob_mean.data.to(x)) / (
self.glob_std.data.to(x)
)

return x
return out

def _compute_current_stats(self, x):
"""Computes mean and std
Expand Down