Program Listing for File build_tree.h
Return to documentation for file (include/build_tree.h)
#ifndef build_tree_h
#define build_tree_h
#include <cassert>
#include <queue>
#include <unordered_map>
#include "exafmm_t.h"
#include "hilbert.h"
#include "fmm_base.h"
namespace exafmm_t {
template <typename T>
void get_bounds(const Bodies<T>& sources, const Bodies<T>& targets, vec3& x0, real_t& r0) {
vec3 Xmin = sources[0].X;
vec3 Xmax = sources[0].X;
for (size_t b=0; b<sources.size(); ++b) {
Xmin = min(sources[b].X, Xmin);
Xmax = max(sources[b].X, Xmax);
}
for (size_t b=0; b<targets.size(); ++b) {
Xmin = min(targets[b].X, Xmin);
Xmax = max(targets[b].X, Xmax);
}
x0 = (Xmax + Xmin) / 2;
r0 = fmax(max(x0-Xmin), max(Xmax-x0));
r0 *= 1.00001;
}
/**
 * @brief Bucket the bodies in [begin, end) by the octant of `node` they fall in.
 *
 * The octant index has bit d set when the body's coordinate d is strictly
 * greater than the node center's coordinate d. Bodies are written to
 * `buffer` grouped by octant, in ascending octant order, preserving the
 * input order inside each octant. `bodies` itself is not modified.
 *
 * @tparam T      Value type of the bodies.
 * @param node    Node whose center (node->x) splits the octants.
 * @param bodies  Input array of bodies.
 * @param buffer  Output array; entries [begin, end) are overwritten
 *                (fields X, q and ibody are copied).
 * @param begin   Index of the first body to sort.
 * @param end     One past the index of the last body to sort.
 * @param size    Output: number of bodies in each of the 8 octants.
 * @param offsets Output: index in `buffer` where each octant's bodies start.
 */
template <typename T>
void sort_bodies(Node<T>* const node, Body<T>* const bodies, Body<T>* const buffer,
                 int begin, int end, std::vector<int>& size, std::vector<int>& offsets) {
  vec3 X = node->x;  // the center of the node
  auto octant_of = [&X](vec3& x) {
    return (x[0] > X[0]) + ((x[1] > X[1]) << 1) + ((x[2] > X[2]) << 2);
  };

  // Count number of bodies in each octant.
  // assign() (not resize()) so that a reused vector starts from zero counts.
  size.assign(8, 0);
  for (int i=begin; i<end; i++) {
    size[octant_of(bodies[i].X)]++;
  }

  // Exclusive scan to get offsets
  offsets.resize(8);
  int offset = begin;
  for (int i=0; i<8; i++) {
    offsets[i] = offset;
    offset += size[i];
  }

  // Sort bodies by octant: counter[o] is the next free slot of octant o
  std::vector<int> counter(offsets);
  for (int i=begin; i<end; i++) {
    const int octant = octant_of(bodies[i].X);
    Body<T>& slot = buffer[counter[octant]];
    slot.X = bodies[i].X;
    slot.q = bodies[i].q;
    slot.ibody = bodies[i].ibody;
    counter[octant]++;
  }
}
/**
 * Recursively build the subtree rooted at `node`.
 *
 * The node becomes a leaf when it holds at most fmm.ncrit sources and at most
 * fmm.ncrit targets and, if `leafkeys` is non-empty, its key is listed in
 * leafkeys[node->level]. Otherwise its bodies are sorted by octant into the
 * buffer arrays, 8 children are appended to `nodes`, and the function recurses
 * into each child with the body/buffer arrays swapped and `direction` flipped.
 *
 * @param sources         Array currently holding the source bodies of this node.
 * @param sources_buffer  Scratch array of the same extent as `sources`.
 * @param source_begin    Index of this node's first source.
 * @param source_end      One past the index of this node's last source.
 * @param targets         Array currently holding the target bodies of this node.
 * @param targets_buffer  Scratch array of the same extent as `targets`.
 * @param target_begin    Index of this node's first target.
 * @param target_end      One past the index of this node's last target.
 * @param node            Node to fill in; must be an element of `nodes`, with
 *                        x, r and level already set by the caller.
 * @param nodes           Storage of all nodes; children are appended to it.
 * @param leafs           Receives pointers to leaves that hold at least one body.
 * @param nonleafs        Receives pointers to all non-leaf nodes.
 * @param leafkeys        Per-level sets of keys allowed to be leaves; empty
 *                        means no restriction.
 * @param fmm             Supplies nsurf, ncrit, x0 and r0.
 * @param direction       True when `sources`/`targets` are the arrays that were
 *                        passed as buffers one recursion level up; leaf bodies
 *                        are then copied into the *_buffer arguments.
 */
template <typename T>
void build_tree(Body<T>* sources, Body<T>* sources_buffer, int source_begin, int source_end,
Body<T>* targets, Body<T>* targets_buffer, int target_begin, int target_end,
Node<T>* node, Nodes<T>& nodes, NodePtrs<T>& leafs, NodePtrs<T>& nonleafs,
const Keys& leafkeys, FmmBase<T>& fmm, bool direction=false) {
node->idx = int(node-&nodes[0]); // current node's index in nodes
node->nsrcs = source_end - source_begin;
node->ntrgs = target_end - target_begin;
// zero-initialized equivalent vectors, fmm.nsurf entries each
node->up_equiv.resize(fmm.nsurf, (T)(0.));
node->dn_equiv.resize(fmm.nsurf, (T)(0.));
// derive the node's key from its center and level
ivec3 iX = get3DIndex(node->x, node->level, fmm.x0, fmm.r0);
node->key = getKey(iX, node->level);
bool is_leaf_key = 1;
if (!leafkeys.empty()) { // when leafkeys is given (when balancing tree)
// NOTE(review): leafkeys is indexed by node->level without a bounds check
std::set<uint64_t>::iterator it = leafkeys[node->level].find(node->key);
if (it == leafkeys[node->level].end()) { // if current key is not a leaf key
is_leaf_key = 0;
}
}
// leaf case: few enough bodies and (if keys are prescribed) an allowed leaf key
if (node->nsrcs<=fmm.ncrit && node->ntrgs<=fmm.ncrit && is_leaf_key) {
node->is_leaf = true;
// 4 values per target (presumably potential plus 3 gradient components -- TODO confirm)
node->trg_value.resize(node->ntrgs*4, (T)(0.)); // initialize target result vector
// empty leaves are not recorded in leafs
if (node->nsrcs || node->ntrgs)
leafs.push_back(node);
// the arrays are swapped at every recursion level, so when direction is set
// this leaf's bodies are copied over into the *_buffer arrays
if (direction) {
for (int i=source_begin; i<source_end; i++) {
sources_buffer[i].X = sources[i].X;
sources_buffer[i].q = sources[i].q;
sources_buffer[i].ibody = sources[i].ibody;
}
// for targets only X and ibody are copied (q is not)
for (int i=target_begin; i<target_end; i++) {
targets_buffer[i].X = targets[i].X;
targets_buffer[i].ibody = targets[i].ibody;
}
}
// Copy sources and targets' coords and values to leaf (only during 2:1 tree balancing)
if (!leafkeys.empty()) {
// when direction is set the bodies were just copied into the *_buffer arrays
Body<T>* first_source = (direction ? sources_buffer : sources) + source_begin;
Body<T>* first_target = (direction ? targets_buffer : targets) + target_begin;
for (Body<T>* B=first_source; B<first_source+node->nsrcs; ++B) {
for (int d=0; d<3; ++d) {
node->src_coord.push_back(B->X[d]);
}
node->isrcs.push_back(B->ibody);
node->src_value.push_back(B->q);
}
for (Body<T>* B=first_target; B<first_target+node->ntrgs; ++B) {
for (int d=0; d<3; ++d) {
node->trg_coord.push_back(B->X[d]);
}
node->itrgs.push_back(B->ibody);
}
}
return;
}
// Sort bodies and save in buffer
std::vector<int> source_size, source_offsets;
std::vector<int> target_size, target_offsets;
sort_bodies(node, sources, sources_buffer, source_begin, source_end, source_size, source_offsets); // sources_buffer is sorted
sort_bodies(node, targets, targets_buffer, target_begin, target_end, target_size, target_offsets); // targets_buffer is sorted
node->is_leaf = false;
nonleafs.push_back(node);
// `node` and the pointers stored in leafs/nonleafs/children point into `nodes`,
// so the resize below must not reallocate.
// NOTE(review): with NDEBUG this assert is compiled out and an undersized
// capacity would silently invalidate those pointers.
assert(nodes.capacity() >= nodes.size()+NCHILD);
nodes.resize(nodes.size()+NCHILD);
// first of the NCHILD nodes just appended
Node<T> * child = &nodes.back() - NCHILD + 1;
node->children.resize(8, nullptr);
for (int c=0; c<8; c++) {
node->children[c] = &child[c];
child[c].x = node->x;
child[c].r = node->r / 2;
// shift the center by +r or -r along axis d depending on bit d of c
for (int d=0; d<3; d++) {
child[c].x[d] += child[c].r * (((c & 1 << d) >> d) * 2 - 1);
}
child[c].parent = node;
child[c].octant = c;
child[c].level = node->level + 1;
// the sorted bodies now live in the buffers, so the arrays swap roles for the child
build_tree(sources_buffer, sources, source_offsets[c], source_offsets[c] + source_size[c],
targets_buffer, targets, target_offsets[c], target_offsets[c] + target_size[c],
&child[c], nodes, leafs, nonleafs,
leafkeys, fmm, !direction);
}
}
/**
 * @brief Build the tree over the given sources and targets.
 *
 * Creates the root node from fmm.x0 / fmm.r0, reserves node storage up front
 * and delegates to the recursive build_tree overload. The recursive routine
 * keeps raw pointers into the returned node vector and asserts that it never
 * has to grow beyond the reserved capacity.
 *
 * @tparam T        Value type of the bodies.
 * @param sources   Source bodies; reordered by the tree construction.
 * @param targets   Target bodies; reordered by the tree construction.
 * @param leafs     Receives pointers to the non-empty leaf nodes.
 * @param nonleafs  Receives pointers to the non-leaf nodes.
 * @param fmm       Supplies x0, r0, ncrit and nsurf; ncrit must be positive.
 * @param leafkeys  Optional per-level sets of keys allowed to be leaves
 *                  (used when rebuilding a balanced tree); empty by default.
 * @return The vector of nodes, with the root at index 0.
 */
template <typename T>
Nodes<T> build_tree(Bodies<T>& sources, Bodies<T>& targets,
                    NodePtrs<T>& leafs, NodePtrs<T>& nonleafs,
                    FmmBase<T>& fmm, const Keys& leafkeys=Keys()) {
  // ncrit is used as a divisor in the capacity estimate below
  assert(fmm.ncrit > 0);

  // Scratch copies used as the ping-pong buffers of the recursive sort
  Bodies<T> sources_buffer = sources;
  Bodies<T> targets_buffer = targets;

  Nodes<T> nodes(1);
  // Reserve before any pointer into `nodes` is taken
  nodes.reserve((sources.size()+targets.size()) * (32/fmm.ncrit+1));
  nodes[0].parent = nullptr;
  nodes[0].octant = 0;
  nodes[0].x = fmm.x0;
  nodes[0].r = fmm.r0;
  nodes[0].level = 0;

  // data() is well-defined on empty containers, unlike &v[0]
  build_tree(sources.data(), sources_buffer.data(), 0, static_cast<int>(sources.size()),
             targets.data(), targets_buffer.data(), 0, static_cast<int>(targets.size()),
             &nodes[0], nodes, leafs, nonleafs,
             leafkeys, fmm);
  return nodes;
}
/**
 * @brief Walk the tree level by level and collect the node keys of each level.
 *
 * @tparam T      Value type of the nodes.
 * @param root    Root of the tree; must not be null.
 * @param key2id  Filled with a mapping from each visited node's key to its idx.
 * @return One set of keys per visited level, starting with the root's level.
 */
template <typename T>
Keys breadth_first_traversal(Node<T>* const root, std::unordered_map<uint64_t, size_t>& key2id) {
  assert(root);
  Keys keys_by_level;
  std::set<uint64_t> level_keys;   // keys gathered for the level being visited
  int current_level = 0;

  std::queue<Node<T>*> pending;
  pending.push(root);
  while (!pending.empty()) {
    Node<T>* const visiting = pending.front();
    pending.pop();

    // A change of level means the previous level is complete: flush it
    if (visiting->level != current_level) {
      keys_by_level.push_back(level_keys);
      level_keys.clear();
      current_level = visiting->level;
    }

    level_keys.insert(visiting->key);
    key2id[visiting->key] = visiting->idx;

    if (visiting->is_leaf) continue;
    for (int c=0; c<NCHILD; ++c) {
      pending.push(visiting->children[c]);
    }
  }

  // Flush the last level
  if (!level_keys.empty()) {
    keys_by_level.push_back(level_keys);
  }
  return keys_by_level;
}
/**
 * @brief Compute the per-level key sets of a 2:1 balanced version of a tree.
 *
 * Sweeping from the deepest level up to level 1, the set N of keys that must
 * be non-leaf at level l is formed from the tree's own non-leaf keys plus the
 * set S carried over from the level below. S is then rebuilt as the parents
 * of all in-bounds colleagues (the 26 neighbours) of the keys in N, and the
 * 8 children of every key in N are added to the result at level l+1.
 * Levels 0 and 1 are filled in explicitly (key 0, and keys 1..8).
 *
 * @tparam K      Value type of the nodes.
 * @param keys    Keys of the unbalanced tree, one set per level.
 * @param key2id  Maps a key of the unbalanced tree to its index in `nodes`.
 * @param nodes   Nodes of the unbalanced tree (only is_leaf is read).
 * @return Keys of the balanced tree, with the same number of levels as `keys`.
 */
template <typename K>
Keys balance_tree(const Keys& keys, const std::unordered_map<uint64_t, size_t>& key2id, const Nodes<K>& nodes) {
  Keys bkeys(keys.size());  // balanced keys
  if (keys.empty()) return bkeys;  // nothing to balance; avoids indexing an empty result

  const int maxlevel = static_cast<int>(keys.size()) - 1;
  std::set<uint64_t> S, N;
  for (int l=maxlevel; l>0; --l) {
    // N <- S + nonleafs
    N.clear();
    for (const uint64_t key : keys[l]) {
      if (!nodes[key2id.at(key)].is_leaf) {  // choose nonleafs
        N.insert(key);
      }
    }
    N.insert(S.begin(), S.end());
    S.clear();

    // S <- Parent(Colleagues(N))
    const int ncells = 1 << l;  // cells per dimension at level l (integer form of pow(2,l))
    for (const uint64_t key : N) {
      ivec3 iX = get3DIndex(key);
      ivec3 ciX;
      for (int m=-1; m<=1; ++m) {
        for (int n=-1; n<=1; ++n) {
          for (int p=-1; p<=1; ++p) {
            if (!(m||n||p)) continue;  // skip the cell itself
            ciX[0] = iX[0] + m;
            ciX[1] = iX[1] + n;
            ciX[2] = iX[2] + p;
            const bool inside = ciX[0]>=0 && ciX[0]<ncells &&
                                ciX[1]>=0 && ciX[1]<ncells &&
                                ciX[2]>=0 && ciX[2]<ncells;
            if (inside) {
              const uint64_t colleague = getKey(ciX, l);
              S.insert(getParent(colleague));  // S: parent of N's colleague
            }
          }
        }
      }
    }

    // bkeys[l+1] <- bkeys[l+1] + Children(N)
    if (l != maxlevel) {
      std::set<uint64_t>& next_level = bkeys[l+1];
      for (const uint64_t key : N) {
        const uint64_t child = getChild(key);
        for (int i=0; i<8; ++i) {
          next_level.insert(child+i);
        }
      }
    }
  }

  // manually add keys for lvl 0 and 1
  bkeys[0].insert(0);
  if (bkeys.size() > 1) {  // a single-level tree has no level 1 to fill
    for (int i=1; i<9; ++i) bkeys[1].insert(i);
  }
  return bkeys;
}
/**
 * @brief Extract the leaf keys from the complete per-level key sets of a tree.
 *
 * A key at level l-1 is a leaf when no key at level l has it as parent. All
 * keys of the deepest level are leaves.
 *
 * Declared inline because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple definitions.
 *
 * @param keys Keys of all nodes, one set per level.
 * @return Leaf keys, one set per level (same number of levels as `keys`).
 */
inline Keys find_leaf_keys(const Keys& keys) {
  Keys leafkeys(keys.size());
  if (keys.empty()) return leafkeys;  // keys.back() below would be undefined

  for (int l=static_cast<int>(keys.size())-1; l>=1; --l) {
    // start from every key of the parent level ...
    std::set<uint64_t> parentkeys = keys[l-1];
    // ... and remove nonleaf keys, i.e. those that have a child at level l
    for (const uint64_t key : keys[l]) {
      const uint64_t parentkey = getParent(key);
      parentkeys.erase(parentkey);
    }
    leafkeys[l-1].swap(parentkeys);
  }
  leafkeys.back() = keys.back();
  return leafkeys;
}
/**
 * @brief Rebuild the tree in place so that it follows a prescribed set of leaf keys.
 *
 * The keys of the current tree are gathered breadth-first. For a tree that
 * consists of a single node those keys are used as the leaf keys directly;
 * otherwise the balanced key sets are computed and their leaves extracted.
 * The tree is then discarded and rebuilt from the bodies with these leaf keys.
 *
 * @tparam T        Value type of the bodies and nodes.
 * @param nodes     Current tree; replaced by the rebuilt tree.
 * @param sources   Source bodies.
 * @param targets   Target bodies.
 * @param leafs     Cleared and refilled with the leaves of the rebuilt tree.
 * @param nonleafs  Cleared and refilled with the non-leaf nodes of the rebuilt tree.
 * @param fmm       fmm.depth is set to the number of collected levels minus one.
 */
template <typename T>
void balance_tree(Nodes<T>& nodes, Bodies<T>& sources, Bodies<T>& targets,
                  NodePtrs<T>& leafs, NodePtrs<T>& nonleafs, FmmBase<T>& fmm) {
  std::unordered_map<uint64_t, size_t> key2id;
  const Keys keys = breadth_first_traversal(&nodes[0], key2id);

  // Decide which keys the rebuilt tree may use as leaves; this must happen
  // before the old nodes are discarded because balancing reads them.
  const bool single_node = (nodes.size() == 1);
  const Keys leaf_keys = single_node
      ? keys
      : find_leaf_keys(balance_tree(keys, key2id, nodes));

  // Throw away the old tree and build the new one
  nodes.clear();
  leafs.clear();
  nonleafs.clear();
  nodes = build_tree(sources, targets, leafs, nonleafs, fmm, leaf_keys);

  fmm.depth = keys.size() - 1;
}
}
#endif