1+ import numpy as np
2+ from scipy import io as spio
3+ from matplotlib import pyplot as plt
4+ from sklearn import svm
5+
6+ def SVM ():
7+ '''data1——线性分类'''
8+ data1 = spio .loadmat ('data1.mat' )
9+ X = data1 ['X' ]
10+ y = data1 ['y' ]
11+ y = np .ravel (y )
12+ plot_data (X ,y )
13+
14+ model = svm .SVC (C = 1.0 ,kernel = 'linear' ).fit (X ,y ) # 指定核函数为线性核函数
15+ plot_decisionBoundary (X , y , model ) # 画决策边界
16+ '''data2——非线性分类'''
17+ data2 = spio .loadmat ('data2.mat' )
18+ X = data2 ['X' ]
19+ y = data2 ['y' ]
20+ y = np .ravel (y )
21+ plt = plot_data (X ,y )
22+ plt .show ()
23+
24+ model = svm .SVC (gamma = 100 ).fit (X ,y ) # gamma为核函数的系数,值越大拟合的越好
25+ plot_decisionBoundary (X , y , model ,class_ = 'notLinear' ) # 画决策边界
26+
27+
28+
29+ # 作图
30+ def plot_data (X ,y ):
31+ plt .figure (figsize = (10 ,8 ))
32+ pos = np .where (y == 1 ) # 找到y=1的位置
33+ neg = np .where (y == 0 ) # 找到y=0的位置
34+ p1 , = plt .plot (np .ravel (X [pos ,0 ]),np .ravel (X [pos ,1 ]),'ro' ,markersize = 8 )
35+ p2 , = plt .plot (np .ravel (X [neg ,0 ]),np .ravel (X [neg ,1 ]),'g^' ,markersize = 8 )
36+ plt .xlabel ("X1" )
37+ plt .ylabel ("X2" )
38+ plt .legend ([p1 ,p2 ],["y==1" ,"y==0" ])
39+ return plt
40+
41+ # 画决策边界
42+ def plot_decisionBoundary (X ,y ,model ,class_ = 'linear' ):
43+ plt = plot_data (X , y )
44+
45+ # 线性边界
46+ if class_ == 'linear' :
47+ w = model .coef_
48+ b = model .intercept_
49+ xp = np .linspace (np .min (X [:,0 ]),np .max (X [:,1 ]),100 )
50+ yp = - (w [0 ,0 ]* xp + b )/ w [0 ,1 ]
51+ plt .plot (xp ,yp ,'b-' ,linewidth = 2.0 )
52+ plt .show ()
53+ else : # 非线性边界
54+ x_1 = np .transpose (np .linspace (np .min (X [:,0 ]),np .max (X [:,0 ]),100 ).reshape (1 ,- 1 ))
55+ x_2 = np .transpose (np .linspace (np .min (X [:,1 ]),np .max (X [:,1 ]),100 ).reshape (1 ,- 1 ))
56+ X1 ,X2 = np .meshgrid (x_1 ,x_2 )
57+ vals = np .zeros (X1 .shape )
58+ for i in range (X1 .shape [1 ]):
59+ this_X = np .hstack ((X1 [:,i ].reshape (- 1 ,1 ),X2 [:,i ].reshape (- 1 ,1 )))
60+ vals [:,i ] = model .predict (this_X )
61+
62+ plt .contour (X1 ,X2 ,vals ,[0 ,1 ],color = 'blue' )
63+ plt .show ()
64+
65+
66+
67+ if __name__ == "__main__" :
68+ SVM ()
69+
70+
71+
0 commit comments