我是靠谱客的博主 潇洒黑米,最近开发中收集的这篇文章主要介绍Spark MLlib源码分析—TFIDF源码详解,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

以下代码是我依据SparkMLlib(版本1.6)
1、HashingTF 是使用哈希表来存储分词,并计算分词频数(TF),生成HashMap表。在Map中,K为分词对应索引号,V为分词的频数。在声明HashingTF 时,需要设置numFeatures,该属性实为设置哈希表的大小;如果设置numFeatures过小,则在存储分词时会出现重叠现象,所以不要设置太小,一般情况下设置为30w~50w之间。
2、IDF是计算每个分词出现在文章中的次数,并计算log值。在声明IDF时,可以设置minDocFreq,即过滤掉出现文章数小于minDocFreq的分词。
3、IDFModel 主要是计算TF*IDF,另外IDFModel也可以将IDF数据保存下来(即模型的保存),在测试语料时,只需要计算测试语料中每个分词的在该篇文章中的词频TF,就可以计算TFIDF。

package org.apache.spark.mllib.feature
class HashingTF(val numFeatures: Int) extends Serializable {
def this() = this(1 << 20)
def nonNegativeMod(x: Int, mod: Int): Int = { //根据 numFeatures 设置的哈希表容量,来设定索引号
val rawMod = x % mod
rawMod + (if (rawMod < 0) mod else 0)
}
def indexOf(term: Any): Int = nonNegativeMod(term.##, numFeatures) //根据分词来生成索引号
def transform(document: Iterable[_]): Vector = {
//每篇文章一个hash表,记录每篇文章中的词频
val termFrequencies = mutable.HashMap.empty[Int, Double]
document.foreach { term =>
val i = indexOf(term)
//map中的getOrElse(i, 0.0)函数表示如果找到i位置的值就返回,否则就默认为0.0
termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)//注意这里有加1计数操作
}
Vectors.sparse(numFeatures, termFrequencies.toSeq)
}
def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = {
dataset.map(this.transform)
}
}
class IDF(val minDocFreq: Int){
def this() = this(0) //默认minDocFreq为0,用来过滤文章出现次数过少的分词
def fit(dataset: RDD[Vector]): IDFModel = {
val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(minDocFreq = minDocFreq))(
seqOp = (df, v) => df.add(v),
combOp = (df1, df2) => df1.merge(df2)
).idf()
new IDFModel(idf)
}
}
private object IDF {
/** Document frequency aggregator. */
class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {
/** number of documents */
private var m = 0L
/** document frequency vector */
private var df: BDV[Long] = _
def this() = this(0)
private def isEmpty: Boolean = m == 0L
def add(doc: Vector): this.type = { //add -> 计算分词在每个分区中的文章频率
if (isEmpty) {
df = BDV.zeros(doc.size)
}
doc match {
case SparseVector(size, indices, values) =>
val nnz = indices.size
var k = 0
while (k < nnz) {
if (values(k) > 0) {
//表示分词values(k)在该篇文章中出现过
df(indices(k)) += 1L
//计数分词indices(k)出现在多少篇文章中
}
k += 1
}
case DenseVector(values) =>
val n = values.size
var j = 0
while (j < n) {
if (values(j) > 0.0) {
//作用和上面一样,只是在spark中有DenseVector 和 SparseVector两种向量的区别。
df(j) += 1L
}
j += 1
}
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
m += 1L
this
}
/** Merges another. */
def merge(other: DocumentFrequencyAggregator): this.type = { //将各个分区聚合到一起
if (!other.isEmpty) {
m += other.m
if (df == null) {
df = other.df.copy
} else {
df += other.df
}
}
this
}
/** 返回当前IDF的向量 */
def idf(): Vector = {
if (isEmpty) {
throw new IllegalStateException("Haven't seen any document yet.")
}
val n = df.length
val inv = new Array[Double](n)
var j = 0
while (j < n) {
if (df(j) >= minDocFreq) {
inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) //计算IDF —— log(D/d(j))
}
j += 1
}
Vectors.dense(inv)
}
}
}
class IDFModel(val idf: Vector) extends Serializable {
// idf 里面存储的是IDF向量
def transform(dataset: RDD[Vector]): RDD[Vector] = {
//dataset里面存储的是TF向量
val bcIdf = dataset.context.broadcast(idf)
dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v)))
}
def transform(v: Vector): Vector = IDFModel.transform(idf, v)
}
private object IDFModel {
def transform(idf: Vector, v: Vector): Vector = { // 这里就是
idf * v (v是TF向量)
val n = v.size
v match {
case SparseVector(size, indices, values) =>
val nnz = indices.size
val newValues = new Array[Double](nnz)
var k = 0
while (k < nnz) {
newValues(k) = values(k) * idf(indices(k))
//SparseVector 向量下 TF * IDF
k += 1
}
Vectors.sparse(n, indices, newValues)
case DenseVector(values) =>
val newValues = new Array[Double](n)
var j = 0
while (j < n) {
newValues(j) = values(j) * idf(j)
//DenseVector 向量下 TF * IDF
j += 1
}
Vectors.dense(newValues)
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
}
}

最后

以上就是潇洒黑米为你收集整理的Spark MLlib源码分析—TFIDF源码详解的全部内容,希望文章能够帮你解决Spark MLlib源码分析—TFIDF源码详解所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(48)

评论列表共有 0 条评论

立即
投稿
返回
顶部