diff --git a/recipes/CommonVoice/ASR/transformer/README.md b/recipes/CommonVoice/ASR/transformer/README.md
index e5ca8565d9..bd366f1672 100644
--- a/recipes/CommonVoice/ASR/transformer/README.md
+++ b/recipes/CommonVoice/ASR/transformer/README.md
@@ -4,6 +4,12 @@ This folder contains scripts necessary to run an ASR experiment with the CommonV
 # How to run
 python train.py hparams/{hparam_file}.yaml
 
+## For Whisper fine-tuning:
+
+python train_with_whisper.py hparams/train_<locale>_hf_whisper.yaml (e.g., hparams/train_fr_hf_whisper.yaml)
+
+Note: when fine-tuning the Whisper large model, you can reduce the memory needed during model recovery (see https://github.com/speechbrain/speechbrain/pull/1743).
+
 # Data preparation
 It is important to note that CommonVoice initially offers mp3 audio files at 48kHz. Hence, audio files are downsampled on the fly within the dataio function of the training script.
 
@@ -12,12 +18,31 @@ Here is a list of the different languages that we tested within the CommonVoice
 with our transformers:
 - French
 
+For Whisper-large-v2 fine-tuning, here is a list of the different languages that we tested within the CommonVoice 10.0 dataset:
+- Hindi
+- Arabic
+- Persian
+- Serbian
+- Mongolian
+- French
+
+
 # Results
 
 | Language | Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Model link | GPUs |
 | ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:|
 | French | 2020-06-22 | train_fr.yaml | No | 5.15 | 17.80 | 6.01 | 19.21 | [model](https://drive.google.com/drive/folders/12ny6daoz1Ze1MmgLrsqf352AXvhwob6d?usp=sharing) | 1xV100 16GB |
 
+## Whisper fine-tuning results
+The following table reports Whisper fine-tuning results after 1 epoch with the whisper-large-v2 model, freezing the encoder and fine-tuning the decoder.
+| Language | Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | Model link | GPUs |
+| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:|
+| Arabic | 2023-01-10 | train_ar_hf_whisper.yaml | No | 4.02 | 12.47 | 5.20 | 16.96 | [model](https://drive.google.com/drive/folders/10mYPYfj9NpDNAa0nO16Zd_K1bIEUOIpx?usp=sharing) | 1xV100 16GB |
+| Persian | 2023-01-10 | train_fa_hf_whisper.yaml | No | 6.91 | 25.30 | 9.38 | 31.75 | [model](https://drive.google.com/drive/folders/1nzMMYmB5SxMKsFUk-rM9_ijcqzia8pX7?usp=sharing) | 1xV100 16GB |
+| Mongolian | 2023-01-10 | train_mn_hf_whisper.yaml | No | 24.05 | 62.37 | 25.73 | 64.92 | [model](https://drive.google.com/drive/folders/10E2xclgNx_6BFxNmv9i1HorBNnsMveP_?usp=sharing) | 1xV100 16GB |
+| Hindi | 2023-01-10 | train_hi_hf_whisper.yaml | No | 4.54 | 10.46 | 7.00 | 15.27 | [model](https://drive.google.com/drive/folders/11PKCsyIE703mmDv6n6n_UnD0bUgMPbg_?usp=sharing) | 1xV100 16GB |
+| Serbian | 2023-01-10 | train_sr_hf_whisper.yaml | No | 8.92 | 27.12 | 7.60 | 23.63 | [model](https://drive.google.com/drive/folders/1QG67qoekEB29jBd9knt8stLJD4T_xgG7?usp=sharing) | 1xV100 16GB |
+| French | 2023-01-10 | train_fr_hf_whisper.yaml | No | 3.00 | 8.95 | 3.83 | 10.62 | [model](https://drive.google.com/drive/folders/1_iI_G-pMYNeyLsvmHPgNR6gPi8zazkF4?usp=sharing) | 1xV100 16GB |
+
 The output folders with checkpoints and logs can be found [here](https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing).
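A note on the Whisper run line above: all six `train_*_hf_whisper.yaml` files ship with `whisper_hub: openai/whisper-tiny` as a cheap default, so reproducing the whisper-large-v2 numbers in the results table means overriding that hparam. Any top-level hparam can be overridden from the command line thanks to HyperPyYAML; a sketch of such an invocation, with a placeholder data path:

python train_with_whisper.py hparams/train_fr_hf_whisper.yaml --whisper_hub=openai/whisper-large-v2 --data_folder=/path/to/CommonVoice/fr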
diff --git a/recipes/CommonVoice/ASR/transformer/extra_requirements.txt b/recipes/CommonVoice/ASR/transformer/extra_requirements.txt
new file mode 100644
index 0000000000..976a2b1f39
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/extra_requirements.txt
@@ -0,0 +1 @@
+transformers
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml
new file mode 100644
index 0000000000..2126bbc880
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml
@@ -0,0 +1,142 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+wer_file: !ref <output_folder>/wer.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the HuggingFace Whisper model to fine-tune.
+whisper_hub: openai/whisper-tiny
+test_only: False # Set it to True if you only want to do the evaluation
+
+# Normalize inputs with the same normalization used in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: ar # use 'it' for Italian, 'fr' for French, 'en' for English, etc. This is the language of the CommonVoice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterances longer than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+auto_mix_prec: False
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They need to be hard-coded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 0.1
+test_beam_size: 8
+
+# Model parameters
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+    sample_rate: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml
new file mode 100644
index 0000000000..a6e612cc68
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml
@@ -0,0 +1,142 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+wer_file: !ref <output_folder>/wer.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the HuggingFace Whisper model to fine-tune.
+whisper_hub: openai/whisper-tiny
+test_only: False # Set it to True if you only want to do the evaluation
+
+# Normalize inputs with the same normalization used in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: fa # use 'it' for Italian, 'fr' for French, 'en' for English, etc. This is the language of the CommonVoice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterances longer than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+auto_mix_prec: False
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They need to be hard-coded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 0.1
+test_beam_size: 8
+
+# Model parameters
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+    sample_rate: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml
new file mode 100644
index 0000000000..94117376c7
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml
@@ -0,0 +1,142 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+wer_file: !ref <output_folder>/wer.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the HuggingFace Whisper model to fine-tune.
+whisper_hub: openai/whisper-tiny
+test_only: False # Set it to True if you only want to do the evaluation
+
+# Normalize inputs with the same normalization used in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: fr # use 'it' for Italian, 'fr' for French, 'en' for English, etc. This is the language of the CommonVoice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterances longer than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+auto_mix_prec: False
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They need to be hard-coded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 0.1
+test_beam_size: 8
+
+# Model parameters
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+    sample_rate: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml
new file mode 100644
index 0000000000..7edb1668fe
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml
@@ -0,0 +1,142 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+wer_file: !ref <output_folder>/wer.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the HuggingFace Whisper model to fine-tune.
+whisper_hub: openai/whisper-tiny
+test_only: False # Set it to True if you only want to do the evaluation
+
+# Normalize inputs with the same normalization used in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: hi # use 'it' for Italian, 'fr' for French, 'en' for English, etc. This is the language of the CommonVoice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterances longer than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+auto_mix_prec: False
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They need to be hard-coded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 0.1
+test_beam_size: 8
+
+# Model parameters
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+    sample_rate: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml
new file mode 100644
index 0000000000..3ddce71fb3
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml
@@ -0,0 +1,142 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+wer_file: !ref <output_folder>/wer.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the HuggingFace Whisper model to fine-tune.
+whisper_hub: openai/whisper-tiny
+test_only: False # Set it to True if you only want to do the evaluation
+
+# Normalize inputs with the same normalization used in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: mn # use 'it' for Italian, 'fr' for French, 'en' for English, etc. This is the language of the CommonVoice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterances longer than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+auto_mix_prec: False
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They need to be hard-coded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 0.1
+test_beam_size: 8
+
+# Model parameters
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+    sample_rate: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
diff --git a/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml b/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml
new file mode 100644
index 0000000000..52e4442697
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml
@@ -0,0 +1,143 @@
+# ################################
+# Model: Whisper (Encoder-Decoder) + NLL
+# Augmentation: TimeDomainSpecAugment
+# Authors: Pooneh Mousavi 2022
+# ################################
+
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 1986
+__set_seed: !apply:torch.manual_seed [!ref <seed>]
+output_folder: !ref results/train_whisper/<seed>/<locale>
+wer_file: !ref <output_folder>/wer.txt
+save_folder: !ref <output_folder>/save
+train_log: !ref <output_folder>/train_log.txt
+
+# URL for the HuggingFace Whisper model to fine-tune.
+whisper_hub: openai/whisper-tiny
+test_only: False # Set it to True if you only want to do the evaluation
+
+# Normalize inputs with the same normalization used in the paper (https://cdn.openai.com/papers/whisper.pdf). Refer to Appendix C for further information.
+normalized_transcripts: True
+
+# Data files
+locale: sr # use 'it' for Italian, 'fr' for French, 'en' for English, etc. This is the language of the CommonVoice data.
+data_folder: !PLACEHOLDER
+train_tsv_file: !ref <data_folder>/train.tsv  # Standard CommonVoice .tsv files
+dev_tsv_file: !ref <data_folder>/dev.tsv  # Standard CommonVoice .tsv files
+test_tsv_file: !ref <data_folder>/test.tsv  # Standard CommonVoice .tsv files
+accented_letters: True
+train_csv: !ref <save_folder>/train.csv
+valid_csv: !ref <save_folder>/dev.csv
+test_csv: !ref <save_folder>/test.csv
+skip_prep: False # Skip data preparation
+
+# We remove utterances longer than 10s in the train/dev/test sets as
+# longer sentences certainly correspond to "open microphones".
+avoid_if_longer_than: 10.0
+
+ckpt_interval_minutes: 30 # save checkpoint every N min
+
+# Training parameters
+number_of_epochs: 1
+lr_whisper: 0.00003
+sorting: ascending
+auto_mix_prec: False
+sample_rate: 16000
+
+# With data_parallel batch_size is split into N jobs
+# With DDP batch_size is multiplied by N jobs
+batch_size: 12
+test_batch_size: 8
+
+# These values are only used for the searchers.
+# They need to be hard-coded and should not be changed with Whisper.
+# They are used as part of the searching process.
+# The bos token of the searcher will be timestamp_index
+# and will be concatenated with the bos, language and task tokens.
+timestamp_index: 50363
+eos_index: 50257
+bos_index: 50258
+
+# Decoding parameters
+min_decode_ratio: 0.0
+max_decode_ratio: 0.1
+test_beam_size: 8
+
+# Model parameters
+freeze_whisper: False
+freeze_encoder: True
+
+train_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+valid_loader_kwargs:
+    batch_size: !ref <batch_size>
+
+test_loader_kwargs:
+    batch_size: !ref <test_batch_size>
+
+
+#
+# Functions and classes
+#
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
+    sample_rate: !ref <sample_rate>
+    speeds: [95, 100, 105]
+
+whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
+    source: !ref <whisper_hub>
+    freeze: !ref <freeze_whisper>
+    freeze_encoder: !ref <freeze_encoder>
+    save_path: !ref <save_folder>/whisper_checkpoint
+    encoder_only: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+nll_loss: !name:speechbrain.nnet.losses.nll_loss
+
+modules:
+    whisper: !ref <whisper>
+
+whisper_opt_class: !name:torch.optim.AdamW
+    lr: !ref <lr_whisper>
+    weight_decay: 0.000000001
+
+valid_greedy_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
+    model: !ref <whisper>
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+
+test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
+    module: [!ref <whisper>]
+    bos_index: !ref <timestamp_index>
+    eos_index: !ref <eos_index>
+    min_decode_ratio: !ref <min_decode_ratio>
+    max_decode_ratio: !ref <max_decode_ratio>
+    beam_size: !ref <test_beam_size>
+
+lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NewBobScheduler
+    initial_value: !ref <lr_whisper>
+    improvement_threshold: 0.0025
+    annealing_factor: 0.9
+    patient: 0
+
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        whisper: !ref <whisper>
+        scheduler_whisper: !ref <lr_annealing_whisper>
+        counter: !ref <epoch_counter>
+
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+    split_tokens: True
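The three hard-coded searcher indices in these YAML files are ids from Whisper's multilingual vocabulary, which is why the comments insist they should not be changed. A quick, hedged sanity check, assuming the `transformers` package from `extra_requirements.txt` is installed:

```python
# Hedged sanity check of the hard-coded searcher indices above.
from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# Expected output with the multilingual Whisper vocabulary:
# ['<|endoftext|>', '<|startoftranscript|>', '<|notimestamps|>']
print(tok.convert_ids_to_tokens([50257, 50258, 50363]))
```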
diff --git a/recipes/CommonVoice/ASR/transformer/train_with_whisper.py b/recipes/CommonVoice/ASR/transformer/train_with_whisper.py
new file mode 100644
index 0000000000..b9fe1b811d
--- /dev/null
+++ b/recipes/CommonVoice/ASR/transformer/train_with_whisper.py
@@ -0,0 +1,328 @@
+#!/usr/bin/env python3
+"""Recipe for training a whisper-based ASR system with CommonVoice.
+The system employs whisper from OpenAI (https://cdn.openai.com/papers/whisper.pdf).
+This recipe fine-tunes the Whisper encoder-decoder.
+
+To run this recipe, do the following:
+> python train_with_whisper.py hparams/train_<locale>_hf_whisper.yaml
+
+Authors
+ * Pooneh Mousavi 2022
+"""
+
+import sys
+import torch
+import logging
+import torchaudio
+import speechbrain as sb
+from speechbrain.utils.distributed import run_on_main
+from speechbrain.utils.data_utils import undo_padding
+from hyperpyyaml import load_hyperpyyaml
+from transformers.models.whisper.tokenization_whisper import LANGUAGES
+
+logger = logging.getLogger(__name__)
+
+
+# Define training procedure
+class ASR(sb.Brain):
+    def compute_forward(self, batch, stage):
+        """Forward computations from the waveform batches to the output probabilities."""
+        batch = batch.to(self.device)
+        wavs, wav_lens = batch.sig
+        bos_tokens, bos_tokens_lens = batch.tokens_bos
+
+        # Add augmentation if specified
+        if stage == sb.Stage.TRAIN:
+            if hasattr(self.hparams, "augmentation"):
+                wavs = self.hparams.augmentation(wavs, wav_lens)
+
+        # We compute the padding mask and replace the values with the pad_token_id
+        # that the Whisper decoder expects to see.
+        abs_tokens_lens = (bos_tokens_lens * bos_tokens.shape[1]).long()
+        pad_mask = (
+            torch.arange(abs_tokens_lens.max(), device=self.device)[None, :]
+            < abs_tokens_lens[:, None]
+        )
+        bos_tokens[~pad_mask] = self.tokenizer.pad_token_id
+
+        # Forward encoder + decoder
+        enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)
+
+        hyps = None
+        if stage == sb.Stage.VALID:
+            hyps, _ = self.hparams.valid_greedy_searcher(enc_out, wav_lens)
+        elif stage == sb.Stage.TEST:
+            hyps, _ = self.hparams.test_beam_searcher(enc_out, wav_lens)
+
+        return logits, hyps, wav_lens
+
+    def compute_objectives(self, predictions, batch, stage):
+        """Computes the NLL loss given predictions and targets."""
+
+        logits, hyps, wav_lens = predictions
+        batch = batch.to(self.device)
+        ids = batch.id
+        tokens_eos, tokens_eos_lens = batch.tokens_eos
+
+        log_probs = self.hparams.log_softmax(logits)
+        loss = self.hparams.nll_loss(
+            log_probs, tokens_eos, length=tokens_eos_lens,
+        )
+
+        if stage != sb.Stage.TRAIN:
+            tokens, tokens_lens = batch.tokens
+
+            # Decode token terms to words
+            predicted_words = self.tokenizer.batch_decode(
+                hyps, skip_special_tokens=True
+            )
+
+            # Convert indices to words
+            target_words = undo_padding(tokens, tokens_lens)
+            target_words = self.tokenizer.batch_decode(
+                target_words, skip_special_tokens=True
+            )
+
+            if hasattr(self.hparams, "normalized_transcripts"):
+                predicted_words = [
+                    self.tokenizer._normalize(text).split(" ")
+                    for text in predicted_words
+                ]
+
+                target_words = [
+                    self.tokenizer._normalize(text).split(" ")
+                    for text in target_words
+                ]
+            else:
+                predicted_words = [text.split(" ") for text in predicted_words]
+
+                target_words = [text.split(" ") for text in target_words]
+
+            self.wer_metric.append(ids, predicted_words, target_words)
+            self.cer_metric.append(ids, predicted_words, target_words)
+
+        return loss
+
+    def on_stage_start(self, stage, epoch):
+        """Gets called at the beginning of each epoch"""
+        if stage != sb.Stage.TRAIN:
+            self.cer_metric = self.hparams.cer_computer()
+            self.wer_metric = self.hparams.error_rate_computer()
+
+    def on_stage_end(self, stage, stage_loss, epoch):
+        """Gets called at the end of an epoch."""
+        # Compute/store important stats
+        stage_stats = {"loss": stage_loss}
+        if stage == sb.Stage.TRAIN:
+            self.train_stats = stage_stats
+        else:
+            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
+            stage_stats["WER"] = self.wer_metric.summarize("error_rate")
+
+        # Perform end-of-iteration things, like annealing, logging, etc.
+        if stage == sb.Stage.VALID:
+            old_lr_whisper, new_lr_whisper = self.hparams.lr_annealing_whisper(
+                stage_stats["loss"]
+            )
+            sb.nnet.schedulers.update_learning_rate(
+                self.optimizer, new_lr_whisper
+            )
+            self.hparams.train_logger.log_stats(
+                stats_meta={"epoch": epoch, "lr_whisper": old_lr_whisper},
+                train_stats=self.train_stats,
+                valid_stats=stage_stats,
+            )
+            self.checkpointer.save_and_keep_only(
+                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
+            )
+        elif stage == sb.Stage.TEST:
+            self.hparams.train_logger.log_stats(
+                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
+                test_stats=stage_stats,
+            )
+            with open(self.hparams.wer_file, "w") as w:
+                self.wer_metric.write_stats(w)
+
+
+def dataio_prepare(hparams, tokenizer):
+    """This function prepares the datasets to be used in the brain class.
+    It also defines the data processing pipeline through user-defined functions."""
+    data_folder = hparams["data_folder"]
+
+    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
+    )
+
+    if hparams["sorting"] == "ascending":
+        # we sort training data to speed up training and get better results.
+        train_data = train_data.filtered_sorted(sort_key="duration")
+        # when sorting do not shuffle in dataloader! otherwise it is pointless
+        hparams["train_loader_kwargs"]["shuffle"] = False
+
+    elif hparams["sorting"] == "descending":
+        train_data = train_data.filtered_sorted(
+            sort_key="duration", reverse=True
+        )
+        # when sorting do not shuffle in dataloader! otherwise it is pointless
+        hparams["train_loader_kwargs"]["shuffle"] = False
+
+    elif hparams["sorting"] == "random":
+        pass
+
+    else:
+        raise NotImplementedError(
+            "sorting must be random, ascending or descending"
+        )
+
+    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
+    )
+    valid_data = valid_data.filtered_sorted(sort_key="duration")
+
+    # test is separate
+    test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
+        csv_path=hparams["test_csv"], replacements={"data_root": data_folder},
+    )
+
+    datasets = [train_data, valid_data, test_data]
+
+    # 2. Define audio pipeline:
+    @sb.utils.data_pipeline.takes("wav")
+    @sb.utils.data_pipeline.provides("sig")
+    def audio_pipeline(wav):
+        info = torchaudio.info(wav)
+        sig = sb.dataio.dataio.read_audio(wav)
+        resampled = torchaudio.transforms.Resample(
+            info.sample_rate, hparams["sample_rate"],
+        )(sig)
+        return resampled
+
+    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
+
+    # 3. Define text pipeline:
+    @sb.utils.data_pipeline.takes("wrd")
+    @sb.utils.data_pipeline.provides(
+        "wrd", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
+    )
+    def text_pipeline(wrd):
+        yield wrd
+        tokens_list = tokenizer.encode(wrd)
+        # avoid bos and eos tokens.
+        tokens_list = tokens_list[1:-1]
+        yield tokens_list
+        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
+        yield tokens_bos
+        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
+        yield tokens_eos
+        tokens = torch.LongTensor(tokens_list)
+        yield tokens
+
+    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
+
+    # 4. Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets,
+        ["id", "sig", "tokens_list", "tokens_bos", "tokens_eos", "tokens"],
+    )
+
+    return train_data, valid_data, test_data
+
+
+if __name__ == "__main__":
+    # CLI:
+    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
+
+    # If distributed_launch=True then
+    # create ddp_group with the right communication protocol
+    sb.utils.distributed.ddp_init_group(run_opts)
+
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin, overrides)
+
+    # Create experiment directory
+    sb.create_experiment_directory(
+        experiment_directory=hparams["output_folder"],
+        hyperparams_to_save=hparams_file,
+        overrides=overrides,
+    )
+
+    # Dataset prep (parsing CommonVoice)
+    from common_voice_prepare import prepare_common_voice  # noqa
+
+    # multi-gpu (ddp) save data preparation
+    run_on_main(
+        prepare_common_voice,
+        kwargs={
+            "data_folder": hparams["data_folder"],
+            "save_folder": hparams["save_folder"],
+            "train_tsv_file": hparams["train_tsv_file"],
+            "dev_tsv_file": hparams["dev_tsv_file"],
+            "test_tsv_file": hparams["test_tsv_file"],
+            "accented_letters": hparams["accented_letters"],
+            "language": hparams["locale"],
+            "skip_prep": hparams["skip_prep"],
+        },
+    )
+
+    # Defining tokenizer and loading it
+    tokenizer = hparams["whisper"].tokenizer
+    language = LANGUAGES[hparams["locale"]]
+    tokenizer.set_prefix_tokens(language, "transcribe", False)
+
+    # we need to prepare the tokens for the searchers
+    hparams["valid_greedy_searcher"].set_decoder_input_tokens(
+        tokenizer.prefix_tokens
+    )
+    hparams["valid_greedy_searcher"].set_language_token(
+        tokenizer.prefix_tokens[1]
+    )
+
+    hparams["test_beam_searcher"].set_decoder_input_tokens(
+        tokenizer.prefix_tokens
+    )
+    hparams["test_beam_searcher"].set_language_token(tokenizer.prefix_tokens[1])
+
+    # here we create the datasets objects as well as tokenization and encoding
+    train_data, valid_data, test_data = dataio_prepare(hparams, tokenizer)
+
+    # Trainer initialization
+    asr_brain = ASR(
+        modules=hparams["modules"],
+        hparams=hparams,
+        run_opts=run_opts,
+        checkpointer=hparams["checkpointer"],
+        opt_class=hparams["whisper_opt_class"],
+    )
+
+    # We load the pretrained whisper model
+    if "pretrainer" in hparams.keys():
+        run_on_main(hparams["pretrainer"].collect_files)
+        hparams["pretrainer"].load_collected(asr_brain.device)
+
+    # We dynamically add the tokenizer to our brain class.
+    # NB: This tokenizer corresponds to the one used for Whisper.
+    asr_brain.tokenizer = tokenizer
+
+    if hparams["test_only"] is False:
+        # Training
+        asr_brain.fit(
+            asr_brain.hparams.epoch_counter,
+            train_data,
+            valid_data,
+            train_loader_kwargs=hparams["train_loader_kwargs"],
+            valid_loader_kwargs=hparams["valid_loader_kwargs"],
+        )
+
+    # Testing
+    asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
+    asr_brain.evaluate(
+        test_data,
+        min_key="WER",
+        test_loader_kwargs=hparams["test_loader_kwargs"],
+    )
+
+    asr_brain.hparams.wer_file = hparams["output_folder"] + "/wer_valid.txt"
+    asr_brain.evaluate(
+        valid_data,
+        min_key="WER",
+        test_loader_kwargs=hparams["test_loader_kwargs"],
+    )
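The padding-mask trick in `compute_forward` is the least obvious part of the script, so here is a minimal, self-contained illustration with made-up token values (the pad id 50257 below is an assumption for illustration; the real script uses `self.tokenizer.pad_token_id`):

```python
import torch

# Two BOS-prefixed sequences padded to length 4; SpeechBrain stores lengths
# relative to the longest item, so relative lengths are 1.0 and 0.5 here.
bos_tokens = torch.tensor([[50258, 11, 12, 13], [50258, 21, 0, 0]])
bos_tokens_lens = torch.tensor([1.0, 0.5])

# Absolute lengths: tensor([4, 2])
abs_tokens_lens = (bos_tokens_lens * bos_tokens.shape[1]).long()
# True where a position holds a real token, False on padding.
pad_mask = (
    torch.arange(abs_tokens_lens.max())[None, :] < abs_tokens_lens[:, None]
)
# Overwrite the padded positions with the id the decoder expects.
bos_tokens[~pad_mask] = 50257
print(bos_tokens)
# tensor([[50258,    11,    12,    13],
#         [50258,    21, 50257, 50257]])
```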
diff --git a/recipes/CommonVoice/common_voice_prepare.py b/recipes/CommonVoice/common_voice_prepare.py
index 0bffc52b3f..c7583e3082 100644
--- a/recipes/CommonVoice/common_voice_prepare.py
+++ b/recipes/CommonVoice/common_voice_prepare.py
@@ -1,11 +1,11 @@
 """
 Data preparation.
-
 Download: https://voice.mozilla.org/en/datasets
-
 Author
 ------
 Titouan Parcollet
+Luca Della Libera 2022
+Pooneh Mousavi 2022
 """
 
 import os
@@ -32,12 +32,11 @@ def prepare_common_voice(
     """
     Prepares the csv files for the Mozilla Common Voice dataset.
     Download: https://voice.mozilla.org/en/datasets
-
     Arguments
     ---------
     data_folder : str
         Path to the folder where the original Common Voice dataset is stored.
-        This path should include the lang: /datasets/CommonVoice/en/
+        This path should include the lang: /datasets/CommonVoice/<language>/
     save_folder : str
         The directory where to store the csv files.
@@ -53,7 +52,6 @@
         Specify the language for text normalization.
     skip_prep: bool
         If True, skip data preparation.
-
     Example
     -------
     >>> from recipes.CommonVoice.common_voice_prepare import prepare_common_voice
@@ -119,7 +117,6 @@
     # Additional checks to make sure the data folder contains Common Voice
     check_commonvoice_folders(data_folder)
-
     # Creating csv files for {train, dev, test} data
     file_pairs = zip(
         [train_tsv_file, dev_tsv_file, test_tsv_file],
@@ -134,9 +131,7 @@
 def skip(save_csv_train, save_csv_dev, save_csv_test):
     """
     Detects if the Common Voice data preparation has been already done.
-
     If the preparation has been done, we can skip it.
-
     Returns
     -------
     bool
@@ -162,7 +157,6 @@ def create_csv(
 ):
     """
     Creates the csv file given a list of wav files.
-
     Arguments
     ---------
     orig_tsv_file : str
@@ -172,7 +166,6 @@
     accented_letters : bool, optional
         Defines if accented letters will be kept as individual letters or
         transformed to the closest non-accented letters.
-
     Returns
     -------
     None
@@ -234,52 +227,7 @@
         words = unicode_normalisation(words)
 
         # !! Language specific cleaning !!
-        # Important: feel free to specify the text normalization
-        # corresponding to your alphabet.
-
-        if language in ["en", "fr", "it", "rw"]:
-            words = re.sub(
-                "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words
-            ).upper()
-
-        if language == "de":
-            # this replacement helps preserve the case of ß
-            # (and helps retain solitary occurrences of SS)
-            # since python's upper() converts ß to SS.
-            words = words.replace("ß", "0000ß0000")
-            words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper()
-            words = words.replace("'", " ")
-            words = words.replace("’", " ")
-            words = words.replace(
-                "0000SS0000", "ß"
-            )  # replace 0000SS0000 back to ß as its initial presence in the corpus
-
-        if language == "fr":
-            # Replace J'y D'hui etc by J_ D_hui
-            words = words.replace("'", " ")
-            words = words.replace("’", " ")
-
-        elif language == "ar":
-            HAMZA = "\u0621"
-            ALEF_MADDA = "\u0622"
-            ALEF_HAMZA_ABOVE = "\u0623"
-            letters = (
-                "ابتةثجحخدذرزسشصضطظعغفقكلمنهويىءآأؤإئ"
-                + HAMZA
-                + ALEF_MADDA
-                + ALEF_HAMZA_ABOVE
-            )
-            words = re.sub("[^" + letters + " ]+", "", words).upper()
-        elif language == "ga-IE":
-            # Irish lower() is complicated, but upper() is nondeterministic, so use lowercase
-            def pfxuc(a):
-                return len(a) >= 2 and a[0] in "tn" and a[1] in "AEIOUÁÉÍÓÚ"
-
-            def galc(w):
-                return w.lower() if not pfxuc(w) else w[0] + "-" + w[1:].lower()
-
-            words = re.sub("[^-A-Za-z'ÁÉÍÓÚáéíóú]+", " ", words)
-            words = " ".join(map(galc, words.split(" ")))
+        words = language_specific_preprocess(language, words)
 
         # Remove accents if specified
         if not accented_letters:
@@ -298,8 +246,12 @@
             chars = " ".join([char for char in chars][:])
 
         # Remove too short sentences (or empty):
-        if len(words.split(" ")) < 3:
-            continue
+        if language in ["ja", "ch"]:
+            if len(chars) < 3:
+                continue
+        else:
+            if len(words.split(" ")) < 3:
+                continue
 
         # Composition of the csv_line
         csv_line = [snt_id, str(duration), mp3_path, spk_id, str(words)]
@@ -325,27 +277,87 @@
     logger.info(msg)
 
 
+def language_specific_preprocess(language, words):
+    # !! Language specific cleaning !!
+    # Important: feel free to specify the text normalization
+    # corresponding to your alphabet.
+
+    if language in ["en", "fr", "it", "rw"]:
+        words = re.sub(
+            "[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words
+        ).upper()
+
+    if language == "de":
+        # this replacement helps preserve the case of ß
+        # (and helps retain solitary occurrences of SS)
+        # since python's upper() converts ß to SS.
+        words = words.replace("ß", "0000ß0000")
+        words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper()
+        words = words.replace("'", " ")
+        words = words.replace("’", " ")
+        words = words.replace(
+            "0000SS0000", "ß"
+        )  # replace 0000SS0000 back to ß as its initial presence in the corpus
+
+    if language == "fr":
+        # Replace J'y D'hui etc by J_ D_hui
+        words = words.replace("'", " ")
+        words = words.replace("’", " ")
+
+    elif language == "ar":
+        HAMZA = "\u0621"
+        ALEF_MADDA = "\u0622"
+        ALEF_HAMZA_ABOVE = "\u0623"
+        letters = (
+            "ابتةثجحخدذرزژشسصضطظعغفقكلمنهويىءآأؤإئ"
+            + HAMZA
+            + ALEF_MADDA
+            + ALEF_HAMZA_ABOVE
+        )
+        words = re.sub("[^" + letters + " ]+", "", words).upper()
+    elif language == "fa":
+        HAMZA = "\u0621"
+        ALEF_MADDA = "\u0622"
+        ALEF_HAMZA_ABOVE = "\u0623"
+        letters = (
+            "ابپتةثجحخچدذرزژسشصضطظعغفقگکلمنهویىءآأؤإئ"
+            + HAMZA
+            + ALEF_MADDA
+            + ALEF_HAMZA_ABOVE
+        )
+        words = re.sub("[^" + letters + " ]+", "", words).upper()
+    elif language == "ga-IE":
+        # Irish lower() is complicated, but upper() is nondeterministic, so use lowercase
+        def pfxuc(a):
+            return len(a) >= 2 and a[0] in "tn" and a[1] in "AEIOUÁÉÍÓÚ"
+
+        def galc(w):
+            return w.lower() if not pfxuc(w) else w[0] + "-" + w[1:].lower()
+
+        words = re.sub("[^-A-Za-z'ÁÉÍÓÚáéíóú]+", " ", words)
+        words = " ".join(map(galc, words.split(" ")))
+    elif language == "es":
+        # Fix the following error in dataset large:
+        # KeyError: 'The item En noviembre lanzaron Queen Elizabeth , coproducida por Foreign Noi$e . requires replacements which were not supplied.'
+        words = words.replace("$", "s")
+    return words
+
+
 def check_commonvoice_folders(data_folder):
     """
     Check if the data folder actually contains the Common Voice dataset.
-
     If not, raises an error.
-
     Returns
     -------
     None
-
     Raises
     ------
     FileNotFoundError
         If data folder doesn't contain Common Voice dataset.
     """
-
     files_str = "/clips"
-
     # Checking clips
     if not os.path.exists(data_folder + files_str):
-
         err_msg = (
             "the folder %s does not exist (it is expected in "
            "the Common Voice dataset)" % (data_folder + files_str)
@@ -354,20 +366,13 @@
 def unicode_normalisation(text):
-
-    try:
-        text = unicode(text, "utf-8")
-    except NameError:  # unicode is a default on python 3
-        pass
     return str(text)
 
 
 def strip_accents(text):
-
     text = (
         unicodedata.normalize("NFD", text)
         .encode("ascii", "ignore")
         .decode("utf-8")
     )
-
     return str(text)
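To make the behavior of the extracted `language_specific_preprocess` helper concrete, here is a hedged usage sketch; the import path mirrors the doctest in `prepare_common_voice` and assumes you run from the repository root:

```python
from recipes.CommonVoice.common_voice_prepare import language_specific_preprocess

# French: apostrophes are turned into spaces and the text is upper-cased.
print(language_specific_preprocess("fr", "j'y vais"))       # J Y VAIS
# Spanish: only the "$" fix applies.
print(language_specific_preprocess("es", "Foreign Noi$e"))  # Foreign Noise
```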
diff --git a/speechbrain/pretrained/interfaces.py b/speechbrain/pretrained/interfaces.py
index ab020766f8..27ab1666f1 100644
--- a/speechbrain/pretrained/interfaces.py
+++ b/speechbrain/pretrained/interfaces.py
@@ -8,6 +8,7 @@
  * Titouan Parcollet 2021
  * Abdel Heba 2021
  * Andreas Nautsch 2022
+ * Pooneh Mousavi 2023
 """
 import logging
 import hashlib
@@ -2871,3 +2872,126 @@
     def forward(self, spectrogram):
         "Decodes the input spectrograms"
         return self.decode_batch(spectrogram)
+
+
+class WhisperASR(Pretrained):
+    """A ready-to-use Whisper ASR model
+
+    The class can be used to run the entire encoder-decoder whisper model
+    (transcribe()) to transcribe speech. The given YAML must contain the fields
+    specified in the *_NEEDED[] lists.
+
+    # Example
+    # -------
+    # >>> from speechbrain.pretrained.interfaces import WhisperASR
+    # >>> tmpdir = getfixture("tmpdir")
+    # >>> asr_model = WhisperASR.from_hparams(source="speechbrain/asr-whisper-large-v2-commonvoice-fr", savedir=tmpdir,)
+    # >>> asr_model.transcribe_file("tests/samples/example2.wav")
+    """
+
+    HPARAMS_NEEDED = ["language"]
+    MODULES_NEEDED = ["whisper", "decoder"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tokenizer = self.hparams.whisper.tokenizer
+        self.tokenizer.set_prefix_tokens(
+            self.hparams.language, "transcribe", False
+        )
+        self.hparams.decoder.set_decoder_input_tokens(
+            self.tokenizer.prefix_tokens
+        )
+
+    def transcribe_file(self, path):
+        """Transcribes the given audiofile into a sequence of words.
+
+        Arguments
+        ---------
+        path : str
+            Path to audio file which to transcribe.
+
+        Returns
+        -------
+        str
+            The audiofile transcription produced by this ASR system.
+        """
+        waveform = self.load_audio(path)
+        # Fake a batch:
+        batch = waveform.unsqueeze(0)
+        rel_length = torch.tensor([1.0])
+        predicted_words, predicted_tokens = self.transcribe_batch(
+            batch, rel_length
+        )
+        return predicted_words
+
+    def encode_batch(self, wavs, wav_lens):
+        """Encodes the input audio into a sequence of hidden states
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        wavs = wavs.float()
+        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+        encoder_out = self.mods.whisper.forward_encoder(wavs)
+        return encoder_out
+
+    def transcribe_batch(self, wavs, wav_lens):
+        """Transcribes the input audio into a sequence of words
+
+        The waveforms should already be in the model's desired format.
+        You can call:
+        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
+        to get a correctly converted signal in most cases.
+
+        Arguments
+        ---------
+        wavs : torch.tensor
+            Batch of waveforms [batch, time, channels].
+        wav_lens : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        list
+            Each waveform in the batch transcribed.
+        tensor
+            Each predicted token id.
+        """
+        with torch.no_grad():
+            wav_lens = wav_lens.to(self.device)
+            encoder_out = self.encode_batch(wavs, wav_lens)
+            predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens)
+            predicted_words = self.tokenizer.batch_decode(
+                predicted_tokens, skip_special_tokens=True
+            )
+            if self.hparams.normalized_transcripts:
+                predicted_words = [
+                    self.tokenizer._normalize(text).split(" ")
+                    for text in predicted_words
+                ]
+
+        return predicted_words, predicted_tokens
+
+    def forward(self, wavs, wav_lens):
+        """Runs full transcription - note: no gradients through decoding"""
+        return self.transcribe_batch(wavs, wav_lens)
diff --git a/tests/recipes/CommonVoice.csv b/tests/recipes/CommonVoice.csv
index 219e22cce3..b6a1f9b608 100644
--- a/tests/recipes/CommonVoice.csv
+++ b/tests/recipes/CommonVoice.csv
@@ -15,4 +15,10 @@ ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py,recipes/Co
 ASR,CommonVoice,recipes/CommonVoice/ASR/seq2seq/train_with_wav2vec.py,recipes/CommonVoice/ASR/seq2seq/hparams/train_rw_with_wav2vec.yaml,recipes/CommonVoice/ASR/seq2seq/common_voice_prepare.py,recipes/CommonVoice/ASR/seq2seq/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-rw,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --number_of_ctc_epochs=1 --skip_prep=True --output_neurons=27 --dnn_neurons=128 --emb_size=64 --dec_neurons=128 --beam_size=3 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,
 ASR,CommonVoice,recipes/CommonVoice/ASR/transducer/train.py,recipes/CommonVoice/ASR/transducer/hparams/train_fr.yaml,recipes/CommonVoice/ASR/transducer/common_voice_prepare.py,recipes/CommonVoice/ASR/transducer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,"--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --rnn_neurons=64 --dnn_neurons=64 --dec_neurons=64 --joint_dim=64 --cnn_channels=[64, 100, 128]",
 ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train.py,recipes/CommonVoice/ASR/transformer/hparams/train_fr.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --output_neurons=27 --d_model=128 --num_encoder_layers=3 --num_decoder_layers=3 --d_ffn=256 --stage_one_epochs=1,
+ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_ar_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
+ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_fa_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
+ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_fr_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
+ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_sr_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
+ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_mn_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
+ASR,CommonVoice,recipes/CommonVoice/ASR/transformer/train_with_whisper.py,recipes/CommonVoice/ASR/transformer/hparams/train_hi_hf_whisper.yaml,recipes/CommonVoice/ASR/transformer/common_voice_prepare.py,recipes/CommonVoice/ASR/transformer/README.md,https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
 SSL,CommonVoice,recipes/CommonVoice/self-supervised-learning/wav2vec2/train_hf_wav2vec2.py,recipes/CommonVoice/self-supervised-learning/wav2vec2/hparams/wav2vec2_base.yaml,recipes/CommonVoice/self-supervised-learning/wav2vec2/common_voice_prepare.py,recipes/CommonVoice/self-supervised-learning/wav2vec2/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --d_model=128 --wav2vec2_folder=tests/tmp/wav2vec2_checkpoint,