001 /* 002 // $Id: //open/mondrian/src/main/mondrian/olap/fun/LinReg.java#13 $ 003 // This software is subject to the terms of the Common Public License 004 // Agreement, available at the following URL: 005 // http://www.opensource.org/licenses/cpl.html. 006 // Copyright (C) 2005-2008 Julian Hyde 007 // All Rights Reserved. 008 // You must accept the terms of that agreement to use this software. 009 */ 010 011 012 package mondrian.olap.fun; 013 014 import mondrian.olap.*; 015 import mondrian.olap.type.TupleType; 016 import mondrian.olap.type.SetType; 017 import mondrian.calc.*; 018 import mondrian.calc.impl.AbstractDoubleCalc; 019 import mondrian.calc.impl.ValueCalc; 020 import mondrian.mdx.ResolvedFunCall; 021 022 import java.util.ArrayList; 023 import java.util.Iterator; 024 import java.util.List; 025 026 /** 027 * Abstract base class for definitions of linear regression functions. 028 * 029 * @see InterceptFunDef 030 * @see PointFunDef 031 * @see R2FunDef 032 * @see SlopeFunDef 033 * @see VarianceFunDef 034 * 035 * <h2>Correlation coefficient</h2> 036 * <p><i>Correlation coefficient</i></p> 037 * 038 * <p>The correlation coefficient, r, ranges from -1 to + 1. The 039 * nonparametric Spearman correlation coefficient, abbreviated rs, has 040 * the same range.</p> 041 * 042 * <table border="1" cellpadding="6" cellspacing="0"> 043 * <tr> 044 * <td>Value of r (or rs)</td> 045 * <td>Interpretation</td> 046 * </tr> 047 * <tr> 048 * <td valign="top">r= 0</td> 049 * 050 * <td>The two variables do not vary together at all.</td> 051 * </tr> 052 * <tr> 053 * <td valign="top">0 > r > 1</td> 054 * <td> 055 * <p>The two variables tend to increase or decrease together.</p> 056 * </td> 057 * </tr> 058 * <tr> 059 * <td valign="top">r = 1.0</td> 060 * <td> 061 * <p>Perfect correlation.</p> 062 * </td> 063 * </tr> 064 * 065 * <tr> 066 * <td valign="top">-1 > r > 0</td> 067 * <td> 068 * <p>One variable increases as the other decreases.</p> 069 * </td> 070 * </tr> 071 * 072 * <tr> 073 * <td valign="top">r = -1.0</td> 074 * <td> 075 * <p></p> 076 * <p>Perfect negative or inverse correlation.</p> 077 * </td> 078 * </tr> 079 * </table> 080 * 081 * <p>If r or rs is far from zero, there are four possible explanations:</p> 082 * <p>The X variable helps determine the value of the Y variable.</p> 083 * <ul> 084 * <li>The Y variable helps determine the value of the X variable. 085 * <li>Another variable influences both X and Y. 086 * <li>X and Y don't really correlate at all, and you just 087 * happened to observe such a strong correlation by chance. The P value 088 * determines how often this could occur. 089 * </ul> 090 * <p><i>r2 </i></p> 091 * 092 * <p>Perhaps the best way to interpret the value of r is to square it to 093 * calculate r2. Statisticians call this quantity the coefficient of 094 * determination, but scientists call it r squared. It is has a value 095 * that ranges from zero to one, and is the fraction of the variance in 096 * the two variables that is shared. For example, if r2=0.59, then 59% of 097 * the variance in X can be explained by variation in Y. Likewise, 098 * 59% of the variance in Y can be explained by (or goes along with) 099 * variation in X. More simply, 59% of the variance is shared between X 100 * and Y.</p> 101 * 102 * <p>(<a href="http://www.graphpad.com/articles/interpret/corl_n_linear_reg/correlation.htm">Source</a>). 103 * 104 * <p>Also see: <a href="http://mathworld.wolfram.com/LeastSquaresFitting.html">least squares fitting</a>. 105 */ 106 107 108 public abstract class LinReg extends FunDefBase { 109 /** Code for the specific function. */ 110 final int regType; 111 112 public static final int Point = 0; 113 public static final int R2 = 1; 114 public static final int Intercept = 2; 115 public static final int Slope = 3; 116 public static final int Variance = 4; 117 118 static final Resolver InterceptResolver = new ReflectiveMultiResolver( 119 "LinRegIntercept", 120 "LinRegIntercept(<Set>, <Numeric Expression>[, <Numeric Expression>])", 121 "Calculates the linear regression of a set and returns the value of b in the regression line y = ax + b.", 122 new String[]{"fnxn","fnxnn"}, 123 InterceptFunDef.class); 124 125 static final Resolver PointResolver = new ReflectiveMultiResolver( 126 "LinRegPoint", 127 "LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])", 128 "Calculates the linear regression of a set and returns the value of y in the regression line y = ax + b.", 129 new String[]{"fnnxn","fnnxnn"}, 130 PointFunDef.class); 131 132 static final Resolver SlopeResolver = new ReflectiveMultiResolver( 133 "LinRegSlope", 134 "LinRegSlope(<Set>, <Numeric Expression>[, <Numeric Expression>])", 135 "Calculates the linear regression of a set and returns the value of a in the regression line y = ax + b.", 136 new String[]{"fnxn","fnxnn"}, 137 SlopeFunDef.class); 138 139 static final Resolver R2Resolver = new ReflectiveMultiResolver( 140 "LinRegR2", 141 "LinRegR2(<Set>, <Numeric Expression>[, <Numeric Expression>])", 142 "Calculates the linear regression of a set and returns R2 (the coefficient of determination).", 143 new String[]{"fnxn","fnxnn"}, 144 R2FunDef.class); 145 146 static final Resolver VarianceResolver = new ReflectiveMultiResolver( 147 "LinRegVariance", 148 "LinRegVariance(<Set>, <Numeric Expression>[, <Numeric Expression>])", 149 "Calculates the linear regression of a set and returns the variance associated with the regression line y = ax + b.", 150 new String[]{"fnxn","fnxnn"}, 151 VarianceFunDef.class); 152 153 154 public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) { 155 final ListCalc listCalc = compiler.compileList(call.getArg(0)); 156 final DoubleCalc yCalc = compiler.compileDouble(call.getArg(1)); 157 final DoubleCalc xCalc = call.getArgCount() > 2 ? 158 compiler.compileDouble(call.getArg(2)) : 159 new ValueCalc(call); 160 final boolean isTuples = 161 ((SetType) listCalc.getType()).getElementType() instanceof 162 TupleType; 163 return new LinRegCalc(call, listCalc, yCalc, xCalc, isTuples, regType); 164 } 165 166 ///////////////////////////////////////////////////////////////////////// 167 // 168 // Helper 169 // 170 ///////////////////////////////////////////////////////////////////////// 171 static class Value { 172 private List xs; 173 private List ys; 174 /** 175 * The intercept for the linear regression model. Initialized 176 * following a call to accuracy. 177 */ 178 double intercept; 179 180 /** 181 * The slope for the linear regression model. Initialized following a 182 * call to accuracy. 183 */ 184 double slope; 185 186 /** the coefficient of determination */ 187 double rSquared = Double.MAX_VALUE; 188 189 /** variance = sum square diff mean / n - 1 */ 190 double variance = Double.MAX_VALUE; 191 192 Value(double intercept, double slope, List xs, List ys) { 193 this.intercept = intercept; 194 this.slope = slope; 195 this.xs = xs; 196 this.ys = ys; 197 } 198 199 public double getIntercept() { 200 return this.intercept; 201 } 202 203 public double getSlope() { 204 return this.slope; 205 } 206 207 public double getRSquared() { 208 return this.rSquared; 209 } 210 211 /** 212 * strength of the correlation 213 * 214 * @param rSquared 215 */ 216 public void setRSquared(double rSquared) { 217 this.rSquared = rSquared; 218 } 219 220 public double getVariance() { 221 return this.variance; 222 } 223 224 public void setVariance(double variance) { 225 this.variance = variance; 226 } 227 228 public String toString() { 229 return "LinReg.Value: slope of " 230 + slope 231 + " and an intercept of " + intercept 232 + ". That is, y=" 233 + intercept 234 + (slope > 0.0 ? " +" : " ") 235 + slope 236 + " * x."; 237 } 238 } 239 240 /** 241 * Definition of the <code>LinRegIntercept</code> MDX function. 242 * 243 * <p>Synopsis: 244 * <blockquote><code>LinRegIntercept(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])</code></blockquote> 245 */ 246 public static class InterceptFunDef extends LinReg { 247 public InterceptFunDef(FunDef funDef) { 248 super(funDef, Intercept); 249 } 250 } 251 252 /** 253 * Definition of the <code>LinRegPoint</code> MDX function. 254 * 255 * <p>Synopsis: 256 * <blockquote><code>LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])</code></blockquote> 257 */ 258 public static class PointFunDef extends LinReg { 259 public PointFunDef(FunDef funDef) { 260 super(funDef, Point); 261 } 262 263 public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) { 264 final DoubleCalc xPointCalc = compiler.compileDouble(call.getArg(0)); 265 final ListCalc listCalc = compiler.compileList(call.getArg(1)); 266 final DoubleCalc yCalc = compiler.compileDouble(call.getArg(2)); 267 final DoubleCalc xCalc = call.getArgCount() > 3 ? 268 compiler.compileDouble(call.getArg(3)) : 269 new ValueCalc(call); 270 final boolean isTuples = 271 ((SetType) listCalc.getType()).getElementType() instanceof 272 TupleType; 273 return new PointCalc(call, xPointCalc, listCalc, yCalc, xCalc, isTuples); 274 } 275 276 } 277 278 private static class PointCalc extends AbstractDoubleCalc { 279 private final DoubleCalc xPointCalc; 280 private final ListCalc listCalc; 281 private final DoubleCalc yCalc; 282 private final DoubleCalc xCalc; 283 private final boolean tuples; 284 285 public PointCalc( 286 ResolvedFunCall call, 287 DoubleCalc xPointCalc, 288 ListCalc listCalc, 289 DoubleCalc yCalc, DoubleCalc xCalc, boolean tuples) { 290 super(call, new Calc[]{xPointCalc, listCalc, yCalc, xCalc}); 291 this.xPointCalc = xPointCalc; 292 this.listCalc = listCalc; 293 this.yCalc = yCalc; 294 this.xCalc = xCalc; 295 this.tuples = tuples; 296 } 297 298 public double evaluateDouble(Evaluator evaluator) { 299 double xPoint = xPointCalc.evaluateDouble(evaluator); 300 Value value = 301 process(evaluator, listCalc, yCalc, xCalc, tuples); 302 if (value == null) { 303 return FunUtil.DoubleNull; 304 } 305 // use first arg to generate y position 306 double yPoint = xPoint * value.getSlope() + 307 value.getIntercept(); 308 return yPoint; 309 } 310 } 311 312 /** 313 * Definition of the <code>LinRegSlope</code> MDX function. 314 * 315 * <p>Synopsis: 316 * <blockquote><code>LinRegSlope(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])</code></blockquote> 317 */ 318 public static class SlopeFunDef extends LinReg { 319 public SlopeFunDef(FunDef funDef) { 320 super(funDef, Slope); 321 } 322 } 323 324 /** 325 * Definition of the <code>LinRegR2</code> MDX function. 326 * 327 * <p>Synopsis: 328 * <blockquote><code>LinRegR2(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])</code></blockquote> 329 */ 330 public static class R2FunDef extends LinReg { 331 public R2FunDef(FunDef funDef) { 332 super(funDef, R2); 333 } 334 } 335 336 /** 337 * Definition of the <code>LinRegVariance</code> MDX function. 338 * 339 * <p>Synopsis: 340 * <blockquote><code>LinRegVariance(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])</code></blockquote> 341 */ 342 public static class VarianceFunDef extends LinReg { 343 public VarianceFunDef(FunDef funDef) { 344 super(funDef, Variance); 345 } 346 } 347 348 protected static void debug(String type, String msg) { 349 // comment out for no output 350 // RME 351 //System.out.println(type + ": " +msg); 352 } 353 354 355 protected LinReg(FunDef funDef, int regType) { 356 super(funDef); 357 this.regType = regType; 358 } 359 360 protected static LinReg.Value process( 361 Evaluator evaluator, 362 ListCalc listCalc, 363 DoubleCalc yCalc, 364 DoubleCalc xCalc, 365 boolean isTuples) { 366 List members = listCalc.evaluateList(evaluator); 367 368 evaluator = evaluator.push(); 369 370 SetWrapper[] sws = evaluateSet( 371 evaluator, members, new DoubleCalc[] {yCalc, xCalc}, isTuples); 372 SetWrapper swY = sws[0]; 373 SetWrapper swX = sws[1]; 374 375 if (swY.errorCount > 0) { 376 debug("LinReg.process","ERROR error(s) count =" + swY.errorCount); 377 // TODO: throw exception 378 return null; 379 } else if (swY.v.size() == 0) { 380 return null; 381 } 382 383 return linearReg(swX.v, swY.v); 384 } 385 386 public static LinReg.Value accuracy(LinReg.Value value) { 387 // for variance 388 double sumErrSquared = 0.0; 389 390 double sumErr = 0.0; 391 392 // for r2 393 // data 394 double sumSquaredY = 0.0; 395 double sumY = 0.0; 396 // predicted 397 double sumSquaredYF = 0.0; 398 double sumYF = 0.0; 399 400 // Obtain the forecast values for this model 401 List yfs = forecast(value); 402 403 // Calculate the Sum of the Absolute Errors 404 Iterator ity = value.ys.iterator(); 405 Iterator ityf = yfs.iterator(); 406 while (ity.hasNext()) { 407 // Get next data point 408 Double dy = (Double) ity.next(); 409 if (dy == null) { 410 continue; 411 } 412 Double dyf = (Double) ityf.next(); 413 if (dyf == null) { 414 continue; 415 } 416 417 double y = dy.doubleValue(); 418 double yf = dyf.doubleValue(); 419 420 // Calculate error in forecast, and update sums appropriately 421 422 // the y residual or error 423 double error = yf - y; 424 425 sumErr += error; 426 sumErrSquared += error * error; 427 428 sumY += y; 429 sumSquaredY += (y * y); 430 431 sumYF =+ yf; 432 sumSquaredYF =+ (yf * yf); 433 } 434 435 436 // Initialize the accuracy indicators 437 int n = value.ys.size(); 438 439 // Variance 440 // The estimate the value of the error variance is a measure of 441 // variability of the y values about the estimated line. 442 // http://home.ubalt.edu/ntsbarsh/Business-stat/opre504.htm 443 // s2 = SSE/(n-2) = sum (y - yf)2 /(n-2) 444 if (n > 2) { 445 double variance = sumErrSquared / (n - 2); 446 447 value.setVariance(variance); 448 } 449 450 // R2 451 // R2 = 1 - (SSE/SST) 452 // SSE = sum square error = Sum((error-MSE)*(error-MSE)) 453 // MSE = mean error = Sum(error)/n 454 // SST = sum square y diff = Sum((y-MST)*(y-MST)) 455 // MST = mean y = Sum(y)/n 456 double MSE = sumErr / n; 457 double MST = sumY / n; 458 double SSE = 0.0; 459 double SST = 0.0; 460 ity = value.ys.iterator(); 461 ityf = yfs.iterator(); 462 while (ity.hasNext()) { 463 // Get next data point 464 Double dy = (Double) ity.next(); 465 if (dy == null) { 466 continue; 467 } 468 Double dyf = (Double) ityf.next(); 469 if (dyf == null) { 470 continue; 471 } 472 473 double y = dy.doubleValue(); 474 double yf = dyf.doubleValue(); 475 476 double error = yf - y; 477 SSE += (error - MSE) * (error - MSE); 478 SST += (y - MST) * (y - MST); 479 } 480 if (SST != 0.0) { 481 double rSquared = 1 - (SSE / SST); 482 483 value.setRSquared(rSquared); 484 } 485 486 487 return value; 488 } 489 490 public static LinReg.Value linearReg(List xlist, List ylist) { 491 492 // y and x have same number of points 493 int size = ylist.size(); 494 double sumX = 0.0; 495 double sumY = 0.0; 496 double sumXX = 0.0; 497 double sumXY = 0.0; 498 499 debug("LinReg.linearReg","ylist.size()=" + ylist.size()); 500 debug("LinReg.linearReg","xlist.size()=" + xlist.size()); 501 int n = 0; 502 for (int i = 0; i < size; i++) { 503 Object yo = ylist.get(i); 504 Object xo = xlist.get(i); 505 if ((yo == null) || (xo == null)) { 506 continue; 507 } 508 n++; 509 double y = ((Double) yo).doubleValue(); 510 double x = ((Double) xo).doubleValue(); 511 512 debug("LinReg.linearReg"," " + i + " (" + x + "," + y + ")"); 513 sumX += x; 514 sumY += y; 515 sumXX += x * x; 516 sumXY += x * y; 517 } 518 519 double xMean = sumX / n; 520 double yMean = sumY / n; 521 522 debug("LinReg.linearReg", "yMean=" + yMean); 523 debug("LinReg.linearReg", "(n*sumXX - sumX*sumX)=" + (n * sumXX - sumX * sumX)); 524 // The regression line is the line that minimizes the variance of the 525 // errors. The mean error is zero; so, this means that it minimizes the 526 // sum of the squares errors. 527 double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX); 528 double intercept = yMean - slope * xMean; 529 530 LinReg.Value value = new LinReg.Value(intercept, slope, xlist, ylist); 531 debug("LinReg.linearReg","value=" + value); 532 533 return value; 534 } 535 536 537 public static List forecast(LinReg.Value value) { 538 List yfs = new ArrayList(value.xs.size()); 539 540 Iterator it = value.xs.iterator(); 541 while (it.hasNext()) { 542 Double d = (Double) it.next(); 543 // If the value is missing we still must put a place 544 // holder in the y axis, otherwise there is a discontinuity 545 // between the data and the fit. 546 if (d == null) { 547 yfs.add(null); 548 } else { 549 double x = d.doubleValue(); 550 double yf = value.intercept + value.slope * x; 551 yfs.add(new Double(yf)); 552 } 553 } 554 555 return yfs; 556 } 557 558 private static class LinRegCalc extends AbstractDoubleCalc { 559 private final ListCalc listCalc; 560 private final DoubleCalc yCalc; 561 private final DoubleCalc xCalc; 562 private final boolean tuples; 563 private final int regType; 564 565 public LinRegCalc( 566 ResolvedFunCall call, 567 ListCalc listCalc, 568 DoubleCalc yCalc, 569 DoubleCalc xCalc, 570 boolean tuples, 571 int regType) { 572 super(call, new Calc[]{listCalc, yCalc, xCalc}); 573 this.listCalc = listCalc; 574 this.yCalc = yCalc; 575 this.xCalc = xCalc; 576 this.tuples = tuples; 577 this.regType = regType; 578 } 579 580 public double evaluateDouble(Evaluator evaluator) { 581 Value value = 582 process(evaluator, listCalc, yCalc, xCalc, tuples); 583 if (value == null) { 584 return FunUtil.DoubleNull; 585 } 586 switch (regType) { 587 case Intercept: 588 return value.getIntercept(); 589 case Slope: 590 return value.getSlope(); 591 case Variance: 592 return value.getVariance(); 593 case R2: 594 return value.getRSquared(); 595 default: 596 case Point: 597 throw Util.newInternal("unexpected value " + regType); 598 } 599 } 600 } 601 } 602 603 // End LinReg.java