首页 > 【cs229-Lecture2】Linear Regression with One Variable (Week 1)(含测试数据和源码)

【cs229-Lecture2】Linear Regression with One Variable (Week 1)(含测试数据和源码)

从Ⅱ到Ⅳ都在讲的是线性回归,其中第Ⅱ章讲得是简单线性回归(simple linear regression, SLR)(单变量),第Ⅲ章讲的是线代基础,第Ⅳ章讲的是多元回归(大于一个自变量)。

本文的目的主要是对Ⅱ章中出现的一些算法进行实现,适合的人群为已经看完本章节Stanford课程的学者。本人只是一名初学者,尽可能以白话的方式来说明问题。不足之处,还请指正。

在开始讨论具体步骤之前,首先给出简要的思维路线:

1.拥有一个点集,为了得到一条最佳拟合的直线;

2.通过“最小二乘法”来衡量拟合程度,得到代价方程;

3.利用“梯度下降算法”使得代价方程取得极小值点;




首先,介绍几个概念:

回归在数学上来说是给定一个点集,能够用一条曲线去拟合之。如果这个曲线是一条直线,那就被称为线性回归;如果曲线是一条二次曲线,就被称为二次回归,回归还有很多的变种,如locally weighted回归,logistic回归等等。

课程中得到的h就是线性回归方程:

image


下面,首先来介绍一下单变量的线性回归:

问题是这样的:给定一个点集,找出一条直线去拟合,要求拟合的效果达到最佳(最佳拟合)。

既然是直线,我们先假设直线的方程为:image

     如图:image

    点集有了,直线方程有了,接下来,我们要做的就是计算出imageimage,使得拟合效果达到最佳(最佳拟合)。

    那么,拟合效果的评判标准是什么呢?换句话说,我们需要知道一种对拟合效果的度量。

   在这里,我们提出“最小二乘法”:(以下摘自wiki)

最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。

利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。

对于“最小二乘法”就不再展开讨论,只要知道他是一个度量标准,我们可以用它来评判计算出的直线方程是否达到了最佳拟合就够了。

那么,回到问题上来,在单变量的线性回归中,这个拟合效果的表达式是利用最小二乘法将未知量残差平方和最小化

image

结合课程,定义了一个成本函数:

image

其实,到这里,要是把点集的具体数值代入到成本函数中,就已经完全抽象出了一个高等数学问题(解一个二元函数的最小值问题)。

image

其中,a,b,c,d,e,f均为已知。

课程中介绍了一种叫“Gradient descent”的方法——梯度下降算法

image

两张图说明算法的基本思想:

imageimage

image

所谓梯度下降算法(一种求局部最优解的方法),举个例子就好比你现在在一座山上,你想要尽快地到达山底(极小值点),这是一个下降的过程,这里就涉及到了两个问题:1)你下山的时候,跨多大的步子(当然,肯定不是越大越好,因为有一种可能就是你一步跨地太大,正好错过了极小的位置);2)你朝哪个方向跨步(注意,这个方向是不断变化的,你每到一个新的位置,要判断一下下一步朝那个方向走才是最好的,但是有一点可以肯定的是,要想尽快到达最低点,应从最陡的地方下山)。

那么,什么时候算是你到了一个极小点呢,显然,当你所处的位置发生的变化不断减小,直至收敛于某一位置,就说明那个位置就是一个极小值点。

 

so,我们来看image的变化,则我们需要让imageimage求偏导,倒数代表变化率。也就是要朝着对陡的地方下山(因为沿着最陡显然比较快),就得到了image的变化情况:image

image

image

简化之后:

image

 

步长不宜过大或过小

image

梯度下降法是按下面的流程进行的:(转自:http://blog.sina.com.cn/s/blog_62339a2401015jyq.html)

1)首先对θ赋值,这个值可以是随机的,也可以让θ是一个全零的向量。

2)改变θ的值,使得J(θ)按梯度下降的方向进行减少。

        为了方便大家的理解,首先给出单变量的例子:

       eg:求image的最小值。(注:image

image

       java代码如下:

·

package OneVariable;public class OneVariable{public static void main(String[] args){double e=0.00001;//定义迭代精度double alpha=0.5;//定义迭代步长double x=0;            //初始化xdouble y0=2*x*x+3*x+1;//与初始化x对应的y值double y1=0;//定义变量,用于保存当前值while (true){x=x-alpha*(4.0*x+3.0);y1=2*x*x+3*x+1;if (Math.abs(y1-y0)//如果2次迭代的结果变化很小,结束迭代
        {break;}y0=y1;//更新迭代的结果
    }System.out.println("Min(f(x))="+y0);System.out.println("minx="+x);}
}//输出
Min(f(x))=1.0
minx=-1.5

两个变量的时候,为了更清楚,给出下面的图:

image

这是一个表示参数θ与误差函数J(θ)的关系图,红色的部分是表示J(θ)有着比较高的取值,我们需要的是,能够让J(θ)的值尽量的低。也就是深蓝色的部分。θ0,θ1表示θ向量的两个维度。

在上面提到梯度下降法的第一步是给θ给一个初值,假设随机给的初值是在图上的十字点。

然后我们将θ按照梯度下降的方向进行调整,就会使得J(θ)往更低的方向进行变化,如图所示,算法的结束将是在θ下降到无法继续下降为止。

image

当然,可能梯度下降的最终点并非是全局最小点,可能是一个局部最小点,可能是下面的情况:

image

上面这张图就是描述的一个局部最小点,这是我们重新选择了一个初始点得到的,看来我们这个算法将会在很大的程度上被初始点的选择影响而陷入局部最小点 

一个很重要的地方值得注意的是,梯度是有方向的,对于一个向量θ,每一维分量θi都可以求出一个梯度的方向,我们就可以找到一个整体的方向,在变化的时候,我们就朝着下降最多的方向进行变化就可以达到一个最小点,不管它是局部的还是全局的。

 


理论的知识就讲到这,下面,我们就用java去实现这个算法:

梯度下降有两种:批量梯度下降和随机梯度下降。详见:http://blog.csdn.net/lilyth_lilyth/article/details/8973972

测试数据就用课后题中的数据(ex1data1.txt),用matlab打开作图得到:

image

 

首先说明:以下源码是不正确的,具体为什么不正确我还没搞清楚!非常希望各位高手能够指正!

测试数据及源码下载:http://pan.baidu.com/s/1mgiIVm4

 

OneVariable.java
 1 package OneVariableVersion;
 2 
 3 import java.io.IOException;
 4 import java.util.List;
 5 
 6 
 7 /**
 8  * Linear Regression with One Variable
 9  * @author XBW
10  * @date 2014年8月17日
11  */
12 
13 public class OneVariable{
14     public static final Double e=0.00001;
15     public static List DS;
16     public static Double step;
17     public static Double m;
18     
19     /**
20      * 计算当前参数是否符合
21      * @param ans
22      * @param datalist
23      * @return
24      */
25     public static Ans calc(Ans ans){
26         Double costfun;
27         do{
28             costfun=calcAccuracy(ans);
29             ans=update(ans);
30             step*=0.3;
31         }while(Math.abs(costfun-calcAccuracy(ans))>e);
32         ans.ifConvergence=true;
33         return ans;
34     }
35     
36     /**
37      * 判断当前ans是否满足精度,y=t0+t1*x
38      * @param ans
39      * @return
40      */
41     public static Double calcAccuracy(Ans ans){
42         Double cost=0.0;
43         Double tmp;
44         for(int i=0;i){
45             tmp=DS.get(i).y-(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]);
46             cost+=tmp*tmp;
47         }
48         cost/=(2*m);
49         return cost;
50     }
51     
52     /**
53      * 更新ans
54      * @param ans,学习速率为step,m为数据量
55      * @return
56      */
57     public static Ans update(Ans ans){
58         Double[] tmp=new Double[100] ;
59         for(int i=0;i<2;i++){
60             tmp[i]=ans.theta[i]-step*fun(ans,i);
61         }
62         for(int i=0;i<2;i++){
63             ans.theta[i]=tmp[i];
64         }
65         return ans;
66     }
67     
68     /**
69      * 计算偏导
70      * @return
71      */
72     public static Double fun(Ans ans,int xi){
73         Double ret = 0.0;
74         for(int i=0;i){
75             ret+=(ans.theta[0]*DS.get(i).x[0]+ans.theta[1]*DS.get(i).x[1]-DS.get(i).y)*DS.get(i).x[xi];
76         }
77         ret/=m;
78         return ret;        
79     }
80     
81     public static void main(String[] args) throws IOException{
82         DS=new DataSet().ds;
83         step=1.0;        
84         m=(double)DS.size();
85         
86         
87         Double[] theta={0.0,0.0};                     //初始设定theta0=0,theta1=0
88         Ans ans=new Ans(theta,false);
89         Ans answer;
90         answer=calc(ans);
91         System.out.println("theta1= "+answer.theta[0]+"      theta2="+answer.theta[1]);
92     }
93 }

 

DataSet.java

 

 1 package OneVariableVersion;
 2 
 3 import java.io.BufferedReader;
 4 import java.io.File;
 5 import java.io.FileReader;
 6 import java.io.IOException;
 7 import java.util.ArrayList;
 8 import java.util.List;
 9 
10 
11 /**
12  * 数据处理
13  * @author XBW
14  * @date 2014年8月17日
15  */
16 
17 public class DataSet{
18     String defaultpath="D:\MachineLearning\StanfordbyAndrewNg\II.LinearRegressionwithOneVariable(Week1)\homework\ex1data1.txt";
19     
20     List ds=new ArrayList();
21     
22     public DataSet() throws IOException{
23         File dataset=new File(defaultpath);
24         BufferedReader br = new BufferedReader(new FileReader(dataset));
25         String tsing;
26         while((tsing=br.readLine())!=null){
27             String[] dlist=tsing.split(",");
28             Data dtmp=new Data(Double.parseDouble(dlist[0]),Double.parseDouble(dlist[1]));
29             this.ds.add(dtmp);
30         }
31         br.close();
32     }
33 }

 

 

 

Ans.java

 

 1 package OneVariableVersion;
 2 
 3 /**
 4  * 保存结果,y=t0+t1*x
 5  * @author XBW
 6  * @date 2014年8月17日
 7  */
 8 
 9 public class Ans {
10     Double[] theta;
11     boolean ifConvergence;
12     
13     public Ans(Double[] tmp,boolean ifCon){
14         this.theta=tmp;
15         this.ifConvergence=ifCon;
16     }
17 }

 

 

 

 

Data.java

 

 1 package OneVariableVersion;
 2 
 3 
 4 /**
 5  * 一条数据
 6  * @author XBW
 7  * @date 2014年8月17日
 8  */
 9 public class Data {
10     Double[] x=new Double[2];
11     Double y;
12     
13     public Data(Double xtmp,Double ytmp){
14         this.x[0]=1.0;
15         this.x[1]=xtmp;
16         this.y=ytmp;
17     }
18 }

 

 

 

 

 

 

总结:写代码的时候有几个讲究:

  1. 步长是否需要动态变化,按照coursera公开课上讲的是不必要动态改变的,因为偏导数会越来越小,但在实际情况下,按照一定的比值缩小或者自己定义一种缩小的方式可能是有必要的,所以具体情况具体分析;
  2. 初始步长的设定也是很重要的,大了就不会得到结果,因为发散了;步长越大,下降速率越快,但是也会导致震荡,所以,还是哪句话:具体问题具体分析;

转载于:https://www.cnblogs.com/XBWer/p/3912792.html

更多相关:

  • 原文出处: 韩昊    1 2 3 4 5 6 7 8 9 10 作 者:韩 昊 知 乎:Heinrich 微 博:@花生油工人 知乎专栏:与时间无关的故事   谨以此文献给大连海事大学的吴楠老师,柳晓鸣老师,王新年老师以及张晶泊老师。   转载的同学请保留上面这句话,谢谢。如果还能保留文章来源就更感激不尽了。 我保证这篇文章...

  • 原文出处: 韩昊   我保证这篇文章和你以前看过的所有文章都不同,这是 2012 年还在果壳的时候写的,但是当时没有来得及写完就出国了……于是拖了两年,嗯,我是拖延症患者…… 这篇文章的核心思想就是: 要让读者在不看任何数学公式的情况下理解傅里叶分析。 傅里叶分析不仅仅是一个数学工具,更是一种可以彻底颠覆一个人以前世界观的思维...

  • 很多Linux高手都喜欢使用screen命令,screen命令可以使你轻松地使用一个终端控制其他终端。尽管screen本身是一个非常有用的工具,byobu作为screen的增强版本,比screen更加好用而且美观,并且提供有用的信息和快捷的热键。 想象一下这样一个场景:你通过Secure Shell(ssh)链接到一个服务器,并...

  • NarrowbandPrimary Synchronization Signal时域位置每1个SFN存在一个NPSSSFNSubframeSymbol长度每个SFN5最后11个symbol11个symbols频域位置NB-IOT下行带宽固定180kHz,一个PRB,12个子载波。...

  •  [h1]反斜杠只能够阻止一个字符  [h2]位于键盘的左上角,和~公用一个键。...

  • Qt默认的QSlider和QSpinbox只能实现整数调整,不能实现浮点的变化,因此设计了如下可实现浮点变化的QFloatSlider和QFloatSpinner: QFloatSlider.h class QFloatSlider : public QSlider {Q_OBJECTpublic:QFloatSlider(QWi...

  • 一、概述 之前的文章介绍过卡尔曼滤波算法进行定位,我们知道kalman算法适合用于线性的高斯分布的状态环境中,我们也介绍了EKF,来解决在非高斯和非线性环境下的机器人定位算法。但是他们在现实应用中存在计算量,内存消耗上不是很高效。这就引出了MCL算法。 粒子滤波很粗浅的说就是一开始在地图空间很均匀的撒一把粒子,然后通过获取机器人的...

  • 1.精度问题 由于是double类型,r=mid 而不是r=mid-12.如果首位两端(f(0)和f(100))同号,证明解不在[1,100]区间内 这是我之所以TE的原因,没有预先判断3.若在这个区间内,则一定可要求出解 所以binarysearch 返回m#include #include ...

  • 代理(Proxy)模式给某一个对象提供一个代理,并由代理对象控制对原对象的引用。 代理模式的英文叫做Proxy或Surrogate,中文都可译成"代理"。所谓代理,就是一个人或者一个机构代表另一个人或者另一个机构采取行动。在一些情况下,一个客户不想或者不能够直接引用一个对象,而代理对象可以在客户端和目标对象之间起到中介的作用。 类图:...

  • 题目链接:http://uva.onlinejudge.org/index.php?option=com_onlinejudge&Itemid=8&category=41&page=show_problem&problem=1121 题意:给出两点坐标,用一条最短的线(曲线或者直线)连接起来,坐标系中原点处有一个半径为R的圆,连线不能...

  • 在.Net Framework中,配置文件一般采用的是XML格式的,.NET Framework提供了专门的ConfigurationManager来读取配置文件的内容,.net core中推荐使用json格式的配置文件,那么在.net core中该如何读取json文件呢?1、在Startup类中读取json配置文件1、使用Confi...

  •   1 public class FrameSubject extends JFrame {   2    3   …………..   4    5   //因为无法使用多重继承,这儿就只能使用对象组合的方式来引入一个   6    7   //java.util.Observerable对象了。   8    9   DateSub...

  • 本案例主要说明如何使用NSwag 工具使用桌面工具快速生成c# 客户端代码、快速的访问Web Api。 NSwagStudio 下载地址 比较强大、可以生成TypeScript、WebApi Controller、CSharp Client  1、运行WebApi项目  URL http://yourserver/swagger 然后...

  •   在绑定完Action的所有参数后,WebAPI并不会马上执行该方法,而要对参数进行验证,以保证输入的合法性.   ModelState 在ApiController中一个ModelState属性用来获取参数验证结果.   public abstract class ApiController : IHttpController,...

  • 1# 引用  C:AVEVAMarineOH12.1.SP4Aveva.ApplicationFramework.dll C:AVEVAMarineOH12.1.SP4Aveva.ApplicationFramework.Presentation.dll 2# 引用命名空间, using Aveva.Applicati...