算法分析

首先,由于要求最大化下面的式子:
$$ \sum\limits_{i=1}^{n-1}(b_{i}^{b_{i+1}}\bmod998244353) $$
容易想到使用 DP。

其次,由于双端队列需要控制两端的位置,所以显然要使用区间 DP

状态设计

首先记录一个区间的左、右端点,所以第一步令 $f_{l,r}$ 表示 $a$ 中 $[l,r]$ 这一区间最大的贡献。

但是,$[l,r]$ 可能由 $[l+1,r]$ 或 $[l,r-1]$ 得来,无法固定,所以需要记录 $[l,r]$ 是由哪个区间得来的。

能不能记录这个数的下标呢?这是不行的,因为这样空间复杂度就会变为 $O(n^3)$,铁定爆炸。

根据之前的描述,$[l,r]$ 只有两种得到的可能性,所以可以将第三维表示为 $0/1$:

  • $f_{l,r,0}$ 表示当前区间为 $[l,r]$,且是先弹出 $a_l$;
  • $f_{l,r,1}$ 表示当前区间为 $[l,r]$,且是先弹出 $a_r$。

这样,我们就可以进行状态转移了。

状态转移

声明:下面的转移方程中暂时先忽略取模。

考虑 $f_{l,r,0}$ 如何转移。

由于 $f_{l,r,0}$ 由 $[l+1,r]$ 得来,所以此时先弹出的是 $a_l$,即 $b_i=a_l$。

由于 $[l+1,r]$ 不固定,所以我们需要分类讨论:

  • $[l+1,r]$ 弹出的是 $l+1$,那么 $(b_i,b_{i+1})=(a_l,a_{l+1})$,所以结果是 $f_{l+1,r,0}+a_{l}^{a_{l+1}}$;
  • $[l+1,r]$ 弹出的是 $r$,那么 $(b_i,b_{i+1})=(a_l,a_r)$,所以结果是 $f_{l+1,r,1}+a_{l}^{a_{r}}$。

所以
$$ f_{l,r,0}=\max\{f_{l+1,r,0}+a_{l}^{a_{l+1}},f_{l+1,r,1}+a_{l}^{a_{r}}\} $$

或许给张图能更好理解(?

f[l][r][0] 的转移 图解

接下来考虑 $f_{l,r,1}$ 如何转移,与上面的转移同理。

$f_{l,r,1}$ 由 $[l,r-1]$ 得来,所以 $b_i=a_r$。

  • $[l,r-1]$ 弹出的是 $l$,那么 $b_{i+1}=a_l$;
  • $[l,r-1]$ 弹出的是 $r-1$,那么 $b_{i+1}=a_{r-1}$。

所以
$$ f_{l,r,1}=\max\{f_{l,r-1,0}+a_{r}^{a_{l}},f_{l,r-1,1}+a_{r}^{a_{r-1}}\} $$

转移细节

初始化

全部赋值为 $0$,因为一个空的区间显然结果为 $0$。

最终答案

最终结果为 $\max\{f_{1,n,0},f_{1,n,1}\}$,即整个数组第一次弹出左/右端点的答案中的最大值。

代码实现

实现细节

  • 本题中特别规定 $0^0=0$,所以快速幂中需要特判;

  • 区间长度至少为 $2$,否则弹出后是空的,没有意义。

  • $f$ 数组要开为 long long 类型:

    由于仅在快速幂的时候取模,求和是不取模,那么每个值最大为 $998,244,352$,需要求和 $n-1$ 次,$998,244,352\times(n_{\max}-1)=997,246,107,648>10^{9}$,所以要开 long long

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <cstring>
#include <iostream>
#define max(a, b) ((a)>(b)?(a):(b))
const int N = 1e3 + 10;
const int MOD = 998244353;
typedef long long lint;

int T, n, a[N];
lint f[N][N][2];

lint qpow(lint a, lint b) { // 快速幂
if (a == 0 && b == 0) return 0; // 特判
lint res = 1;
for (; b; b >>= 1) {
if (b & 1) (res *= a) %= MOD;
(a *= a) %= MOD;
}
return res;
}

int main() {
std::cin.tie(0)->sync_with_stdio(0);
for (std::cin >> T; T; --T) {
std::cin >> n, memset(f, 0, sizeof(f)); // 多测清空
for (int i = 1; i <= n; ++i) std::cin >> a[i];
for (int len = 2; len <= n; ++len) { // 从小到大枚举区间长度
for (int l = 1, r; ; ++l) {
if ((r = l + len - 1) > n) break; // 右端点超出范围就退出
// 根据状态转移方程进行转移
f[l][r][0] = max(f[l + 1][r][0] + qpow(a[l], a[l + 1]), f[l + 1][r][1] + qpow(a[l], a[r]));
f[l][r][1] = max(f[l][r - 1][0] + qpow(a[r], a[l]), f[l][r - 1][1] + qpow(a[r], a[r - 1]));
}
}
std::cout << max(f[1][n][0], f[1][n][1]) << '\n';
}
std::cout.flush();
return 0;
}