Luogu P5298 PKUWC2018 Minimax 题解 [ 紫 ] [ 树形 dp ] [ 线段树合并 ] [ 概率 dp ]
Minimax:线段树合并优化 dp 好题。树形 dp
因为要求出每一个值的出现概率,首先我们可以想到一个很暴力的 dp 式子。
定义 \(dp_{i,j}\) 表示在节点 \(i\) 时,权值 \(j\) 的出现概率,设 \(l\) 表示左儿子,\(r\) 表示右儿子,则有如下转移:
[*]当 \(j\) 在左儿子中时,\(dp_{i,j}\gets dp_{l,j}\times(p_i\times\sum_{k=1}^{j-1}dp_{r,k}+(1-p_i)\times\sum_{k=j+1}^{V}dp_{r,k})\),理解的话就是对父亲节点选大的还是选小的进行分讨。
[*]当 \(j\) 在右儿子中时,\(dp_{i,j}\gets dp_{r,j}\times(p_i\times\sum_{k=1}^{j-1}dp_{l,k}+(1-p_i)\times\sum_{k=j+1}^Vdp_{l,k})\)。
直接转移即可,时间复杂度 \(O(nV)\)。
线段树合并优化
显然原来的时间复杂度会炸掉,但是我们发现每个节点最开始时最多只有一个 dp 位置是有值的,所以我们考虑用这种均摊复杂度的线段树合并来优化这个 dp。
因为 dp 转移的时候需要用到前缀和后缀和,所以我们进行 merge 的时候记录节点 \(x,y\) 的前缀和 \(px,py\) 与后缀和 \(sx,sy\) 以及父亲节点的概率 \(p\)。
梳理一下 merge 的流程:
[*]进入节点 \(x,y\)。
[*]如果 \(x,y\) 其中之一是空树,则说明直接更新 dp 值即可。
[*]若 \(x\) 是空树,对应着上述 \(j\) 在右儿子中的转移方式,则我们对 \(y\) 的整颗子树内的 dp 值全部乘上 \((p\times\sum_{k=1}^{j-1}dp_{l,k}+(1-p)\times\sum_{k=j+1}^Vdp_{l,k})=(p\times px+(1-p)\times sx)\) 即可。这个可以用懒标记实现区间乘。
[*]若 \(y\) 是空树,对应着上述 \(j\) 在左儿子中的转移方式,则我们对 \(x\) 的整颗子树内的 dp 值全部乘上 \((p\times\sum_{k=1}^{j-1}dp_{r,k}+(1-p)\times\sum_{k=j+1}^Vdp_{r,k})=(p\times py+(1-p)\times sy)\) 即可。这个可以用懒标记实现区间乘。
[*]否则就说明要递归合并,递归左右儿子的时候记得更新 \(sx,sy,px,py\) 的值。
[*]最后将左右儿子的 dp 值加起来就是这个区间的 dp 值。
时间复杂度 \(O(n\log n)\)。
代码
#include <bits/stdc++.h>#define fi first#define se second#define eb(x) emplace_back(x)#define pb(x) push_back(x)#define lc(x) (tr.ls)#define rc(x) (tr.rs)using namespace std;typedef long long ll;typedef unsigned long long ull;typedef long double ldb;using pi=pair<int,int>;const int N=300005;const ll mod=998244353;int n,fa,m=0,b,son,cd,p,ans;ll qpow(ll a,ll b){ ll res=1; while(b) { if(b&1)res=(res*a)%mod; b>>=1; a=(a*a)%mod; } return res;}int getrk(int x){ return (lower_bound(b+1,b+m+1,x)-b);}struct Node{ int ls,rs; ll dp,tag=1;};struct Segtree{ Node tr; int root,tot=0; void pushup(int p) { tr.dp=(tr.dp+tr.dp)%mod; } void pushdown(int p) { if(tr.tag!=1) { tr.tag=(tr.tag*tr.tag)%mod; tr.tag=(tr.tag*tr.tag)%mod; tr.dp=(tr.dp*tr.tag)%mod; tr.dp=(tr.dp*tr.tag)%mod; } tr.tag=1; } void modify(int p,int v) { tr.dp=(tr.dp*1ll*v)%mod; tr.tag=(tr.tag*1ll*v)%mod; } void update(int &u,int ln,int rn,int x,ll k) { if(u==0)u=++tot; if(ln==rn){tr.dp+=k;return;} int mid=(ln+rn)>>1; if(x<=mid)update(lc(u),ln,mid,x,k); else update(rc(u),mid+1,rn,x,k); pushup(u); } int merge(int x,int y,int px,int py,int sx,int sy,int p) { if(x==0&&y==0)return 0; if(x==0) { modify(y,(1ll*p*px%mod+1ll*((1-p)%mod+mod)%mod*sx)%mod); return y; } if(y==0) { modify(x,(1ll*p*py%mod+1ll*((1-p)%mod+mod)%mod*sy)%mod); return x; } pushdown(x);pushdown(y); int lx=tr.dp,rx=tr.dp,ly=tr.dp,ry=tr.dp; tr.ls=merge(lc(x),lc(y),px,py,(sx+rx)%mod,(sy+ry)%mod,p); tr.rs=merge(rc(x),rc(y),(px+lx)%mod,(py+ly)%mod,sx,sy,p); pushup(x); return x; } void query(int u,int ln,int rn) { if(ln==rn){ans=tr.dp;return;} int mid=(ln+rn)>>1; pushdown(u); query(lc(u),ln,mid); query(rc(u),mid+1,rn); }}tr1;void dfs1(int u){ if(son==0) { tr1.update(tr1.root,1,m,getrk(p),1); return; } if(son==0) { dfs1(son); tr1.root=tr1.root]; return; } dfs1(son); dfs1(son); tr1.root=tr1.merge(tr1.root],tr1.root],0,0,0,0,p);}int main(){ //freopen("sample.in","r",stdin); //freopen("sample.out","w",stdout); ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); cin>>n; for(int i=1;i<=n;i++)cin>>fa; for(int i=1;i<=n;i++) { son]]]=i; cd]++; } for(int i=1;i<=n;i++) { cin>>p; if(cd)p=p*1ll*qpow(10000,mod-2)%mod; else b[++m]=p; } sort(b+1,b+m+1); m=unique(b+1,b+m+1)-b-1; dfs1(1); tr1.query(tr1.root,1,m); ll res=0; for(int i=1;i<=m;i++)res=(res+1ll*i*b%mod*ans%mod*ans%mod)%mod; cout<<res; return 0;}
页:
[1]