12/2
22:02
OI

CSP-S 2019 题解

格雷码

签到题。模拟题意,递归做就行了。

括号树

考虑序列上的DP 令f_i,g_i分别表示[1,i]的答案和以i的括号结尾的合法对有多少。用栈维护括号序列。如果是左括号,继承之前的f,否则记栈顶为j

g_i = g_{j-1} + 1

f_i=f_{i-1}+g_i

把这个序列上的简单DP丢到树上即可。因为每次只对栈操作一下,用一个变量记录操作即可。

树上的数

神仙题,先咕了。

Emiya家今天的饭

我们记S_i= \sum_{j=1}^n A_{i,j}

考虑一个O(n^3m)的DP。先求出总方案数,减去不合法方案。发现只会有一列超出限制。枚举列x,令f_{i,j,k}表示前i行,总共,这一列分别选了j,k个的方案数。

f_{i,j,k}=f_{i-1,j,k}+f_{i-1,j-1,k-1} A_{i,x} + f_{i-1,j-1,k} (S_i-A_{i,x})

实际上并不需要知道两种具体有多少,只需要知道差值。考虑O(n^2m),定义f_{i,j}为前i行,x列减其他列为j

f_{i,j}=f_{i-1,j}+ f_{i-1,j-1}A_{i,x}+f_{i-1,j+1}(S_i-A_{i,x})

划分

可以猜到一个结论:最优方案中,最后一段的和最小。结论看上去是比较显然的,然而具体证明比较麻烦。可以考虑使用反证+数学归纳,或者用毛爷爷的官方题解的方法

有了这个结论,就可以考虑DP了。定义f_i为以i结尾的最优方案,最后一段的开始位置-1的位置。

记前缀和为s,如果j可以转移到i,有

s_i – s_j \geq s_j – s_{f_j} \rightarrow s_i \geq 2s_j – s_{f_j}

用单调队列维护就行了。

因为卡空间,最后用f构造答案。用两个long long压在一起当高精度。考完写的时候我直接用了__int128。

树的重心

一种思路是发现重心在重链上,考虑倍增,枚举断边,分成子树和非子树计算。

另一种思路,是枚举每个点,计算它是多少方案的重心。记录枚举的点的最大,次大儿子为x,y。记子树大小为S,删掉了k个点。如果删去的是非x内的边,有:

S_x \leq \lfloor \frac{n-k}{2} \rfloor \rightarrow k \leq n – 2 S_x

如果删的是x内的边,最大的儿子可能变成y。有:

S_x – k \leq \lfloor \frac{n-k}{2} \rfloor, S_y \leq \lfloor \frac{n-k}{2} \rfloor

\rightarrow 2S_x – n \leq k \leq n – 2 S_y

于是可以考虑维护每个子树内可以删的点大小是什么。这个可以通过线段树合并简单维护。

考虑计算原树子树内的一个点,它的儿子会加上父亲所在的那个树。所有子树可用的大小,可以通过树状数组动态维护。父亲的即是删掉其它儿子剩下的那一部分。


下面是Day2三道题的代码

#include <bits/stdc++.h>

inline int rd() {
  int a = 1, b = 0; char c = getchar();
  while (!isdigit(c)) a = c == '-' ? 0 : 1, c = getchar();
  while (isdigit(c)) b = b * 10 + c - '0', c = getchar();
  return a ? b : -b;
}

const int N = 105, M = 2005, P = 998244353;

typedef long long ll;

int n, m, A[N][M], S[N], f[N][N * 2], ans = 1;

inline int calc(int x) {
  memset(f, 0, sizeof(f));
  f[0][n + 10] = 1;
  for (int i = 1; i <= n; ++i)
    for (int j = n + 10 - i; j <= n + 10 + i; ++j)
      f[i][j] = ((f[i - 1][j] + (ll)f[i - 1][j - 1] * A[i][x] % P
        + (ll)f[i - 1][j + 1] * (S[i] - A[i][x]) % P) % P + P) % P;
  int ret = 0;
  for (int i = 1; i <= n; ++i)
    ret = (ret + f[n][i + n + 10]) % P;
  return ret;
}

int main() {
  n = rd(), m = rd();
  for (int i = 1; i <= n; ++i)
    for (int j = 1; j <= m; ++j)
      S[i] = (S[i] + (A[i][j] = rd())) % P;
  for (int i = 1; i <= n; ++i)
    ans = (ll)ans * (S[i] + 1) % P;
  for (int i = 1; i <= m; ++i)
    ans = (ans - calc(i) + P) % P;
  ans = (ans - 1 + P) % P;
  printf("%d\n", ans);
  return 0;
}
#include <bits/stdc++.h>

inline int rd() {
  int a = 1, b = 0; char c = getchar();
  while (!isdigit(c)) a = c == '-' ? 0 : 1, c = getchar();
  while (isdigit(c)) b = b * 10 + c - '0', c = getchar();
  return a ? b : -b;
}

void write(__int128 x) {
  if (!x) return;
  write(x / 10);
  putchar(x % 10 + '0');
}

const int N = 4e7 + 233, M = 1e5 + 233, MOD = 1 << 30;
typedef long long ll;

int n, type, P[M], L[M], R[M];
ll sum[N], B[N]; int que[N], head, tail, f[N];
__int128 ans;

int main() {
  n = rd(), type = rd();
  if (type == 0) {
    for (int i = 1; i <= n; ++i)
      sum[i] = sum[i - 1] + rd();
  } else {
    int x = rd(), y = rd(), z = rd(), m, now = 1;
    B[1] = rd(), B[2] = rd(), m = rd();
    for (int i = 1; i <= m; ++i)
      P[i] = rd(), L[i] = rd(), R[i] = rd();
    for (int i = 3; i <= n; ++i)
      B[i] = (x * B[i - 1] + y * B[i - 2] + z) % MOD;
    for (int i = 1; i <= n; ++i) {
      if (i > P[now]) ++now;
      sum[i] = sum[i - 1] + (B[i] % (R[now] - L[now] + 1)) + L[now];
    }
  }
  for (int i = 1; i <= n; ++i) {
    while (head != tail && sum[i] >= 2 * sum[que[head + 1]] - sum[f[que[head + 1]]])
      ++head;
    f[i] = que[head];
    while (head != tail && 2 * sum[i] - sum[f[i]] <= 2 * sum[que[tail]] - sum[f[que[tail]]])
      --tail;
    que[++tail] = i;
  }
  for ( ; n; n = f[n])
    ans += (__int128)(sum[n] - sum[f[n]]) * (sum[n] - sum[f[n]]);
  write(ans), putchar('\n');
  return 0;
}
#include <bits/stdc++.h>

inline int rd() {
  int a = 1, b = 0; char c = getchar();
  while (!isdigit(c)) a = c == '-' ? 0 : 1, c = getchar();
  while (isdigit(c)) b = b * 10 + c - '0', c = getchar();
  return a ? b : -b;
}

const int N = 3e5 + 2333;

typedef long long ll;

int n;
long long ans;

struct Graph {
  int to, nxt;
} G[N * 2];
int head[N], tot;

inline void addedge(int x, int y) {
  G[++tot].to = y, G[tot].nxt = head[x],
  head[x] = tot;
}

int ls[N * 50], rs[N * 50], sum[N * 50], num, root[N];

int insert(int x, int L = 1, int R = n) {
  int p = ++num; sum[p] = 1;
  if (L != R) {
    int mid = (L + R) >> 1;
    if (x <= mid) ls[p] = insert(x, L, mid);
    else rs[p] = insert(x, mid + 1, R);
  }
  return p;
}

int merge(int x, int y) {
  if (!x || !y) return x + y;
  int p = ++num;
  sum[p] = sum[x] + sum[y];
  ls[p] = merge(ls[x], ls[y]);
  rs[p] = merge(rs[x], rs[y]);
  return p;
}

int query(int p, int l, int r, int L, int R) {
  if (!p) return 0;
  if (l <= L && r >= R)
    return sum[p];
  int mid = (L + R) >> 1, ret = 0;
  if (l <= mid)
    ret += query(ls[p], l, r, L, mid);
  if (r > mid)
    ret += query(rs[p], l, r, mid + 1, R);
  return ret;
}

int query(int p, int l, int r) {
  if (l > r) return 0;
  return query(p, l, r, 1, n);
}

int bit[N];

inline void add(int x, int y) {
  for ( ; x <= n; x += x & -x)
    bit[x] += y;
}

inline int ask(int x) {
  int ret = 0;
  for ( ; x; x -= x & -x)
    ret += bit[x];
  return ret;
}

inline int ask(int l, int r) {
  if (l > r)
    return 0;
  return ask(r) - ask(l - 1);
}

int left_x[N], right_x[N], left_nx[N], right_nx[N];
int max[N][2], size[N];

void dfs1(int x, int fa) {
  size[x] = 1;

  for (int i = head[x]; i; i = G[i].nxt) {
    int y = G[i].to;
    if (y != fa) {
      dfs1(y, x);
      size[x] += size[y];
      if (size[y] > size[max[x][0]]) {
        if (size[max[x][0]] > size[max[x][1]])
          max[x][1] = max[x][0];
        max[x][0] = y;
      } else {
        if (size[y] > size[max[x][1]])
          max[x][1] = y;
      }
    }
  }

  int mx = size[max[x][0]], mx2 = size[max[x][1]];

  if (n - size[x] > size[max[x][0]]) {
    if (size[max[x][0]] > size[max[x][1]]) {
      max[x][1] = max[x][0];
      mx2 = size[max[x][1]];
    }
    max[x][0] = fa;
    mx = n - size[x];
  } else {
    if (n - size[x] > size[max[x][1]]) {
      max[x][1] = fa;
      mx2 = n - size[x];
    }
  }

  if (fa) add(size[x], 1);

  left_nx[x] = 1;
  right_nx[x] = n - 2 * mx;
  left_x[x] = std::max(2 * mx - n, 1);
  right_x[x] = n - 2 * mx2;

  root[x] = insert(size[x]);
  for (int i = head[x]; i; i = G[i].nxt) {
    int y = G[i].to;
    if (y != fa) {
      root[x] = merge(root[x], root[y]);
    }
  }

  if (max[x][0] != fa) {
    for (int i = head[x]; i; i = G[i].nxt) {
      int y = G[i].to;
      if (y == max[x][0]) {
        ans -= (ll)x * query(root[y], left_nx[x], right_nx[x]);
        ans += (ll)x * query(root[y], left_x[x], right_x[x]);
      }
    }
  } else {
    for (int i = head[x]; i; i = G[i].nxt) {
      int y = G[i].to;
      if (y != fa) {
        ans -= (ll)x * query(root[y], left_x[x], right_x[x]);
        ans += (ll)x * query(root[y], left_nx[x], right_nx[x]);
      }
    }
  }
}

void dfs2(int x, int fa) {
  if (max[x][0] == fa)
    ans += (ll)x * ask(left_x[x], right_x[x]);
  else
    ans += (ll)x * ask(left_nx[x], right_nx[x]);

  for (int i = head[x]; i; i = G[i].nxt) {
    int y = G[i].to;
    if (y != fa) {
      add(n - size[y], 1);
      add(size[y], -1);
      dfs2(y, x);
      add(n - size[y], -1);
      add(size[y], 1);
    }
  }
}

inline void solve() {
  for (int i = 1; i <= n; ++i) {
    head[i] = root[i] = size[i] = left_x[i] = left_nx[i] = right_x[i] =
      right_nx[i] = max[i][0] = max[i][1] = bit[i] = 0;
  }
  for (int i = 1; i <= num; ++i)
    ls[i] = rs[i] = sum[i] = 0;
  ans = tot = num = 0;

  n = rd();
  for (int i = 1; i < n; ++i) {
    int x = rd(), y = rd();
    addedge(x, y);
    addedge(y, x);
  }
  dfs1(1, 0);
  dfs2(1, 0);
  std::cout << ans << std::endl;
}

int main() {
  for (int T = rd(); T; --T)
    solve();
  return 0;
}