block by bricedev 1bd45a5f6d727499ee46

Gradient descent

Full Screen

Linear regression fitted with the gradient descent method.

index.html

<!DOCTYPE html>
<meta charset="utf-8">
<style>

/* Base font for everything in the page, including axis tick labels and titles. */
body {
  font: 10px sans-serif;
}

/* Axis domain paths and tick lines: unfilled, black, pixel-aligned strokes. */
.axis line,
.axis path {
  shape-rendering: crispEdges;
  stroke: #000;
  fill: none;
}

/* The regression line appended by the script (class "line"). */
.line {
  stroke-width: 1px;
  stroke: #000;
  fill: none;
}

</style>
<body>
<script src="https://d3js.org/d3.v3.min.js"></script>
<script>

// Plot area = 960x500 canvas minus the margins reserved for the axes.
var margin = {top: 20, right: 20, bottom: 30, left: 40};
var width = 960 - margin.left - margin.right;
var height = 500 - margin.top - margin.bottom;

// Prints the fitted coefficients with three decimal places.
var format = d3.format(".3f");

// Data -> pixel scales. The y range is inverted so larger values sit higher.
var x = d3.scale.linear().range([0, width]);
var y = d3.scale.linear().range([height, 0]);

// Axis generators bound to the scales above.
var xAxis = d3.svg.axis().scale(x).orient("bottom");
var yAxis = d3.svg.axis().scale(y).orient("left");

// Root <svg> sized to include the margins. `svg` refers to the inner <g>,
// shifted by the top-left margin so drawing uses plot-area coordinates.
var svg = d3.select("body")
  .append("svg")
  .attr("width", width + margin.left + margin.right)
  .attr("height", height + margin.top + margin.bottom)
  .append("g")
  .attr("transform", "translate(" + margin.left + "," + margin.top + ")");

// Load the data, draw the scatter plot with its axes, then animate a
// batch-gradient-descent fit of h(x) = theta0 + theta1 * x to the points.
d3.csv("data.csv", function(error, data) {
  // Fail loudly on load/parse errors; otherwise `data` is undefined and the
  // forEach below throws a less helpful TypeError.
  if (error) throw error;

  // CSV values arrive as strings; coerce both columns to numbers.
  data.forEach(function(d) {
    d.population = +d.population;
    d.profit = +d.profit;
  });

  x.domain(d3.extent(data, function(d) { return d.population; })).nice();
  y.domain(d3.extent(data, function(d) { return d.profit; })).nice();

  // The fitted line is drawn across the whole (niced) x domain.
  var xMin = x.domain()[0],
      xMax = x.domain()[1];

  svg.append("g")
      .attr("class", "x axis")
      .attr("transform", "translate(0," + height + ")")
      .call(xAxis)
    .append("text")
      .attr("class", "label")
      .attr("x", width)
      .attr("y", -6)
      .style("text-anchor", "end")
      .style("font-weight","bold")
      .text("Population of City in 10,000s");

  svg.append("g")
      .attr("class", "y axis")
      .call(yAxis)
    .append("text")
      .attr("class", "label")
      .attr("transform", "rotate(-90)")
      .attr("y", 6)
      .attr("dy", ".71em")
      .style("font-weight","bold")
      .style("text-anchor", "end")
      .text("Profit in $10,000s");

  svg.selectAll(".dot")
      .data(data)
    .enter().append("circle")
      .attr("class", "dot")
      .attr("r", 3.5)
      .attr("cx", function(d) { return x(d.population); })
      .attr("cy", function(d) { return y(d.profit); })
      .style("fill","#d73027");

  // Nothing to fit without data: 1/m would be Infinity and the thetas NaN.
  var m = data.length;
  if (m === 0) return;

  // Gradient descent settings. These are all locals of this callback, so the
  // parameters theta0/theta1 do not leak into the global scope.
  var iteration = 0,
      iterationNumber = 1500,
      alpha = 0.01,
      theta0 = 0,
      theta1 = 0;

  var line = svg.append("line")
    .attr("class", "line");

  var hyp = svg.append("text")
    .attr("x", width/2)
    .attr("y", 40)
    .style("text-anchor","middle")
    .style("font-size","35px");

  // Redraw the current hypothesis line across [xMin, xMax] and update the
  // equation label. The initial draw and every timer tick both use this.
  function drawHypothesis() {
    line.attr("x1", x(xMin))
        .attr("y1", y(theta1 * xMin + theta0))
        .attr("x2", x(xMax))
        .attr("y2", y(theta1 * xMax + theta0));
    hyp.text("hθ(x) = " + format(theta0) + " + " + format(theta1) + "x");
  }

  drawHypothesis();

  // One gradient step per timer tick, starting after a 200 ms delay.
  d3.timer(function() {
    // Residuals h(x) - y under the current parameters. Both partial
    // derivatives of the squared-error cost (1/2m) * sum(residual^2) use them.
    var residuals = data.map(function(d) {
      return (theta1 * d.population + theta0) - d.profit;
    });

    // Simultaneous update: both new values are computed from the old thetas
    // before either one is overwritten.
    var temp0 = theta0 - alpha * (1/m) * d3.sum(residuals);
    var temp1 = theta1 - alpha * (1/m) * d3.sum(data, function(d, i) { return residuals[i] * d.population; });

    theta0 = temp0;
    theta1 = temp1;

    drawHypothesis();

    // Returning true stops the d3 timer after exactly iterationNumber steps.
    return ++iteration >= iterationNumber;
  }, 200);

});

</script>

data.csv

population,profit
6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705