题解 P5326 【[ZJOI2019]开关】
lyx_cjz
2019-05-08 08:42:36
首先容易想到用高斯消元法解决这个问题,这样的复杂度是 $O(8^n)$。
可以发现,这是对稀疏矩阵高斯消元,可以做到 $O(4^n)$。
当然,这都是过不了的。
我们来考虑优化这个消元,设 $f_{T}$ 表示对 $s_i=[i\in T]$ 时的答案。并假定 $\sum_{i=1}^n {p_i}=1$。
对 $T\neq \varnothing$,我们有:
$$f_T=1+ \sum_{i=1}^n p_i f_{T\oplus \{i\}}$$
我们观察到,这是一个异或卷积!
定义 $G=\sum_{i=1}^n p_ix^{\{i\}}$ 为集合幂级数,设 $F=\sum_T f_T x^T$ 也为集合幂级数。
于是 $F=\sum_T x^T + F \times G + c\cdot x^{\varnothing}$。
其中,$c$ 是由于 $T=\varnothing$ 时不能用上式处理所做的待定系数。
现在,我们需要求 $[x^S]F$,根据 FWT 的知识,我们对该式先做 FWT。
记 $\widetilde{F}$ 为 $F$ 的 FWT 结果。
对集合 $U$,该式在 $U$ 处的 FWT 为:
$$[x^U]\widetilde{F} \times \left (1- [x^U]\widetilde{G}\right ) = \sum_T (-1)^{|T \cap U|}+ c$$
当 $U=\varnothing$ 时,$[x^U]\widetilde{G}=1$,左式为 $0$。解得 $c=-2^n$。
对其他 $U$,$[x^U]\widetilde{G}< 1$,所以可以解出 $[x^U]\widetilde{F}=\dfrac{c}{1-[x^U]\widetilde{G}}$。
做 IFWT,得到:
$$f_S=\dfrac{1}{2^n} \sum_T (-1)^{|S\cap T|}[x^T]\widetilde{F}=\dfrac{[x^\varnothing]\widetilde{F}}{2^n}-\sum_{T\neq \varnothing} (-1)^{|S\cap T|} \dfrac{1}{1-[x^T]\widetilde{G}}$$
对于 $[x^{\varnothing}]\widetilde{F}$,利用 $f_{\varnothing}=0$ 可以得到:
$$\dfrac{[x^\varnothing]\widetilde{F}}{2^n}=\sum_{T\neq \varnothing} \dfrac{1}{1-[x^T]\widetilde{G}}$$
所以最终的答案就是:
$$\sum_{T\neq \varnothing} \left(1-(-1)^{|S\cap T|}\right) \dfrac{1}{1-[x^T]\widetilde{G}}=\sum_{T\neq \varnothing} [|S\cap T|\equiv1 \hspace{-0.444em} \pmod{2}]\dfrac{1}{\sum_{i\in T}p_i}$$
于是我们就得到了一个 $O(2^n)$ 的优秀做法!
然而这样还是有点慢啊!
注意到原题中 $p_i$ 的值域是比较小的,我们可以用 $dp[i][j][k]$ 表示前 $i$ 个数,和 $S$ 交集奇偶性为 $j$,当前总概率为 $k$ 的方案数。
最后对每种 $k$ 统计答案即可。
时间复杂度为 $O(n \sum p_i)$。在 $n$ 更大时可以进一步用 FFT 进行优化,但这里不需要。
代码如下:
```cpp
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
typedef long long ll;
inline void upd(int&a,const int&b){a+=b-mod;a+=a>>31&mod;}
const int M=50010;
const int N=200;
int ans,dp[2][M],n,inv[M];
pair<int,int>a[N];
int main(){
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++)cin>>a[i].second;
for(int i=1;i<=n;i++)cin>>a[i].first;
sort(a+1,a+n+1);
int sz=0;
dp[0][0]=1;
for(int i=1;i<=n;i++){
int op=a[i].second,p=a[i].first;
for(int j=sz;j>=0;--j){
upd(dp[op][j+p],dp[0][j]);
upd(dp[!op][j+p],dp[1][j]);
}
sz+=p;
}
inv[1]=1;
for(int i=2;i<=sz;++i)
inv[i]=mod-mod/i*(ll)inv[mod%i]%mod;
for(int i=1;i<=sz;++i)
ans=(ans+(ll)inv[i]*dp[1][i])%mod;
cout<<(ll)ans*sz%mod<<endl;
return 0;
}
```