概述
树形DP的转移是一个卷积的转移形式
可以先链剖,一个点的轻儿子先合并,然后一条重链用分治FFT合并
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef vector<int> poly;
const int N=800010,P=998244353;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline void read(int &x){
char c=nc(); x=0;
for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc());
}
int n,m,cnt,a[N],G[N],son[N],size[N],p[N],fa[N],t;
struct edge{
int t,nx;
}E[N<<2];
poly f[N],g[N];
int num,w[2][N],rev[N];
inline void addedge(int x,int y){
E[++cnt].t=y; E[cnt].nx=G[x]; G[x]=cnt;
E[++cnt].t=x; E[cnt].nx=G[y]; G[y]=cnt;
}
void pfs(int x,int f){
size[x]=1; p[++t]=x; fa[x]=f;
for(int i=G[x];i;i=E[i].nx)
if(E[i].t!=f){
pfs(E[i].t,x);
size[x]+=size[E[i].t];
if(size[E[i].t]>size[son[x]]) son[x]=E[i].t;
}
}
inline int Pow(int x,int y){
int ret=1;
for(;y;y>>=1,x=1LL*x*x%P) if(y&1) ret=1LL*x*ret%P;
return ret;
}
inline void Pre(const int &n){
num=n; int g=Pow(3,(P-1)/n);
w[0][0]=w[1][0]=1;
for(int i=1;i<n;i++) w[1][i]=1LL*w[1][i-1]*g%P;
for(int i=1;i<n;i++) w[0][i]=w[1][n-i];
}
inline void NTT(int *a,int n,int r){
for(int i=1;i<n;i++) if(rev[i]>i) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=(i<<1))
for(int k=0;k<i;k++){
int x=a[j+k],y=1LL*a[j+k+i]*w[r][num/(i<<1)*k]%P;
a[j+k]=(x+y)%P; a[j+k+i]=(x-y+P)%P;
}
if(!r) for(int i=0,inv=Pow(n,P-2);i<n;i++) a[i]=1LL*a[i]*inv%P;
}
poly operator *(poly a,poly b){
if(!a.size() || !b.size()) return a.size()?b:a;
poly ret;
if(a.size()+b.size()<500){
ret.resize(a.size()+b.size()-1);
for(int i=0;i<a.size();i++)
for(int j=0;j<b.size();j++)
ret[i+j]=(ret[i+j]+1LL*a[i]*b[j])%P;
return ret;
}
int n,L=0;
for(n=1;n<=a.size()+b.size();n<<=1,L++); L--;
for(int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<L);
static int tmpa[N],tmpb[N];
for(int i=0;i<a.size();i++) tmpa[i]=a[i];
for(int i=0;i<b.size();i++) tmpb[i]=b[i];
NTT(tmpa,n,1); NTT(tmpb,n,1);
for(int i=0;i<n;i++) tmpa[i]=1LL*tmpa[i]*tmpb[i]%P;
NTT(tmpa,n,0); ret.resize(a.size()+b.size()-1);
for(int i=0;i<ret.size();i++) ret[i]=tmpa[i];
for(int i=0;i<n;i++) tmpa[i]=tmpb[i]=0;
return ret;
}
poly operator +(poly a,poly b){
poly ret; ret.resize(max(a.size(),b.size()));
for(int i=0;i<a.size();i++) ret[i]=a[i];
for(int i=0;i<b.size();i++) ret[i]=(ret[i]+b[i])%P;
return ret;
}
struct polyc{
poly a00,a01,a10,a11;
polyc(){}
polyc(poly a,poly b):a00(a),a11(b){}
int size(){ return max(max(a00.size(),a01.size()),max(a10.size(),a11.size())); }
friend polyc operator *(polyc a,polyc b){
polyc ret;
ret.a00=a.a00*b.a00+a.a01*b.a00+a.a00*b.a10;
ret.a01=a.a00*b.a01+a.a00*b.a11+a.a01*b.a01;
ret.a10=a.a10*b.a00+a.a10*b.a10+a.a11*b.a00;
ret.a11=a.a10*b.a01+a.a10*b.a11+a.a11*b.a01;
return ret;
}
friend bool operator <(polyc a,polyc b){
return a.size()>b.size();
}
};
struct polypair{
poly a,b;
polypair(){}
polypair(poly _a,poly _b):a(_a),b(_b){}
friend polypair operator *(polypair a,polypair b){
return polypair(a.a*b.a,a.b*b.b);
}
friend bool operator <(polypair a,polypair b){
return a.a.size()>b.a.size();
}
};
namespace HuffmanFFT{
priority_queue<polypair> a;
void Push(poly _a,poly _b){
a.push(polypair(_a,_b));
}
polypair work(){
while(a.size()>1){
polypair A=a.top(); a.pop();
polypair B=a.top(); a.pop();
a.push(A*B);
}
polypair ret=a.top(); a.pop();
return ret;
}
}
namespace DivAndConq{
vector<polyc> a;
void Push(poly _a,poly _b){ a.push_back(polyc(_b,_a)); }
void Clear(){ a.clear(); }
polyc solve(int l=0,int r=a.size()-1){
if(l==r) return a[l];
int mid=l+r>>1;
return solve(l,mid)*solve(mid+1,r);
}
}
inline void solve(int x){
DivAndConq::Clear();
for(int u=x;u;u=son[u]){
for(int i=G[u];i;i=E[i].nx)
if(E[i].t!=fa[u] && E[i].t!=son[u])
HuffmanFFT::Push(f[E[i].t]+g[E[i].t],g[E[i].t]);
polypair cur; if(HuffmanFFT::a.size()) cur=HuffmanFFT::work();
poly U; U.push_back(0); U.push_back(a[u]);
if(cur.b.size()) cur.b=cur.b*U; else cur.b=U;
if(!cur.a.size()) cur.a.push_back(1);
DivAndConq::Push(cur.b,cur.a);
}
polyc cur=DivAndConq::solve();
f[x]=cur.a10+cur.a11; g[x]=cur.a00+cur.a01;
}
int main(){
read(n); read(m);
int _m; for(_m=1;_m<=n;_m<<=1); Pre(_m);
for(int i=1;i<=n;i++) read(a[i]);
for(int i=1,x,y;i<n;i++)
read(x),read(y),addedge(x,y);
pfs(1,0);
for(int i=t;i;i--)
if(son[fa[p[i]]]!=p[i]) solve(p[i]);
poly ans=f[1]+g[1];
if(ans.size()>m) printf("%dn",ans[m]);
else puts("0");
return 0;
}
最后
以上就是愉快樱桃为你收集整理的[链剖 FFT] LOJ#6289. 花朵的全部内容,希望文章能够帮你解决[链剖 FFT] LOJ#6289. 花朵所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复