我是靠谱客的博主 犹豫睫毛膏,最近开发中收集的这篇文章主要介绍【Codechef DEVLOCK Devu and Locks】【倍增二维FFT】题意分析代码,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

题意

求有多少个 n n n位十进制数(可以有前导零),满足模 p p p等于 0 0 0且各位数字之和不超过 m m m
n ≤ 1 0 9 , p ≤ 16 , m ≤ 15000 nle 10^9,ple 16, mle 15000 n109,p16,m15000

分析

注意到第 i i i位贡献的系数为 1 0 i   m o d   p 10^ibmod p 10imodp,两位的贡献不同当且仅当对应系数不同。因此可以把数位按照系数分类。设 n u m i num_i numi表示有多少位满足贡献系数为 i i i n u m i num_i numi可以通过求 1 0 k 10^k 10k p p p的周期来求出。

可以用倍增FFT求出 f i , j f_{i,j} fi,j表示选了 n u m i num_i numi 0 0 0 9 9 9之间的数,和为 j j j的方案数。显然 f i , j f_{i,j} fi,j的每种方案中,选取数字乘上对应系数之和模 p p p的值为 i j   m o d   p ijbmod p ijmodp。从而得到 g i , k , j g_{i,k,j} gi,k,j表示在系数为 i i i的位置中,选出来的数乘上系数之和模 p p p的值为 k k k,且选出来的数之和为 j j j的方案数。把 g 0 , ⋯   , g p − 1 g_0,cdots,g_{p-1} g0,,gp1通过二维FFT合并,就得到答案了。

二维FFT的实现方法是先对 A A A的第一维做DFT得到数组 B B B,再对 B B B的第二维做DFT得到数组 C C C。则 C C C就是二维DFT得到的数组。但这题里面 p p p比较小,因此第二维可以直接暴力卷积。

倍增FFT的时间复杂度为 O ( p m log ⁡ n log ⁡ m ) O(pmlog nlog m) O(pmlognlogm),二维FFT的时间复杂度为 O ( p 2 m log ⁡ m + p 3 m ) O(p^2mlog m+p^3m) O(p2mlogm+p3m),因此总的时间复杂度为 O ( p m log ⁡ n log ⁡ m + p 2 m log ⁡ m + p 3 m ) O(pmlog nlog m+p^2mlog m+p^3m) O(pmlognlogm+p2mlogm+p3m)

代码

#include<bits/stdc++.h>
#define pb push_back
using namespace std;

typedef long long LL;

const int N = 33005;
const int P = 55;
const int MOD = 998244353;

int n, p, m, num[P], bz[40][N], f[P][N], g[P][N], tmp[P][N], ans[P][N], L, rev[N];
vector<int> wn1[25], wn2[25];

int gcd(int x, int y)
{
	return !y ? x : gcd(y, x % y);
}

int ksm(int x, int y, int mo)
{
	int ans = 1;
	while (y)
	{
		if (y & 1) ans = (LL)ans * x % mo;
		x = (LL)x * x % mo; y >>= 1;
	}
	return ans;
}

void pre()
{
	int now[p], w = 1 % p, ls;
	memset(now, 0, sizeof(now));
	for (int i = 1; i <= n && !now[w]; i++, w = w * 10 % p) num[w]++, now[w] = i, ls = i;
	int T = ls - now[w] + 1, tmp = n - now[w] + 1 - T;
	for (int i = 0; i < T; i++, w = w * 10 % p) num[w] += tmp / T;
	for (int i = 0; i < tmp % T; i++, w = w * 10 % p) num[w]++;
	int lg = 0;
	for (L = 1; L <= m * 2; L <<= 1, lg++);
	for (int i = 0; i < L; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
	for (int i = 0; i < 20; i++)
	{
		int w1 = ksm(3, (MOD - 1) / (1 << i) / 2, MOD), w2 = ksm(3, MOD - 1 - (MOD - 1) / (1 << i) / 2, MOD);
		wn1[i].pb(1); wn2[i].pb(1);
		for (int j = 1; j < (1 << i); j++) wn1[i].pb((LL)wn1[i][j - 1] * w1 % MOD), wn2[i].pb((LL)wn2[i][j - 1] * w2 % MOD);
	}
}

void NTT(int * a, int f)
{
	for (int i = 0; i < L; i++) if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int i = 1, lg = 0; i < L; i <<= 1, lg++)
		for (int j = 0; j < L; j += (i << 1))
			for (int k = 0; k < i; k++)
			{
				int u = a[j + k], v = (LL)a[j + k + i] * (f == 1 ? wn1[lg][k] : wn2[lg][k]) % MOD;
				a[j + k] = (u + v) % MOD; a[j + k + i] = (u + MOD - v) % MOD;
			}
	if (f == -1) for (int i = 0, inv = ksm(L, MOD - 2, MOD); i < L; i++) a[i] = (LL)a[i] * inv % MOD;
}

void solve1()
{
	int mx = 0;
	for (int i = 0; i < p; i++) mx = max(mx, num[i]), f[i][0] = 1;
	for (int i = 0; i <= min(9, m); i++) bz[0][i] = 1;
	NTT(bz[0], 1);
	for (int i = 0; (1 << i) <= mx; i++)
	{
		for (int j = 0; j < p; j++)
			if (num[j] & (1 << i))
			{
				NTT(f[j], 1);
				for (int k = 0; k < L; k++) f[j][k] = (LL)f[j][k] * bz[i][k] % MOD;
				NTT(f[j], -1);
				for (int k = m + 1; k < L; k++) f[j][k] = 0;
			}
		if ((1 << (i + 1)) > mx) break;
		for (int j = 0; j < L; j++) bz[i + 1][j] = (LL)bz[i][j] * bz[i][j] % MOD;
		NTT(bz[i + 1], -1);
		for (int j = m + 1; j < L; j++) bz[i + 1][j] = 0;
		NTT(bz[i + 1], 1);
	}
}

void solve2()
{
	ans[0][0] = 1;
	for (int i = 0; i < p; i++)
	{
		for (int j = 0; j < p; j++) memset(g[j], 0, sizeof(g[j])), memset(tmp[j], 0, sizeof(tmp[j]));
		for (int j = 0; j <= m; j++)
			(g[j * i % p][j] += f[i][j]) %= MOD;
		for (int j = 0; j < p; j++) NTT(ans[j], 1), NTT(g[j], 1);
		for (int j = 0; j < p; j++)
			for (int k = 0, tar = j; k < p; k++, tar = (tar + 1) % p)
				for (int l = 0; l < L; l++)
					(tmp[tar][l] += (LL)ans[j][l] * g[k][l] % MOD) %= MOD;
		for (int j = 0; j < p; j++)
		{
			NTT(tmp[j], -1);
			for (int k = 0; k <= m; k++) ans[j][k] = tmp[j][k];
			for (int k = m + 1; k < L; k++) ans[j][k] = 0;
		}
	}
}

int main()
{
	scanf("%d%d%d", &n, &p, &m);
	pre();
	solve1();
	solve2();
	for (int i = 1; i <= m; i++) (ans[0][i] += ans[0][i - 1]) %= MOD;
	for (int i = 0; i <= m; i++) printf("%d ", ans[0][i]);
	return 0;
}

最后

以上就是犹豫睫毛膏为你收集整理的【Codechef DEVLOCK Devu and Locks】【倍增二维FFT】题意分析代码的全部内容,希望文章能够帮你解决【Codechef DEVLOCK Devu and Locks】【倍增二维FFT】题意分析代码所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(63)

评论列表共有 0 条评论

立即
投稿
返回
顶部