8 changes: 6 additions & 2 deletions recipes/CommonVoice/ASR/CTC/README.md
@@ -9,19 +9,22 @@ It is important to note that CommonVoice initially offers mp3 audio files at 42H

# Languages
Here is a list of the languages that we tested with CTC on the CommonVoice dataset:
- English
- German
- French
- Italian
- Kinyarwanda

# Results

| Language | CommonVoice Release | hyperparams file | LM | Val. CER | Val. WER | Test CER | Test WER | HuggingFace link | Model link | GPUs |
| ------------- |:-------------:|:---------------------------:| -----:| -----:| -----:| -----:| -----:| :-----------:| :-----------:| :-----------:|
| English | 2020-12-11 | train_en_with_wav2vec.yaml | No | 5.01 | 12.57 | 7.32 | 15.58 | Not Avail. | [model](https://drive.google.com/drive/folders/1tYO__An68xrM5pR1UIXzEkwzvKX2Tz2o?usp=sharing) | 2xV100 32GB |
| German | 2022-08-16 | train_de_with_wav2vec.yaml | No | 1.90 | 8.02 | 2.40 | 9.54 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-de) | [model](https://drive.google.com/drive/folders/19G2Zm8896QSVDqVfs7PS_W86-K0-5xeC?usp=sharing) | 1xRTXA6000 48GB |
| French | 2020-12-11 | train_fr_with_wav2vec.yaml | No | 2.60 | 8.59 | 3.19 | 9.96 | [model](https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-fr) | [model](https://drive.google.com/drive/folders/1T9DfdZwcNI9CURxhLCi8GA5JVz8adiY8?usp=sharing) | 2xV100 32GB |
| Italian | 2020-12-11 | train_it_with_wav2vec.yaml | No | 2.77 | 9.83 | 3.16 | 10.85 | Not Avail. | [model](https://drive.google.com/drive/folders/1JhlxeA04tWg_vKcNChOoXSnjBe4luRby?usp=sharing) | 2xV100 32GB |
| Kinyarwanda | 2020-12-11 | train_rw_with_wav2vec.yaml | No | 6.20 | 20.07 | 8.25 | 23.12 | Not Avail. | [model](https://drive.google.com/drive/folders/12_BDenvOqEERDZLAN-KdiAHklvuo35tx?usp=sharing) | 2xV100 32GB |

*For German, one epoch takes around 5.5 hours.* <br>
The output folders with checkpoints and logs can be found [here](https://drive.google.com/drive/folders/11NMzY0zV-NqJmPMyZfC3RtT64bYe-G_O?usp=sharing).

## How to simply use pretrained models to transcribe my audio file?
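A minimal sketch with the German model from the table above (this assumes the `speechbrain.pretrained.EncoderASR` interface applies to this CTC model; `example-de.wav` is a placeholder for your own recording):

```python
# Sketch: transcribe a file with the pretrained German wav2vec2 + CTC model.
# Assumes SpeechBrain is installed; the audio file name below is a placeholder.
from speechbrain.pretrained import EncoderASR

asr_model = EncoderASR.from_hparams(
    source="speechbrain/asr-wav2vec2-commonvoice-de",
    savedir="pretrained_models/asr-wav2vec2-commonvoice-de",
)
print(asr_model.transcribe_file("example-de.wav"))
```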
@@ -48,3 +51,4 @@ Please, cite SpeechBrain if you use it for your research or business.
note={arXiv:2106.04624}
}
```
179 changes: 179 additions & 0 deletions recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml
@@ -0,0 +1,179 @@
# Model: wav2vec2 + DNN + CTC
# Augmentation: SpecAugment
# Authors: Titouan Parcollet 2021
# Mirco Ravanelli 2021
# Sangeet Sagar 2022
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 8200
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/wav2vec2_ctc_de/<seed>
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# HuggingFace hub ID for the LARGE German wav2vec2 model (XLSR-53).
wav2vec2_hub: facebook/wav2vec2-large-xlsr-53-german

# Dataset prep parameters
data_folder: !PLACEHOLDER
train_tsv_file: !ref <data_folder>/train.tsv
dev_tsv_file: !ref <data_folder>/dev.tsv
test_tsv_file: !ref <data_folder>/test.tsv
accented_letters: True
language: de
train_csv: !ref <save_folder>/train.csv
valid_csv: !ref <save_folder>/dev.csv
test_csv: !ref <save_folder>/test.csv
skip_prep: False

# We remove utterances longer than 10s in the train/dev/test sets as
# longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 10.0

# Training parameters
number_of_epochs: 45
lr: 1.0
lr_wav2vec: 0.0001
sorting: ascending
auto_mix_prec: False
sample_rate: 16000
ckpt_interval_minutes: 30 # save checkpoint every N min

# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 6 per GPU to fit 16GB of VRAM
batch_size: 8
test_batch_size: 8
dataloader_num_workers: 8
test_num_workers: 8

dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: !ref <dataloader_num_workers>
test_dataloader_options:
    batch_size: !ref <test_batch_size>
    num_workers: !ref <test_num_workers>

# BPE parameters
token_type: char # ["unigram", "bpe", "char"]
character_coverage: 1.0

# Model parameters
# activation: !name:torch.nn.LeakyReLU
dnn_neurons: 1024
wav2vec_output_dim: !ref <dnn_neurons>
freeze_wav2vec: False
dropout: 0.15

# Outputs
output_neurons: 32 # BPE size, index(blank/eos/bos) = 0

# Be sure that the bos and eos indexes match the ones used by the BPE model
blank_index: 0
bos_index: 1
eos_index: 2

# Functions and classes
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
    sample_rate: !ref <sample_rate>
    speeds: [95, 100, 105]

enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <wav2vec_output_dim>]
    linear1: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation: !new:torch.nn.LeakyReLU
    drop: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear2: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation2: !new:torch.nn.LeakyReLU
    drop2: !new:torch.nn.Dropout
        p: !ref <dropout>
    linear3: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation3: !new:torch.nn.LeakyReLU

wav2vec2: !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
    save_path: !ref <save_folder>/wav2vec2_checkpoint

#####
# Uncomment this block if you prefer to use a Fairseq pretrained model instead
# of a HuggingFace one. Here, we provide a URL obtained from the
# Fairseq GitHub for the multilingual XLSR model.
#
#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
# pretrained_path: !ref <wav2vec2_url>
# output_norm: True
# freeze: False
# save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

model_opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8

wav2vec_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>

lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0

lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        wav2vec2: !ref <wav2vec2>
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
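As an illustration of how a recipe file like this is consumed, here is a short sketch using the `hyperpyyaml` loader that SpeechBrain recipes rely on (the CommonVoice path in the override is hypothetical, and the file path assumes you run from the repository root):

```python
# Sketch: load the hparams file and override its !PLACEHOLDER entry.
from hyperpyyaml import load_hyperpyyaml

with open("recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml") as fin:
    hparams = load_hyperpyyaml(
        fin,
        overrides={"data_folder": "/path/to/cv-corpus/de"},  # hypothetical path
    )

# !ref and !new: tags are resolved at load time, so objects such as the
# wav2vec2 encoder are already instantiated and the seed has been applied.
print(hparams["output_folder"])  # results/wav2vec2_ctc_de/8200
```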
12 changes: 12 additions & 0 deletions recipes/CommonVoice/common_voice_prepare.py
@@ -242,6 +242,18 @@ def create_csv(
"[^’'A-Za-z0-9À-ÖØ-öø-ÿЀ-ӿéæœâçèàûî]+", " ", words
).upper()

if language == "de":
# this replacement helps preserve the case of ß
# (and helps retain solitary occurrences of SS)
# since python's upper() converts ß to SS.
words = words.replace("ß", "0000ß0000")
words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper()
words = words.replace("'", " ")
words = words.replace("’", " ")
words = words.replace(
"0000SS0000", "ß"
) # replace 0000SS0000 back to ß as its initial presence in the corpus

if language == "fr":
# Replace J'y D'hui etc by J_ D_hui
words = words.replace("'", " ")
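To make the intent of the ß sentinel concrete, here is a small self-contained sketch of the same normalization idea (`normalize_de` is a hypothetical helper for illustration, not a function defined in the recipe):

```python
import re


def normalize_de(words: str) -> str:
    # Python's upper() maps ß to SS, so protect ß with a sentinel first.
    words = words.replace("ß", "0000ß0000")
    words = re.sub("[^’'A-Za-z0-9öÖäÄüÜß]+", " ", words).upper()
    words = words.replace("'", " ")
    words = words.replace("’", " ")
    # The sentinel itself was upper-cased to 0000SS0000; mapping it back
    # restores ß, while genuine "ss" in the transcript stays "SS".
    return words.replace("0000SS0000", "ß")


print(normalize_de("die straße ist nass"))  # DIE STRAßE IST NASS
```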
4 changes: 3 additions & 1 deletion speechbrain/processing/features.py
@@ -836,7 +836,9 @@ def forward(self, x):

        # Derivative estimation (with a fixed convolutional kernel)
        delta_coeff = (
            torch.nn.functional.conv1d(
                x, self.kernel.to(x.device), groups=x.shape[1]
            )
            / self.denom
        )

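A short sketch of the situation the `.to(x.device)` change addresses (it assumes the `[batch, time, features]` input layout documented for `ComputeDeltas`; the CUDA branch only runs when a GPU is present):

```python
import torch
from speechbrain.processing.features import ComputeDeltas

compute_deltas = ComputeDeltas()

# Random "fbank-like" features: batch of 4, 100 frames, 40 coefficients.
feats = torch.randn(4, 100, 40)
if torch.cuda.is_available():
    # Before this change, the fixed derivative kernel could sit on the CPU
    # while the features were on the GPU, making conv1d fail with a device
    # mismatch; .to(x.device) keeps both operands on the same device.
    feats = feats.cuda()

deltas = compute_deltas(feats)
print(deltas.shape)  # same shape as feats
```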
1 change: 1 addition & 0 deletions tests/recipes.csv
@@ -149,3 +149,4 @@ recipe0147,ASR,LibriSpeech,recipes/LibriSpeech/ASR/CTC/train_with_wav2vec.py,rec
recipe0148,Separation,WSJ0Mix,recipes/WSJ0Mix/separation/train.py,recipes/WSJ0Mix/separation/hparams/sepformer-conformerintra.yaml,recipes/WSJ0Mix/separation/dynamic_mixing.py,recipes/WSJ0Mix/separation/README.md,https://drive.google.com/drive/folders/1NcB7pKj7qWzDaI3ScDOyQwJLvRdW9rfl,
recipe0149,ASR,AISHELL-1,recipes/AISHELL-1/ASR/CTC/train_with_wav2vec.py,recipes/AISHELL-1/ASR/CTC/hparams/train_with_wav2vec.yaml,recipes/AISHELL-1/ASR/CTC/aishell_prepare.py,recipes/AISHELL-1/ASR/CTC/README.md,https://drive.google.com/drive/folders/1GTB5IzQPl57j-0I1IpmvKg722Ti4ahLz?usp=sharing,https://huggingface.co/speechbrain/asr-wav2vec2-ctc-aishell,,
recipe0150,TTS,AISHELL-1,recipes/LibriTTS/vocoder/hifigan/train.py,recipes/LibriTTS/vocoder/hifigan/hparams/train.yaml,recipes/LibriTTS/libritts_prepare.py,recipes/LibriTTS/README.md,https://drive.google.com/drive/folders/1cImFzEonNYhetS9tmH9R_d0EFXXN0zpn?usp=sharing,https://huggingface.co/speechbrain/tts-hifigan-libritts-16kHz,,
recipe0151,ASR,CommonVoice,recipes/CommonVoice/ASR/CTC/train_with_wav2vec.py,recipes/CommonVoice/ASR/CTC/hparams/train_de_with_wav2vec.yaml,recipes/CommonVoice/ASR/CTC/common_voice_prepare.py,recipes/CommonVoice/ASR/CTC/README.md,https://drive.google.com/drive/folders/19G2Zm8896QSVDqVfs7PS_W86-K0-5xeC?usp=sharing,https://huggingface.co/speechbrain/asr-wav2vec2-commonvoice-de,,