
Conversation

@lucadellalib
Collaborator

When map_location=device, the checkpoint parameters are loaded on the CPU first (see the torch.load docs) and then moved to the target device. This means that, before the checkpoint parameters are copied into the model, there are two full independent copies of the model parameters on the target device: the parameters already in the model, and the parameters being recovered from the checkpoint. If the model is on the CPU, this waste of memory cannot be avoided. However, if the model is on a device other than the CPU (e.g. "cuda"), we can avoid moving the loaded checkpoint parameters to that device, saving device memory and preventing potential out-of-memory errors with huge models. Since obj.load_state_dict copies the loaded parameters into the model/optimizer/scheduler in-place, one by one, it automatically takes care of moving them to the model's device, even if they are still on the CPU.

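A minimal sketch of the idea (illustrative code only, not the SpeechBrain API; "ckpt.pt" is a hypothetical checkpoint file):

```python
import torch

model = torch.nn.Linear(1024, 1024).to("cuda")  # model already on the target device

# Wasteful: mapping the checkpoint onto the device keeps two full copies of the
# parameters in device memory until load_state_dict has finished copying them.
state_dict = torch.load("ckpt.pt", map_location="cuda")
model.load_state_dict(state_dict)

# Cheaper: keep the checkpoint on the CPU; load_state_dict copies each CPU tensor
# into the existing CUDA parameters in-place, so no second device copy is created.
state_dict = torch.load("ckpt.pt", map_location="cpu")
model.load_state_dict(state_dict)
```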
@mravanelli mravanelli requested a review from Gastron December 4, 2022 21:48
@mravanelli
Collaborator

@Gastron could you please take a look?

@Gastron
Collaborator

Gastron commented Dec 13, 2022

This is a good idea; I think we should go through with this change. As noted, the torch.load documentation itself suggests this approach.

However, the device argument of the loading code becomes unnecessary with this approach. It should be removed; unfortunately, it gets used (again, unnecessarily) in many places, e.g.:

self.checkpointer.recover_if_possible(

@lucadellalib would you be willing to go over the codebase and remove the unnecessary device argument? It should be removed from torch_recovery and torch_parameter_transfer, as well as from mark_as_loader and mark_as_transfer. It is also passed in calls to recover_if_possible and possibly load_checkpoint. Note that this will take some time.
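As an illustration of the change being requested, here is a hypothetical before/after of a custom loading hook (the names and signatures are illustrative, not the exact SpeechBrain ones):

```python
import torch

# Before: the hook receives a device and maps the checkpoint onto it,
# temporarily doubling the parameter memory on that device.
def torch_recovery_old(obj, path, end_of_epoch, device):
    obj.load_state_dict(torch.load(path, map_location=device))

# After: the checkpoint is always loaded on the CPU; load_state_dict moves each
# tensor to wherever the corresponding parameter already lives, so the device
# argument is no longer needed.
def torch_recovery_new(obj, path, end_of_epoch):
    obj.load_state_dict(torch.load(path, map_location="cpu"))
```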

@lucadellalib
Collaborator Author

@Gastron I think it would be better to leave it as it is, to minimize the number of changes (easier to debug in case of problems, lower risk of breaking other components, etc.) and especially for backward compatibility. Some users might be using those functions and suddenly find their code no longer working because they are passing a now-unexpected device argument.

@mravanelli
Collaborator

mravanelli commented Dec 30, 2022 via email

@Gastron
Collaborator

Gastron commented Jan 5, 2023

I think it would be better to take the full step of removing the device argument. However, I agree with Mirco that this is a breaking change, meaning we should only land it in the next major version.

Of course, we understand if @lucadellalib doesn't want to take the time to remove the argument from so many places. How do you feel, Luca? Perhaps someone else should contribute that part?

@mravanelli
Collaborator

mravanelli commented Jan 5, 2023 via email

@lucadellalib lucadellalib changed the base branch from develop to unstable-v0.6 January 27, 2023 01:00
@lucadellalib
Collaborator Author

lucadellalib commented Jan 27, 2023

@Gastron @mravanelli I removed the unnecessary device argument from all the places discussed; please take a look.

@Gastron
Collaborator

Gastron commented Feb 1, 2023

Great work! I browsed all the changes; looks good to me. This touches a lot of recipes, so is it ok to merge it now, before the recipe testing PR, @anautsch?

@anautsch
Collaborator

anautsch commented Feb 1, 2023

lgtm!

@Gastron we'll see what breaks and fix it, but this PR's changes seem complementary. Thank you, @lucadellalib!

@Gastron Gastron merged commit 02bead2 into speechbrain:unstable-v0.6 Feb 2, 2023