我是靠谱客的博主 无心蜻蜓,最近开发中收集的这篇文章主要介绍【tensorflow.js学习笔记(5)】使用RNN学习“加法运算”,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

tensorflow.js实现了几种RNN的接口,包括SimpleRNN、GRU和LSTM。这篇笔记介绍如何在浏览器环境下利用tensorflow.js训练RNN学习加法运算,即给出一个加法算式的字符串,算出数字结果,类似于自然语言处理。

1、生成训练、测试数据

// digits-每个字符位数,trainingSize-训练集大小
function generateData(digits, trainingSize) {
  // 所有可选字符集
  const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
  const arraySize = digitArray.length;

  // 输出
  const output = [];
  const maxLen = digits + 1 + digits;

  // 从digitArray挑选digits个数据拼为一个数字
  const f = () => {
    let str = '';
    while (str.length < digits) {
      const index = Math.floor(Math.random() * arraySize);
      str += digitArray[index];
    }
    return Number.parseInt(str);
  };

  // 生成trainingSize组数据
  while (output.length < trainingSize) {
    const a = f();
    const b = f();

    const q = `${a}+${b}`;
    // 补空格
    const query = q + ' '.repeat(maxLen - q.length);
    let ans = (a + b).toString();
    // 补空格
    ans += ' '.repeat(digits + 1 - ans.length);
    output.push([query, ans]);
  }
  return output;
}

digits代表输入数字的位数,比如567的位数是3。函数f从digitArray中随机挑选digits个数拼为一个输入。输入a、加号、输入b整体拼为一个query,a+b的真实结果拼为ans。为防止第一个数字为0改变数字位数,query和ans均后补空格,函数返回query、ans字符对。

2、数据分组并转为tensor

// 90%训练集,10%测试集
const split = Math.floor(trainingSize * 0.9);
this.trainData = data.slice(0, split);
this.testData = data.slice(split);

// 转为tensors,并分为训练组、测试组
[this.trainXs, this.trainYs] = convertDataToTensors(this.trainData, this.charTable, digits);
[this.testXs, this.testYs] = convertDataToTensors(this.testData, this.charTable, digits);

将generateData生成的数据分为训练组和验证组,并将字符串转为tensor。转换函数converDataToTensors如下。

function convertDataToTensors(data, charTable, digits) {
  const maxLen = digits + 1 + digits;
  // data中每一项datum = [query, ans]
  const questions = data.map(datum => datum[0]);
  const answers = data.map(datum => datum[1]);
  return [
    charTable.encodeBatch(questions, maxLen),
    charTable.encodeBatch(answers, digits + 1),
  ];
}

对query、ans编码,需要字符集类CharacterTable。

class CharacterTable {
  constructor(chars) {
    this.chars = chars;
    // 字符-位置index
    this.charIndices = {};
    // 位置index-字符
    this.indicesChar = {};
    this.size = this.chars.length;
    for (let i = 0; i < this.size; ++i) {
      const char = this.chars[i];
      this.charIndices[this.chars[i]] = i;
      this.indicesChar[i] = this.chars[i];
    }
  }

  // 输入questions、answers数组,输出转化的tensor
  encodeBatch(strings, maxLen) {
    const numExamples = strings.length;
    const b

最后

以上就是无心蜻蜓为你收集整理的【tensorflow.js学习笔记(5)】使用RNN学习“加法运算”的全部内容,希望文章能够帮你解决【tensorflow.js学习笔记(5)】使用RNN学习“加法运算”所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(51)

评论列表共有 0 条评论

立即
投稿
返回
顶部