BEST-RQ implementation #2309
Conversation
Hello @whettenr, thanks for this PR! We've recently merged a lot of new PRs into develop as part of SpeechBrain 1.0. Unfortunately, this PR has many conflicts due to the latest merges. Do you mind updating your fork and fixing all the potential conflicts? Thanks. Best,
Hey @Adel-Moumen,
Hello @whettenr, a few major changes before we go into the details of the code. You know that SB follows a dataset/task-oriented structure, so you will need to refactor the structure of your code. BEST-RQ is trained on LibriSpeech, so it should live in the LibriSpeech self-supervised-learning folder, alongside wav2vec2. The fine-tuning recipes should go in the corresponding task folders, like ASR/CTC. You also have too many hparams, I think; let's keep only the most meaningful ones.
@whettenr thank you for your contribution. I tried reproducing your results from this paper. I followed your pretraining implementation from here using the Branchformer configuration. The only differences in my setup were increasing the batch size from 100 to 1000 and, to match your setup, reducing the warmup steps from 25000 to 2500. After pretraining, I froze the model, used a learned linear combination of all hidden representations following the SUPERB CTC downstream task, and trained a single-layer BiLSTM on LibriSpeech 100h, but I got a PER of 65 on test-clean, which is significantly worse than the WER reported in the paper. Do you have any ideas on what might be causing this issue?
This is my expert.py, which shows how I loaded and used the pretrained BestRQ model in s3prl:
I'm not sure what exactly could be causing it, but there are a few key differences that might lead to bad performance.
My advice: Hope this helps!
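For context, here is a minimal sketch of the SUPERB-style learnable weighted combination of frozen hidden layers discussed above; class and argument names are illustrative only and do not mirror the s3prl or MP3S code.

```python
import torch

class WeightedLayerSum(torch.nn.Module):
    """Learnable weighted sum over the frozen upstream model's hidden layers (sketch)."""

    def __init__(self, num_layers):
        super().__init__()
        self.layer_weights = torch.nn.Parameter(torch.zeros(num_layers))

    def forward(self, hidden_states):
        # hidden_states: list of per-layer tensors, each (batch, time, dim).
        stacked = torch.stack(hidden_states, dim=0)         # (layers, batch, time, dim)
        weights = torch.softmax(self.layer_weights, dim=0)
        return (weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)
```

The combined features then feed a small probe (e.g. the single-layer BiLSTM + CTC mentioned above) while the pretrained encoder stays frozen.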
Thanks, @whettenr, for the quick response. I followed your recommendation and used the Conformer setup. However, I noticed that although you have defined the Noam scheduler in the Conformer config file, it is not used in your train.py file (https://github.com/whettenr/speechbrain/blob/sync-deletion/recipes/BEST-RQ/train.py). During training, the logs showed the same value for the learning rate throughout. Was this intentional or an oversight? I assumed it might have been an oversight, so I modified train.py slightly to incorporate the Noam scheduler. Here is the change I made: I also added TensorBoard to visualize training and validation losses: The current setup is exactly the same as yours, except I am using 5 GPUs with a batch size of 1000 samples per GPU, buckets=100, warmup=700, and 10% masking. Let me know if the values I obtained for training and validation are close to the ranges of your values or if they are off. By the way, the accuracy on validation is around 20%.
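For reference, here is a rough sketch of how a Noam scheduler declared in the hparams YAML is typically stepped inside a SpeechBrain Brain; the `noam_annealing` key and the exact hook are assumptions and may not match this recipe's train.py.

```python
import speechbrain as sb

class BestRQBrain(sb.Brain):
    # Sketch only: assumes the hparams YAML defines a NoamScheduler under the
    # key `noam_annealing` and a standard single-optimizer setup.
    def fit_batch(self, batch):
        outputs = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        # Step the scheduler once per optimizer update so the learning rate
        # actually follows the Noam warmup/decay curve in the logs.
        self.hparams.noam_annealing(self.optimizer)
        return loss.detach()
```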
If you want to reproduce the results, I believe @whettenr used the MP3S benchmark available here: https://github.com/speechbrain/benchmarks/tree/main/benchmarks/MP3S. As pointed out by Ryan, increasing the batch size and reducing the warmup might negatively affect the downstream results. BTW, I just saw on your GitHub profile that you were at JHU. I am currently working there as part of the JSALT workshop, so if you want we can have a one-to-one discussion about SpeechBrain and your issue :)
I believe a validation accuracy around 20% is pretty good! Have you tried that model on the downstream tasks? And yes, I did use the MP3S benchmark.
Thank you, @Adel-Moumen! I will definitely try the MP3S benchmark. It's such a coincidence that you're at JSALT. I hope you're enjoying your time there. Unfortunately, I'm currently in Boston, but I look forward to meeting you in person someday.
@whettenr I saw your logs here https://github.com/whettenr/bestrq/blob/main/results/best_hyperconformer/2000/log.txt and the numbers look close to mine. Do you mind sharing the
@asumagic I took care of cleaning, fixing, and documenting the code, and even retraining the models. Could you please have a very brief look at the code to see if anything seems crazy to you? I also added the return type documentation due to this new hidden_states return. The models are missing from the README, but we can do another PR for that, IMHO. If @asumagic is happy, I'll merge.
asumagic left a comment:
LGTM, nice! I've only made a superficial review, though. With those points fixed, if the recipe tests pass, it's OK by me.
Hi @AmirHussein96, have you tried fine-tuning with only a linear layer and CTC loss, like HuBERT? Does it work?
@xxchauncey it does work; there is a recipe for that in this PR, and it will get merged soon. We will train the models and share the results soon-ish as well.
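For anyone wondering what that looks like in practice, here is a rough HuBERT-style sketch (pretrained encoder plus a single linear projection trained with CTC); all names below are illustrative, not the recipe's actual code.

```python
import torch
import torch.nn.functional as F

class LinearCTCHead(torch.nn.Module):
    """Hypothetical linear + CTC fine-tuning head on top of a pretrained encoder."""

    def __init__(self, encoder, encoder_dim, vocab_size):
        super().__init__()
        self.encoder = encoder  # e.g. a pretrained BEST-RQ encoder (assumption)
        self.proj = torch.nn.Linear(encoder_dim, vocab_size + 1)  # +1 for the CTC blank

    def forward(self, feats):
        hidden = self.encoder(feats)                  # (batch, time, encoder_dim)
        return self.proj(hidden).log_softmax(dim=-1)  # log-probs for CTC

# Usage (shapes only): CTC expects (time, batch, classes) log-probs.
# loss = F.ctc_loss(log_probs.transpose(0, 1), targets, input_lens, target_lens,
#                   blank=vocab_size)
```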
@asumagic I added a few docstring tests for the hidden_states thing, but I guess you are right, we don't even have a single unit test for the TransformerASR models outside of the docstring tests... maybe we should have a PR to fix that.
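As a rough illustration of what such a unit test could look like (using a toy stand-in encoder, not the real TransformerASR API):

```python
import torch

class TinyEncoder(torch.nn.Module):
    """Toy stand-in returning per-layer hidden states; not the real TransformerASR."""

    def __init__(self, d_model=64, num_layers=4):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(d_model, d_model) for _ in range(num_layers)
        )

    def forward(self, x):
        hidden_states = []
        for layer in self.layers:
            x = torch.relu(layer(x))
            hidden_states.append(x)
        return x, hidden_states


def test_encoder_returns_hidden_states():
    enc = TinyEncoder(d_model=64, num_layers=4)
    out, hidden_states = enc(torch.randn(2, 50, 64))
    assert len(hidden_states) == 4
    assert all(h.shape == out.shape for h in hidden_states)
```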
yayyy |


What does this PR do?
Adds an implementation of BEST-RQ (the core random-projection quantizer idea is sketched right after this list).
Adds a layer dropout interface for the transformer classes (also sketched below).
Adds output hidden layers and wrappers to be able to run the MP3S benchmarks.
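For readers unfamiliar with BEST-RQ, here is a minimal sketch of the random-projection quantizer that produces the discrete targets: a frozen random projection followed by a nearest-neighbour lookup in a frozen random codebook. Module and argument names are illustrative, not the recipe's actual code.

```python
import torch
import torch.nn.functional as F

class RandomProjectionQuantizer(torch.nn.Module):
    """Frozen random-projection quantizer in the spirit of BEST-RQ (illustrative sketch)."""

    def __init__(self, input_dim, codebook_dim=16, codebook_size=8192):
        super().__init__()
        # Both the projection and the codebook are randomly initialized and never trained.
        proj = torch.empty(input_dim, codebook_dim)
        torch.nn.init.xavier_uniform_(proj)
        self.register_buffer("proj", proj)
        codebook = F.normalize(torch.randn(codebook_size, codebook_dim), dim=-1)
        self.register_buffer("codebook", codebook)

    @torch.no_grad()
    def forward(self, feats):
        # feats: (batch, time, input_dim) -> target ids: (batch, time)
        x = F.normalize(feats @ self.proj, dim=-1)
        # Nearest codebook entry for every frame; these ids become the
        # prediction targets for the masked positions during pretraining.
        dists = torch.cdist(x, self.codebook.unsqueeze(0).expand(x.size(0), -1, -1))
        return dists.argmin(dim=-1)
```

The encoder is then trained to predict these ids for the masked frames with a cross-entropy loss.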
Fixes #<issue_number>
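And a minimal sketch of what a layer dropout (stochastic depth) interface over a stack of transformer layers can look like, as mentioned in the list above; the function and argument names are assumptions for illustration, not the actual SpeechBrain interface.

```python
import torch

def forward_with_layer_dropout(layers, x, p_drop=0.1, training=True):
    """Stochastic layer dropout over a stack of transformer layers (sketch).

    During training each layer is skipped with probability `p_drop`;
    at inference every layer runs.
    """
    for layer in layers:
        if training and torch.rand(()).item() < p_drop:
            continue  # skip this layer for this forward pass
        x = layer(x)
    return x
```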
Before submitting
PR review
Reviewer checklist