[FEATURE] [RFC] Support for interchangeable attention backends #2607

@fffffgggg54

Is your feature request related to a problem? Please describe.
Currently, many models rely on a standard multi-head self-attention operator. timm lets the user choose between 2 versions: an eager PyTorch implementation and the fused implementation provided by PyTorch (torch.nn.functional.scaled_dot_product_attention), which in turn can use one of the 3 implementations available through PyTorch (FA2, memory-efficient attention, eager). This can be restrictive (a better implementation may be available elsewhere, or upstream issues may prevent PT SDPA from working correctly) or leave performance on the table (FA3 and other newer implementations). Adding more supported backends to timm for the user to choose from (and eventually allowing users to register their own) would alleviate this restriction. Overall, the current way eager vs sdpa is handled is also somewhat hacky imo.

Describe the solution you'd like
My thought is to create a registry for attention backends, similar to how models are managed. timm would attempt to import each supported backend (flash_attn, xformers, others) and register it on success. The registry should also be open to the user, in case they want to provide some other implementation with the same call signature. I'm not sure if this is the best approach.
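
A rough sketch of what such a registry might look like, purely for discussion. The names `register_attn_backend`, `get_attn_backend`, and `list_attn_backends` are hypothetical and do not exist in timm; the assumed call signature is `(q, k, v, dropout_p, is_causal)` with tensors shaped (batch, heads, tokens, head_dim). Only `F.scaled_dot_product_attention` and `flash_attn.flash_attn_func` are real library calls here.

```python
from typing import Callable, Dict, List

# Hypothetical registry (not part of timm): backend name -> attention callable.
# Assumed convention: q, k, v are shaped (batch, heads, tokens, head_dim).
_ATTN_BACKENDS: Dict[str, Callable] = {}


def register_attn_backend(name: str) -> Callable:
    """Decorator that stores an attention callable under `name`."""
    def decorator(fn: Callable) -> Callable:
        if name in _ATTN_BACKENDS:
            raise ValueError(f"attention backend '{name}' is already registered")
        _ATTN_BACKENDS[name] = fn
        return fn
    return decorator


def list_attn_backends() -> List[str]:
    """Names of all backends that registered successfully."""
    return sorted(_ATTN_BACKENDS)


def get_attn_backend(name: str) -> Callable:
    """Look up a backend, listing the available ones if the name is unknown."""
    if name not in _ATTN_BACKENDS:
        raise KeyError(f"unknown attention backend '{name}', available: {list_attn_backends()}")
    return _ATTN_BACKENDS[name]


# Optional backends: attempt the import and register only on success.
try:
    import torch.nn.functional as F

    if hasattr(F, "scaled_dot_product_attention"):  # present in PyTorch >= 2.0
        @register_attn_backend("sdpa")
        def _sdpa(q, k, v, dropout_p=0.0, is_causal=False):
            return F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=is_causal)
except ImportError:
    pass

try:
    from flash_attn import flash_attn_func

    @register_attn_backend("flash_attn")
    def _flash_attn(q, k, v, dropout_p=0.0, is_causal=False):
        # flash_attn_func expects (batch, tokens, heads, head_dim), so swap
        # the heads and tokens axes on the way in and on the way out.
        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            dropout_p=dropout_p, causal=is_causal)
        return out.transpose(1, 2)
except ImportError:
    pass
```

A user-provided implementation would go through the same decorator, and a model could call `get_attn_backend(name)(q, k, v)` instead of branching on a fused/eager flag. Layout differences between backends (as with flash_attn above) would have to be hidden inside each wrapper.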

Describe alternatives you've considered
An alternative would be to modify/monkeypatch the model implementation to call another attention implementation. I'm not sure how necessary any of this is, since I don't know how large the performance advantage of FA3/others over PT sdpa is for vision models. Part of the reason other libraries keep an attention impl registry seems to be that language models have much more variation in attention than vision models.
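
For comparison, the monkeypatching alternative could look roughly like the following. `ToyAttention`, `ToyModel`, and `patch_attention` are stand-ins invented for illustration, not timm classes; a real patch would have to target whatever attention module and method the model actually uses.

```python
import types


class ToyAttention:
    """Stand-in for a model's attention module (hypothetical, not timm's)."""

    def attend(self, q, k, v):
        return "eager"


class ToyModel:
    """Stand-in model holding a list of attention blocks."""

    def __init__(self, depth=2):
        self.blocks = [ToyAttention() for _ in range(depth)]


def custom_attend(self, q, k, v):
    """Replacement attention; a real one would call another kernel here."""
    return "custom"


def patch_attention(model, new_attend):
    """Rebind `attend` on every attention block; return how many were patched."""
    patched = 0
    for block in model.blocks:
        if isinstance(block, ToyAttention):
            # Bind per instance so other models and the class stay untouched.
            block.attend = types.MethodType(new_attend, block)
            patched += 1
    return patched
```

This works without any library support, but every model family needs its own patch and the patch breaks silently if the module internals change, which is the main argument for a registry instead.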

Labels: enhancement