MHA |
多头注意力模块,支持 flash_attn ,输入数据格式为:x:(B,T,C),atten_mask:(B,T) |
🔗 |
GQA |
分组注意力模块,支持 flash_attn ,输入数据格式为:x:(B,T,C),atten_mask:(B,T) |
🔗 |
MQA |
多查询注意力模块,支持 flash_attn ,输入数据格式为:x:(B,T,C),atten_mask:(B,T) |
🔗 |
SWA |
滑动窗口注意力模块,支持 flash_attn ,输入数据格式为:x:(B,T,C),atten_mask:(B,T) |
🔗 |
PosEncoding |
位置编码,RotaryPositionalEncoding 、AbsolutePositionEmbedding 、LearnedPositionEmbedding 。输入:x:(B,T,C) |
🔗 |
Norm |
正则化操作,LayerNorm 、BatchNorm 、RMSNorm 。输入:(B,T,C) 或者 (B,C,H,W) |
🔗 |
ResNet |
视觉编码器,ResNet50 , ResNet101 , ResNet152 。参数:num_classes:预测类别, channel_ratio:通道裁剪比率 |
🔗 |
Vit |
视觉编码器,Vit |
🔗 |
SwinTransformer |
视觉编码器,SwinTransformer |
🔗 |