Step 1: Region Partition and Projection

这一步本质上做两件事:

  1. 把整张特征图切成若干个不重叠区域。
  2. 对每个 token 做 Q/K/V 投影,再把每个区域压缩成一个区域级表示,供后续路由使用。

1. 输入张量的含义

论文里写输入视觉特征为:

$$ X \in \mathbb{R}^{H \times W \times C} $$

含义是:

  • $H, W$:特征图的高和宽
  • $C$:每个 token 的通道维度

在代码里,这对应的是视觉 token 重新排回二维后的结果。代码中先将视觉 token 变回二维:

x_2d = x.view(B, H, W, C)

这里代码里多了一个 batch 维度 $B$,论文为了表达简洁省略了它。

2. Region Partition:区域划分

论文将特征图划分为 $S \times S$ 个不重叠区域。总区域数为:

$$ P = S^2 $$

由于整张特征图总共有 $HW$ 个 token,所以每个区域中的 token 数量为:

$$ \frac{HW}{S^2} $$

论文将重排后的区域特征记为:

$$ X^r \in \mathbb{R}^{S^2 \times \frac{HW}{S^2} \times C} $$

这个式子的含义是:

  • 第一维 $S^2$:一共有多少个区域
  • 第二维 $\frac{HW}{S^2}$:每个区域内部有多少个 token
  • 第三维 $C$:每个 token 的特征维度

在你的代码里,默认配置是:

$$ H = 16,\quad W = 16,\quad S = 4 $$

因此:

$$ S^2 = 16,\quad \frac{HW}{S^2} = \frac{16 \times 16}{4^2} = \frac{256}{16} = 16 $$

也就是说,代码里实际上是把 $16 \times 16$ 的特征图切成 $4 \times 4$ 个大区域,每个区域内部包含 $4 \times 4 = 16$ 个 token。

这对应代码中的分窗操作:

x_win = self._split_windows(x_2d, S)

代码里保留的是窗口内部二维结构,shape 更接近:

$$ (B, P, h, w, C) $$

其中:

$$ P = S^2,\quad h = \frac{H}{S},\quad w = \frac{W}{S} $$

而论文里把区域内部的 $(h, w)$ 展平成了 token 序列,因此写成:

$$ X^r \in \mathbb{R}^{S^2 \times \frac{HW}{S^2} \times C} $$

两者本质上是同一件事,只是表达形式不同。

3. Projection:Q / K / V 投影

论文给出的公式为:

$$ Q = XW^q,\quad K = XW^k,\quad V = XW^v $$

这表示对每个 token 做三组线性映射:

  • $Q$:Query,表示“我想关注什么”
  • $K$:Key,表示“我是什么内容”
  • $V$:Value,表示“我要传递什么信息”

在你的代码里,这一步对应:

self.qkv = nn.Linear(dim, dim * 3, bias=True)
qkv = self.qkv(x_win)
q, k, v = qkv.split([C, C, C], dim=-1)

也就是说:

  • 论文写法是分别写出三次线性映射
  • 代码实现是一次线性层直接投影到 $3C$,再拆成 $Q/K/V$

数学上两者是等价的。

如果按代码结构写得更严谨一些,其实更接近:

$$ Q^r = X^r W^q,\quad K^r = X^r W^k,\quad V^r = X^r W^v $$

因为代码里是在“区域重排后的特征”上做投影,而不是严格先对原始 $X$ 投影、再重排。

4. 为什么还要做区域级平均

论文接着写:

$$ Q_{win} = \mathrm{Mean}(Q^r),\quad K_{win} = \mathrm{Mean}(K^r) $$

并且有:

$$ Q_{win}, K_{win} \in \mathbb{R}^{S^2 \times C} $$

这一步的作用是:把每个区域内部的多个 token 压缩成一个区域级代表向量。

这样做的原因是:

  • 如果直接让所有 token 两两比较,计算量很大,而且容易受噪声干扰
  • 先得到区域级表示后,就可以先做“区域和区域之间”的粗粒度筛选
  • 后面只在少数被选中的重要区域上做细粒度注意力

在代码里,这一步对应:

q_win = q.mean(dim=(2, 3))
k_win = k.mean(dim=(2, 3))

这里是在每个窗口内部沿着空间维度做平均,因此输出 shape 为: $$ (B, P, C) $$

忽略 batch 维后,正好对应论文中的:

$$ Q_{win}, K_{win} \in \mathbb{R}^{S^2 \times C} $$

5. 这一步的直觉理解

可以把这一步理解为:

  • 原始 token:每个 patch 的细节描述
  • 区域平均向量:这个区域整体语义的摘要

例如在人脸伪造检测任务中,不同区域可能代表:

  • 头发或背景
  • 面部主体
  • 融合边界
  • 衣物或外围噪声

那么 $Q_{win}$ 和 $K_{win}$ 就是这些区域的“摘要表示”。

后面的路由步骤不会立刻在所有 token 上做全局注意力,而是先利用这些区域摘要判断:

  • 哪些区域值得关注
  • 哪些区域可以过滤掉

所以 Step 1 的核心不是直接做注意力,而是为 Step 2 的区域路由建立粗粒度索引基础。

6. 和 baseline 的本质区别

baseline adapter.py 中没有这一步“先分区、再形成区域摘要”的机制。它是把 query token 和 vision token 直接拼起来,然后整体做标准 self-attention。

而你的改进版在这一步先把视觉 token 组织成区域结构,并提取区域级表示,因此后续注意力不再是无差别的全局计算,而是:

  1. 先做粗粒度区域筛选
  2. 再做细粒度稀疏注意力

这正是 AFB 中“Focus”的起点。

Step 2: Region-to-Region Routing

这一步的目标是:先在区域层面判断“哪些区域和当前区域最相关”,然后只保留最重要的 $k$ 个连接,形成稀疏路由。

1. 区域间相似度矩阵

论文给出的公式是:

$$ A_{win} = Q_{win} K_{win}^T $$

其中:

$$ A_{win} \in \mathbb{R}^{S^2 \times S^2} $$

这个矩阵的含义是:

  • 行索引 $i$:第 $i$ 个 query 区域
  • 列索引 $j$:第 $j$ 个 key 区域
  • 元素 $A_{win}^{(i,j)}$:区域 $i$ 对区域 $j$ 的语义相关性

换句话说,$A_{win}$ 是一个区域到区域的相似度图。

如果某个值很大,说明:

  • 第 $i$ 个区域发出的 query
  • 和第 $j$ 个区域提供的 key

在语义空间里比较匹配,因此第 $i$ 个区域后续应该更多地去看第 $j$ 个区域。

在代码里,对应的是:

attn = (q_win * self.scale) @ k_win.transpose(-1, -2)

这里实际代码比论文多了一项缩放:

$$ \text{attn} = \frac{Q_{win} K_{win}^T}{\sqrt{d}} $$

其中:

$$ \text{self.scale} = d^{-1/2} $$

这和标准 attention 的缩放写法一致,用来避免点积值过大导致 softmax 不稳定。

因此,论文里的:

$$ A_{win} = Q_{win} K_{win}^T $$

可以理解为省略了缩放因子的简写版本。

2. 为什么叫 directed semantic graph

论文把这个过程描述成构建一个 directed semantic graph,是因为:

  • 每个区域都作为一个节点
  • 从区域 $i$ 指向区域 $j$ 的边权是 $A_{win}^{(i,j)}$

并且这个图是有方向的,因为:

$$ A_{win}^{(i,j)} \neq A_{win}^{(j,i)} $$

一般并不保证对称。

这表示:

  • 区域 $i$ 关注区域 $j$
  • 不等于区域 $j$ 也同样程度地关注区域 $i$

这和普通无向图不同,更接近“谁去查询谁”的语义依赖关系。

3. Top-k 稀疏路由

论文接着写:

$$ I_{win} = \mathrm{TopKIndex}(A_{win}, k) $$

其中:

$$ I_{win} \in \mathbb{N}^{S^2 \times k} $$

这个矩阵的含义是:

  • 对每个 query 区域 $i$
  • 在所有 $S^2$ 个候选区域里
  • 只保留相似度最高的前 $k$ 个区域索引

也就是说,第 $i$ 行:

$$ I_{win}^{(i)} = [j_1, j_2, \dots, j_k] $$

表示“区域 $i$ 后面只去看这 $k$ 个最相关区域”。

在你的代码里,对应:

topk_val, topk_idx = torch.topk(attn, k=self.topk, dim=-1)
weight = F.softmax(topk_val, dim=-1)

其中:

  • topk_idx 对应论文里的 $I_{win}$
  • topk_val 是被选中的 top-$k$ 相似度分数
  • weight 是对这些 top-$k$ 分数再做 softmax 后得到的路由权重

如果严格对应数学表达,代码其实做的是两部分:

$$ I_{win} = \mathrm{TopKIndex}(A_{win}, k) $$

以及

$$ W_{win} = \mathrm{Softmax}(\mathrm{TopKValue}(A_{win}, k)) $$

也就是说,代码不只保留“连到谁”,还保留“连得有多强”。

4. 这一步为什么能实现 Focus

这一步是 AFB 中 “Focus” 的真正开始,因为它显式做了筛选。

原来全局注意力会让每个区域都可能和所有其他区域交互,总共有:

$$ S^2 $$

个候选区域。

但经过 top-$k$ 筛选后,每个区域只保留:

$$ k $$

个连接,其余

$$ S^2 - k $$

个区域直接被丢弃。

在你的默认设置里:

$$ S = 4,\quad S^2 = 16,\quad k = 4 $$

所以每个区域只保留 4 个最相关区域,相当于过滤掉:

$$ 16 - 4 = 12 $$

个不重要区域,也就是丢弃了 75% 的区域连接。

这就是为什么它可以起到“背景去噪”的作用:

  • 背景墙面
  • 衣物
  • 边缘噪声

如果和当前区域语义关联弱,就不会进入后续的细粒度注意力计算。

Step 3: Token Gathering and Sparse Attention

Step 2 只是在区域级别上决定“该看谁”,但真正的注意力仍然要落到 token 级别上。Step 3 就是在做这件事。

1. Gather 的含义

论文写:

$$ K_g = \mathrm{Gather}(K, I_{win}), \quad V_g = \mathrm{Gather}(V, I_{win}) $$

这表示:

  • 当前 query 区域自身的 query token 会保留
  • 但 key/value 不再来自所有区域
  • 而是只从 top-$k$ 路由到的那些区域中收集

也就是说,先在粗粒度上决定“看哪些区域”,再在这些区域内部取出所有 token 参与注意力。

2. Gather 后的维度为什么是这样

论文写:

$$ K_g, V_g \in \mathbb{R}^{S^2 \times \frac{kHW}{S^2} \times C} $$

这个维度很好理解。

每个 query 区域原本只对应一个区域内部的 token 数:

$$ \frac{HW}{S^2} $$

现在它连向了 $k$ 个区域,所以收集到的 key/value token 总数变成:

$$ k \cdot \frac{HW}{S^2} = \frac{kHW}{S^2} $$

所以对每个 query 区域来说,它后续做注意力时,不是面对全图所有 token,而只面对这部分被路由过来的 token。

在你的代码里,对应的是:

k_all = k_tok.view(B, 1, P, T, C).expand(-1, P, -1, -1, -1)
v_all = v_tok.view(B, 1, P, T, C).expand(-1, P, -1, -1, -1)
idx = r_idx.view(B, P, K, 1, 1).expand(-1, -1, -1, T, C)
 
k_sel = torch.gather(k_all, 2, idx)
v_sel = torch.gather(v_all, 2, idx)

这里:

  • $P = S^2$
  • $T = \frac{HW}{S^2}$
  • $K$ 在代码里就是 top-$k$

因此 k_selv_sel 的 shape 是:

$$ (B, P, K, T, C) $$

把中间两个维度合并以后,就对应论文里的:

$$ (B, S^2, \frac{kHW}{S^2}, C) $$

和论文表达是完全一致的。

3. 稀疏注意力公式

论文写最后的细粒度注意力为:

$$ \mathrm{Att}(Q, K_g, V_g)

\mathrm{Softmax}\left(\frac{QK_g^T}{\sqrt{d_k}}\right)V_g + \mathrm{LePE}(V) $$

这个公式可以拆成三部分理解。

第一部分,点积匹配:

$$ \frac{QK_g^T}{\sqrt{d_k}} $$

表示 query token 和 gather 后的 key token 做相似度计算。

第二部分,softmax:

$$ \mathrm{Softmax}\left(\frac{QK_g^T}{\sqrt{d_k}}\right) $$

表示把相似度变成归一化权重。

第三部分,加权聚合:

$$ \mathrm{Softmax}\left(\frac{QK_g^T}{\sqrt{d_k}}\right)V_g $$

表示用这些权重对 gather 到的 value token 做加权求和,得到当前区域更新后的表示。

4. 代码里是怎么实现的

对应代码:

q_tok = q_tok.view(B * P, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
k_sel = k_sel.view(B * P, K * T, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
v_sel = v_sel.view(B * P, K * T, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
 
attn = (q_tok * scale) @ k_sel.transpose(-1, -2)
attn = attn.softmax(dim=-1)
out = attn @ v_sel

这里表示:

  • 当前 query 区域内部有 $T$ 个 query token
  • 被路由来的 key/value 总共有 $K \cdot T$ 个 token
  • 多头注意力在这些 token 上做稀疏匹配

所以它不是全图 token 间的全连接注意力,而是:

  • 每个区域内部的 query token
  • 只和少量被路由进来的区域 token

做细粒度 attention。

5. 路由权重在代码里还被进一步利用了

你的代码里还有一步论文正文没有写得特别明确:

r_weight_expanded = r_weight.unsqueeze(-1).unsqueeze(-1)
v_sel = v_sel * r_weight_expanded

这表示在真正做 token-level attention 之前,先用区域级路由权重对选中的 value 做一次加权。

也就是说,区域层面的重要性不仅决定“选谁”,还会影响“选中的区域贡献多大”。

如果更严格地写成数学形式,可以补充理解为:

$$ \tilde{V}g = W{win} \odot V_g $$

然后后续注意力实际更接近:

$$ \mathrm{Att}(Q, K_g, \tilde{V}_g) $$

这一步让 coarse routing 和 fine attention 之间的耦合更强。

6. LePE 的作用

论文公式最后还有:

$$ \mathrm{LePE}(V) $$

它表示局部增强位置编码,用来补回局部空间连续性信息。

因为纯 attention 容易更偏内容相似性,而忽略局部邻域结构。LePE 通过深度可分离卷积,把局部空间偏置重新注入特征。

你的代码里对应:

self.lepe = nn.Conv2d(dim, dim, kernel_size=dwconv_kernel,
                      padding=dwconv_kernel // 2, groups=dim)

以及:

lepe = self.lepe(out_2d.permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
out_2d = out_2d + lepe

所以最终输出并不是单纯的稀疏注意力结果,而是:

  • 稀疏 attention 的全局依赖
  • 加上 depth-wise convolution 带来的局部结构偏置

7. Step 2 和 Step 3 合起来在做什么

如果把这两步连起来看,其实就是一个很清楚的 coarse-to-fine 机制:

  1. 先用区域级摘要向量计算区域之间的相关性
  2. 对每个区域只保留 top-$k$ 个最相关区域
  3. 只从这些被选中的区域收集 token
  4. 再在收集到的 token 上做细粒度多头注意力

所以 AFB 的本质不是“直接做更复杂的注意力”,而是:

  • 先粗筛
  • 再精看

这就是你方法里 “Focus” 的数学本质。

3.3 Shortcut Forgetting Block (SFB)

SFB 的目标不是“再去找更显著的区域”,而是反过来抑制模型最容易依赖的显著特征,防止 adapter 学会不稳健的 shortcut。

如果说 AFB 解决的是“看哪里”的问题,那么 SFB 解决的是“看什么”的问题:

  • AFB 让模型尽量聚焦到可能含有伪造痕迹的区域
  • SFB 则进一步阻止模型只依赖最显眼、最容易分的那部分响应

这也是为什么论文把它叫做 Forget,而不是另一个 Focus

1. SFB 在代码里插在什么位置

结合你的实现,SFB 不是作用在原始图像上,也不是作用在 token 序列进入 ViT 之前,而是作用在第 8 层 adapter 输出出来的二维视觉特征图上。

在代码里,经过前 8 个 block 后,会先保存中间输出:

outs.append(
    {
        'query': x[:, :-v_L, ...],
        'x': x[:, -v_L:, ...].permute(0, 2, 1).reshape(n, d, h, w),
    }
)

这里:

  • feature['query'] 是 query token
  • feature['x'] 是视觉特征图,shape 为 $(N, D, H, W)$

随后 SFB 就作用在这个 feature['x'] 上:

masked_x = self.cfm(feature['x'])
xray_pred, attn_bias = self.mask_decoder(feature['query'], masked_x)

这说明你的 SFB 实际执行顺序是:

  1. 先让 AFB/BRA 完成视觉特征提炼
  2. 再对提炼后的视觉特征图做 shortcut forgetting
  3. 最后再把 masked feature 送进 mask_decoder

所以 SFB 的位置非常明确:它是 mask_decoder 之前的一个 feature regularizer。

2. SFB 的核心思想

论文 3.3 的核心假设是:

模型最容易依赖的 shortcut,通常会在特征图上激活得更强。

因此,不需要像 CFM 那样使用梯度来估计“哪些位置最关键”,而是可以直接看特征响应幅值。

论文中定义空间重要性图:

$$ M_{i,j} = |x_{i,j}|2 = \sqrt{\sum{c=1}^{C}(x_{i,j,c})^2} $$

其中:

  • $x_{i,j} \in \mathbb{R}^{C}$ 表示特征图在位置 $(i,j)$ 处的通道向量
  • $M_{i,j}$ 表示该空间位置的响应强度

在代码里对应:

importance_map = torch.norm(x, p=2, dim=1, keepdim=True)

这里输入 x 的 shape 是:

$$ x \in \mathbb{R}^{N \times C \times H \times W} $$

所以沿通道维做 $L_2$ 范数之后,就得到:

$$ M \in \mathbb{R}^{N \times 1 \times H \times W} $$

也就是说,对每个 batch 样本,SFB 都会生成一张空间重要性热图。

3. 为什么用 $L_2$ 范数

这个设计背后的直觉是:

  • 如果某个位置在很多通道上都有较强响应
  • 那么该位置的向量范数会更大
  • 这通常意味着模型在这个位置“看得更重”

因此 $L_2$ 范数是一个非常便宜但有效的显著性估计器。

和梯度型方法相比,它有两个优点:

  1. 不需要额外反向传播或二阶信息
  2. 可以直接在正常前向过程中计算,开销非常低

这也是论文里说它比 gradient-based CFM 更高效的原因。

4. Curriculum Masking Schedule:为什么要做课程式调度

如果一开始就把最重要的特征硬删掉,训练通常会不稳定,因为模型连最基础的判别模式都还没学会。

所以论文设计了两个随时间变化的量:

$$ p(t) = \min\left(p_{\max},\ p_{\mathrm{init}} + \frac{t}{T}(p_{\max}-p_{\mathrm{init}})\right) $$

$$ r(t) = \max\left(r_{\mathrm{final}},\ r_{\mathrm{init}} - \frac{t}{T}(r_{\mathrm{init}}-r_{\mathrm{final}})\right) $$

其中:

  • $p(t)$:当前 step 下触发 masking 的概率
  • $r(t)$:当前保留比例 keep ratio
  • $T$:调度饱和步数

论文语义很明确:

  • 随着训练进行,masking 更频繁,因为 $p(t)$ 逐渐增大
  • 同时 masking 更激进,因为 $r(t)$ 逐渐减小

在你的代码中,这个 schedule 由 _update_schedule() 实现:

self.current_prob = self.initial_mask_prob + (self.max_mask_prob - self.initial_mask_prob) * progress
self.current_alpha = self.initial_mask_alpha + (0.0 - self.initial_mask_alpha) * progress
self.current_ratio = self.initial_keep_ratio + (self.final_keep_ratio - self.initial_keep_ratio) * progress

这里除了论文中写到的 $p(t)$ 和 $r(t)$,代码还引入了第三个动态变量:

$$ \alpha(t) $$

它控制“被抑制区域到底衰减多少”,对应 soft-to-hard transition。

5. 你当前代码里的真实超参数

Adapter.__init__() 里,SFB 实际使用的是:

$$ p_{\mathrm{init}} = 0.05,\quad p_{\max} = 0.5 $$

$$ r_{\mathrm{init}} = 0.98,\quad r_{\mathrm{final}} = 0.7 $$

$$ \alpha_{\mathrm{init}} = 0.95,\quad \alpha_{\mathrm{final}} = 0 $$

$$ T = 30000 $$

也就是说:

  • 一开始只有 5% 的 batch 会触发 masking
  • 后期最多 50% 的 batch 会触发 masking
  • 一开始保留 98% 的空间位置,只抑制最显著的 2%
  • 最后保留 70% 的位置,也就是最多抑制 30% 的高响应区域
  • 一开始被抑制区域不是直接清零,而是只保留原值的 95%
  • 训练后期才逐渐走向完全抑制

这和论文实验设置部分写的是一致的。

6. keep ratio 在代码里到底是怎么变成 mask 的

代码先把重要性图展平:

importance_flat = importance_map.view(N, 1, -1)

然后计算当前 step 下要保留多少个位置:

num_pixels = H * W
k = int(num_pixels * self.current_ratio)

这里的 $k$ 不是 AFB 里的 routing top-$k$,而是:

$$ k = HW \cdot r(t) $$

表示当前要保留的空间位置数量。

随后代码执行:

topk_vals, _ = torch.topk(importance_flat, k, dim=-1, largest=False)
threshold = topk_vals[:, :, -1].unsqueeze(-1)

这一步很关键。这里用的是 largest=False,也就是:

  • 取最小的 $k$ 个响应值
  • 将这些较小响应对应的位置视为保留区域

因此,SFB 实际上是在:

  • 保留低响应位置
  • 抑制高响应位置

这和论文“mask out the most salient features”的描述是完全一致的。

换成数学表达,就是:

对每个样本,令阈值 $\tau(t)$ 为重要性图中第 $k$ 小的值,则二值保留掩码为:

$$ B_{i,j}(t) = \begin{cases} 1, & M_{i,j} \le \tau(t) \ 0, & M_{i,j} > \tau(t) \end{cases} $$

这里:

  • $B_{i,j}=1$ 表示 keep
  • $B_{i,j}=0$ 表示 drop

在代码里对应:

keep_mask = (importance_flat <= threshold).float().view(N, 1, H, W)
drop_mask = 1.0 - keep_mask

7. 这里有一个容易忽略但很重要的实现细节

论文里说的是“top-$(1-r(t))$ regions”,但你的代码实际并不是在 coarse region 级别做 masking,而是在最终特征图的空间位置上做 masking。

也就是说,它真正操作的是:

$$ H \times W $$

个 feature-map locations,而不是更粗的区域节点。

在你当前实现里,特征图大小是:

$$ H = W = 16 $$

所以 SFB 实际是在这 $16 \times 16 = 256$ 个空间位置上做显著性排序和抑制。

因此如果论文想更严格对齐代码,“regions” 这个词在这里最好换成:

  • spatial locations
  • feature positions
  • salient areas on the feature map

会更准确。

8. Soft-to-Hard Transition:为什么不是直接置零

论文给出的 masked feature 公式为:

$$ \tilde{X} = X \odot B + (1-B)\odot (X \cdot \alpha(t)) $$

这个式子非常重要,因为它表示被屏蔽的位置并不是一开始就完全删除,而是根据 $\alpha(t)$ 逐渐衰减。

可以分两部分理解:

第一部分:

$$ X \odot B $$

表示保留区域原样通过。

第二部分:

$$ (1-B)\odot (X \cdot \alpha(t)) $$

表示被认为“过于显著”的位置不会立刻清零,而是先乘一个衰减因子 $\alpha(t)$。

在代码里对应:

x_masked = x * keep_mask + (x * drop_mask * self.current_alpha)

这说明你的实现是一个典型的 soft-to-hard forgetting:

  • 训练前期:$\alpha(t)\approx 0.95$,只是轻微抑制
  • 训练后期:$\alpha(t)\to 0$,高响应位置接近被完全清零

9. 为什么这样更稳定

如果一开始就把最强响应区域完全删掉,模型会出现两个问题:

  1. 表征分布骤变,训练不稳定
  2. 早期模型尚未学会次级线索,直接强删会损伤优化

而 soft-to-hard 的过渡允许模型先建立基本判别能力,再逐步摆脱 shortcut 依赖。

这也是 SFB 的本质:

  • 不是粗暴破坏特征
  • 而是循序渐进地提高“忘记 shortcut”的难度

10. masking 不是每个 step 都一定触发

你的代码中还有一个经常被忽略的设计:

if torch.rand(1).item() > self.current_prob:
    return x

也就是说,哪怕当前已经计算出了重要性图和 mask,也不代表这一次前向一定会真正应用 masking。

只有以概率:

$$ p(t) $$

才会真正执行抑制。

这让 SFB 具备两层 curriculum:

  1. 触发频率逐渐升高
  2. 一旦触发,抑制强度也逐渐增强

因此它不是 deterministic hard masking,而是 stochastic curriculum masking。

11. 训练和推理阶段的行为不同

在代码里:

if self.training:
    ...
    return x_masked
else:
    return x

这表示:

  • 训练阶段:SFB 作为正则化器,会随机触发并抑制高响应区域
  • 推理阶段:默认不对特征做 masking,直接返回原始特征

所以 SFB 不是测试时也持续干预模型的模块,它本质上是训练阶段的 robustness regularizer。

这点非常重要,因为它说明你的方法在测试时并没有额外破坏判别信息,而是依靠训练期学到的更稳健表征来提升泛化。

12. last_mask 的作用

代码里还有:

self.last_mask = keep_mask.detach()

这不是训练逻辑本身的一部分,而是为了可视化保留最近一次生成的 mask。

即使在推理模式下,代码也会先算出当前样本的显著性 mask,只是最终不应用到特征上,而是保存在:

self.cfm.last_mask

这样你就可以在 return_map=True 时,把 SFB 的 forget map 一起返回做可视化。

13. 从优化角度理解 SFB

SFB 的优化目标不是让模型“不看重要区域”,而是让模型“不能只靠最显眼的那部分区域”。

更准确地说,它在训练时人为制造一种约束:

  • 如果模型只会依赖最强响应的 shortcut
  • 那么一旦这些 shortcut 被抑制,损失就会上升
  • 为了降低损失,模型被迫去开发次级但更稳健的判别线索

在人脸伪造检测里,这些次级线索通常更可能是:

  • blending boundaries
  • facial component inconsistencies
  • subtle texture mismatch
  • local structural artifacts

而不是:

  • 背景噪声
  • 衣物纹理
  • 头发高频
  • 训练集特有压缩痕迹

所以 SFB 的作用并不是“去噪”本身,而是“去 shortcut 依赖”。

14. AFB 和 SFB 的分工关系

把 3.2 和 3.3 连起来看,会更清楚你的方法逻辑:

  • AFB:先从空间上做筛选,减少背景和无关区域干扰
  • SFB:再从特征上做抑制,避免模型过度依赖最强激活 shortcut

也就是:

  • AFB 负责 where to look
  • SFB 负责 what not to rely on

这两个模块是互补的,而不是重复的。

15. 结合代码后,3.3 节最准确的一句话总结

如果完全按你现在的实现来总结,SFB 可以表述为:

SFB first computes an activation-magnitude importance map on the final adapter feature map, then applies a stochastic curriculum-based soft-to-hard suppression to the most salient spatial locations during training, forcing the decoder to rely on less obvious but more transferable forensic cues.

翻成中文就是:

SFB 先在最终 adapter 特征图上用通道 $L_2$ 范数估计空间显著性,再在训练阶段以概率触发、按课程式调度对最显著位置实施从软到硬的抑制,从而迫使解码器学习那些不那么显眼但更可迁移的伪造线索。

AFB 输入到底是什么

关于 Artifact Focusing Module,一个很容易混淆的问题是:进入这个模块的到底是不是“图像”,以及 Q/K/V 到底是从哪里来的。

结合你当前代码,结论很明确:

  • 进入 AFB/BRA 的只有 vision token
  • 不包含 query token
  • 也不是原始 RGB 图像

更准确地说,AFB 的输入是 vision token sequence,然后这些 token 会被重新排成二维空间特征图,以便执行分区、路由和稀疏注意力。

1. 输入到 AFB 的不是 query token

在你的实现里,进入 BRAttnForAdapter 之后,首先会把完整序列拆成两部分:

q_tok = x[:, :QL, :]
v_tok = x[:, QL:, :]

这里:

  • q_tok 是 query token
  • v_tok 是 vision token

真正进入 BRA 的是:

v_upd = self.bra_vis(v_tok, hw=(H, W))

也就是说,AFB/BRA 只处理 vision token

2. 为什么看起来像“变成了图像”

这一步更准确的说法不是“变回图像”,而是“重排成二维 feature map”。

因为 vision token 原本是序列:

$$ v_{tok} \in \mathbb{R}^{B \times N \times C} $$

其中:

  • $B$ 是 batch size
  • $N=H\times W$ 是 token 数
  • $C$ 是通道维度

而 BRA 要做的是:

  • 区域划分
  • 窗口路由
  • token gathering
  • LePE 卷积增强

这些操作都依赖明确的二维空间结构,所以代码中先执行:

x_2d = x.view(B, H, W, C)

得到:

$$ x_{2d} \in \mathbb{R}^{B \times H \times W \times C} $$

因此这不是在恢复原始图像像素,而是在恢复 token 对应的空间拓扑。

3. AFB 内部的 Q/K/V 从哪里来

AFB 中的 Q/K/V 不是来自 query token,也不是来自原图,而是来自重排后的 vision feature。

在代码里:

qkv = self.qkv(x_win)
q, k, v = qkv.split([C, C, C], dim=-1)

也就是说:

  • 先把二维 feature map 划成窗口
  • 再在这些视觉特征上做线性投影
  • 得到 BRA 内部使用的 Q/K/V

所以这里的 Q/K/V 本质上是:

$$ Q_v,\quad K_v,\quad V_v $$

也就是视觉分支内部的 query/key/value,而不是 adapter 里那组 learnable query token。

4. query token 在哪里参与

在你的实现里,query token 并不参与 AFB 内部的 BRA 路由。它们是在 vision token 被 BRA 更新之后,才作为读取端参与进来的。

对应代码:

q = self.q_proj(q_tok)
k = self.k_proj(v_upd)
v = self.v_proj(v_upd)
q_upd = self._mh_attn(q, k, v)

这表示:

  • 先由 BRA 更新 vision token,得到 $v_{upd}$
  • 再让 query token 对更新后的 vision token 做 attention

因此整个过程不是“query 和 vision 一起进入 BRA”,而是:

  1. vision -> vision:通过 BRA 做空间稀疏建模
  2. query -> vision:通过后续 attention 去读取更新后的视觉表示

5. 画图时应该怎么表述才准确

如果这张图画的是 AFB 的内部两阶段注意力框架,那么输入最好标成:

  • Vision Tokens
  • Visual Feature Tokens
  • Reshaped Spatial Features

而不建议直接写成 Image,因为那会让读者误解成这个模块是在原始像素空间里做运算。

更准确的描述应该是:

The AFB first reshapes the input vision tokens into a 2D spatial feature map, and then derives region-wise and token-wise Q/K/V projections from these visual features for bi-level routing attention.

翻成中文就是:

AFB 首先将输入的 vision tokens 重排为二维空间特征图,再从这些视觉特征中生成区域级和 token 级的 Q/K/V,用于后续的双层路由注意力。