block by armollica 64ffc3bd8fc76c5657719a842e39c4e3

K-D Tree Nearest Neighbor

Full Screen

Nearest-neighbor search (1-NN) using a k-d tree. The red dot is the nearest neighbor; orange dots were scanned during the search but not selected.

Compare this with the quadtree-based nearest-neighbor search in this block. The k-d tree technique seems to scan more points, although the two approaches limit the search set in different ways, so the number of scanned points isn’t really a direct measure of which is more efficient.

Here’s a more up-to-date version of this block that works for k nearest neighbors.

index.html

<html>
  <head>
    <style>
      /* Splitting lines of the k-d tree: one <path class="line"> per tree node. */
      .line {
        fill: none;
        stroke: #ccc;
      }
      
      /* Default look of every data point (<circle class="point">). */
      .point {
        fill: #999;
        stroke: #fff;
      }

      /* Points whose tree node was visited by the latest search; the class is
         toggled by update() in the script below. */
      .point.scanned {
        fill: orange;
        stroke: #999;
      }
      
      /* The nearest neighbor found by the latest search. This rule comes after
         .point.scanned with equal specificity, so red wins when a point carries
         both classes. */
      .point.selected {
        fill: red;
        stroke: #999;
      }
      
      /* Outline circle drawn around the query point; update() sets its radius
         to the distance of the nearest neighbor. */
      .halo {
        fill: none;
        stroke: red;
      }
    </style>
  </head>
  <body>
     <script src="kd-tree.js"></script>
     <script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
     <script>
       // Demo setup: build a k-d tree over random points, draw its splitting
       // lines and points, and re-run the nearest-neighbor search on mousemove.
       // Everything here is declared at global scope; update() (defined further
       // down in this script) reads `tree`, `points` and `halo`.

       // Size of the drawing area in pixels.
       var width = 960,
           height = 500;
       
       var svg = d3.select("body").append("svg")
        .attr("width", width)
        .attr("height", height);
       
       // 2000 points with x in [0, width) and y in [0, height).
       var data = d3.range(2000)
        .map(function() {
          return {
            x: width * Math.random(),
            y: height * Math.random(),
            value: d3.random.normal()() // just for testing purposes
          };
        });
        
        // KDTree() returns a builder; calling it with `data` returns the root
        // Node of the tree (see kd-tree.js).
        var tree = KDTree()
          .x(function(d) { return d.x; })
          .y(function(d) { return d.y; })
          (data);
        
        // One path per splitting line; tree.lines() clips each line to the
        // region of the [[0,0], [width, height]] extent that its node covers.
        svg.append("g").attr("class", "lines")
            .selectAll(".line").data(tree.lines([[0,0], [width, height]]))
          .enter().append("path")
            .attr("class", "line")
            .attr("d", d3.svg.line());
            
        // One circle per tree node. The bound datum `d` is the Node itself, so
        // update() can compare it by identity against the search result.
        var points = svg.append("g").attr("class", "points")
            .selectAll(".point").data(tree.flatten())
          .enter().append("circle")
            .attr("class", "point")
            .attr("cx", function(d) { return d.location[0]; })
            .attr("cy", function(d) { return d.location[1]; })
            .attr("r", 4);
        
        // Circle that update() centers on the query point.
        var halo = svg.append("circle").attr("class", "halo");
        
        // Initial search so something is highlighted before the mouse moves.
        // update is a function declaration, so it is already callable here.
        update([width/3, height/2]);
        
        // Fully transparent rectangle covering the whole drawing area. It is
        // appended last, so it is the last child of the svg and is the element
        // that carries the mousemove listener.
        svg.append("rect").attr("class", "event-canvas")
          .attr("width", width)
          .attr("height", height)
          .attr("fill-opacity", 0)
          .on("mousemove", function() { update(d3.mouse(this)); });
          
        function update(point) {
          var nearest = tree.find(point);
              
          points
            .classed("scanned", function(d) { return nearest.scannedNodes.indexOf(d) !== -1; })
            .classed("selected", function(d) { return d === nearest.node; });
            
          halo
            .attr("cx", point[0])
            .attr("cy", point[1])
            .attr("r", nearest.distance);
        };
     </script>
  </body>
</html>

kd-tree.js

/**
 * A single node of a k-d tree.
 *
 * NOTE(review): index.html loads this file as a classic browser script, where
 * a top-level function named `Node` replaces the DOM's built-in `Node`
 * global. Confirm nothing on the page relies on the built-in one.
 *
 * @constructor
 * @param {Array<number>} location - Coordinates of the point stored here.
 * @param {number} axis - Index into `location` of the coordinate this node
 *     splits on.
 * @param {Array<?Node>} subnodes - `[left child, right child]`; a missing
 *     child is `null`.
 * @param {*} datum - Caller-supplied value associated with the point.
 */
function Node(location, axis, subnodes, datum) {
  this.location = location;
  this.axis = axis;
  this.subnodes = subnodes;  // = children nodes = [left child, right child]
  this.datum = datum;
}

/**
 * Converts the subtree rooted at this node into nested arrays.
 *
 * @returns {Array} `[location, leftSubtree, rightSubtree]`, where a missing
 *     subtree is `null`. The array also carries an `axis` property holding
 *     this node's splitting axis.
 */
Node.prototype.toArray = function() {
  var left = this.subnodes[0],
      right = this.subnodes[1];

  // Each child is tested on its own: a node may have a left child but no
  // right child (or the reverse).
  var array = [
    this.location,
    left ? left.toArray() : null,
    right ? right.toArray() : null
  ];
  array.axis = this.axis;
  return array;
};

/**
 * Lists every node of the subtree rooted at this node.
 *
 * @returns {Array<Node>} This node first, then the nodes of the left
 *     subtree, then the nodes of the right subtree.
 */
Node.prototype.flatten = function() {
  var nodes = [this],
      left = this.subnodes[0],
      right = this.subnodes[1];

  if (left) nodes = nodes.concat(left.flatten());
  if (right) nodes = nodes.concat(right.flatten());
  return nodes;
};

/**
 * Nearest neighbor search (1-NN) in the subtree rooted at this node.
 * Distances are computed with the file-level `distance(p0, p1)` helper.
 *
 * 1-NN algorithm outlined here:
 * http://web.stanford.edu/class/cs106l/handouts/assignment-3-kdtree.pdf
 *
 * @param {Array<number>} target - Coordinates of the query point.
 * @returns {{node: Node, distance: number, scannedNodes: Array<Node>}}
 *     `node` is the closest node, `distance` its distance to `target`, and
 *     `scannedNodes` every node visited, in visiting order (kept just for
 *     testing purposes).
 */
Node.prototype.find = function(target) {
  var guess = this,
      bestDist = Infinity,
      scannedNodes = [];

  search(this);

  return {
    node: guess,
    distance: bestDist,
    scannedNodes: scannedNodes
  };

  function search(node) {
    if (!node) return;

    scannedNodes.push(node);

    // If the current location is better than the best known location,
    // update the best known location
    var nodeDist = distance(node.location, target);
    if (nodeDist < bestDist) {
      bestDist = nodeDist;
      guess = node;
    }

    // Recursively search the half of the tree that contains the target
    var axis = node.axis,
        targetIsLeft = target[axis] < node.location[axis],
        nearChild = targetIsLeft ? node.subnodes[0] : node.subnodes[1],
        farChild = targetIsLeft ? node.subnodes[1] : node.subnodes[0];

    search(nearChild);

    // If the candidate hypersphere crosses this splitting plane, look on the
    // other side of the plane by examining the other subtree. bestDist is
    // read after the near side has been searched, so it is as tight as
    // possible here.
    if (farChild && Math.abs(node.location[axis] - target[axis]) < bestDist) {
      search(farChild);
    }
  }
};

/**
 * Computes the splitting line of every node in this subtree, each clipped
 * to the region its node covers. Only works for 2D (axis 0 or 1).
 *
 * @param {Array<Array<number>>} extent - `[[x0, y0], [x1, y1]]`, the region
 *     covered by this node.
 * @returns {Array<Array<Array<number>>>} Line segments `[[xa, ya], [xb, yb]]`:
 *     this node's first, then the left subtree's, then the right subtree's.
 */
Node.prototype.lines = function(extent) {
  var x0 = extent[0][0],
      y0 = extent[0][1],
      x1 = extent[1][0],
      y1 = extent[1][1],
      x = this.location[0],
      y = this.location[1],
      leftChild = this.subnodes[0],
      rightChild = this.subnodes[1],
      line,
      leftExtent,
      rightExtent;

  if (this.axis === 0) {
    // Vertical split at x: children cover the regions left and right of it.
    line = [[x, y0], [x, y1]];
    leftExtent = [[x0, y0], [x, y1]];
    rightExtent = [[x, y0], [x1, y1]];
  }
  else if (this.axis === 1) {
    // Horizontal split at y: children cover the regions on either side of it.
    line = [[x0, y], [x1, y]];
    leftExtent = [[x0, y0], [x1, y]];
    rightExtent = [[x0, y], [x1, y1]];
  }

  // For any other axis no extent is set, so the children are not visited.
  var segments = [line];
  if (leftExtent && leftChild) {
    segments = segments.concat(leftChild.lines(leftExtent));
  }
  if (rightExtent && rightChild) {
    segments = segments.concat(rightChild.lines(rightExtent));
  }
  return segments;
};

/**
 * Creates a configurable builder for 2-D k-d trees.
 *
 * The returned function `tree(data)` maps every datum to an `[x, y]` point
 * using the configured accessors and returns the root `Node` of a k-d tree
 * over those points, or `null` when `data` is empty. Each node's `location`
 * is the `[x, y]` array (which also carries the source datum as `.datum`)
 * and each node's `datum` is the source datum itself.
 *
 * `tree.x(fn)` / `tree.y(fn)` set the coordinate accessors and return the
 * builder for chaining; called without arguments they return the current
 * accessor. The defaults read `d[0]` and `d[1]`.
 *
 * @returns {function(Array): ?Node} The tree builder.
 */
function KDTree() {
  var x = function(d) { return d[0]; },
      y = function(d) { return d[1]; };

  function tree(data) {
    var points = data.map(function(d) {
      var point = [x(d), y(d)];
      point.datum = d;
      return point;
    });

    return treeify(points, 0);
  }

  tree.x = function(_) {
    if (!arguments.length) return x;
    x = _;
    return tree;
  };

  tree.y = function(_) {
    if (!arguments.length) return y;
    y = _;
    return tree;
  };

  return tree;

  // Recursively builds the subtree for `points` (sorted in place) at the
  // given depth. Adapted from https://en.wikipedia.org/wiki/K-d_tree
  function treeify(points, depth) {
    // Base case: no points left on this side of the split.
    if (points.length === 0) return null;

    // Select axis based on depth so that axis cycles through all valid values
    var k = points[0].length,
        axis = depth % k;

    // TODO: To speed up, consider splitting points based on approximation of
    //       median; take median of random sample of points (perhaps of 1/10th
    //       of the points)

    // Sort point list and choose median as pivot element
    points.sort(function(a, b) { return a[axis] - b[axis]; });
    var medianIndex = Math.floor(points.length / 2);

    // Create node and construct subtrees
    var point = points[medianIndex],
        leftPoints = points.slice(0, medianIndex),
        rightPoints = points.slice(medianIndex + 1);

    return new Node(
      point,
      axis,
      [treeify(leftPoints, depth + 1), treeify(rightPoints, depth + 1)],
      point.datum
    );
  }
}

/**
 * Smallest value produced by `accessor` over the elements of `array`.
 *
 * @param {Array} array - Elements to inspect.
 * @param {function(*): *} accessor - Called with each element; its return
 *     values are compared with `<`.
 * @returns {*} The smallest accessor value, or `undefined` when `array` is
 *     empty.
 */
function min(array, accessor) {
  // reduce() without an initial value throws on an empty array, so the
  // empty case is answered up front.
  if (array.length === 0) return undefined;

  return array
    .map(function(d) { return accessor(d); })
    .reduce(function(a, b) { return a < b ? a : b; });
}

/**
 * Largest value produced by `accessor` over the elements of `array`.
 *
 * @param {Array} array - Elements to inspect.
 * @param {function(*): *} accessor - Called with each element; its return
 *     values are compared with `>`.
 * @returns {*} The largest accessor value, or `undefined` when `array` is
 *     empty.
 */
function max(array, accessor) {
  // reduce() without an initial value throws on an empty array, so the
  // empty case is answered up front.
  if (array.length === 0) return undefined;

  return array
    .map(function(d) { return accessor(d); })
    .reduce(function(a, b) { return a > b ? a : b; });
}

/**
 * Builds an accessor function for a fixed property.
 *
 * @param {string|number} key - Property name or index to read.
 * @returns {function(*): *} Function returning `d[key]` for its argument `d`.
 */
function get(key) {
  return function(record) {
    var value = record[key];
    return value;
  };
}

/**
 * Euclidean distance between two points of any dimension.
 *
 * Both points are arrays of numbers. Every coordinate index the two arrays
 * share takes part, so 2-D points give the same result as the classic
 * `sqrt(dx^2 + dy^2)` formula.
 *
 * @param {Array<number>} p0 - First point.
 * @param {Array<number>} p1 - Second point.
 * @returns {number} The distance between `p0` and `p1`.
 */
function distance(p0, p1) {
  var dimensions = Math.min(p0.length, p1.length),
      sumOfSquares = 0;

  for (var i = 0; i < dimensions; i++) {
    var delta = p1[i] - p0[i];
    sumOfSquares += delta * delta;
  }

  return Math.sqrt(sumOfSquares);
}