Convergence of the cost function J(θ) under batch gradient descent for several learning rates α.
<!DOCTYPE html>
<meta charset="utf-8">
<style>
/* Base font for every piece of text in the page (tick and series labels). */
body {
  font: 10px sans-serif;
}

/* Axis domain paths and tick lines: unfilled white strokes, drawn with the
   crisp-edges rendering hint so thin gridlines stay sharp. */
.axis path,
.axis line {
  fill: none;
  shape-rendering: crispEdges;
  stroke: white;
}

/* Tick labels. */
.axis text {
  font-weight: bold;
}

/* Cost curves: unfilled black strokes with rounded joins. */
.line {
  fill: none;
  stroke: black;
  stroke-linejoin: round;
  stroke-width: 1.2px;
}
</style>
<body>
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>
// Chart geometry: the inner plotting area is the 960x500 canvas minus margins.
var margin = {top: 20, right: 20, bottom: 20, left: 30};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Linear scales: x maps iteration number to pixels, y maps cost to pixels
// (range is inverted because SVG y coordinates grow downward).
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);

// Axes whose ticks are stretched across the whole plot area (negative tick
// size), so the tick lines double as gridlines.
var xAxis = d3.svg.axis()
    .orient("bottom")
    .scale(x)
    .tickPadding(8)
    .tickSize(-height);
var yAxis = d3.svg.axis()
    .orient("left")
    .scale(y)
    .tickPadding(8)
    .tickSize(-width);

// Outer <svg> sized to the full canvas, then an inner <g> shifted by the
// margins. `svg` refers to that inner group; all chart content is drawn into it.
var svg = d3.select("body").append("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom)
  .append("g")
    .attr("transform", "translate(" + [margin.left, margin.top] + ")");
// Load the training set, run batch gradient descent once per learning rate α,
// and plot the cost J(θ) against the iteration number for each run.
d3.csv("data.csv", function(error, data) {
  // Fail loudly on a load/parse error instead of crashing later on `data.length`.
  if (error) throw error;
  // With no rows m would be 0 and every cost NaN.
  if (!data.length) throw new Error("data.csv contains no rows");

  // Gradient descent settings: number of parameter updates per run and the
  // learning rates to compare.
  var iterationNumber = 150,
      alpha = [0.01, 0.001, 0.0005, 0.0001];

  // CSV columns arrive as strings; coerce them to numbers.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  // One cost history per learning rate, each run starting from θ = (0, 0).
  var values = alpha.map(function(alphaValue) {
    return {alpha: alphaValue, data: gradientDescent(data, alphaValue, iterationNumber)};
  });

  // Each history holds iterations 0..iterationNumber inclusive, so the curves
  // span the whole x domain.
  x.domain([0, iterationNumber]);
  var yMin = d3.min(values, function(run) {
    return d3.min(run.data, function(d) { return d.cost; });
  });
  var yMax = d3.max(values, function(run) {
    return d3.max(run.data, function(d) { return d.cost; });
  });
  y.domain([yMin, yMax]).nice();

  var line = d3.svg.line()
      .interpolate("basis")
      .x(function(d) { return x(d.iteration); })
      .y(function(d) { return y(d.cost); });

  // Grey plot background behind the white gridlines.
  svg.append("rect")
      .attr("class", "background")
      .attr("width", width)
      .attr("height", height)
      .style("fill", "#e7e7e7");

  svg.append("g")
      .attr("class", "x axis")
      .attr("transform", "translate(0," + height + ")")
      .call(xAxis)
    .append("text")
      .attr("class", "label")
      .attr("x", width - 5)
      .attr("y", -6)
      .style("text-anchor", "end")
      .style("font-weight", "bold")
      .text("Number of iterations");

  svg.append("g")
      .attr("class", "y axis")
      .call(yAxis)
    .append("text")
      .attr("class", "label")
      .attr("transform", "rotate(-90)")
      .attr("dx", "-.71em")
      .attr("y", 6)
      .attr("dy", ".71em")
      .style("font-weight", "bold")
      .style("text-anchor", "end")
      .text("J(θ)");

  // Each series gets a <path> plus a <text> whose <textPath> follows it.
  // Ids are index-based: raw α values such as "0.01" start with a digit and
  // contain ".", so they are not valid XML Names (required for SVG 1.1 ids).
  var series = svg.selectAll(".line")
      .data(values)
    .enter();

  series.append("path")
      .attr("id", function(d, i) { return "cost-line-" + i; })
      .attr("class", "line")
      .attr("d", function(d) { return line(d.data); });

  series.append("text")
      .attr("dy", "-3px")
    .append("textPath")
      .attr("xlink:href", function(d, i) { return "#cost-line-" + i; })
      .attr("startOffset", "24%")
      .style("font-weight", "bold")
      .text(function(d) { return "α = " + d.alpha; });

  // Runs batch gradient descent for the linear model h(x) = θ0 + θ1·x, starting
  // from θ = (0, 0). Returns [{iteration, cost}] with the cost before any update
  // (iteration 0) and after each of the `iterations` updates.
  function gradientDescent(rows, alphaValue, iterations) {
    var m = rows.length,
        theta0 = 0,
        theta1 = 0,
        history = [{iteration: 0, cost: computeCost(rows, theta0, theta1)}];
    for (var i = 1; i <= iterations; i++) {
      var grad0 = 0,
          grad1 = 0;
      // One pass over the rows: the residual feeds both partial derivatives.
      for (var j = 0; j < m; j++) {
        var residual = theta1 * rows[j].population + theta0 - rows[j].profit;
        grad0 += residual;
        grad1 += residual * rows[j].population;
      }
      // Both gradients use the old θ, so θ0 and θ1 are updated simultaneously.
      theta0 -= alphaValue * grad0 / m;
      theta1 -= alphaValue * grad1 / m;
      history.push({iteration: i, cost: computeCost(rows, theta0, theta1)});
    }
    return history;
  }

  // Squared-error cost J(θ) = 1/(2m) · Σ (θ0 + θ1·population − profit)².
  function computeCost(rows, theta0, theta1) {
    var cost = 0;
    rows.forEach(function(d) {
      var residual = theta1 * d.population + theta0 - d.profit;
      cost += residual * residual;
    });
    return cost / (2 * rows.length);
  }
});
</script>
population,profit
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705