cf1622D 组合数/容斥
D - Shuffle
长度为n的01串,可以选择一段恰好含m个1的子串shuffle(重排),求一共能得到几个不一样的串
首先这种问题要么dp要么组合数,dp无解于是先n^2傻瓜预处理一个组合数。
考虑以 i 为左端点的最大合法区间,i点只有1和0两种情况
我们采取容斥的思想,计算到 以 i 为左端点时,我们只计算 i 取反的情况。因为 i 相同的情况显然会在 > i 的区间中计算到。
那么有个疑问点,我们只考虑 区间中恰好含 1 数量为 m 的时候会出现一个问题:
-
例如: 1100 (m==2) 。我们先计算了 i = 1, j = 4。让 i取反为 0, 则后面三位可以填2个1,C(3,2) = 3。
但是 i 为 1 的情况怎么计算呢?我们对于 k(当前1的数量) < m 且 j 已经到达边界的区间同样这样考虑即可。
i = 2, j = 4 时令 第2位取反为0,后面两位填1个1, C(2, 1) = 2。
答案加上原始区间就是 6。
问题就这样成功容斥解决了!
简单来说就是:保证左指针变化然后右边随便组合
/*******************************
| Author: koifish
| Problem: D. Shuffle
| Contest: Educational Codeforces Round 120 (Rated for Div. 2)
| URL: https://codeforces.com/contest/1622/problem/E
| When: 2021-12-27 22:35:29
|
| Memory: 256 MB
| Time: 2000 ms
*******************************/
#include <bits/stdc++.h>
#define ll long long
#define int ll
#define mp make_pair
#define fi first
#define se second
#define pb push_back
#define vi vector<int>
#define pi pair<int, int>
#define vpii vector<pi>
#define il inline
#define ri register
#define all(a) a.begin(), a.end()
#define fr(a) freopen(a, "r", stdin)
#define fo(a) freopen(a, "w", stdout);
#define mod 998244353
#define debug puts("------------------------")
#define lowbit(x) (x&-x)
template<typename T> bool chkmin(T &a, T b){return (b < a) ? a = b, 1 : 0;}
template<typename T> bool chkmax(T &a, T b){return (b > a) ? a = b, 1 : 0;}
ll ksm(ll a, ll b) {if (b == 0) return 1; ll ns = ksm(a, b >> 1); ns = ns * ns % mod; if (b & 1) ns = ns * a % mod; return ns;}
void Read(int &a) {a=0;int c=getchar(),b=1; while(c>'9'||c<'0') {if(c=='-')b=-1;c=getchar();} while(c>='0'&&c<='9') a=(a<<3)+(a<<1)+c-48,c=getchar();a*=b; }
int read() {int a=0,c=getchar(),b=1; while(c>'9'||c<'0') {if(c=='-')b=-1;c=getchar();} while(c>='0'&&c<='9') a=(a<<3)+(a<<1)+c-48,c=getchar();return a*=b; }
void write(int x) {if(x>9)write(x/10);putchar('0'+x%10);}
void W(int x) {if(x<0){putchar('-'),x=-x;}write(x);}
#define LOCAL
using namespace std;
const int maxn = 5005;
/**/
int n, m, sum[maxn], c[maxn][maxn];
char s[maxn];
/**/
il int calc(int i, int j) {
return (i < 0 || j < 0 || i < j) ? 0 : c[i][j];
}
signed main()
{
n = read(); m = read();
for(int i = 0; i <= n; i++) c[i][0] = 1;
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= i; j++) {
c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % mod;
}
}
scanf("%s", s + 1);
int ans = 1;
for(int i = 1; i <= n; i++) sum[i] = sum[i - 1] + (s[i] == '1');
if(sum[n] < m) {
puts("1");
return 0;
}
for(int i = 1, j = 0; i <= n; i++) {
while(j <= n && sum[j] - sum[i - 1] <= m) j++;
int len = (j-1) - i + 1;
int k = sum[j - 1] - sum[i - 1] - (s[i] == '0');
ans = (ans + calc(len - 1, k)) % mod;
}
cout << ans << '\n';
return 0;
}