题目链接
洛谷:
Solution
精神污染....这玩意比数树还难写...就是窝太菜了代码过于冗长然后调试还写了两三K
思路据说比较套路??反正窝是不会
我们可以很容易的把答案写出来:
\[ ans_k=\frac{1}{nm}\sum_{i=1}^{n}\sum_{j=1}^m (a_i+b_j)^k \] 我们忽略掉那个\(nm\)的系数,最后乘回来就好了,然后化简下:\[ \begin{align} ans_k=&\sum_{i=1}^{n}\sum_{j=1}^{m}\sum_{x=0}^k\binom{k}{x}a_i^xb_i^{k-x}\\ =&\sum_{x=0}^k\binom{k}{x}\sum_{i=1}^{n}a_i^x\sum_{j=1}^{m}b_i^{k-x} \end{align} \] 显然这是个卷积形式,然后算法瓶颈就在如何求出\(f(x)=\sum_{i=1}^{n}a_i^x\)。我们写出这个玩意的生成函数:
\[ \begin{align} F(x)=&\sum_{i=0}^{\infty} f(i)x^i\\ =&\sum_{i=0}^{\infty}\sum_{j=1}^{n}a_j^ix^i\\ =&\sum_{j=1}^{n}\sum_{i=0}^{\infty} a_j^ix^i\\ =&\sum_{j=1}^{n}\frac{1}{1-a_jx} \end{align} \] 注意到:\[ (\ln (1-ax))'=\frac{-a}{1-ax} \] 即:\[ x(\ln(1-ax))'=1-\frac{1}{1-ax} \] 也就是说:\[ F(x)=\sum_{i=1}^{n}1-x(\ln(1-a_ix))'=n-x\sum_{i=1}^{n}(\ln(1-a_ix))' \] 注意到导数满足加法律,后面的在化一下就是:\[ \begin{align} F(x)=&n-x\left(\sum_{i=1}^{n}\ln (1-a_ix)\right)'\\ =&n-x\left(\ln \prod_{i=1}^{n}(1-a_ix)\right)' \end{align} \] 注意到里面的连乘形式可以分治\(\rm FFT\)在\(O(n\log ^2 n)\)解决,然后再照着式子算一下就好了,需要写个多项式求\(\ln\),注意要把前面忽略的东西弄回去。总复杂度\(O(n\log^2 n)\)。
代码大概还能凑合着看吧...
#includeusing namespace std;void read(int &x) { x=0;int f=1;char ch=getchar(); for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f; for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;}void print(int x) { if(x<0) putchar('-'),x=-x; if(!x) return ;print(x/10),putchar(x%10+48);}void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}#define lf double#define ll long long #define pii pair #define vec vector #define pb push_back#define mp make_pair#define fr first#define sc second#define FOR(i,l,r) for(register int i=l,r_##i=r;i<=r_##i;++i) const int maxn = 1e6+10;const int inf = 1e9;const lf eps = 1e-8;const int mod = 998244353;int a[maxn],b[maxn],fac[maxn],ifac[maxn],inv[maxn];int w[maxn],n,m,T,mxn,bit,N,tmp[15][maxn],pos[maxn];int add(int x,int y) {return x+y>mod?x+y-mod:x+y;}int del(int x,int y) {return x-y<0?x-y+mod:x-y;}int mul(int x,int y) {return 1ll*x*y-1ll*x*y/mod*mod;}int qpow(int aa,int x) { int res=1; for(;x;x>>=1,aa=mul(aa,aa)) if(x&1) res=mul(res,aa); return res;}void clear(int *l,int *r) { if(l>=r) return ; while(l!=r) *l++=0;*l=0;}void ntt_init(int len) { for(mxn=1;mxn<=len;mxn<<=1); w[0]=1;w[1]=qpow(3,(mod-1)/mxn); for(int i=2;i<=mxn;i++) w[i]=mul(w[i-1],w[1]); inv[0]=inv[1]=fac[0]=ifac[0]=1; for(int i=2;i<=mxn;i++) inv[i]=mul(mod-mod/i,inv[mod%i]); for(int i=1;i<=mxn;i++) fac[i]=mul(fac[i-1],i); for(int i=1;i<=mxn;i++) ifac[i]=mul(ifac[i-1],inv[i]);}void get(int len) {for(bit=0,N=1;N<=len;N<<=1,bit++);}void get_pos() {for(int i=1;i >1]>>1|((i&1)<<(bit-1));}void ntt(int *r,int op) { for(int i=1;i i) swap(r[i],r[pos[i]]); for(int i=1,d=mxn>>1;i <<=1,d>>=1) for(int j=0;j <<1) for(int k=0;k >1);get(len);get_pos(); for(int i=0;i