
Conversation

@sangeet2020
Contributor

@sangeet2020 sangeet2020 commented Jun 6, 2023

Change needed in Whisper fine-tuning recipe to accommodate the latest release of transformers>4.30

Hi SB team,
I am initiating a PR with minor fixes to the main Whisper fine-tuning script to accommodate changes in the latest PyTorch 2.0 release.
After pulling the most recent develop branch and installing the latest torch version, I have been experiencing the issue below.

To reproduce this issue (choose any dataset)

(sb_env) [root@serv-3338 transformer]# python train_with_whisper.py --debug hparams/train_hf_whisper.yaml --seed 101 --model_version tiny
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/train_whisper/101
rescuespeech_prepare - ../../csv_files/Task_ASR/train.csv already exists, skipping data preparation!
rescuespeech_prepare - ../../csv_files/Task_ASR/dev.csv already exists, skipping data preparation!
rescuespeech_prepare - ../../csv_files/Task_ASR/test.csv already exists, skipping data preparation!
speechbrain.utils.train_logger - Training on input type: : clean_wav
speechbrain.core - Info: auto_mix_prec arg from hparam file is used
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Since debug mode is active, switching checkpointer output to temporary directory: /tmp/tmp9847_aub
speechbrain.core - 37.8M trainable parameters in ASR
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
  0%|                                                                                                                                                                                                               | 0/653 [00:07<?, ?it/s]
speechbrain.core - Exception:
Traceback (most recent call last):
  File "/netscratch/sagar/thesis/speechbrain/recipes/RescueSpeech/ASR/transformer/train_with_whisper.py", line 323, in <module>
    asr_brain.fit(
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/core.py", line 1238, in fit
    self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/core.py", line 1091, in _fit_train
    loss = self.fit_batch(batch)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/core.py", line 959, in fit_batch
    outputs = self.compute_forward(batch, Stage.TRAIN)
  File "/netscratch/sagar/thesis/speechbrain/recipes/RescueSpeech/ASR/transformer/train_with_whisper.py", line 53, in compute_forward
    enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)
  File "/netscratch/sagar/thesis/sb_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/lobes/models/huggingface_whisper.py", line 166, in forward
    out_encoder = self.forward_encoder(wav)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/lobes/models/huggingface_whisper.py", line 189, in forward_encoder
    return self._get_encoder_states(wav)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/lobes/models/huggingface_whisper.py", line 200, in _get_encoder_states
    mel = self._get_mel(wav)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/lobes/models/huggingface_whisper.py", line 217, in _get_mel
    mels = self._log_mel_spectrogram(mels)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/lobes/models/huggingface_whisper.py", line 247, in _log_mel_spectrogram
    mel_spec = filters @ magnitudes
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 80] but got: [2, 201].

This happens with torch==2.0.0. The detailed env.log is:

SpeechBrain system description
==============================
Python version:
3.10.6 (main, Jun  5 2023, 22:14:15) [GCC 7.5.0]
==============================
Installed Python packages:
appdirs==1.4.4
attrs==23.1.0
black==19.10b0
certifi==2023.5.7
cfgv==3.3.1
charset-normalizer==3.1.0
click==8.0.4
cmake==3.26.3
distlib==0.3.6
entrypoints==0.3
filelock==3.12.0
flake8==3.7.9
fsspec==2023.5.0
huggingface-hub==0.15.1
HyperPyYAML==1.2.1
identify==2.5.24
idna==3.4
Jinja2==3.1.2
joblib==1.2.0
lit==16.0.5.post0
MarkupSafe==2.1.3
mccabe==0.6.1
more-itertools==9.1.0
mpmath==1.3.0
networkx==3.1
nodeenv==1.8.0
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
packaging==23.1
pathspec==0.11.1
Pillow==9.5.0
platformdirs==3.5.1
pluggy==0.13.1
pre-commit==3.3.2
py==1.11.0
pycodestyle==2.5.0
pyflakes==2.1.1
pytest==5.4.1
PyYAML==6.0
regex==2023.6.3
requests==2.31.0
ruamel.yaml==0.17.28
ruamel.yaml.clib==0.2.7
scipy==1.8.1
sentencepiece==0.1.99
-e /netscratch/sagar/thesis/speechbrain
sympy==1.12
tokenizers==0.13.3
toml==0.10.2
torch==2.0.0
torchaudio==2.0.1
torchvision==0.15.1
tqdm==4.65.0
transformers==4.29.2
triton==2.0.0
typed-ast==1.5.4
typing_extensions==4.6.3
urllib3==2.0.2
virtualenv==20.23.0
wcwidth==0.2.6
yamllint==1.23.0
==============================
Could not get git revision
==============================
CUDA version:
11.7

NOTE

This issue does not occur (hence, is not reproducible) with torch==1.11.0+cu113.

The changes I propose have been tested with the exact versions mentioned above in env.log. With these changes, the error above vanishes.

thank you


Note: when merged, we would like to include your PR title in our contributions list; check out one of our past version releases:
https://github.com/speechbrain/speechbrain/releases/tag/v0.5.14

Tip: below, on « Create Pull Request », use the drop-down to select « Create Draft Pull Request » – your PR will stay in draft mode until you mark it « Ready for review ».

@Adel-Moumen
Collaborator

Hello @sangeet2020,

Thanks for the PR.

Is it backwards compatible with torch 1.13?

@sangeet2020
Contributor Author

Hello @Adel-Moumen,

After quite a lot of experiments, it seems that the issue is not with torch, but with transformers.

Here are my observations on

  • Python 3.10.6
  • torch==2.0.1+cu117

| transformers | train_with_whisper.py | Size mismatch error |
| --- | --- | --- |
| 4.26.0 | no changes in the script | No |
| 4.30.2 | no changes in the script | Yes |
| 4.30.2 | script changed as suggested in this PR | No |

I would request that you try these settings once, just to be sure.

Let me know if you need log files for each.

Thank You

@Adel-Moumen Adel-Moumen self-assigned this Jun 19, 2023

  filters = self._mel_filters
- mel_spec = filters @ magnitudes
+ mel_spec = filters.transpose(0, 1) @ magnitudes
Collaborator


I don't understand why it is coming from the transformers library... To me, it seems to be related to the torch library...

Contributor Author

@sangeet2020 sangeet2020 Jun 20, 2023


What torch and transformers versions are you using? I'll try the exact same versions.
Can you try installing torch 2.0 with transformers 4.26 vs. transformers 4.32? I can confirm that I kept the torch version fixed throughout my experiments and varied only the transformers version, and the error vanishes, as per the table above.


Collaborator


In transformers >= 4.30, they changed how they compute feature_extractor.mel_filters. As a result, in transformers >= 4.30 the shape of feature_extractor.mel_filters is (201, 80), while in previous versions it was (80, 201). This causes a problem in our _log_mel_spectrogram function (which we copied from OpenAI) when we calculate mel_spec = filters @ magnitudes, since it expects the filters to be (80, 201).
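For illustration, here is a minimal sketch of the mismatch (the shapes come from the traceback above; the frame count is made up):

```python
import torch

n_mels, n_freqs, n_frames = 80, 201, 3000  # shapes from the traceback above

magnitudes = torch.randn(2, n_freqs, n_frames)  # a batch of 2 STFT magnitude tensors

old_filters = torch.randn(n_mels, n_freqs)  # transformers < 4.30: (80, 201)
new_filters = torch.randn(n_freqs, n_mels)  # transformers >= 4.30: (201, 80)

print((old_filters @ magnitudes).shape)  # torch.Size([2, 80, 3000])
print((new_filters.transpose(0, 1) @ magnitudes).shape)  # same shape, with the fix
# new_filters @ magnitudes  # would raise the RuntimeError shown above
```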

Collaborator


Thanks Pooneh, I see. Do you know why they changed their feature_extractor.mel_filters?

Collaborator


It is related to this PR on transformers:
huggingface/transformers#21998

Basically, what they do is replace the hand-rolled STFT in the different models, including Whisper, with the one from audio_utils.

In transformers <= 4.28:
To get self.mel_filters, they use a Whisper-specific function, get_mel_filters in transformers/models/whisper/feature_extraction_whisper.py, which returns (n_mels, n_freqs) --> (80, 201). Then, to calculate the log-Mel spectrogram, they call their own stft function in transformers/models/whisper/feature_extraction_whisper.py and finally call _np_extract_fbank_features, which is basically the same as our _log_mel_spectrogram function in speechbrain/lobes/models/huggingface_whisper.py (itself the same as the OpenAI function).

In transformers >= 4.29:
To get self.mel_filters, they use the mel_filter_bank function in transformers/audio_utils.py. It is the same as the following torchaudio function:
https://pytorch.org/audio/main/generated/torchaudio.functional.melscale_fbanks.html
This returns (n_freqs, n_mels) --> (201, 80). Then they call the spectrogram function in audio_utils.py to generate the log-Mel spectrogram, transposing mel_filters:
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
This function is quite similar to MelScale in torchaudio:
https://pytorch.org/audio/main/generated/torchaudio.transforms.MelScale.html
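As a quick sanity check of the torchaudio convention (the parameters below are the usual Whisper settings, assumed here for illustration):

```python
import torchaudio

# melscale_fbanks returns (n_freqs, n_mels): (201, 80) for Whisper's
# n_fft=400 / 80-mel / 16 kHz configuration, i.e. the transpose of OpenAI's layout.
fbanks = torchaudio.functional.melscale_fbanks(
    n_freqs=201, f_min=0.0, f_max=8000.0, n_mels=80, sample_rate=16000
)
print(fbanks.shape)  # torch.Size([201, 80])
```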

Final points:

1. It seems they are trying to unify all audio functions across the different models and use the same implementation as torchaudio and librosa. These models are affected by this change:
   • CLAP
   • M-CTC-T
   • SpeechT5
   • TVLT
   • Whisper
2. Based on the implementation, I think the suggested fix by @sangeet2020 makes sense.
3. I am going to run all CommonVoice recipes for the major release. Maybe it would be better to merge this change so I could test the Whisper recipes and make sure there won't be any bugs introduced by it.

Collaborator


@Adel-Moumen and @sangeet2020, I am thinking about applying the transpose only if the shape is not (n_mels, n_freqs), so it would also work with older versions of transformers.
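Something along these lines, as a rough sketch (the function name is illustrative, not the actual SpeechBrain code):

```python
import torch

def apply_mel_filters(filters: torch.Tensor, magnitudes: torch.Tensor,
                      n_mels: int = 80) -> torch.Tensor:
    """Project STFT magnitudes onto the mel scale, accepting either layout:
    (n_mels, n_freqs) from transformers < 4.30 or (n_freqs, n_mels) from
    transformers >= 4.30."""
    if filters.shape[0] != n_mels:
        filters = filters.transpose(0, 1)
    return filters @ magnitudes
```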

@poonehmousavi poonehmousavi mentioned this pull request Jun 24, 2023
@mravanelli
Collaborator

I think we need to proceed with that and merge it after doing some final tests. My understanding is that the problem is not PyTorch 2.0 but the version of transformers used. In that case, I think we should also update the extra-requirements file in the recipes that use Whisper (e.g., https://github.com/speechbrain/speechbrain/blob/develop/recipes/CommonVoice/ASR/transformer/extra_requirements.txt).

@poonehmousavi could you please proceed with it and do the last tests? When you think everything is fine we can merge it (we need it for CL-MASR as well).

@poonehmousavi
Collaborator

@mravanelli Yes, the problem is transformers; with the proposed solution, it works with the new version of transformers. I have tested it on a small dataset. I will test it more thoroughly while running the Whisper experiments for CommonVoice.

@mravanelli
Collaborator

Also, it looks like there are many places where transformers is added as an extra dependency (e.g., https://github.com/speechbrain/speechbrain/blob/develop/recipes/LibriSpeech/ASR/CTC/extra_requirements.txt). Do you think all these recipes might be affected somehow by the change made in transformers >= 4.30?

@poonehmousavi
Collaborator

These transformers models are affected by this change:
  • CLAP
  • M-CTC-T
  • SpeechT5
  • TVLT
  • Whisper
Apart from Whisper, is there any other model from the above list that we use?
Also, in general, it shouldn't cause any problem: in Whisper we don't use the transformers function for extracting features; we use the OpenAI function for efficiency, as far as I remember (though I'm not entirely sure), and that is the source of the inconsistencies.
@Adel-Moumen, what was the reason we don't directly use the transformers WhisperFeatureExtractor?

@Adel-Moumen
Collaborator

Hello @poonehmousavi, the reason was that it was extremely slow. This is why we decided to use OpenAI's implementation, because it was way faster.

@poonehmousavi
Collaborator

> Hello @poonehmousavi, the reason was that it was extremely slow. This is why we decided to use OpenAI's implementation, because it was way faster.

They changed their function and basically use the torchaudio function, so I guess it might be faster now. One other solution is to use their own function instead of transposing.
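Using their extractor directly would look roughly like this (a sketch; whisper-tiny is just a placeholder checkpoint):

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
audio = np.zeros(16000, dtype=np.float32)  # 1 s of dummy 16 kHz audio
feats = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
print(feats.input_features.shape)  # torch.Size([1, 80, 3000]): pads to 30 s
```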

@Adel-Moumen
Collaborator

> > Hello @poonehmousavi, the reason was that it was extremely slow. This is why we decided to use OpenAI's implementation, because it was way faster.
>
> They changed their function and basically use the torchaudio function, so I guess it might be faster now. One other solution is to use their own function instead of transposing.

Yes, I agree. Could you please try and compare?

@poonehmousavi
Collaborator

Sure. I will try that.

@sangeet2020
Contributor Author

I think the proposed solution is not the best one. When trying to run inference with a Whisper model trained in SpeechBrain using transformers 4.30.2, I get a size mismatch while loading the state dictionary for the HuggingFaceWhisper module of the SB pre-trained model. This error occurs during parameter transfer.

The error is

    hparams["pretrainer"].load_collected()
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/utils/parameter_transfer.py", line 312, in load_collected
    self._call_load_hooks(paramfiles, device)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/utils/parameter_transfer.py", line 329, in _call_load_hooks
    default_hook(obj, loadpath, device=device)
  File "/netscratch/sagar/thesis/speechbrain/speechbrain/utils/checkpoints.py", line 144, in torch_parameter_transfe
r
    incompatible_keys = obj.load_state_dict(
  File "/netscratch/sagar/thesis/sb_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load
_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for HuggingFaceWhisper:
        size mismatch for _mel_filters: copying a param with shape torch.Size([80, 201]) from checkpoint, the shape 
in current model is torch.Size([201, 80]).

@sangeet2020
Contributor Author

sangeet2020 commented Jun 26, 2023

> I have tested it on a small dataset.

@poonehmousavi
Could you try loading the intermediate checkpoint (only the whisper.ckpt) and running inference, while keeping the same transformers version?

@poonehmousavi
Collaborator

poonehmousavi commented Jun 26, 2023

> @poonehmousavi
> Could you try loading the intermediate checkpoint (only the whisper.ckpt) and running inference, while keeping the same transformers version?

For sure it won't work, because it tries to load a Whisper checkpoint trained with an older version of transformers, which saved _mel_filters, and that causes the mismatch. I am going to update the checkpoints soon using the new versions.
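For old checkpoints, a hypothetical one-off patch would be to transpose the stale buffer before loading (the key name and shapes come from the error above; this is only a sketch, not a vetted migration script):

```python
import torch

# Load the old-format whisper.ckpt (a plain state dict)
state_dict = torch.load("whisper.ckpt", map_location="cpu")
if state_dict["_mel_filters"].shape == (80, 201):  # old (n_mels, n_freqs) layout
    state_dict["_mel_filters"] = state_dict["_mel_filters"].transpose(0, 1).contiguous()
torch.save(state_dict, "whisper_patched.ckpt")
```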

@sangeet2020
Contributor Author

sangeet2020 commented Jun 26, 2023

No, even a Whisper checkpoint trained with the latest transformers version doesn't work for me.

@poonehmousavi
Collaborator

I will take a look to check what the problem is.

@sangeet2020 sangeet2020 changed the title Change needed in Whisper fine-tuning recipe to accommodate torch2.0 Change needed in Whisper fine-tuning recipe to accommodate transformers4.30.0 Jun 27, 2023
@mravanelli
Collaborator

mravanelli commented Jul 2, 2023

I pushed the solution implemented by @lucadellalib in the CL-MASR benchmark. This solution is more backward compatible.
However, the models that we have on HF only work with the old version of transformers as they have been trained in this way:

https://huggingface.co/speechbrain/whisper_rescuespeech
https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fr
https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fa
https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-sr
https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-mn
https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-ar
https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-hi

  • @poonehmousavi, do you see a way to make these models compatible with the latest version of transformers? Since we are going to release a new version, we might want our models to work with the latest version. If there is no easy fix for that, I guess the only option is retraining, right?

@mravanelli
Collaborator

I would like to take this opportunity to bring attention to the poor performance I observed while using our Whisper interface.

For example, when using the French language (https://huggingface.co/speechbrain/asr-whisper-large-v2-commonvoice-fr), the following code snippet demonstrates the issue:

>>> from speechbrain.pretrained import WhisperASR
torchvision is not available - cannot save figures
>>> import speechbrain
>>> speechbrain.__file__
'/scratch/ravanelm/speechbrain_fix_whisper/speechbrain/__init__.py'
>>> asr_model = WhisperASR.from_hparams(source="speechbrain/asr-whisper-large-v2-commonvoice-fr", savedir="pretrained_models/asr-whisper-large-v2-commonvoice-fr")
Downloading (…)rocessor_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 185k/185k [00:00<00:00, 2.98MB/s]
Downloading (…)lve/main/config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.99k/1.99k [00:00<00:00, 11.5MB/s]
Downloading pytorch_model.bin: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.17G/6.17G [00:55<00:00, 110MB/s]
speechbrain.lobes.models.huggingface_whisper - whisper encoder-decoder is frozen.
>>> asr_model.transcribe_file("/scratch/ravanelm/speechbrain_devjuly2/asr-whisper-large-v2-commonvoice-fr/example-fr.wav")
[['selon', 'se', 's', 'dw']]

The transcription output, [['selon', 'se', 's', 'dw']], does not accurately represent what the speaker is saying. It appears that the transcription is truncated.

The same issue occurs with the German language model:

>>> asr_model = WhisperASR.from_hparams(source="speechbrain/rescuespeech_whisper", savedir="pretrained_models/rescuespeech_whisper")
speechbrain.lobes.models.huggingface_whisper - whisper encoder-decoder is frozen.
>>> asr_model.transcribe_file("/scratch/ravanelm/speechbrain_devjuly2/whisper_rescuespeech/example_de.wav")
[['wenn', 'd', 'doff']]

Once again, the transcription [['wenn', 'd', 'doff']] is truncated.

These observations lead me to suspect that we may have a consistent issue across all the Whisper HF models we released. I believe there might be a problem with our interface that prevents it from performing the same computations that the fine-tuning recipes perform during training and validation.

@poonehmousavi , any idea? This is pretty crucial as our Whisper interfaces appear to be functioning incorrectly.

@lucadellalib
Collaborator

@mravanelli @poonehmousavi I just pushed a better fix that should solve the remaining backward-compatibility issues, even with previously trained models.

@poonehmousavi
Collaborator

I have checked the proposed solution. It works with the new version and is backward compatible with previous versions. Previously trained models can be used with this change as well. Therefore, I think we can merge this pull request.

@mravanelli mravanelli self-requested a review July 5, 2023 15:02
@mravanelli
Collaborator

I did some final tests and I confirm that everything is working and backward compatible. Thank you all for contributing to this fix!

@mravanelli mravanelli merged commit fe94a92 into speechbrain:develop Jul 5, 2023
@poonehmousavi
Collaborator

I have tried the Whisper interface for French and Arabic, and I didn't observe the problem you mentioned. Here is the result for the example file on HF:
[Screenshot: whisper_fr_interface]

@sangeet2020
Contributor Author

Thanks a lot @lucadellalib for making the fix and @poonehmousavi for doing the needed tests.
✌️

