考虑一个子问题:求从某个点 \(u\) 能到达的点数。
如果要精确地计算出来,最优解法只能是 \(O(\frac{n^2}{w})\) 的 bitset。但是我们还没有利用到题目的性质,我们只需要判断一个点是否至多有一个点互不可达。
考虑拓扑排序的过程,队列里面的点两两互不可达。维护一个 \(f_u\) 表示从 \(u\) 能到达的点数。
若某个时刻队列点数 \(\ge 3\),那么这些点全寄了,不用管(也可以打标记);
若某个时刻队列点数 \(= 1\),那么队头的点可以到达剩下的所有点,直接累加。
若某个时刻队列点数 \(= 2\),设它们为 \(x, y\)。如果 \(y\) 的后继中存在一个点入度为 \(1\),那么说明它只能由 \(y\) 到达,此时给 \(x\) 打标记。如果不满足这个条件,就直接把剩下的点数累加到 \(f_x\)。
建反图再做一遍。时间复杂度 \(O(n + m)\)。
code
// Problem: F. Upgrading Cities
// Contest: Codeforces - Codeforces Round 520 (Div. 2)
// URL: https://codeforces.com/problemset/problem/1062/F
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 300100;
int n, m, f[maxn], ind[maxn];
bool vis[maxn];
vector<int> G[maxn], T[maxn];
void solve() {
scanf("%d%d", &n, &m);
while (m--) {
int u, v;
scanf("%d%d", &u, &v);
G[u].pb(v);
T[v].pb(u);
}
int cnt = 0;
queue<int> q;
for (int i = 1; i <= n; ++i) {
ind[i] = (int)T[i].size();
if (!ind[i]) {
q.push(i);
++cnt;
}
}
while (q.size()) {
int u = q.front();
q.pop();
if (q.empty()) {
f[u] += n - cnt;
} else if ((int)q.size() == 1) {
bool flag = 1;
for (int v : G[q.front()]) {
if (ind[v] == 1) {
flag = 0;
break;
}
}
if (flag) {
f[u] += n - cnt;
} else {
vis[u] = 1;
}
} else {
vis[u] = 1;
}
for (int v : G[u]) {
if (!(--ind[v])) {
q.push(v);
++cnt;
}
}
}
cnt = 0;
for (int i = 1; i <= n; ++i) {
ind[i] = (int)G[i].size();
if (!ind[i]) {
q.push(i);
++cnt;
}
}
while (q.size()) {
int u = q.front();
q.pop();
if (q.empty()) {
f[u] += n - cnt;
} else if ((int)q.size() == 1) {
bool flag = 1;
for (int v : T[q.front()]) {
if (ind[v] == 1) {
flag = 0;
break;
}
}
if (flag) {
f[u] += n - cnt;
} else {
vis[u] = 1;
}
} else {
vis[u] = 1;
}
for (int v : T[u]) {
if (!(--ind[v])) {
q.push(v);
++cnt;
}
}
}
int ans = 0;
for (int i = 1; i <= n; ++i) {
ans += (!vis[i] && f[i] >= n - 2);
}
printf("%d\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}