
[Bug]: M1 GPU (mps) support #1794

@mattiasu96

Description

Describe the bug

It looks like the SpeechBrain library does not support the M1 GPU (the MPS backend). The error is raised when running a pre-trained model on the MPS backend (at least, this is the case I found; I suspect it also happens in other situations). In particular, the error is:

{ValueError}invalid type: 'torch.mps.FloatTensor'

The error is caused by this line in dual_path.py (a file in the SpeechBrain library, line 1066):

        if gap > 0:
            pad = torch.Tensor(torch.zeros(B, N, gap)).type(input.type())

The failure happens because input.type() returns 'torch.mps.FloatTensor', which is not a valid tensor type string.
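For reference, the underlying behaviour can be reproduced in isolation. A minimal sketch, assuming a Mac where torch.backends.mps.is_available() returns True:

import torch

# On the affected PyTorch versions, .type() on an MPS tensor returns a
# string that .type(...) itself rejects.
x = torch.zeros(2, 3, device="mps")
print(x.type())   # prints 'torch.mps.FloatTensor'
x.type(x.type())  # raises ValueError: invalid type: 'torch.mps.FloatTensor'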

This problem has already been reported in PyTorch (here: pytorch/pytorch#82296) and appears to be on its way to being fixed.

However, it looks like SpeechBrain will either need to upgrade its PyTorch dependency (from the PyTorch discussion, the fix is expected to land in PyTorch 2.0) or find a workaround for the datatype handling in the meantime 🤔
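In the meantime, a possible workaround (an untested sketch against the lines quoted above, not an official patch) is to skip the type-string round trip in _padding() and build the padding tensors directly from the input's device and dtype:

# Hypothetical replacement for dual_path.py lines 1066-1070: torch.zeros
# accepts device and dtype directly, so the 'torch.mps.FloatTensor' type
# string never needs to be produced or parsed.
if gap > 0:
    pad = torch.zeros(B, N, gap, device=input.device, dtype=input.dtype)
    input = torch.cat([input, pad], dim=2)

_pad = torch.zeros(B, N, P, device=input.device, dtype=input.dtype)

This keeps the padding on the same device as the input, so it should behave identically on CPU and CUDA while avoiding the failing code path on MPS.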

Expected behaviour

Being able to use the MPS backend on an M1 Mac to run SpeechBrain models.

To Reproduce

from speechbrain.pretrained.interfaces import SepformerSeparation
import torchaudio
import torch

separator = SepformerSeparation.from_hparams(source="speechbrain/sepformer-wsj02mix", savedir="./pretrained-sepformer-wsj02mix", run_opts={"device": "mps"})

s1, fs = torchaudio.load('./my_file.wav')  # insert any wav file you want here
resampler = torchaudio.transforms.Resample(fs, 8000)

s1 = resampler(s1)

est_sources = separator.separate_batch(s1)

Versions

0.5.13

Relevant log output

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[49], line 1
----> 1 est_sources = separator.separate_batch(s1)

File ~/Desktop/personal_git/voice-assistant/.venv/lib/python3.10/site-packages/speechbrain/pretrained/interfaces.py:1976, in SepformerSeparation.separate_batch(self, mix)
   1974 mix = mix.to(self.device)
   1975 mix_w = self.mods.encoder(mix)
-> 1976 est_mask = self.mods.masknet(mix_w)
   1977 mix_w = torch.stack([mix_w] * self.hparams.num_spks)
   1978 sep_h = mix_w * est_mask

File ~/Desktop/personal_git/voice-assistant/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Desktop/personal_git/voice-assistant/.venv/lib/python3.10/site-packages/speechbrain/lobes/models/dual_path.py:1017, in Dual_Path_Model.forward(self, x)
   1012     x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
   1013         x.size(1) ** 0.5
   1014     )
   1016 # [B, N, K, S]
-> 1017 x, gap = self._Segmentation(x, self.K)
   1019 # [B, N, K, S]
   1020 for i in range(self.num_layers):

File ~/Desktop/personal_git/voice-assistant/.venv/lib/python3.10/site-packages/speechbrain/lobes/models/dual_path.py:1097, in Dual_Path_Model._Segmentation(self, input, K)
   1095 B, N, L = input.shape
   1096 P = K // 2
-> 1097 input, gap = self._padding(input, K)
   1098 # [B, N, K, S]
   1099 input1 = input[:, :, :-P].contiguous().view(B, N, -1, K)

File ~/Desktop/personal_git/voice-assistant/.venv/lib/python3.10/site-packages/speechbrain/lobes/models/dual_path.py:1067, in Dual_Path_Model._padding(self, input, K)
   1065 gap = K - (P + L % K) % K
   1066 if gap > 0:
-> 1067     pad = torch.Tensor(torch.zeros(B, N, gap)).type(input.type())
   1068     input = torch.cat([input, pad], dim=2)
   1070 _pad = torch.Tensor(torch.zeros(B, N, P)).type(input.type())

ValueError: invalid type: 'torch.mps.FloatTensor'

Additional context

No response
