C Forgiving Matching
解题思路
为了求出答案我们需要统计$S$的每个子串对于$L$的失配字符的数量。
我们可以反过来想为什么我们不先求出$S$每一个字串和$L$匹配上的字符的数量。
我们令$f(i)$为子串$S(i\cdots i+m-1)$对$L$匹配上的字符的数量。
我们可以考虑每个字符的匹配情况,然后把它加起来就得到了$f(i)$
假定我们当前考虑的字符为$c$,我们令$A_i=(S_i=c)$,$B_i=(L_i=c)$,那么
$$
f(i)=\sum_{j=1}^{m}A(i+j-1)B(j)
$$
我们将$B$翻转后变成$B'$
$$
f(i)=\sum_{j=1}^{m}A(i+j-1)B'(m-j+1)
$$
这很显然可以通过FFT计算得到。
接下来考虑通配符的情况
一个子串通过通配符匹配到的字符数量为(S对应子串中*的数量+L中*的数量-S和L*匹配上的数量)。就是一个简单的容斥,匹配数量可以用上面的方法也可以直接算出来。
参考代码
#include <bits/stdc++.h>
#define DEBUG puts("Here is a BUG")
#define MEM(a,b) memset((a),(b),sizeof(a))
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define PI 3.1415926535897932626
#define all(a) a.begin(),a.end()
#define see(a) cerr<<(#a)<<" = "<<(a)<<endl
#define yesno(a) if(a)cout<<"YES"<<endl;else cout<<"NO"<<endl;
typedef long long ll;
using namespace std;
const double eps=1e-8;
const int MAXN=(int)1e6+5;
const int MOD=998244353;
const int INF=0x3f3f3f3f;
const int dx[]={-1,1,0,0};
const int dy[]={0,0,-1,1};
void NTT(int *a,int n,int op,int G=3) {
auto powmod=[](int a,int b){int res=1;a%=MOD; assert(b>=0); for(;b;b>>=1){if(b&1)res=1ll*res*a%MOD;a=1ll*a*a%MOD;}return res;};
static int r[MAXN];
static int rn=0;
if(rn!=n)
{
int len=__builtin_ffs(n)-1;
for(int i=1;i<n;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(len-1));
rn=n;
}
for(int i=0;i<n;++i) if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1;i<n;i<<=1) {
int gn=powmod(G,(MOD-1)/(i<<1));
for(int j=0;j<n;j+=(i<<1)) {
int t1,t2,g=1;
for(int k=0;k<i;++k,g=1LL*g*gn%MOD) {
t1=a[j+k],t2=1LL*g*a[j+k+i]%MOD;
a[j+k]=(t1+t2)%MOD,a[j+k+i]=(t1-t2+MOD)%MOD;
}
}
}
if(op==1) return;
int ny=powmod(n,MOD-2); reverse(a+1,a+n);
for(int i=0;i<n;++i) a[i]=1LL*a[i]*ny%MOD;
}
int a[MAXN],b[MAXN];
int f[MAXN];
char S[MAXN],L[MAXN];
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
#ifndef ONLINE_JUDGE
//freopen("0.in","r",stdin);
#endif
int n,m;
int T;
cin>>T;
while(T--)
{
cin>>n>>m;
cin>>S>>L;
memset(f,0,sizeof(int)*(n+m));
int lim=1,len=0;
while(lim<(n+m))lim<<=1,len++;
reverse(L,L+m);
for (int i = 0; i < 10; ++i)
{
memset(a,0,sizeof(int)*(lim));
memset(b,0,sizeof(int)*(lim));
rep(j,0,n)a[j]=(S[j]==(i+'0'));
rep(j,0,m)b[j]=(L[j]==(i+'0'));
NTT(a,lim,1);
NTT(b,lim,1);
rep(j,0,lim)a[j]=(1ll*a[j]*b[j])%MOD;
NTT(a,lim,-1);
rep(j,m-1,n)f[j]+=a[j];
}
//处理通配符
int cnt1=0,cnt2=0;
rep(i,0,m)cnt1+=(L[i]=='*');
rep(i,0,m-1)cnt2+=(S[i]=='*');
rep(i,m-1,n)
{
cnt2+=(S[i]=='*');
f[i]+=cnt2+cnt1;
cnt2-=(S[i-m+1]=='*');
}
memset(a,0,sizeof(int)*(lim));
memset(b,0,sizeof(int)*(lim));
rep(j,0,n)a[j]=(S[j]=='*');
rep(j,0,m)b[j]=(L[j]=='*');
NTT(a,lim,1);
NTT(b,lim,1);
rep(j,0,lim)a[j]=(1ll*a[j]*b[j])%MOD;
NTT(a,lim,-1);
rep(j,m-1,n)f[j]-=a[j];
memset(b,0,sizeof(int)*(m+1));
rep(i,m-1,n)b[m-f[i]]++;
int cnt=0;
rep(i,0,m+1)
{
cnt+=b[i];
cout<<cnt<<endl;
}
}
return 0;
}
I.Rise in Price
解题思路
三维dp乱搞,每次取前100个最大的转移。
就是有点卡常,用vector
炸了😅。
参考代码
#include <bits/stdc++.h>
#define DEBUG puts("Here is a BUG")
#define MEM(a,b) memset((a),(b),sizeof(a))
#define rep(i,a,n) for (int i=a;i<n;i++)
#define per(i,a,n) for (int i=n-1;i>=a;i--)
#define PI 3.1415926535897932626
#define all(a) a.begin(),a.end()
#define see(a) cerr<<(#a)<<" = "<<(a)<<endl
#define yesno(a) if(a)cout<<"YES"<<endl;else cout<<"NO"<<endl;
typedef long long ll;
using namespace std;
const double eps=1e-8;
const int MAXN=(int)1e2+5;
const int MOD=(int)1e9+7;
const int INF=0x3f3f3f3f;
const int dx[]={-1,0};
const int dy[]={0,-1};
pair<int,int> dp[MAXN][MAXN][MAXN*50],tmp[MAXN];
int cnt[MAXN][MAXN],ct;
int a[MAXN][MAXN],b[MAXN][MAXN];
bool cmp(pair<int,int> x,pair<int,int> y)
{
return 1ll*x.first*x.second>1ll*y.first*y.second;
}
int main()
{
int T,n;
scanf("%d",&T);
while(T--)
{
scanf("%d",&n);
rep(i,0,n)rep(j,0,n)cnt[i][j]=0;
rep(i,0,n)rep(j,0,n)scanf("%d",&a[i][j]);
rep(i,0,n)rep(j,0,n)scanf("%d",&b[i][j]);
dp[0][0][cnt[0][0]++]=make_pair(b[0][0],a[0][0]);
for (int i = 0; i < n; ++i)
{
for (int j = 0; j < n; ++j)
{
if(i==0&&j==0)continue;
for (int k = 0; k < 2; ++k)
{
int x=i+dx[k];
int y=j+dy[k];
if (x>=0&&y>=0)
{
for (int z=0;z<cnt[x][y];z++)
dp[i][j][cnt[i][j]++]=make_pair(dp[x][y][z].first+b[i][j],dp[x][y][z].second+a[i][j]);
}
}
sort(dp[i][j],dp[i][j]+cnt[i][j],cmp);
cnt[i][j]=min(cnt[i][j],100);
}
}
printf("%lld\n",1ll*dp[n-1][n-1][0].first*dp[n-1][n-1][0].second);
}
return 0;
}