Code: Three Paths on a Tree
#include <bits/stdc++.h>
using namespace std;
// Breadth-first search started simultaneously from every vertex in sources.
//
// adj     - adjacency list; adj[v] lists the neighbours of vertex v.
// sources - start vertices; each one gets level 0 and parent -1.
// parent  - output: for every vertex reached, the vertex it was discovered
//           from (-1 for sources). Must already hold adj.size() entries;
//           entries of vertices that are never reached are left untouched.
// Returns level, where level[v] is the number of edges from v to the nearest
// source, or -1 if v is not reachable from any source.
vector<int> multi_source_bfs(const vector<vector<int>> &adj,
                             const vector<int> &sources, vector<int> &parent) {
  vector<int> level(adj.size(), -1);
  queue<int> q;
  for (int src : sources) {
    level[src] = 0;
    parent[src] = -1;
    q.push(src);
  }
  while (!q.empty()) {
    int now = q.front();
    q.pop();
    for (int child : adj[now]) {
      // In BFS order the first discovery of a vertex is along a shortest path.
      if (level[child] == -1) {
        level[child] = level[now] + 1;
        parent[child] = now;
        q.push(child);
      }
    }
  }
  return level;
}
// Solves "Three Paths on a Tree" for the tree given by adj (0-indexed).
// It chooses three distinct vertices a, b, c that maximize the number of edges
// covered by the union of the simple paths a-b, b-c and a-c. It prints that
// number, then a, b, c (1-indexed).
//
// Two BFS passes find the endpoints a and b of a diameter. A multi-source BFS
// from every vertex on the a-b path then measures each vertex's distance to
// the path. The farthest vertex other than a and b is c, and the answer is
// diameter + dist(c, path).
//
// adj must describe a connected tree. With fewer than 3 vertices no three
// distinct vertices exist, so nothing is printed.
void solve(vector<vector<int>> &adj) {
  int n = adj.size();
  if (n < 3) return;
  vector<int> parent(n, -1);
  vector<int> sources;

  // Returns the index of the first vertex with the largest dist value,
  // ignoring skip1 and skip2 (pass -1 to ignore nothing). Returns -1 if no
  // vertex remains.
  auto farthest_vertex = [n](const vector<int> &dist, int skip1,
                             int skip2) -> int {
    int best = -1;
    for (int i = 0; i < n; i++) {
      if (i == skip1 || i == skip2) continue;
      if (best == -1 || dist[i] > dist[best]) best = i;
    }
    return best;
  };

  // Pass 1: the vertex farthest from vertex 0 is one endpoint of a diameter.
  sources.push_back(0);
  vector<int> dist = multi_source_bfs(adj, sources, parent);
  int diameter_end_a = farthest_vertex(dist, -1, -1);

  // Pass 2: the vertex farthest from that endpoint is the other endpoint.
  // parent now points one step back towards diameter_end_a.
  sources.assign(1, diameter_end_a);
  dist = multi_source_bfs(adj, sources, parent);
  int diameter_end_b = farthest_vertex(dist, -1, -1);
  int diameter = dist[diameter_end_b];

  // Collect every vertex on the diameter path by walking parent pointers.
  sources.clear();
  for (int node = diameter_end_b; node != -1; node = parent[node]) {
    sources.push_back(node);
  }

  // Pass 3: distance from every vertex to the nearest vertex on the path.
  dist = multi_source_bfs(adj, sources, parent);
  // Exclude both endpoints. If the tree is a single path, every vertex is at
  // distance 0, and an unrestricted search would return an endpoint again and
  // print a repeated vertex.
  int third = farthest_vertex(dist, diameter_end_a, diameter_end_b);

  cout << diameter + dist[third] << "\n";
  cout << diameter_end_a + 1 << " " << diameter_end_b + 1 << " " << third + 1
       << "\n";
}
// Reads a tree: a vertex count n, then n - 1 edges "u v" with 1-indexed
// endpoints. Converts it to a 0-indexed adjacency list and passes it to solve.
// Returns 1 without solving if the input ends early, n is not positive, or an
// edge names a vertex outside 1..n.
int main() {
  // Bulk integer input: skip C-stdio synchronisation and cin/cout flushing.
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n;
  if (!(cin >> n) || n < 1) return 1;
  vector<vector<int>> adj(n);
  for (int i = 0; i < (n - 1); i++) {
    int u, v;
    if (!(cin >> u >> v) || u < 1 || u > n || v < 1 || v > n) return 1;
    u--;
    v--;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }
  solve(adj);
  return 0;
}
#include <bits/stdc++.h>
using namespace std;
// Breadth-first search started simultaneously from every vertex in sources.
//
// adj     - adjacency list; adj[v] lists the neighbours of vertex v.
// sources - start vertices; each one gets level 0 and parent -1.
// parent  - output: for every vertex reached, the vertex it was discovered
//           from (-1 for sources). Must already hold adj.size() entries;
//           entries of vertices that are never reached are left untouched.
// Returns level, where level[v] is the number of edges from v to the nearest
// source, or -1 if v is not reachable from any source.
vector<int> multi_source_bfs(const vector<vector<int>> &adj,
                             const vector<int> &sources, vector<int> &parent) {
  vector<int> level(adj.size(), -1);
  queue<int> q;
  for (int src : sources) {
    level[src] = 0;
    parent[src] = -1;
    q.push(src);
  }
  while (!q.empty()) {
    int now = q.front();
    q.pop();
    for (int child : adj[now]) {
      // In BFS order the first discovery of a vertex is along a shortest path.
      if (level[child] == -1) {
        level[child] = level[now] + 1;
        parent[child] = now;
        q.push(child);
      }
    }
  }
  return level;
}
// Solves "Three Paths on a Tree" for the tree given by adj (0-indexed). It
// prints two lines: the number of edges covered by the union of the paths
// between the three chosen vertices, then those vertices (1-indexed).
//
// Two BFS passes locate the endpoints of a diameter. A multi-source BFS from
// every vertex on that diameter path then gives each vertex's distance to the
// path, and the deepest vertex other than the two endpoints is the third pick.
void solve(vector<vector<int>> &adj) {
  const int n = adj.size();
  vector<int> parent(n, -1);
  vector<int> sources(1, 0);

  // Returns the first index holding the largest value of d.
  auto first_argmax = [n](const vector<int> &d) -> int {
    int pos = -1;
    int top = -1;
    for (int v = 0; v < n; ++v) {
      if (d[v] > top) {
        top = d[v];
        pos = v;
      }
    }
    return pos;
  };

  // The vertex farthest from vertex 0 is one endpoint of a diameter.
  vector<int> dist = multi_source_bfs(adj, sources, parent);
  int end_a = first_argmax(dist);

  // Re-root at that endpoint. The farthest vertex is the other endpoint, and
  // parent now traces every vertex back to end_a.
  sources.assign(1, end_a);
  dist = multi_source_bfs(adj, sources, parent);
  int end_b = first_argmax(dist);
  int diameter = dist[end_b];

  // Every vertex on the diameter path becomes a BFS source.
  sources.clear();
  for (int v = end_b; v != -1; v = parent[v]) {
    sources.push_back(v);
  }
  dist = multi_source_bfs(adj, sources, parent);

  // Track the running maximum depth over all vertices, but only let
  // non-endpoints become the third vertex. Ties go to the later index, so on
  // a tree that is a single path (all depths 0) the last interior vertex is
  // chosen.
  int depth = -1;
  int third = -1;
  for (int v = 0; v < n; ++v) {
    if (dist[v] < depth) continue;
    depth = dist[v];
    const bool is_endpoint = (v == end_a) || (v == end_b);
    if (!is_endpoint) third = v;
  }

  cout << diameter + depth << "\n";
  cout << end_a + 1 << " " << end_b + 1 << " " << third + 1 << "\n";
}
// Reads a tree: a vertex count n, then n - 1 edges "u v" with 1-indexed
// endpoints. Converts it to a 0-indexed adjacency list and passes it to solve.
// Returns 1 without solving if the input ends early, n is not positive, or an
// edge names a vertex outside 1..n.
int main() {
  // Bulk integer input: skip C-stdio synchronisation and cin/cout flushing.
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int n;
  if (!(cin >> n) || n < 1) return 1;
  vector<vector<int>> adj(n);
  for (int i = 0; i < (n - 1); i++) {
    int u, v;
    if (!(cin >> u >> v) || u < 1 || u > n || v < 1 || v > n) return 1;
    u--;
    v--;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }
  solve(adj);
  return 0;
}