Treeone’s Blog

競技プログラミングの記録など

ARC067 E Grouping

初投稿。それなりに頑張って解いたので自分用のメモとして書いた。

問題概要

1からNまでの番号がついた人がN人いる。以下の条件を満たすようなグループ分けは何通りあるかを求め、1000000007で割った余りを出力する。

  • 各グループの人数はA人以上B人以下
  • 任意のiについて、ちょうどi人が含まれるグループの数は0またはC以上D以下


制約
1 \leqq n \leqq 1000

解法

動的計画法(DP)を用いて解く。
i人未満からなるグループのみでj人をグループ分けした時の場合の数をdp[i][j]通りとする。
残りのn-j人の中から、i人グループをk個作る時( k=0 or C \leqq k \leqq D)、どのようにdp配列を更新すればよいだろうか。
まず、k=0の場合は簡単で、何もしないので dp[i+1][j]+=dp[i][j]である。
次に、C \leqq k \leqq D かつ i \times k \leqq n-jを満たしている場合、以下の式のようにdp配列が更新される。
dp[i+1][j+i\times k]+=dp[i][j]\times{}_{n-j}C_i \times _{n-j-i}C_i \times ... \times _{n-j-i\times (k-1)}C_i\times \frac{1}{k!}
ただし、全て1000000007で割った余りを取ることに注意する。
このDPの計算量は一見O(N^3)に思われるが、k0 \leqq k \leqq \frac{n-j}{i}の範囲しか取らないので、全体の計算量はO(N^2 log N)となる。*1

ソースコード

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const ll mod = 1e9 + 7;
ll n, a, b, c, d;
ll dp[1010][1010];
ll fact[1010], fact_inv[1010], inv[1010];

long long mod_pow(long long x, long long n){
    if(n == 0) return 1;
    long long res = mod_pow(x * x % mod, n / 2);
    if(n & 1) res = res * x % mod;
    return res;
}

long long combination(long long n, long long k){
    return fact[n] * fact_inv[k] % mod * fact_inv[n - k] % mod;
}

int main(){
    cin >> n >> a >> b >> c >> d;
    
    memset(dp, 0, sizeof(dp));
    for(int i = 0; i < 1010; i++) dp[i][0] = 1;

    fact[0] = fact_inv[0] = 1;
    for(int i = 1; i < 1010; i++){
        inv[i] = mod_pow(i, mod - 2);
        fact[i] = fact[i - 1] * i % mod;
        fact_inv[i] = fact_inv[i - 1] * inv[i] % mod;
    }

    for(int i = a; i <= b; i++){
        for(int j = 0; j <= n; j++){
            if(dp[i][j] == 0) continue;
            if(j != 0) (dp[i + 1][j] += dp[i][j]) %= mod;
            ll p = 1;            
            for(int k = 1; k <= d; k++){
                if(j + i * k > n) break;
                p = p * combination(n - j - i * (k - 1), i) % mod * inv[k] % mod;
                if(c <= k && k <= d){
                    (dp[i + 1][j + i * k] += dp[i][j] * p % mod) %= mod;
                }
            }
        }
    }
    cout << dp[b + 1][n] << endl;
}

典型DPが解けるようになりたい

*1:\sum_{k=1}^{N} \frac{1}{k} \approx logNという近似を用いた