A d3js ROC Chart. Might be useful for evaluating models.
the data shown are from models that predict various tennis stats. curious about the source data? browse through the R script that @ilanthedataman uses to generate the data.
check out the animated version
inspired by the Interactive ROC Curve bl.ock from ilanman
<!DOCTYPE html>
<meta charset="utf-8">
<head>
<link href='https://fonts.googleapis.com/css?family=Roboto:400,700,300,100,900|Open+Sans:400,300,700,600,800' rel='stylesheet' type='text/css'>
</head>
<style>
body {
font-size: 12px;
font-family: 'Open Sans';
}
path {
stroke-width: 3;
fill: none;
opacity: .7;
}
.axis path,
.axis line {
fill: none;
stroke: grey;
stroke-width: 2;
shape-rendering: crispEdges;
opacity: 1;
}
.d3-tip {
font-family: Verdana;
background: rgba(0, 0, 0, 0.8);
padding: 8px;
color: #fff;
z-index: 5070;
}
</style>
<body>
<div id="roc"> </div>
<script src="https://d3js.org/d3.v3.min.js" charset="utf-8"></script>
<script src="rocChart.js"></script>
<script>
var margin = {top: 30, right: 61, bottom: 70, left: 61},
width = 470 - margin.left - margin.right,
height = 450 - margin.top - margin.bottom;
// fpr for "false positive rate"
// tpr for "true positive rate"
var rocChartOptions = {
"margin": margin,
"width": width,
"height": height,
"interpolationMode": "basis",
"fpr": "X",
"tprVariables": [
{
"name": "BPC",
"label": "Break Points"
},
{
"name": "WNR",
"label": "Winners"
},
{
"name": "FSP",
"label": "First Serve %",
},
{
"name": "NPW",
"label": "Net Points Won"
}
],
"animate": false,
"smooth": true
}
d3.json("data.json", function(error, data) {
rocChart("#roc", data, rocChartOptions)
})
</script>
</body>
function rocChart(id, data, options) {
// set default configuration
var cfg = {
"margin": {top: 30, right: 20, bottom: 70, left: 61},
"width": 470,
"height": 450,
"interpolationMode": "basis",
"ticks": undefined,
"tickValues": [0, .1, .25, .5, .75, .9, 1],
"fpr": "fpr",
"tprVariables": [{
"name": "tpr0",
}],
"animate": true
}
//Put all of the options into a variable called cfg
if('undefined' !== typeof options){
for(var i in options){
if('undefined' !== typeof options[i]){ cfg[i] = options[i]; }
}//for i
}//if
var tprVariables = cfg["tprVariables"];
// if values for labels are not specified
// set the default values for the labels to the corresponding
// true positive rate variable name
tprVariables.forEach(function(d, i) {
if('undefined' === typeof d.label){
d.label = d.name;
}
})
console.log("tprVariables", tprVariables);
var interpolationMode = cfg["interpolationMode"],
fpr = cfg["fpr"],
width = cfg["width"],
height = cfg["height"],
animate = cfg["animate"]
var format = d3.format('.2');
var aucFormat = d3.format('.4r')
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);
var color = d3.scale.category10() // d3.scale.ordinal().range(["steelblue", "red", "green", "purple"]);
var xAxis = d3.svg.axis()
.scale(x)
.orient("top")
.outerTickSize(0);
var yAxis = d3.svg.axis()
.scale(y)
.orient("right")
.outerTickSize(0);
// set the axis ticks based on input parameters,
// if ticks or tickValues are specified
if('undefined' !== typeof cfg["ticks"]) {
xAxis.ticks(cfg["ticks"]);
yAxis.ticks(cfg["ticks"]);
} else if ('undefined' !== typeof cfg["tickValues"]) {
xAxis.tickValues(cfg["tickValues"]);
yAxis.tickValues(cfg["tickValues"]);
} else {
xAxis.ticks(5);
yAxis.ticks(5);
}
// apply the format to the ticks we chose
xAxis.tickFormat(format);
yAxis.tickFormat(format);
// a function that returns a line generator
function curve(data, tpr) {
var lineGenerator = d3.svg.line()
.interpolate(interpolationMode)
.x(function(d) { return x(d[fpr]); })
.y(function(d) { return y(d[tpr]); });
return lineGenerator(data);
}
// a function that returns an area generator
function areaUnderCurve(data, tpr) {
var areaGenerator = d3.svg.area()
.x(function(d) { return x(d[fpr]); })
.y0(height)
.y1(function(d) { return y(d[tpr]); });
return areaGenerator(data);
}
var svg = d3.select("#roc")
.append("svg")
.attr("width", width + margin.left + margin.right)
.attr("height", height + margin.top + margin.bottom)
.append("g")
.attr("transform", "translate(" + margin.left + "," + margin.top + ")");
x.domain([0, 1]);
y.domain([0, 1]);
svg.append("g")
.attr("class", "x axis")
.attr("transform", "translate(0," + height + ")")
.call(xAxis)
.append("text")
.attr("x", width / 2)
.attr("y", 40 )
.style("text-anchor", "middle")
.text("False Positive Rate")
var xAxisG = svg.select("g.x.axis");
// draw the top boundary line
xAxisG.append("line")
.attr({
"x1": -1,
"x2": width + 1,
"y1": -height,
"y2": -height
});
// draw a bottom boundary line over the existing
// x-axis domain path to make even corners
xAxisG.append("line")
.attr({
"x1": -1,
"x2": width + 1,
"y1": 0,
"y2": 0
});
// position the axis tick labels below the x-axis
xAxisG.selectAll('.tick text')
.attr('transform', 'translate(0,' + 25 + ')');
// hide the y-axis ticks for 0 and 1
xAxisG.selectAll("g.tick line")
.style("opacity", function(d) {
// if d is an integer
return d % 1 === 0 ? 0 : 1;
});
svg.append("g")
.attr("class", "y axis")
.call(yAxis)
.append("text")
.attr("transform", "rotate(-90)")
.attr("y", -35)
// manually configured so that the label is centered vertically
.attr("x", 0 - height/1.56)
.style("font-size","12px")
.style("text-anchor", "left")
.text("True Positive Rate");
yAxisG = svg.select("g.y.axis");
// add the right boundary line
yAxisG.append("line")
.attr({
"x1": width,
"x2": width,
"y1": 0,
"y2": height
})
// position the axis tick labels to the right of
// the y-axis and
// translate the first and the last tick labels
// so that they are right aligned
// or even with the 2nd digit of the decimal number
// tick labels
yAxisG.selectAll("g.tick text")
.attr('transform', function(d) {
if(d % 1 === 0) { // if d is an integer
return 'translate(' + -22 + ',0)';
} else if((d*10) % 1 === 0) { // if d is a 1 place decimal
return 'translate(' + -32 + ',0)';
} else {
return 'translate(' + -42 + ',0)';
}
})
// hide the y-axis ticks for 0 and 1
yAxisG.selectAll("g.tick line")
.style("opacity", function(d) {
// if d is an integer
return d % 1 === 0 ? 0 : 1;
});
// draw the random guess line
svg.append("line")
.attr("class", "curve")
.style("stroke", "black")
.attr({
"x1": 0,
"x2": width,
"y1": height,
"y2": 0
})
.style({
"stroke-width": 2,
"stroke-dasharray": "8",
"opacity": 0.4
})
// draw the ROC curves
function drawCurve(data, tpr, stroke){
svg.append("path")
.attr("class", "curve")
.style("stroke", stroke)
.attr("d", curve(data, tpr))
.on('mouseover', function(d) {
var areaID = "#" + tpr + "Area";
svg.select(areaID)
.style("opacity", .4)
var aucText = "." + tpr + "text";
svg.selectAll(aucText)
.style("opacity", .9)
})
.on('mouseout', function(){
var areaID = "#" + tpr + "Area";
svg.select(areaID)
.style("opacity", 0)
var aucText = "." + tpr + "text";
svg.selectAll(aucText)
.style("opacity", 0)
});
}
// draw the area under the ROC curves
function drawArea(data, tpr, fill) {
svg.append("path")
.attr("class", "area")
.attr("id", tpr + "Area")
.style({
"fill": fill,
"opacity": 0
})
.attr("d", areaUnderCurve(data, tpr))
}
function drawAUCText(auc, tpr, label) {
svg.append("g")
.attr("class", tpr + "text")
.style("opacity", 0)
.attr("transform", "translate(" + .5*width + "," + .79*height + ")")
.append("text")
.text(label)
.style({
"fill": "white",
"font-size": 18
});
svg.append("g")
.attr("class", tpr + "text")
.style("opacity", 0)
.attr("transform", "translate(" + .5*width + "," + .84*height + ")")
.append("text")
.text("AUC = " + aucFormat(auc))
.style({
"fill": "white",
"font-size": 18
});
}
// calculate the area under each curve
tprVariables.forEach(function(d){
var tpr = d.name;
var points = generatePoints(data, fpr, tpr);
var auc = calculateArea(points);
d["auc"] = auc;
})
console.log("tprVariables", tprVariables);
// draw curves, areas, and text for each
// true-positive rate in the data
tprVariables.forEach(function(d, i){
console.log("drawing the curve for", d.label)
console.log("color(", i, ")", color(i));
var tpr = d.name;
drawArea(data, tpr, color(i))
drawCurve(data, tpr, color(i));
drawAUCText(d.auc, tpr, d.label);
})
///////////////////////////////////////////////////
////// animate through areas for each curve ///////
///////////////////////////////////////////////////
if(animate && animate !== "false") {
//sort tprVariables ascending by AUC
var tprVariablesAscByAUC = tprVariables.sort(function(a, b) {
return a.auc - b.auc;
})
console.log("tprVariablesAscByAUC", tprVariablesAscByAUC);
for(var i = 0; i < tprVariablesAscByAUC.length; i++) {
areaID = "#" + tprVariablesAscByAUC[i]["name"] + "Area";
svg.select(areaID)
.transition()
.delay(2000 * (i+1))
.duration(250)
.style("opacity", .4)
.transition()
.delay(2000 * (i+2))
.duration(250)
.style("opacity", 0)
textClass = "." + tprVariablesAscByAUC[i]["name"] + "text";
svg.selectAll(textClass)
.transition()
.delay(2000 * (i+1))
.duration(250)
.style("opacity", .9)
.transition()
.delay(2000 * (i+2))
.duration(250)
.style("opacity", 0)
}
}
///////////////////////////////////////////////////
///////////////////////////////////////////////////
///////////////////////////////////////////////////
function generatePoints(data, x, y) {
var points = [];
data.forEach(function(d){
points.push([ Number(d[x]), Number(d[y]) ])
})
return points;
}
// numerical integration
function calculateArea(points) {
var area = 0.0;
var length = points.length;
if (length <= 2) {
return area;
}
points.forEach(function(d, i) {
var x = 0,
y = 1;
if('undefined' !== typeof points[i-1]){
area += (points[i][x] - points[i-1][x]) * (points[i-1][y] + points[i][y]) / 2;
}
});
return area;
}
} // rocChart