Overview
Following on from the previous post, this one builds a Spark-based LSTM character model. The data source is the complete works of Shakespeare; the LSTM learns the text one character at a time and then generates new Shakespeare-style passages. The full code is below, followed by a few small standalone sketches of the trickier pieces.
//Imports (DL4J 0.x-era APIs, as used by this example)
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.util.*;

public class SparkLSTMCharacterExample {
    private static final Logger log = LoggerFactory.getLogger(SparkLSTMCharacterExample.class);

    private static Map<Integer, Character> INT_TO_CHAR = getIntToChar();    //Index -> character map; see getIntToChar() at the bottom
    private static Map<Character, Integer> CHAR_TO_INT = getCharToInt();    //Character -> index map; see getCharToInt() at the bottom
    private static final int N_CHARS = INT_TO_CHAR.size();                  //Number of distinct character indices
    private static int nOut = CHAR_TO_INT.size();                           //Network output size = number of distinct characters
    private static int exampleLength = 1000;                                //Length of each training example sequence: 1,000 characters

    //Command-line parameters, parsed by JCommander
    @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1)
    private boolean useSparkLocal = true;

    @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")
    private int batchSizePerWorker = 8;     //How many examples should be used per worker (executor) when fitting?

    @Parameter(names = "-numEpochs", description = "Number of epochs for training")
    private int numEpochs = 1;

    public static void main(String[] args) throws Exception {
        new SparkLSTMCharacterExample().entryPoint(args);   //Delegate to the entry-point method
    }

    protected void entryPoint(String[] args) throws Exception {
        //Handle command line arguments with JCommander
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provided invalid input -> print the usage info
            jcmdr.usage();
            try { Thread.sleep(500); } catch (Exception e2) { }
            throw e;
        }

        Random rng = new Random(12345);         //Random number generator
        int lstmLayerSize = 200;                //Number of units in each GravesLSTM layer
        int tbpttLength = 50;                   //Length for truncated backpropagation through time, i.e., do parameter updates every 50 characters
        int nSamplesToGenerate = 4;             //Number of samples to generate after each training epoch - the network "writes" new passages
        int nCharactersToSample = 300;          //Length of each generated sample, in characters
        String generationInitialization = null; //Optional character initialization; a random character is used if null
        //The initialization is used to 'prime' the LSTM with a character sequence to continue/complete.
        //Initialization characters must all be in CharacterIterator.getMinimalCharacterSet() by default

        //Set up network configuration:
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .learningRate(0.1)
            .rmsDecay(0.95)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .list()
            //First layer: LSTM. Input size = number of distinct characters, output size = 200.
            //Note how it widens the representation considerably: CNNs tend to shrink layer sizes, while this RNN expands them.
            .layer(0, new GravesLSTM.Builder().nIn(CHAR_TO_INT.size()).nOut(lstmLayerSize)
                .activation("tanh").build())
            //Second layer: another LSTM, 200 in and 200 out
            .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                .activation("tanh").build())
            //Output layer: MCXENT + softmax for classification. Input size is 200; output size matches the character count.
            .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation("softmax")
                .nIn(lstmLayerSize).nOut(nOut).build())
            //Truncated BPTT with length 50: both the forward and backward passes for a parameter update look at 50 steps
            .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength)
            .pretrain(false).backprop(true)
            .build();

        //-------------------------------------------------------------
        //Set up the Spark-specific configuration
        /* How frequently should we average parameters (in number of minibatches)?
           Averaging too frequently can be slow (synchronization + serialization costs), whereas averaging too
           infrequently can result in learning difficulties (i.e., the network may not converge) */
        int averagingFrequency = 3;     //Average parameters every 3 minibatches

        //Set up Spark configuration and context (local mode here)
        SparkConf sparkConf = new SparkConf();
        if (useSparkLocal) {
            sparkConf.setMaster("local[*]");
        }
        sparkConf.setAppName("LSTM Character Example");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        JavaRDD<DataSet> trainingData = getTrainingData(sc);    //Load the data as a training RDD; see getTrainingData() below

        //Set up the TrainingMaster. The TrainingMaster controls how learning is actually executed on Spark.
        //Here, we are using standard parameter averaging.
        //For details on these configuration options, see: https://deeplearning4j.org/spark#configuring
        int examplesPerDataSetObject = 1;   //Each DataSet object holds one example
        ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
            .workerPrefetchNumBatches(2)            //Asynchronously prefetch up to 2 batches
            .averagingFrequency(averagingFrequency) //Average parameters every 3 minibatches
            .batchSizePerWorker(batchSizePerWorker) //Minibatch size of 8 per worker
            .build();

        SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc, conf, tm);   //Wire the configuration into the Spark network
        //Set a score listener; singletonList returns an immutable list containing the one listener
        sparkNetwork.setListeners(Collections.<IterationListener>singletonList(new ScoreIterationListener(1)));

        //Do training, and then generate and print samples from the network
        for (int i = 0; i < numEpochs; i++) {
            //Perform one epoch of training. At the end of each epoch, we are returned a copy of the trained network
            MultiLayerNetwork net = sparkNetwork.fit(trainingData);

            //Sample some characters from the network (done locally)
            log.info("Sampling characters from network given initialization \""
                + (generationInitialization == null ? "" : generationInitialization) + "\"");
            String[] samples = sampleCharactersFromNetwork(generationInitialization, net, rng, INT_TO_CHAR,
                nCharactersToSample, nSamplesToGenerate);   //Generate new samples from the learned model; see the method below
            for (int j = 0; j < samples.length; j++) {      //Print the generated samples
                log.info("----- Sample " + j + " -----");
                log.info(samples[j]);
            }
        }

        //Delete the temp training files, now that we are done with them
        tm.deleteTempFiles(sc);

        log.info("\n\nExample complete");
    }

    /**
     * Get the training data - a JavaRDD<DataSet>
     * Note that this approach for getting training data is a special case for this example (modelling characters), and
     * should not be taken as best practice for loading data (like CSV etc) in general.
     */
    public static JavaRDD<DataSet> getTrainingData(JavaSparkContext sc) throws IOException {
        //Get data. For the sake of this example, we are doing the following operations:
        //File -> String -> List<String> (split into length "sequenceLength" characters) -> JavaRDD<String> -> JavaRDD<DataSet>
        List<String> list = getShakespeareAsList(exampleLength);                    //List of 1,000-character strings
        JavaRDD<String> rawStrings = sc.parallelize(list);                          //Parallelize the data
        Broadcast<Map<Character, Integer>> bcCharToInt = sc.broadcast(CHAR_TO_INT); //Broadcast the character -> index map
        return rawStrings.map(new StringToDataSetFn(bcCharToInt));                  //Map each string to a DataSet, Scala-style
    }

    //To use map in Java we implement a Function, much like Scala's map: the first type parameter is the
    //input type (String), the second is the result type (DataSet)
    private static class StringToDataSetFn implements Function<String, DataSet> {
        private final Broadcast<Map<Character, Integer>> ctiBroadcast;  //The broadcast variable

        private StringToDataSetFn(Broadcast<Map<Character, Integer>> characterIntegerMap) {
            this.ctiBroadcast = characterIntegerMap;    //Constructor receives the broadcast variable
        }

        @Override
        public DataSet call(String s) throws Exception {
            //Here: take a String, and map the characters to a one-hot representation
            Map<Character, Integer> cti = ctiBroadcast.getValue();  //Contents of the broadcast variable
            int length = s.length();    //The last string may be shorter than 1,000 characters, so measure it
            //Convert to ND4J arrays of shape [1 example, rows = character indices, columns = time steps]
            INDArray features = Nd4j.zeros(1, N_CHARS, length - 1);
            INDArray labels = Nd4j.zeros(1, N_CHARS, length - 1);   //Same shape for the labels
            char[] chars = s.toCharArray();
            int[] f = new int[3];   //Two int arrays of length 3, used as [example index, character index, time step] coordinates
            int[] l = new int[3];
            //Note: with "- 2" the final column of features/labels is left all-zero
            for (int i = 0; i < chars.length - 2; i++) {
                f[1] = cti.get(chars[i]);       //Look up the current character's index in the broadcast map
                f[2] = i;                       //Time step of the current character
                l[1] = cti.get(chars[i + 1]);   //Predict the next character given past and current characters
                l[2] = i;
                //f[0]/l[0] stay 0 because this DataSet holds a single example; putScalar sets the
                //coordinate given by f (or l) to 1 - the usual one-hot pattern
                features.putScalar(f, 1.0);
                labels.putScalar(l, 1.0);
            }
            return new DataSet(features, labels);   //Features and labels packed into a DataSet - the map's return value
        }
    }

    //This function downloads (if necessary), loads and splits the raw text data into "sequenceLength" strings
    private static List<String> getShakespeareAsList(int sequenceLength) throws IOException {
        //The Complete Works of William Shakespeare
        //5.3MB file in UTF-8 Encoding, ~5.4 million characters
        //https://www.gutenberg.org/ebooks/100
        String url = "https://s3.amazonaws.com/dl4j-distribution/pg100.txt";    //Download source
        String tempDir = System.getProperty("java.io.tmpdir");                  //Download directory
        String fileLocation = tempDir + "/Shakespeare.txt";                     //Storage location for the downloaded file
        File f = new File(fileLocation);
        if (!f.exists()) {  //Download only if the file isn't already there
            FileUtils.copyURLToFile(new URL(url), f);
            System.out.println("File downloaded to " + f.getAbsolutePath());
        } else {
            System.out.println("Using existing text file at " + f.getAbsolutePath());
        }
        if (!f.exists()) throw new IOException("File does not exist: " + fileLocation);  //Download problem?

        String allData = getDataAsString(fileLocation); //Load the whole file as one large string; see the method below

        List<String> list = new ArrayList<>();
        int length = allData.length();
        int currIdx = 0;
        while (currIdx + sequenceLength < length) { //While another full sequence still fits
            int end = currIdx + sequenceLength;     //End index of this sequence
            String substr = allData.substring(currIdx, end);
            currIdx = end;                          //The end index becomes the next start index
            list.add(substr);                       //Add a 1,000-character string to the list
        }
        return list;
    }

    /**
     * Load data from a file, and remove any invalid characters.
     * Data is returned as a single large String
     */
    private static String getDataAsString(String filePath) throws IOException {
        //Files.readAllLines reads all lines of a file, decoding the bytes with the given charset, and returns
        //a list of lines; it is not suitable for very large files
        List<String> lines = Files.readAllLines(new File(filePath).toPath(), Charset.defaultCharset());
        StringBuilder sb = new StringBuilder();
        //For each line, keep only the characters present in the CHAR_TO_INT map (this is the filtering step),
        //then append a newline; the result is one large newline-separated string
        for (String line : lines) {
            char[] chars = line.toCharArray();
            for (int i = 0; i < chars.length; i++) {
                if (CHAR_TO_INT.containsKey(chars[i])) sb.append(chars[i]);
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Generate a sample from the network, given an (optional, possibly null) initialization. Initialization
     * can be used to 'prime' the RNN with a sequence you want to extend/continue.<br>
     * Note that the initialization is used for all samples.
     *
     * @param initialization     String, may be null. If null, select a random character as initialization for all samples
     * @param charactersToSample Number of characters to sample from network (excluding initialization); 300 here
     * @param numSamples         Number of samples to generate after each epoch; 4 here
     * @param net                MultiLayerNetwork with one or more GravesLSTM/RNN layers and a softmax output layer
     * @param intToChar          Map from index to the corresponding character
     */
    private static String[] sampleCharactersFromNetwork(String initialization, MultiLayerNetwork net, Random rng,
                                                        Map<Integer, Character> intToChar, int charactersToSample, int numSamples) {
        //Set up initialization. If no initialization: use a random character
        if (initialization == null) {
            int randomCharIdx = rng.nextInt(intToChar.size());
            initialization = String.valueOf(intToChar.get(randomCharIdx));
        }

        //Create input for initialization: a 3D array of [number of samples to write, vocabulary size,
        //initialization length]. With a random initialization the last dimension is just 1 - note that this
        //shape differs from the training examples.
        INDArray initializationInput = Nd4j.zeros(numSamples, intToChar.size(), initialization.length());
        char[] init = initialization.toCharArray();
        for (int i = 0; i < init.length; i++) {         //With no user-supplied string this loop runs only once
            int idx = CHAR_TO_INT.get(init[i]);         //Index of the initialization character
            for (int j = 0; j < numSamples; j++) {      //Prime all 4 samples with the same character
                initializationInput.putScalar(new int[]{j, idx, i}, 1.0f);  //One-hot: set the right position to 1
            }
        }

        StringBuilder[] sb = new StringBuilder[numSamples];     //One StringBuilder per sample
        //Each sample starts from the same initialization string; they diverge from there
        for (int i = 0; i < numSamples; i++) sb[i] = new StringBuilder(initialization);

        //Sample from network (and feed samples back into input) one character at a time (for all samples).
        //Sampling is done in parallel here.
        net.rnnClearPreviousState();    //Clear the RNN's previous state
        //rnnTimeStep does a single forward time step: the input shape is [batch size, input size, time steps]
        //and the output activations have matching dimensions - this is effectively the network's prediction mode
        INDArray output = net.rnnTimeStep(initializationInput);
        //Gets the last time step output: output.size(2) - 1 is the final time-step index, and tensorAlongDimension
        //pulls out that slice so each sample's character distribution can be read via getDouble(s, j) below
        output = output.tensorAlongDimension(output.size(2) - 1, 1, 0);

        for (int i = 0; i < charactersToSample; i++) {  //Write the 300 characters
            //Set up next input (single time step) by sampling from previous output
            INDArray nextInput = Nd4j.zeros(numSamples, intToChar.size());  //4 rows, one column per character index
            //Output is a probability distribution. Sample from this for each example we want to generate,
            //and add it to the new input
            for (int s = 0; s < numSamples; s++) {
                double[] outputProbDistribution = new double[intToChar.size()]; //Distribution over the character indices
                for (int j = 0; j < outputProbDistribution.length; j++)
                    outputProbDistribution[j] = output.getDouble(s, j);         //Probability of character j for sample s
                int sampledCharacterIdx = sampleFromDistribution(outputProbDistribution, rng); //Pick an index from the distribution; see below
                nextInput.putScalar(new int[]{s, sampledCharacterIdx}, 1.0f);   //Prepare next time step input
                sb[s].append(intToChar.get(sampledCharacterIdx));               //Add sampled character to StringBuilder (human readable output)
            }
            output = net.rnnTimeStep(nextInput);    //Do one time step of forward pass
        }

        String[] out = new String[numSamples];
        for (int i = 0; i < numSamples; i++) out[i] = sb[i].toString();     //Convert each sample to a String and return
        return out;
    }

    /**
     * Given a probability distribution over discrete classes, sample from the distribution
     * and return the generated class index.
     *
     * @param distribution Probability distribution over classes. Must sum to 1.0
     */
    private static int sampleFromDistribution(double[] distribution, Random rng) {
        double d = rng.nextDouble();
        double sum = 0.0;
        //Accumulate the probabilities until the running sum reaches the random draw, then return that index;
        //indices with larger probabilities are therefore selected proportionally more often
        for (int i = 0; i < distribution.length; i++) {
            sum += distribution[i];
            if (d <= sum) return i;
        }
        //Should never happen if distribution is a valid probability distribution
        throw new IllegalArgumentException("Distribution is invalid? d=" + d + ", sum=" + sum);
    }

    /**
     * A minimal character set, with a-z, A-Z, 0-9 and common punctuation etc
     */
    private static char[] getValidCharacters() {
        List<Character> validChars = new LinkedList<>();    //Fill a list with a-z, A-Z, 0-9 and some punctuation
        for (char c = 'a'; c <= 'z'; c++) validChars.add(c);
        for (char c = 'A'; c <= 'Z'; c++) validChars.add(c);
        for (char c = '0'; c <= '9'; c++) validChars.add(c);
        char[] temp = {'!', '&', '(', ')', '?', '-', '\'', '"', ',', '.', ':', ';', ' ', '\n', '\t'};
        for (char c : temp) validChars.add(c);
        char[] out = new char[validChars.size()];   //Copy the list into a plain char array
        int i = 0;
        for (Character c : validChars) out[i++] = c;
        return out;
    }

    public static Map<Integer, Character> getIntToChar() {
        Map<Integer, Character> map = new HashMap<>();
        char[] chars = getValidCharacters();    //Get the valid characters (see above)
        for (int i = 0; i < chars.length; i++) {
            map.put(i, chars[i]);   //Key: index, value: character
        }
        return map;
    }

    public static Map<Character, Integer> getCharToInt() {
        Map<Character, Integer> map = new HashMap<>();
        char[] chars = getValidCharacters();    //Get the valid characters (see above)
        for (int i = 0; i < chars.length; i++) {
            map.put(chars[i], i);   //Key: character, value: index
        }
        return map;
    }
}
Finally
That is the whole of 悲凉冥王星's walkthrough of deep learning with a Spark-based LSTM; hopefully it helps with any development problems you run into on the same topic.