// Codeforces
// CF Step
//
// Code: Prologue to Leaf Color

#include <atcoder/modint>
#include <bits/stdc++.h>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

class Tree {
  public:
    int n;
    vector<vector<int>> adj;
    vector<vector<mint>> dp;
    vector<int> a;
    // dp[i][d] is the count of induced connected subgraph such that the highest
    // vertex is node i and it has degree d, and all vertices with degree 1
    // except node i has the chosen color.
    // All degrees >= 2 are treated as identical.
    Tree(int n) {
        this->n = n;
        adj.resize(n);
        a.resize(n);
        dp.resize(n, vector<mint>(3, 0));
    }

    void dfs(int src, int par, int color) {
        // Start with no children.
        dp[src][0] = 1;
        for (auto child : adj[src]) {
            if (child == par) {
                continue;
            }
            dfs(child, src, color);
            vector<mint> ndp(3);
            for (int d = 0; d < 3; d++) {
                // Ignore this children.
                ndp[d] += dp[src][d];
                for (int now = 0; now < 3; now++) {
                    if (now == 0 && a[child] != color) {
                        continue;
                    }
                    int nxt = d + 1;
                    if (nxt > 1) {
                        nxt = 2;
                    }
                    // Append this children.
                    ndp[nxt] += dp[src][d] * dp[child][now];
                }
            }
            swap(dp[src], ndp);
        }
    }

    void clear_dp() {
        for (int i = 0; i < n; i++) {
            for (int d = 0; d < 3; d++) {
                dp[i][d] = 0;
            }
        }
    }
};

// Reads one test case from stdin: n, then n 1-based vertex colors, then n - 1
// edges with 1-based endpoints. For every 0-based color c in [0, n), counts
// (modulo mint's modulus) the connected induced subgraphs whose degree-1
// vertices all have color c; a single-vertex subgraph is counted only for its
// own color. Prints the per-color counts on one line, then their sum.
void solve() {
    int n = 0;
    cin >> n;

    Tree t(n);
    auto &a = t.a;
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        a[i]--; // input colors are 1-based
    }

    // n - 1 edges, endpoints are 1-based in the input.
    for (int i = 0; i + 1 < n; i++) {
        int x = 0;
        int y = 0;
        cin >> x >> y;
        x--;
        y--;
        t.adj[x].push_back(y);
        t.adj[y].push_back(x);
    }

    mint ans = 0;
    auto &dp = t.dp;
    for (int color = 0; color < n; color++) {
        // Reset the DP before running the DFS for a new color.
        t.clear_dp();
        t.dfs(0, -1, color);
        mint cur = 0;
        for (int i = 0; i < n; i++) {
            // A top vertex with 0 children (the single-vertex subgraph) or 1
            // child has degree <= 1, so it is counted only if it has `color`.
            if (a[i] == color) {
                cur += dp[i][0];
                cur += dp[i][1];
            }
            // With >= 2 children the top vertex is never a leaf.
            cur += dp[i][2];
        }
        ans += cur;
        cout << cur.val() << " ";
    }
    cout << '\n';
    cout << ans.val() << '\n';
}

// Reads the number of test cases and solves each one in turn.
int main() {
    // Only iostreams are used, so detaching from C stdio and untying cin from
    // cout is safe and speeds up bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Initialized so a failed read (e.g. empty input) runs zero test cases
    // instead of looping on an indeterminate value.
    int t = 0;
    cin >> t;
    for (int i = 0; i < t; i++) {
        solve();
    }
    return 0;
}