block by micahstubbs e110f1db857b1da87595515cacd103dc

latent value learning | scaled up

Full Screen

This iteration draws a 384px by 200px svg and scales it up to 960px wide by 500px tall using the svg viewBox attribute.

An iteration on the block latent value learning | es2015 from @micahstubbs, originally inspired by the block Latent Value Learning from bricof.


This is a cleaned and simplified version of a simulation / animation of latent variable learning used in the Recommendation Systems section of the Stitch Fix Algorithms Tour.

Each circle is assumed to have some latent value along the horizontal axis - some true value for an attribute that we cannot observe directly but that we can try to estimate based on feedback from attempted pair-matches involving one A element and one B element.

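The algorithm used to estimate them, as implemented in animated-learning.js below, is as follows:

1. Start every element's estimated position at the midpoint of the axis (0.5).
2. At each timestep, form a fixed number of A-B pairs: pick an A element at random, then pick a B element at random, with the choice weighted by the elements' current estimated positions.
3. Collect one piece of feedback per pair, indicating which of the two latent values is larger.
4. Depending on that feedback and on the pair's current estimated order, nudge the two estimates in opposite directions by a random step no larger than the learning rate, keeping both within [0, 1].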

The underlying simulation, then, is running this algorithm over a set of entities while also simulating the entities - each has its own latent value and the feedback it provides when paired with other entities is based on the actual differences between their latent values, with some noise added for good measure.

The svg update is straightforward - at each timestep, pairs are shown by lines between the circles, and the circles are transitioned to their new location based on their current estimated latent value.

index.html

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>latent value learning | scaled up</title>
  <style>
  /* fill color shared by the A circles and the "A elements" label */
  .A-color {
    fill: #4B90A6;
  }

  /* fill color shared by the B circles and the "B elements" label */
  .B-color {
    fill: #F3A54A;
  }
  </style>
</head>
<body>
  <!-- d3 must load first: animated-learning.js reads the global d3 when called -->
  <script src="https://d3js.org/d3.v4.min.js"></script>
  <script src="animated-learning.js"></script>
  <script>
    // build the svg and start the simulation / animation loop
    animatedLearning();
  </script>
</body>
</html>

animated-learning.js

/* global d3 */

/**
 * Draws the latent value learning animation into the page body and starts
 * the simulation loop.
 *
 * Two rows of circles (A elements on top, B elements below) start at the
 * horizontal center. On every tick a handful of A/B pairs is drawn, each
 * pair yields a +1 / -1 feedback derived from the hidden `latentValue`s,
 * and the estimated positions are nudged and animated accordingly.
 *
 * Relies on the global `d3` (v4) being loaded.
 *
 * @returns {Object} the timer returned by `d3.interval`, so that a caller
 *   can halt the animation with `.stop()`.
 */
function animatedLearning() {
  // everything is laid out in a 384 x 200 coordinate system; the outer svg's
  // viewBox scales that up to 960 x 500 on screen
  const width = 384;
  const height = 200;  // change the height and the rest of the layout responds ✨

  const parentSVG = d3.select('body').append('svg')
    .attr('width', 960)
    .attr('height', 500)
    .attr('viewBox', `0 0 ${width} ${height}`);

  const margin = {
    top: height * 0.05,
    bottom: height * 0.05,
    left: 0,
    right: 0
  };

  // calculate some useful values for plot layout
  const innerHeight = height - margin.top - margin.bottom;
  const pointYDistance = innerHeight * 0.3;                 // vertical gap between the A row and the B row
  const pointYOffset = (innerHeight - pointYDistance) / 2;  // y of the A row; centers both rows vertically

  // maps an estimated position in [0, 1] to a horizontal pixel coordinate
  const x = d3.scaleLinear()
    .domain([0, 1])
    .range([0, width]);

  const svg = parentSVG.append('svg')
    .attr('width', width)
    .attr('height', height);

  const g = svg.append('g')
    .attr('transform', `translate(${margin.left},${margin.top})`);

  // draw the center line
  g.append('line')
    .attr('x1', x.range()[0])
    .attr('x2', x.range()[1])
    .attr('y1', innerHeight / 2)
    .attr('y2', innerHeight / 2)
    .style('stroke', 'black')
    .style('stroke-width', '0.75')
    .style('fill', 'none');

  // draw A elements label
  g.append('text')
    .attr('x', width / 2)
    .attr('y', pointYOffset * 0.4)
    .attr('dy', '0.35em')
    .classed('A-color', true)
    .attr('text-anchor', 'middle')
    .attr('font-size', '12px')
    .text('A elements');

  // draw B elements label
  g.append('text')
    .attr('x', width / 2)
    .attr('y', pointYOffset + pointYDistance + (pointYOffset * 0.6))
    .attr('dy', '0.35em')
    .classed('B-color', true)
    .attr('text-anchor', 'middle')
    .attr('font-size', '12px')
    .text('B elements');

  const learningRate = 0.2;  // upper bound on a single position nudge

  const nAs = 10;
  const nBs = 10;
  const nPairs = 8;          // pairs drawn per tick

  const tickInterval = 1200; // ms between simulation steps
  const delay = 400;         // ms the pair lines are shown before anything moves
  const move = 500;          // ms duration of the move transition

  // keeps an estimated position inside the [0, 1] domain of the x scale
  const clampToUnit = value => Math.min(1, Math.max(0, value));

  //
  // create elements with latent values and initial positions
  //
  // every element gets a random hidden latentValue and starts with its
  // estimate at the center (0.5)
  const createElements = count => Array.from({ length: count }, (_, id) => ({
    id,
    latentValue: Math.random(),
    currentPosition: 0.5,
    nextPosition: 0.5
  }));

  const As = createElements(nAs);
  const Bs = createElements(nBs);

  //
  // construct circles
  //
  g.selectAll('.A').data(As, d => d.id)
    .enter().append('circle')
      .attr('class', 'A A-color')
      .attr('cx', width / 2)
      .attr('cy', pointYOffset)
      .attr('r', 3);

  g.selectAll('.B').data(Bs, d => d.id)
    .enter().append('circle')
      .attr('class', 'B B-color')
      .attr('cx', width / 2)
      .attr('cy', pointYOffset + pointYDistance)
      .attr('r', 3);

  //
  // simulation / animation loop
  //
  return d3.interval(() => {
    // *** SIMULATION ***

    const pairs = [];

    for (let i = 0; i < nPairs; i += 1) {
      const aId = Math.floor(Math.random() * nAs);
      const a = As[aId];

      // pair selection a stochastic function of distance from respective current positions
      // NOTE(review): the weights are signed differences, so the running sums
      // below are not necessarily increasing - confirm whether an absolute
      // distance was intended.
      const cumWeights = [];
      let runningSum = 0;
      Bs.forEach((b, id) => {
        runningSum += b.currentPosition - a.currentPosition;
        cumWeights.push({ v: runningSum, id });
      });

      // the comparator must return a number; a boolean comparator leaves the
      // resulting order engine-dependent
      cumWeights.sort((p, q) => p.v - q.v);

      // uniform pick, kept whenever the weighted draw below finds no match
      let bId = Math.floor(Math.random() * nBs);
      const largestSum = cumWeights[cumWeights.length - 1].v;
      if (largestSum !== 0) {
        const selRandom = Math.random() * largestSum;
        const sel = cumWeights.find(d => d.v >= selRandom);
        if (sel !== undefined) {
          bId = sel.id;
        }
      }
      const b = Bs[bId];

      pairs.push({ aId, bId });

      // big = 1, small = -1
      const feedback = a.latentValue > b.latentValue ? 1 : -1;

      // use feedback if it contradicts current
      // NOTE(review): as written, the nudge fires for feedback -1 when A's
      // estimate is <= B's (and for +1 when it is >= B's) - confirm this is
      // the intended meaning of "contradicts".
      if (feedback === -1 && a.currentPosition <= b.currentPosition) {
        a.nextPosition = a.currentPosition + (Math.random() * learningRate);
        b.nextPosition = b.currentPosition - (Math.random() * learningRate);
      } else if (feedback === 1 && a.currentPosition >= b.currentPosition) {
        a.nextPosition = a.currentPosition - (Math.random() * learningRate);
        b.nextPosition = b.currentPosition + (Math.random() * learningRate);
      }

      a.nextPosition = clampToUnit(a.nextPosition);
      b.nextPosition = clampToUnit(b.nextPosition);
    }

    // *** SVG ANIMATION ***

    //
    // draw and animate pair lines
    //
    // the previous tick's lines are dropped and this tick's pairs drawn fresh
    g.selectAll('.pair').remove();

    g.selectAll('.pair')
      .data(pairs)
      .enter().append('line')
        .attr('class', 'pair')
        .attr('y1', pointYOffset)
        .attr('y2', pointYOffset + pointYDistance)
        .style('stroke', '#000')
        .style('stroke-width', 0.25)
        .style('fill', 'none')
        .attr('x1', d => x(As[d.aId].currentPosition))
        .attr('x2', d => x(Bs[d.bId].currentPosition))
        .transition()
          .delay(delay)
          .duration(move)
          .attr('x1', d => x(As[d.aId].nextPosition))
          .attr('x2', d => x(Bs[d.bId].nextPosition));

    //
    // animate circles
    //
    g.selectAll('.A')
      .transition()
        .delay(delay)
        .duration(move)
        .attr('cx', d => x(d.nextPosition));

    g.selectAll('.B')
      .transition()
        .delay(delay)
        .duration(move)
        .attr('cx', d => x(d.nextPosition));

    // *** end of svg animation code ***

    //
    // prep next timestep
    //
    As.forEach((d) => {
      d.currentPosition = d.nextPosition;
    });

    Bs.forEach((d) => {
      d.currentPosition = d.nextPosition;
    });
  }, tickInterval);
}