block by micahstubbs f2aff83148a5f64f3222

Animated ROC Chart

Full Screen

A d3js ROC Chart that animates through the areas under each curve.
Might be useful for evaluating models.

The data shown are from models that predict various tennis stats. Curious about the source data? Browse through the R script that @ilanthedataman uses to generate the data.

Inspired by the Interactive ROC Curve bl.ock from ilanman.

index.html

<!DOCTYPE html>
<head>
  <!-- The charset declaration and the stylesheet live inside <head> so the
       document parses without relying on browser error recovery. -->
  <meta charset="utf-8">
  <link href='https://fonts.googleapis.com/css?family=Roboto:400,700,300,100,900|Open+Sans:400,300,700,600,800' rel='stylesheet' type='text/css'>
  <style>

  /* Base typography; the generic family is a fallback for when the web font
     cannot be loaded. */
  body {
    font-size: 12px;
    font-family: 'Open Sans', sans-serif;
  }

  /* Default look of every SVG path; individual paths override stroke and
     fill with inline styles. */
  path {
      stroke-width: 3;
      fill: none;
      opacity: .7;
  }

  /* Axis domain paths and tick lines are drawn crisp and fully opaque. */
  .axis path,
  .axis line {
      fill: none;
      stroke: grey;
      stroke-width: 2;
      shape-rendering: crispEdges;
      opacity: 1;
  }

  /* Tooltip styling; not referenced by the scripts shown on this page. */
  .d3-tip {
      font-family: Verdana;
      background: rgba(0, 0, 0, 0.8);
      padding: 8px;
      color: #fff;
      z-index: 5070;
  }

  </style>
</head>
<body>

<!-- Container the chart is rendered into. -->
<div id="roc"> </div>

<!-- rocChart.js defines rocChart(), which relies on the d3 global loaded first. -->
<script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
<script src="rocChart.js"></script>
<script>

// Chart geometry: the margins are subtracted from a 470x450 outer size to
// get the width and height of the inner plotting area.
var margin = {top: 30, right: 61, bottom: 70, left: 61};
var width = 470 - margin.left - margin.right;
var height = 450 - margin.top - margin.bottom;

// fpr for "false positive rate"
// tpr for "true positive rate"

// Options handed to rocChart(). "fpr" and every tprVariables "name" are
// property names that are looked up on each row of the loaded data.
var rocChartOptions = {
  "margin": margin,
  "width": width,
  "height": height,
  "interpolationMode": "basis",
  "fpr": "X",
  "tprVariables": [
    {
      "name": "BPC",
      "label": "Break Points"
    },
    {
      "name": "WNR",
      "label": "Winners"
    },
    {
      "name": "FSP",
      "label": "First Serve %",
    },
    {
      "name": "NPW",
      "label": "Net Points Won"
    }
  ],
  "animate": true,
  "smooth": true
};

// Load the data and draw the chart once it has arrived.
d3.json("data.json", function(error, data) {
  if (error) {
    // Without data there is nothing to draw; report the failure instead of
    // letting rocChart fail on an undefined data set.
    console.error("could not load data.json", error);
    return;
  }
  rocChart("#roc", data, rocChartOptions);
});

</script>
</body>

rocChart.js

function rocChart(id, data, options) {

  // Default configuration; every defined key of `options` replaces the
  // matching entry below.
  var cfg = {
    "margin": {top: 30, right: 20, bottom: 70, left: 61},
    "width": 470,
    "height": 450,
    "interpolationMode": "basis",
    "ticks": undefined,
    "tickValues": [0, .1, .25, .5, .75, .9, 1],
    "fpr": "fpr",
    "tprVariables": [{
      "name": "tpr0",
    }],
    "animate": true
  };

  // Copy the caller's own, defined options over the defaults. The
  // hasOwnProperty guard keeps properties that are merely inherited through
  // the prototype chain out of cfg.
  if ('undefined' !== typeof options) {
    for (var key in options) {
      if (Object.prototype.hasOwnProperty.call(options, key) &&
          'undefined' !== typeof options[key]) {
        cfg[key] = options[key];
      }
    }
  }

  // Work on shallow copies of the true-positive-rate descriptors. A default
  // label is added here, and later code in this function annotates the
  // entries and may reorder the list; none of that should leak back into the
  // options object owned by the caller.
  var tprVariables = cfg["tprVariables"].map(function(d) {
    var copy = {};
    for (var prop in d) {
      if (Object.prototype.hasOwnProperty.call(d, prop)) {
        copy[prop] = d[prop];
      }
    }
    // if no label is specified, fall back to the variable name
    if ('undefined' === typeof copy.label) {
      copy.label = copy.name;
    }
    return copy;
  });

  console.log("tprVariables", tprVariables);

  // local shortcuts for the settings used throughout the function
  var interpolationMode = cfg["interpolationMode"],
      fpr = cfg["fpr"],
      width = cfg["width"],
      height = cfg["height"],
      animate = cfg["animate"];

  // number formats: one for the axis tick labels, one for the AUC read-out
  var format = d3.format('.2'),
      aucFormat = d3.format('.4r');

  // x runs left to right and y bottom to top across the plotting area;
  // color hands out one categorical color per curve index
  var x = d3.scale.linear().range([0, width]),
      y = d3.scale.linear().range([height, 0]),
      color = d3.scale.category10();

  var xAxis = d3.svg.axis().scale(x).orient("top").outerTickSize(0),
      yAxis = d3.svg.axis().scale(y).orient("right").outerTickSize(0);

  // Both axes get the same tick setup. Precedence: an explicit tick count,
  // then an explicit list of tick values, then a default of five ticks.
  [xAxis, yAxis].forEach(function(axis) {
    if (cfg["ticks"] !== undefined) {
      axis.ticks(cfg["ticks"]);
    } else if (cfg["tickValues"] !== undefined) {
      axis.tickValues(cfg["tickValues"]);
    } else {
      axis.ticks(5);
    }
    // apply the tick label format whichever ticks were chosen
    axis.tickFormat(format);
  });
  
  // Runs a d3 line generator over `data`, reading the x position from the
  // false-positive-rate column and the y position from the `tpr` column, and
  // returns the generator's output (used as a path's "d" attribute).
  function curve(data, tpr) {
    function xPosition(row) { return x(row[fpr]); }
    function yPosition(row) { return y(row[tpr]); }

    var toPath = d3.svg.line()
      .interpolate(interpolationMode)
      .x(xPosition)
      .y(yPosition);

    return toPath(data);
  }

  // Runs a d3 area generator over `data` and returns its output (used as a
  // path's "d" attribute). The area spans from the bottom of the plotting
  // area (y0 = height) up to the `tpr` value of each row. Unlike curve(),
  // no interpolation mode is set on this generator.
  function areaUnderCurve(data, tpr) {
    function xPosition(row) { return x(row[fpr]); }
    function topEdge(row) { return y(row[tpr]); }

    var toArea = d3.svg.area()
      .x(xPosition)
      .y0(height)
      .y1(topEdge);

    return toArea(data);
  }

  // Use the margins from the merged configuration rather than whatever
  // `margin` variable the embedding page happens to define globally.
  var chartMargin = cfg["margin"];

  // Create the svg inside the element matched by `id`, the selector the
  // caller passed in, and keep a reference to an inner group that is shifted
  // by the left/top margin; everything else is drawn into that group.
  var svg = d3.select(id)
    .append("svg")
      .attr("width", width + chartMargin.left + chartMargin.right)
      .attr("height", height + chartMargin.top + chartMargin.bottom)
      .append("g")
        .attr("transform", "translate(" + chartMargin.left + "," + chartMargin.top + ")");

  // both axes cover the unit interval
  x.domain([0, 1]);
  y.domain([0, 1]);

  // x axis group, moved down to the bottom edge of the plotting area
  var xAxisG = svg.append("g")
    .attr("class", "x axis")
    .attr("transform", "translate(0," + height + ")")
    .call(xAxis);

  // axis title, centered below the axis
  xAxisG.append("text")
    .attr("x", width / 2)
    .attr("y", 40 )
    .style("text-anchor", "middle")
    .text("False Positive Rate");

  // Two horizontal boundary lines, both one pixel wider than the plotting
  // area on each side: first the top edge, then one over the existing
  // x-axis domain path so that the corners come out even.
  [-height, 0].forEach(function(edgeY) {
    xAxisG.append("line")
      .attr("x1", -1)
      .attr("x2", width + 1)
      .attr("y1", edgeY)
      .attr("y2", edgeY);
  });

  // position the axis tick labels below the x-axis
  xAxisG.selectAll('.tick text')
    .attr('transform', 'translate(0,' + 25 + ')');

  // hide the x-axis tick marks for 0 and 1, i.e. for integer tick values
  xAxisG.selectAll("g.tick line")
    .style("opacity", function(tickValue) {
      if (tickValue % 1 === 0) {
        return 0;
      }
      return 1;
    });

  // y axis group with its rotated title
  svg.append("g")
    .attr("class", "y axis")
    .call(yAxis)
    .append("text")
      .attr("transform", "rotate(-90)")
      .attr("y", -35)
      // manually configured so that the label is centered vertically
      .attr("x", 0 - height/1.56)
      .style("font-size","12px")
      // "start" is the valid SVG keyword; "left" is not a text-anchor value
      .style("text-anchor", "start")
      .text("True Positive Rate");

  // declared with var so the selection stays local to rocChart instead of
  // becoming an implicit global
  var yAxisG = svg.select("g.y.axis");

  // add the right boundary line
  yAxisG.append("line")
    .attr({
      "x1": width,
      "x2": width,
      "y1": 0,
      "y2": height
    });

  // position the axis tick labels to the left of
  // the y-axis and
  // translate the first and the last tick labels
  // so that they are right aligned
  // or even with the 2nd digit of the decimal number
  // tick labels
  yAxisG.selectAll("g.tick text")
    .attr('transform', function(d) {
      if(d % 1 === 0) { // if d is an integer
        return 'translate(' + -22 + ',0)';
      } else if((d*10) % 1 === 0) { // if d is a 1 place decimal
        return 'translate(' + -32 + ',0)';
      } else {
        return 'translate(' + -42 + ',0)';
      }
    });

  // hide the y-axis ticks for 0 and 1
  yAxisG.selectAll("g.tick line")
    .style("opacity", function(d) {
      // if d is an integer
      return d % 1 === 0 ? 0 : 1;
    });

  // reference diagonal from the bottom-left to the top-right corner of the
  // plotting area: the random guess line
  svg.append("line")
    .attr("class", "curve")
    .style("stroke", "black")
    .attr("x1", 0)
    .attr("x2", width)
    .attr("y1", height)
    .attr("y2", 0)
    .style("stroke-width", 2)
    .style("stroke-dasharray", "8")
    .style("opacity", 0.4);

  // draw the ROC curves
  // Appends the curve for the `tpr` column in the given stroke color and
  // makes it reveal its area and AUC text while the pointer is over it.
  function drawCurve(data, tpr, stroke){
    var areaSelector = "#" + tpr + "Area",
        textSelector = "." + tpr + "text";

    // sets the opacity of this curve's area and of its AUC text groups
    function setHighlight(areaOpacity, textOpacity) {
      svg.select(areaSelector).style("opacity", areaOpacity);
      svg.selectAll(textSelector).style("opacity", textOpacity);
    }

    svg.append("path")
      .attr("class", "curve")
      .style("stroke", stroke)
      .attr("d", curve(data, tpr))
      .on('mouseover', function() { setHighlight(.4, .9); })
      .on('mouseout', function() { setHighlight(0, 0); });
  }

  // Appends the filled area under the `tpr` curve. It starts fully
  // transparent and carries the id "<tpr>Area" so that other code can select
  // it and fade it in.
  function drawArea(data, tpr, fill) {
    var areaPath = svg.append("path");

    areaPath
      .attr("class", "area")
      .attr("id", tpr + "Area")
      .style("fill", fill)
      .style("opacity", 0)
      .attr("d", areaUnderCurve(data, tpr));
  }

  // Adds two initially transparent text groups for the `tpr` curve, both
  // with the class "<tpr>text": the curve's label and, below it, the
  // formatted AUC value.
  function drawAUCText(auc, tpr, label) {

    // appends one group at the given fraction of the height and puts a
    // white text element with `content` inside it
    function addLine(heightFraction, content) {
      var group = svg.append("g")
        .attr("class", tpr + "text")
        .style("opacity", 0)
        .attr("transform", "translate(" + .5*width + "," + heightFraction*height + ")");

      group.append("text")
        .text(content)
        .style({
          "fill": "white",
          "font-size": 18
        });
    }

    addLine(.79, label);
    addLine(.84, "AUC = " + aucFormat(auc));
  }
  
  // annotate every true-positive-rate variable with the area under its curve
  tprVariables.forEach(function(variable) {
    variable["auc"] = calculateArea(generatePoints(data, fpr, variable.name));
  });

  console.log("tprVariables", tprVariables);

  // For each true-positive-rate variable draw, in this order, the hidden
  // area, the curve on top of it and the hidden AUC text. The color is
  // picked by the variable's position in the list.
  tprVariables.forEach(function(variable, index) {
    var shade = color(index);

    console.log("drawing the curve for", variable.label);
    console.log("color(", index, ")", shade);

    drawArea(data, variable.name, shade);
    drawCurve(data, variable.name, shade);
    drawAUCText(variable.auc, variable.name, variable.label);
  });

  ///////////////////////////////////////////////////
  ////// animate through areas for each curve ///////
  ///////////////////////////////////////////////////

  // Unless animation is switched off (a falsy value or the string "false"),
  // reveal each curve's area and AUC text in turn, lowest AUC first.
  if(animate && animate !== "false") {
    // sort a copy, ascending by AUC, so that the order of tprVariables
    // itself (and of the array it came from) is left untouched
    var tprVariablesAscByAUC = tprVariables.slice().sort(function(a, b) {
      return a.auc - b.auc;
    });

    console.log("tprVariablesAscByAUC", tprVariablesAscByAUC);

    // Fades `selection` to `visibleOpacity` and back to 0. The fade-in is
    // delayed by 2000ms per position in the sequence and the fade-out is
    // scheduled 2000ms after it; each fade takes 250ms.
    var flash = function(selection, visibleOpacity, step) {
      selection
        .transition()
          .delay(2000 * (step+1))
          .duration(250)
          .style("opacity", visibleOpacity)
        .transition()
          .delay(2000 * (step+2))
          .duration(250)
          .style("opacity", 0);
    };

    // the selectors are kept in local expressions; nothing is assigned to
    // undeclared (implicitly global) variables
    tprVariablesAscByAUC.forEach(function(d, step) {
      flash(svg.select("#" + d["name"] + "Area"), .4, step);
      flash(svg.selectAll("." + d["name"] + "text"), .9, step);
    });
  }

  ///////////////////////////////////////////////////
  ///////////////////////////////////////////////////
  ///////////////////////////////////////////////////

  // Turns every row of `data` into a numeric [x, y] pair, where `x` and `y`
  // are the property names to read from each row. Values are converted with
  // Number(), so a missing property yields NaN.
  function generatePoints(data, x, y) {
    return data.reduce(function(pairs, row) {
      pairs.push([ Number(row[x]), Number(row[y]) ]);
      return pairs;
    }, []);
  }

  // Numerical integration with the trapezoidal rule.
  // `points` is an array of [x, y] pairs. Every pair of consecutive points
  // contributes (x_i - x_(i-1)) * (y_(i-1) + y_i) / 2 to the total.
  // Returns 0 when there are fewer than two points, because no trapezoid
  // can be formed; two points already describe one trapezoid.
  function calculateArea(points) {
    var area = 0.0;
    var X = 0,
        Y = 1; // positions of the coordinates inside each pair

    for (var i = 1; i < points.length; i++) {
      var previous = points[i-1],
          current = points[i];
      area += (current[X] - previous[X]) * (previous[Y] + current[Y]) / 2;
    }
    return area;
  }

} // rocChart