看了 EM 算法的原理和别人的代码后,终于自己写了一个 EM 的实现。
我从网上找了一些身高与性别的数据,用 EM 算法仅凭身高信息来识别性别。
实现效果还可以:初始参数取男生均值 170、女生均值 160、方差均为 10 时,正确率为 84%;
初始参数取男生均值 165、女生均值 150、方差均为 10 时,正确率为 79%。
可见正确率与初始值有关:EM 算法只保证收敛到似然函数的局部极值,不同的初始值可能收敛到不同的结果。
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/*
 * Distinguish sex from height with the EM algorithm.
 *
 * Heights are modelled as a two-component 1-D Gaussian mixture
 * (component 0 = male, component 1 = female). EM fits the mixture without
 * looking at the labels; the labels are used only afterwards to measure how
 * often the fitted model assigns a person to the matching component.
 */
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

const double kPi = 3.14159265358979323846;

// Parameters of one Gaussian component (var is the variance, not the std dev).
struct Gaussian
{
    double mean;
    double var;
};

// One labelled sample; sex holds the first letter of the label ('m' or 'f').
struct EMData
{
    char sex;
    float fHeight;
};

// Reads whitespace-separated "<label> <height>" pairs from data.txt
// (e.g. "male 164") and appends them to Data. Only the first character of
// each label is stored.
// Returns 0 on success, -1 if the file cannot be opened.
int getdata(std::vector<EMData> &Data)
{
    std::ifstream fin("data.txt");
    if (!fin)
    {
        std::cerr << "error: can't open the file." << std::endl;
        return -1;
    }

    std::string label;   // std::string: no fixed-size buffer to overflow
    float height = 0.0f;
    // Test the extraction itself rather than eof(): a `while (!fin.eof())`
    // loop runs once more after the final record (trailing newline) and
    // would push an uninitialised sample.
    while (fin >> label >> height)
    {
        EMData data;
        data.sex = label[0];  // extraction succeeded, so label is non-empty
        data.fHeight = height;
        Data.push_back(data);
    }
    if (!fin.eof())
        std::cerr << "warning: stopped reading data.txt at a malformed record." << std::endl;
    return 0;
}

// Natural log of the N(g.mean, g.var) density at x.
static double logGaussianPdf(double x, const Gaussian &g)
{
    const double d = x - g.mean;
    return -0.5 * std::log(2.0 * kPi * g.var) - d * d / (2.0 * g.var);
}

// Posterior probability that height x came from component 0 (male):
// a0*N0(x) / (a0*N0(x) + a1*N1(x)), computed in the log domain as
// 1 / (1 + exp(log(a1*N1) - log(a0*N0))) so that both densities
// underflowing to 0 cannot produce 0/0.
static double posteriorMale(double x, const double a[2], const Gaussian g[2])
{
    const double log_male = std::log(a[0]) + logGaussianPdf(x, g[0]);
    const double log_female = std::log(a[1]) + logGaussianPdf(x, g[1]);
    return 1.0 / (1.0 + std::exp(log_female - log_male));
}

// |now - before| / |before|; falls back to the absolute change when before == 0.
static double relativeChange(double now, double before)
{
    const double diff = std::fabs(now - before);
    const double scale = std::fabs(before);
    return scale > 0.0 ? diff / scale : diff;
}

// Fits the two-component mixture to the heights in Data with EM, then
// classifies every sample by its most probable component ('m' for
// component 0, 'f' otherwise) and returns the fraction whose label matches.
//
// male_mean0 / female_mean0 / var0 are the initial means and the shared
// initial variance; EM only reaches a local optimum, so the result depends
// on them. Returns 0 for empty input.
float predict(const std::vector<EMData> &Data,
              double male_mean0 = 180.0,
              double female_mean0 = 150.0,
              double var0 = 10.0)
{
    const std::size_t n = Data.size();
    if (n == 0)
        return 0.0f;  // nothing to fit, and avoids dividing by n below

    const double tlimit = 0.000001;  // convergence threshold on relative parameter change
    const int max_iter = 10000;      // hard cap in case the change never drops below tlimit
    const double var_floor = 1e-6;   // stops a component collapsing onto one point (var == 0)

    // Initial values; index 0 is male, 1 is female.
    Gaussian sex[2];
    sex[0].mean = male_mean0;
    sex[0].var = var0;
    sex[1].mean = female_mean0;
    sex[1].var = var0;
    double a[2] = {0.5, 0.5};  // mixing weights (share of males / females)

    // resp[i] = P(male | height i); P(female | height i) = 1 - resp[i].
    std::vector<double> resp(n);

    for (int iter = 0; iter < max_iter; ++iter)
    {
        const Gaussian sex_old[2] = {sex[0], sex[1]};
        const double a_old[2] = {a[0], a[1]};

        // E step: responsibility of the male component for each sample.
        for (std::size_t i = 0; i < n; ++i)
            resp[i] = posteriorMale(Data[i].fHeight, a, sex);

        // M step: weighted counts and weighted means.
        double w[2] = {0.0, 0.0};
        double wx[2] = {0.0, 0.0};
        for (std::size_t i = 0; i < n; ++i)
        {
            const double x = Data[i].fHeight;
            w[0] += resp[i];
            w[1] += 1.0 - resp[i];
            wx[0] += resp[i] * x;
            wx[1] += (1.0 - resp[i]) * x;
        }
        a[0] = w[0] / (w[0] + w[1]);
        a[1] = w[1] / (w[0] + w[1]);
        for (int k = 0; k < 2; ++k)
            if (w[k] > 0.0)  // a component with no weight keeps its previous mean
                sex[k].mean = wx[k] / w[k];

        // Variances around the freshly updated means.
        double wvar[2] = {0.0, 0.0};
        for (std::size_t i = 0; i < n; ++i)
        {
            const double dm = Data[i].fHeight - sex[0].mean;
            const double df = Data[i].fHeight - sex[1].mean;
            wvar[0] += resp[i] * dm * dm;
            wvar[1] += (1.0 - resp[i]) * df * df;
        }
        for (int k = 0; k < 2; ++k)
            if (w[k] > 0.0)
                sex[k].var = std::max(wvar[k] / w[k], var_floor);

        // Largest relative change of any parameter. Use magnitudes: the max of
        // signed changes would treat an iteration in which every parameter
        // decreased as already converged.
        double t = 0.0;
        for (int k = 0; k < 2; ++k)
        {
            t = std::max(t, relativeChange(a[k], a_old[k]));
            t = std::max(t, relativeChange(sex[k].mean, sex_old[k].mean));
            t = std::max(t, relativeChange(sex[k].var, sex_old[k].var));
        }
        if (t <= tlimit)
            break;
    }

    // Classify by the posterior under the fitted mixture (mixing weights
    // included) and compare with the recorded label.
    std::size_t correct_num = 0;
    for (std::size_t i = 0; i < n; ++i)
    {
        const char csex = (posteriorMale(Data[i].fHeight, a, sex) > 0.5) ? 'm' : 'f';
        if (csex == Data[i].sex)
            ++correct_num;
    }
    return static_cast<float>(correct_num) / static_cast<float>(n);
}

// Loads data.txt, runs EM, and prints the classification accuracy.
int main()
{
    std::vector<EMData> Data;
    if (getdata(Data) != 0)
        return 1;
    if (Data.empty())
    {
        std::cerr << "error: no records in data.txt." << std::endl;
        return 1;
    }

    const float correct_rate = predict(Data);
    std::cout << "correct rate = " << correct_rate << std::endl;
    return 0;
}
数据(data.txt 的内容,每行一条记录:性别 身高):
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
male 164
female 156
male 168
female 160
female 162
male 187
female 162
male 167
female 160.5
female 160
female 158
female 164
female 165
male 174
female 166
female 158
male 162
male 175
male 170
female 161
female 169
female 161
female 160
female 167
male 176
male 169
male 178
male 165
female 155
male 183
male 171
male 179
female 154
male 172
female 172
male 173
male 172
male 175
male 160
male 160
male 160
male 175
male 163
male 181
male 172
male 175
male 175
male 167
male 172
male 169
male 172
male 175
male 172
male 170
male 158
male 167
male 164
male 176
male 182
male 173
male 176
male 163
male 166
male 162
male 169
male 163
male 163
male 176
male 169
male 173
male 163
male 167
male 176
male 168
male 167
male 170
female 155
female 157
female 165
female 156
female 155
female 156
female 160
female 158
female 162
female 162
female 155
female 163
female 160
female 162
female 165
female 159
female 147
female 163
female 157
female 160
female 162
female 158
female 155
female 165
female 161
female 159
female 163
female 158
female 155
female 162
female 157
female 159
female 152
female 156
female 165
female 154
female 156
female 162
最后
以上就是可耐皮带最近收集整理的关于【EM】C++代码实现的全部内容,更多相关【EM】C++代码实现内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复