概述
tensorflow.js实现了几种RNN的接口,包括SimpleRNN、GRU和LSTM。这篇笔记介绍如何在浏览器环境下利用tensorflow.js训练RNN学习加法运算,即给出一个加法算式的字符串,算出数字结果,类似于自然语言处理。
1、生成训练、测试数据
// digits-每个字符位数,trainingSize-训练集大小
function generateData(digits, trainingSize) {
// 所有可选字符集
const digitArray = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];
const arraySize = digitArray.length;
// 输出
const output = [];
const maxLen = digits + 1 + digits;
// 从digitArray挑选digits个数据拼为一个数字
const f = () => {
let str = '';
while (str.length < digits) {
const index = Math.floor(Math.random() * arraySize);
str += digitArray[index];
}
return Number.parseInt(str);
};
// 生成trainingSize组数据
while (output.length < trainingSize) {
const a = f();
const b = f();
const q = `${a}+${b}`;
// 补空格
const query = q + ' '.repeat(maxLen - q.length);
let ans = (a + b).toString();
// 补空格
ans += ' '.repeat(digits + 1 - ans.length);
output.push([query, ans]);
}
return output;
}
digits代表输入数字的位数,比如567的位数是3。函数f从digitArray中随机挑选digits个数拼为一个输入。输入a、加号、输入b整体拼为一个query,a+b的真实结果拼为ans。为防止第一个数字为0改变数字位数,query和ans均后补空格,函数返回query、ans字符对。
2、数据分组并转为tensor
// 90%训练集,10%测试集
const split = Math.floor(trainingSize * 0.9);
this.trainData = data.slice(0, split);
this.testData = data.slice(split);
// 转为tensors,并分为训练组、测试组
[this.trainXs, this.trainYs] = convertDataToTensors(this.trainData, this.charTable, digits);
[this.testXs, this.testYs] = convertDataToTensors(this.testData, this.charTable, digits);
将generateData生成的数据分为训练组和验证组,并将字符串转为tensor。转换函数converDataToTensors如下。
function convertDataToTensors(data, charTable, digits) {
const maxLen = digits + 1 + digits;
// data中每一项datum = [query, ans]
const questions = data.map(datum => datum[0]);
const answers = data.map(datum => datum[1]);
return [
charTable.encodeBatch(questions, maxLen),
charTable.encodeBatch(answers, digits + 1),
];
}
对query、ans编码,需要字符集类CharacterTable。
class CharacterTable {
constructor(chars) {
this.chars = chars;
// 字符-位置index
this.charIndices = {};
// 位置index-字符
this.indicesChar = {};
this.size = this.chars.length;
for (let i = 0; i < this.size; ++i) {
const char = this.chars[i];
this.charIndices[this.chars[i]] = i;
this.indicesChar[i] = this.chars[i];
}
}
// 输入questions、answers数组,输出转化的tensor
encodeBatch(strings, maxLen) {
const numExamples = strings.length;
const b
最后
以上就是无心蜻蜓为你收集整理的【tensorflow.js学习笔记(5)】使用RNN学习“加法运算”的全部内容,希望文章能够帮你解决【tensorflow.js学习笔记(5)】使用RNN学习“加法运算”所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复