01/16
21:03
OI

国王奇遇记 [线性插值]

国王奇遇记加强版之再加强版

本题有O(m^2)做法,网上很多,这里只有O(m)

题意:求

\sum_{i=1}^n i^m m^i

F_m(n)等于这个式子。经过大胆打表看题解,可以知道

F_m(n) = m^n P_m(n) – P_m(0)

其中P_mm次多项式。

下面开始不写_m了,懒w

这个结论虽然不知道怎么想出来的,但是可以证明,大约想法就是归纳+差分。

考虑怎么求P

F(n+1) – F(n) = (n+1)^m m^{n+1} = m^{n+1} P(n+1) – m^n P(n)

P(n+1) = \frac{P(n)}{m} + (m+1)^m

P(n)表示为AP(0)+B

因为Pm次多项式,做m+1次差分:

\sum_{i=0}^{m+1} \binom{m+1}{i} (-1)^{m – i} P(i) = 0

就可以解出A,B啦。

这篇文章中有一种给出1,m+1处点值插值的方法,推到用到了很多组合技巧,很有意思。

不过其实直接拉格朗日插值就行。。。结合一些预处理可以做到O(m)

#include <bits/stdc++.h>

const int N = 500000 + 233, P = 1e9 + 7;

typedef long long ll;

inline ll fpow(ll x, int y) {
  ll ret = 1;
  for ( ; y; y >>= 1, x = x * x % P)
    if (y & 1) ret = ret * x % P;
  return ret;
}

int n, m;
ll fac[N], inv[N], inv_num[N], pre[N], pow_m[N],
  A[N], B[N], p[N], pre_all[N], suf_all[N];

int prime[N], vis[N], tot;

inline void init() {
  pow_m[0] = pow_m[1] = 1;
  for (int i = 1; i <= m + 1; ++i) {
    if (!vis[i]) {
      prime[++tot] = i;
      pow_m[i] = fpow(i, m);
    }
    for (int j = 1; j <= tot && i * prime[j] <= m + 1; ++j) {
      vis[i * prime[j]] = 1;
      pow_m[i * prime[j]] = pow_m[i] * pow_m[prime[j]] % P;
      if (!(i % prime[j]))
        break;
    }
  }

  for (int i = fac[0] = 1; i <= m + 1; ++i)
    fac[i] = fac[i - 1] * i % P;
  inv[m + 1] = fpow(fac[m + 1], P - 2);
  for (int i = m + 1; i; --i)
    inv[i - 1] = inv[i] * i % P;
  inv_num[1] = 1;
  for (int i = 2; i <= m + 1; ++i)
    inv_num[i] = P - 1ll * P / i * inv_num[P % i] % P;

  for (int i = pre_all[0] = suf_all[m + 2] = 1; i <= m + 1; ++i)
    pre_all[i] = pre_all[i - 1] * (n - i) % P;
  for (int i = m + 1; i; --i)
    suf_all[i] = suf_all[i + 1] * (n - i) % P;
}

inline ll solve() {
  A[0] = 1, A[1] = inv_num[m], B[1] = pow_m[1];
  for (int i = 2; i <= m + 1; ++i) {
    A[i] = A[i - 1] * inv_num[m] % P;
    B[i] = (B[i - 1] * inv_num[m] % P + pow_m[i]) % P;
  }

  long long sum_A = 0, sum_B = 0;
  for (int i = 0; i <= m + 1; ++i) {
    long long tmp = fac[m + 1] * inv[i] % P * inv[m + 1 - i] % P;
    if ((m - i) & 1)
      tmp = P - tmp;
    sum_A = (sum_A + tmp * A[i]) % P;
    sum_B = (sum_B + tmp * B[i]) % P;
  }

  p[0] = P - sum_B * fpow(sum_A, P - 2) % P;
  for (int i = 1; i <= m + 1; ++i)
    p[i] = (A[i] * p[0] + B[i]) % P;

  long long ret = 0;

  for (int i = 1; i <= m + 1; ++i) {
    long long tmp = p[i];
    tmp = tmp * pre_all[i - 1] % P * suf_all[i + 1] % P;
    tmp = tmp * inv[i - 1] % P * inv[m + 1 - i] % P;
    if ((m + 1 - i) & 1)
      tmp = P - tmp;
    ret = (ret + tmp) % P;
  }

  ret = fpow(m, n) * ret % P - p[0];
  ret = (ret % P + P) % P;
  return ret;
}

int main() {
  std::cin >> n >> m;
  if (m == 1)
    return std::cout << 1ll * n * (n + 1) / 2 << std::endl, 0;
  init();
  if (m >= n) {
    long long ans = 0, tmp = 1;
    for (int i = 1; i <= n; ++i) {
      tmp = tmp * m % P;
      ans = (ans + pow_m[i] * tmp) % P;
    }
    std::cout << ans << std::endl;
  } else
    std::cout << solve() << std::endl;
  return 0;
}