题解 P5326 【[ZJOI2019]开关】

lyx_cjz

2019-05-08 08:42:36

Solution

首先容易想到用高斯消元法解决这个问题,这样的复杂度是 $O(8^n)$。 可以发现,这是对稀疏矩阵高斯消元,可以做到 $O(4^n)$。 当然,这都是过不了的。 我们来考虑优化这个消元,设 $f_{T}$ 表示对 $s_i=[i\in T]$ 时的答案。并假定 $\sum_{i=1}^n {p_i}=1$。 对 $T\neq \varnothing$,我们有: $$f_T=1+ \sum_{i=1}^n p_i f_{T\oplus \{i\}}$$ 我们观察到,这是一个异或卷积! 定义 $G=\sum_{i=1}^n p_ix^{\{i\}}$ 为集合幂级数,设 $F=\sum_T f_T x^T$ 也为集合幂级数。 于是 $F=\sum_T x^T + F \times G + c\cdot x^{\varnothing}$。 其中,$c$ 是由于 $T=\varnothing$ 时不能用上式处理所做的待定系数。 现在,我们需要求 $[x^S]F$,根据 FWT 的知识,我们对该式先做 FWT。 记 $\widetilde{F}$ 为 $F$ 的 FWT 结果。 对集合 $U$,该式在 $U$ 处的 FWT 为: $$[x^U]\widetilde{F} \times \left (1- [x^U]\widetilde{G}\right ) = \sum_T (-1)^{|T \cap U|}+ c$$ 当 $U=\varnothing$ 时,$[x^U]\widetilde{G}=1$,左式为 $0$。解得 $c=-2^n$。 对其他 $U$,$[x^U]\widetilde{G}< 1$,所以可以解出 $[x^U]\widetilde{F}=\dfrac{c}{1-[x^U]\widetilde{G}}$。 做 IFWT,得到: $$f_S=\dfrac{1}{2^n} \sum_T (-1)^{|S\cap T|}[x^T]\widetilde{F}=\dfrac{[x^\varnothing]\widetilde{F}}{2^n}-\sum_{T\neq \varnothing} (-1)^{|S\cap T|} \dfrac{1}{1-[x^T]\widetilde{G}}$$ 对于 $[x^{\varnothing}]\widetilde{F}$,利用 $f_{\varnothing}=0$ 可以得到: $$\dfrac{[x^\varnothing]\widetilde{F}}{2^n}=\sum_{T\neq \varnothing} \dfrac{1}{1-[x^T]\widetilde{G}}$$ 所以最终的答案就是: $$\sum_{T\neq \varnothing} \left(1-(-1)^{|S\cap T|}\right) \dfrac{1}{1-[x^T]\widetilde{G}}=\sum_{T\neq \varnothing} [|S\cap T|\equiv1 \hspace{-0.444em} \pmod{2}]\dfrac{1}{\sum_{i\in T}p_i}$$ 于是我们就得到了一个 $O(2^n)$ 的优秀做法! 然而这样还是有点慢啊! 注意到原题中 $p_i$ 的值域是比较小的,我们可以用 $dp[i][j][k]$ 表示前 $i$ 个数,和 $S$ 交集奇偶性为 $j$,当前总概率为 $k$ 的方案数。 最后对每种 $k$ 统计答案即可。 时间复杂度为 $O(n \sum p_i)$。在 $n$ 更大时可以进一步用 FFT 进行优化,但这里不需要。 代码如下: ```cpp #include<bits/stdc++.h> using namespace std; const int mod=998244353; typedef long long ll; inline void upd(int&a,const int&b){a+=b-mod;a+=a>>31&mod;} const int M=50010; const int N=200; int ans,dp[2][M],n,inv[M]; pair<int,int>a[N]; int main(){ ios::sync_with_stdio(0);cin.tie(0);cout.tie(0); cin>>n; for(int i=1;i<=n;i++)cin>>a[i].second; for(int i=1;i<=n;i++)cin>>a[i].first; sort(a+1,a+n+1); int sz=0; dp[0][0]=1; for(int i=1;i<=n;i++){ int op=a[i].second,p=a[i].first; for(int j=sz;j>=0;--j){ upd(dp[op][j+p],dp[0][j]); upd(dp[!op][j+p],dp[1][j]); } sz+=p; } inv[1]=1; for(int i=2;i<=sz;++i) inv[i]=mod-mod/i*(ll)inv[mod%i]%mod; for(int i=1;i<=sz;++i) ans=(ans+(ll)inv[i]*dp[1][i])%mod; cout<<(ll)ans*sz%mod<<endl; return 0; } ```