// Source: Codeforces (CF Step)
//
// Code: Sasha and a Walk in the City

#include <atcoder/modint>
#include <bits/stdc++.h>

using namespace std;
using namespace atcoder;

using mint = modint998244353;

class Tree {
  public:
    int n;
    vector<vector<int>> adj;
    vector<vector<vector<mint>>> dp;
    vector<vector<mint>> sum;
    // dp[i][d] is the number of colorings of the i-th subtree when the
    // maximum black nodes on the upward path to i is d.
    Tree(int n) {
        this->n = n;
        adj.resize(n);
        dp.resize(n, vector<vector<mint>>(3, vector<mint>(2, 0)));
        sum.resize(n, vector<mint>(3, 0));
    }

    void dfs(int src, int par) {
        // Start with no children.
        dp[src][0][0] = 1;
        dp[src][1][1] = 1;
        for (auto child : adj[src]) {
            if (child == par) {
                continue;
            }
            dfs(child, src);
            vector<vector<mint>> ndp(3, vector<mint>(2, 0));
            for (int rc = 0; rc <= 1; rc++) {
                for (int d = 0; d < 3; d++) {
                    for (int now = 0; now < 3; now++) {
                        if (d + now >= 3) {
                            continue;
                        }
                        int nxt = max(now + rc, d);
                        if (nxt < 3) {
                            // Append this children.
                            ndp[nxt][rc] += dp[src][d][rc] * sum[child][now];
                        }
                    }
                }
            }
            swap(dp[src], ndp);
        }

        for (int d = 0; d < 3; d++) {
            for (int rc = 0; rc <= 1; rc++) {
                sum[src][d] += dp[src][d][rc];
            }
        }
    }
};

void solve() {
    int n;
    cin >> n;

    Tree t(n);
    int m = n - 1;
    for (int i = 0; i < m; i++) {
        int x, y;
        cin >> x >> y;
        x--;
        y--;
        t.adj[x].push_back(y);
        t.adj[y].push_back(x);
    }

    t.dfs(0, -1);
    mint ans = 0;
    auto &sum = t.sum;
    for (int d = 0; d < 3; d++) {
        ans += sum[0][d];
    }
    cout << ans.val() << "\n";
}

// Entry point: reads the number of test cases and solves each one in turn.
int main() {
    // Untie C and C++ streams for faster bulk I/O.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int testCount;
    cin >> testCount;
    while (testCount-- > 0) {
        solve();
    }
    return 0;
}