Skip to content

Conversation

@ckl117
Copy link
Contributor

@ckl117 ckl117 commented Oct 28, 2024

PR types

Others

PR changes

Others

Description

append attn支持FP8 e4m3量化;
编译自定义算子时,自动生成FP8 cutlass gemm,并增加FP8 cutlass GEMM默认配置;
将FP8组网统一到FusedBlockMultiTransformer,方便后续维护;

@paddle-bot
Copy link

paddle-bot bot commented Oct 28, 2024

Thanks for your contribution!

@codecov
Copy link

codecov bot commented Oct 28, 2024

Codecov Report

Attention: Patch coverage is 0% with 319 lines in your changes missing coverage. Please review.

Project coverage is 52.97%. Comparing base (66c5d65) to head (b50da65).
Report is 2 commits behind head on develop.

Current head b50da65 differs from pull request most recent head deb4651

Please upload reports for the commit deb4651 to get more accurate results.

Files with missing lines Patch % Lines
...dlenlp/experimental/transformers/llama/modeling.py 0.00% 129 Missing ⚠️
...dlenlp/experimental/transformers/qwen2/modeling.py 0.00% 124 Missing ⚠️
...erimental/transformers/fused_transformer_layers.py 0.00% 66 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9328      +/-   ##
===========================================
+ Coverage    52.24%   52.97%   +0.72%     
===========================================
  Files          673      673              
  Lines       109100   107355    -1745     
===========================================
- Hits         56998    56868     -130     
+ Misses       52102    50487    -1615     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment on lines +2873 to +2874
def compute_activation(self, ffn1_out, i):
return ffn1_out
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP8的activation是被融合了吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用自定义算子实现的cutlass版本的FP8 dual gemm融合了act;
使用Paddle实现的cublaslt的FP8 gemm在compute_ffn1函数内计算了act,所以继承的这个方法置空就可以了;

smooth_weight,
q_base_seq_id_this_block,
q_head_idx,
quant_max_bound,quant_min_bound,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

格式化一下C++代码

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Collaborator

@yuanlehome yuanlehome left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

也添加一下llama的重构吧

yuanlehome
yuanlehome previously approved these changes Oct 30, 2024
@ckl117 ckl117 closed this Nov 1, 2024
@ckl117 ckl117 reopened this Nov 1, 2024
@yuanlehome yuanlehome merged commit 5217a3b into PaddlePaddle:develop Nov 4, 2024
11 of 12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants