001    /*
002    // $Id: //open/mondrian/src/main/mondrian/olap/fun/LinReg.java#13 $
003    // This software is subject to the terms of the Common Public License
004    // Agreement, available at the following URL:
005    // http://www.opensource.org/licenses/cpl.html.
006    // Copyright (C) 2005-2008 Julian Hyde
007    // All Rights Reserved.
008    // You must accept the terms of that agreement to use this software.
009    */
010    
011    
012    package mondrian.olap.fun;
013    
014    import mondrian.olap.*;
015    import mondrian.olap.type.TupleType;
016    import mondrian.olap.type.SetType;
017    import mondrian.calc.*;
018    import mondrian.calc.impl.AbstractDoubleCalc;
019    import mondrian.calc.impl.ValueCalc;
020    import mondrian.mdx.ResolvedFunCall;
021    
022    import java.util.ArrayList;
023    import java.util.Iterator;
024    import java.util.List;
025    
026    /**
027     * Abstract base class for definitions of linear regression functions.
028     *
029     * @see InterceptFunDef
030     * @see PointFunDef
031     * @see R2FunDef
032     * @see SlopeFunDef
033     * @see VarianceFunDef
034     *
035     * <h2>Correlation coefficient</h2>
036     * <p><i>Correlation coefficient</i></p>
037     *
 * <p>The correlation coefficient, r, ranges from -1 to +1. The
 * nonparametric Spearman correlation coefficient, abbreviated rs, has
 * the same range.</p>
041     *
042     * <table border="1" cellpadding="6" cellspacing="0">
043     *   <tr>
044     *     <td>Value of r (or rs)</td>
045     *     <td>Interpretation</td>
046     *   </tr>
047     *   <tr>
 *     <td valign="top">r = 0</td>
049     *
050     *     <td>The two variables do not vary together at all.</td>
051     *   </tr>
052     *   <tr>
 *     <td valign="top">0 &lt; r &lt; 1</td>
054     *     <td>
055     *       <p>The two variables tend to increase or decrease together.</p>
056     *     </td>
057     *   </tr>
058     *   <tr>
059     *     <td valign="top">r = 1.0</td>
060     *     <td>
061     *       <p>Perfect correlation.</p>
062     *     </td>
063     *   </tr>
064     *
065     *   <tr>
 *     <td valign="top">-1 &lt; r &lt; 0</td>
067     *     <td>
068     *       <p>One variable increases as the other decreases.</p>
069     *     </td>
070     *   </tr>
071     *
072     *   <tr>
073     *     <td valign="top">r = -1.0</td>
074     *     <td>
075     *       <p></p>
076     *       <p>Perfect negative or inverse correlation.</p>
077     *     </td>
078     *   </tr>
079     * </table>
080     *
 * <p>If r or rs is far from zero, there are four possible explanations:</p>
 * <ul>
 *   <li>The X variable helps determine the value of the Y variable.
 *   <li>The Y variable helps determine the value of the X variable.
 *   <li>Another variable influences both X and Y.
 *   <li>X and Y don't really correlate at all, and you just
 *       happened to observe such a strong correlation by chance. The P value
 *       determines how often this could occur.
 * </ul>
 * <p><i>r<sup>2</sup></i></p>
091     *
 * <p>Perhaps the best way to interpret the value of r is to square it to
 * calculate r<sup>2</sup>. Statisticians call this quantity the coefficient
 * of determination, but scientists call it r squared. It has a value
 * that ranges from zero to one, and is the fraction of the variance in
 * the two variables that is shared. For example, if r<sup>2</sup>=0.59,
 * then 59% of the variance in X can be explained by variation in Y.
 * Likewise, 59% of the variance in Y can be explained by (or goes along
 * with) variation in X. More simply, 59% of the variance is shared between
 * X and Y.</p>
101     *
102     * <p>(<a href="http://www.graphpad.com/articles/interpret/corl_n_linear_reg/correlation.htm">Source</a>).
103     *
104     * <p>Also see: <a href="http://mathworld.wolfram.com/LeastSquaresFitting.html">least squares fitting</a>.
105     */
106    
107    
108    public abstract class LinReg extends FunDefBase {
109        /** Code for the specific function. */
110        final int regType;
111    
112        public static final int Point = 0;
113        public static final int R2 = 1;
114        public static final int Intercept = 2;
115        public static final int Slope = 3;
116        public static final int Variance = 4;
117    
118        static final Resolver InterceptResolver = new ReflectiveMultiResolver(
119                "LinRegIntercept",
120                "LinRegIntercept(<Set>, <Numeric Expression>[, <Numeric Expression>])",
121                "Calculates the linear regression of a set and returns the value of b in the regression line y = ax + b.",
122                new String[]{"fnxn","fnxnn"},
123                InterceptFunDef.class);
124    
125        static final Resolver PointResolver = new ReflectiveMultiResolver(
126                "LinRegPoint",
127                "LinRegPoint(<Numeric Expression>, <Set>, <Numeric Expression>[, <Numeric Expression>])",
128                "Calculates the linear regression of a set and returns the value of y in the regression line y = ax + b.",
129                new String[]{"fnnxn","fnnxnn"},
130                PointFunDef.class);
131    
132        static final Resolver SlopeResolver = new ReflectiveMultiResolver(
133                "LinRegSlope",
134                "LinRegSlope(<Set>, <Numeric Expression>[, <Numeric Expression>])",
135                "Calculates the linear regression of a set and returns the value of a in the regression line y = ax + b.",
136                new String[]{"fnxn","fnxnn"},
137                SlopeFunDef.class);
138    
139        static final Resolver R2Resolver = new ReflectiveMultiResolver(
140                "LinRegR2",
141                "LinRegR2(<Set>, <Numeric Expression>[, <Numeric Expression>])",
142                "Calculates the linear regression of a set and returns R2 (the coefficient of determination).",
143                new String[]{"fnxn","fnxnn"},
144                R2FunDef.class);
145    
146        static final Resolver VarianceResolver = new ReflectiveMultiResolver(
147                "LinRegVariance",
148                "LinRegVariance(<Set>, <Numeric Expression>[, <Numeric Expression>])",
149                "Calculates the linear regression of a set and returns the variance associated with the regression line y = ax + b.",
150                new String[]{"fnxn","fnxnn"},
151                VarianceFunDef.class);
152    
153    
154        public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
155            final ListCalc listCalc = compiler.compileList(call.getArg(0));
156            final DoubleCalc yCalc = compiler.compileDouble(call.getArg(1));
157            final DoubleCalc xCalc = call.getArgCount() > 2 ?
158                    compiler.compileDouble(call.getArg(2)) :
159                    new ValueCalc(call);
160            final boolean isTuples =
161                    ((SetType) listCalc.getType()).getElementType() instanceof
162                    TupleType;
163            return new LinRegCalc(call, listCalc, yCalc, xCalc, isTuples, regType);
164        }
165    
166        /////////////////////////////////////////////////////////////////////////
167        //
168        // Helper
169        //
170        /////////////////////////////////////////////////////////////////////////
171        static class Value {
172            private List xs;
173            private List ys;
174            /**
175             * The intercept for the linear regression model. Initialized
176             * following a call to accuracy.
177             */
178            double intercept;
179    
180            /**
181             * The slope for the linear regression model. Initialized following a
182             * call to accuracy.
183             */
184            double slope;
185    
186             /** the coefficient of determination */
187            double rSquared = Double.MAX_VALUE;
188    
189            /** variance = sum square diff mean / n - 1 */
190            double variance = Double.MAX_VALUE;
191    
192            Value(double intercept, double slope, List xs, List ys) {
193                this.intercept = intercept;
194                this.slope = slope;
195                this.xs = xs;
196                this.ys = ys;
197            }
198    
199            public double getIntercept() {
200                return this.intercept;
201            }
202    
203            public double getSlope() {
204                return this.slope;
205            }
206    
207            public double getRSquared() {
208                return this.rSquared;
209            }
210    
211            /**
212             * strength of the correlation
213             *
214             * @param rSquared
215             */
216            public void setRSquared(double rSquared) {
217                this.rSquared = rSquared;
218            }
219    
220            public double getVariance() {
221                return this.variance;
222            }
223    
224            public void setVariance(double variance) {
225                this.variance = variance;
226            }
227    
228            public String toString() {
229                return "LinReg.Value: slope of "
230                    + slope
231                    + " and an intercept of " + intercept
232                    + ". That is, y="
233                    + intercept
234                    + (slope > 0.0 ? " +" : " ")
235                    + slope
236                    + " * x.";
237            }
238        }
239    
240        /**
241         * Definition of the <code>LinRegIntercept</code> MDX function.
242         *
243         * <p>Synopsis:
244         * <blockquote><code>LinRegIntercept(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric  Expression&gt;])</code></blockquote>
245         */
246        public static class InterceptFunDef extends LinReg {
247            public InterceptFunDef(FunDef funDef) {
248                super(funDef, Intercept);
249            }
250        }
251    
252        /**
253         * Definition of the <code>LinRegPoint</code> MDX function.
254         *
255         * <p>Synopsis:
256         * <blockquote><code>LinRegPoint(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric  Expression&gt;])</code></blockquote>
257         */
258        public static class PointFunDef extends LinReg {
259            public PointFunDef(FunDef funDef) {
260                super(funDef, Point);
261            }
262    
263            public Calc compileCall(ResolvedFunCall call, ExpCompiler compiler) {
264                final DoubleCalc xPointCalc = compiler.compileDouble(call.getArg(0));
265                final ListCalc listCalc = compiler.compileList(call.getArg(1));
266                final DoubleCalc yCalc = compiler.compileDouble(call.getArg(2));
267                final DoubleCalc xCalc = call.getArgCount() > 3 ?
268                        compiler.compileDouble(call.getArg(3)) :
269                        new ValueCalc(call);
270                final boolean isTuples =
271                        ((SetType) listCalc.getType()).getElementType() instanceof
272                        TupleType;
273                return new PointCalc(call, xPointCalc, listCalc, yCalc, xCalc, isTuples);
274            }
275    
276        }
277    
278        private static class PointCalc extends AbstractDoubleCalc {
279            private final DoubleCalc xPointCalc;
280            private final ListCalc listCalc;
281            private final DoubleCalc yCalc;
282            private final DoubleCalc xCalc;
283            private final boolean tuples;
284    
285            public PointCalc(
286                    ResolvedFunCall call,
287                    DoubleCalc xPointCalc,
288                    ListCalc listCalc,
289                    DoubleCalc yCalc, DoubleCalc xCalc, boolean tuples) {
290                super(call, new Calc[]{xPointCalc, listCalc, yCalc, xCalc});
291                this.xPointCalc = xPointCalc;
292                this.listCalc = listCalc;
293                this.yCalc = yCalc;
294                this.xCalc = xCalc;
295                this.tuples = tuples;
296            }
297    
298            public double evaluateDouble(Evaluator evaluator) {
299                double xPoint = xPointCalc.evaluateDouble(evaluator);
300                Value value =
301                        process(evaluator, listCalc, yCalc, xCalc, tuples);
302                if (value == null) {
303                    return FunUtil.DoubleNull;
304                }
305                // use first arg to generate y position
306                double yPoint = xPoint * value.getSlope() +
307                        value.getIntercept();
308                return yPoint;
309            }
310        }
311    
312        /**
313         * Definition of the <code>LinRegSlope</code> MDX function.
314         *
315         * <p>Synopsis:
316         * <blockquote><code>LinRegSlope(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric  Expression&gt;])</code></blockquote>
317         */
318        public static class SlopeFunDef extends LinReg {
319            public SlopeFunDef(FunDef funDef) {
320                super(funDef, Slope);
321            }
322        }
323    
324        /**
325         * Definition of the <code>LinRegR2</code> MDX function.
326         *
327         * <p>Synopsis:
328         * <blockquote><code>LinRegR2(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric  Expression&gt;])</code></blockquote>
329         */
330        public static class R2FunDef extends LinReg {
331            public R2FunDef(FunDef funDef) {
332                super(funDef, R2);
333            }
334        }
335    
336        /**
337         * Definition of the <code>LinRegVariance</code> MDX function.
338         *
339         * <p>Synopsis:
340         * <blockquote><code>LinRegVariance(&lt;Numeric Expression&gt;, &lt;Set&gt;, &lt;Numeric Expression&gt;[, &lt;Numeric  Expression&gt;])</code></blockquote>
341         */
342        public static class VarianceFunDef extends LinReg {
343            public VarianceFunDef(FunDef funDef) {
344                super(funDef, Variance);
345            }
346        }
347    
348        protected static void debug(String type, String msg) {
349            // comment out for no output
350    // RME
351            //System.out.println(type + ": " +msg);
352        }
353    
354    
355        protected LinReg(FunDef funDef, int regType) {
356            super(funDef);
357            this.regType = regType;
358        }
359    
360        protected static LinReg.Value process(
361                Evaluator evaluator,
362                ListCalc listCalc,
363                DoubleCalc yCalc,
364                DoubleCalc xCalc,
365                boolean isTuples) {
366            List members = listCalc.evaluateList(evaluator);
367    
368            evaluator = evaluator.push();
369    
370            SetWrapper[] sws = evaluateSet(
371                    evaluator, members, new DoubleCalc[] {yCalc, xCalc}, isTuples);
372            SetWrapper swY = sws[0];
373            SetWrapper swX = sws[1];
374    
375            if (swY.errorCount > 0) {
376    debug("LinReg.process","ERROR error(s) count ="  + swY.errorCount);
377                // TODO: throw exception
378                return null;
379            } else if (swY.v.size() == 0) {
380                return null;
381            }
382    
383            return linearReg(swX.v, swY.v);
384        }
385    
386        public static LinReg.Value accuracy(LinReg.Value value) {
387            // for variance
388            double sumErrSquared = 0.0;
389    
390            double sumErr = 0.0;
391    
392            // for r2
393            // data
394            double sumSquaredY = 0.0;
395            double sumY = 0.0;
396            // predicted
397            double sumSquaredYF = 0.0;
398            double sumYF = 0.0;
399    
400            // Obtain the forecast values for this model
401            List yfs = forecast(value);
402    
403            // Calculate the Sum of the Absolute Errors
404            Iterator ity = value.ys.iterator();
405            Iterator ityf = yfs.iterator();
406            while (ity.hasNext()) {
407                // Get next data point
408                Double dy = (Double) ity.next();
409                if (dy == null) {
410                    continue;
411                }
412                Double dyf = (Double) ityf.next();
413                if (dyf == null) {
414                    continue;
415                }
416    
417                double y = dy.doubleValue();
418                double yf = dyf.doubleValue();
419    
420                // Calculate error in forecast, and update sums appropriately
421    
422                // the y residual or error
423                double error = yf - y;
424    
425                sumErr += error;
426                sumErrSquared += error * error;
427    
428                sumY += y;
429                sumSquaredY += (y * y);
430    
431                sumYF =+ yf;
432                sumSquaredYF =+ (yf * yf);
433            }
434    
435    
436            // Initialize the accuracy indicators
437            int n = value.ys.size();
438    
439            // Variance
440            // The estimate the value of the error variance is a measure of
441            // variability of the y values about the estimated line.
442            // http://home.ubalt.edu/ntsbarsh/Business-stat/opre504.htm
443            // s2 = SSE/(n-2) = sum (y - yf)2 /(n-2)
444            if (n > 2) {
445                double variance = sumErrSquared / (n - 2);
446    
447                value.setVariance(variance);
448            }
449    
450            // R2
451            // R2 = 1 - (SSE/SST)
452            // SSE = sum square error = Sum((error-MSE)*(error-MSE))
453            // MSE = mean error = Sum(error)/n
454            // SST = sum square y diff = Sum((y-MST)*(y-MST))
455            // MST = mean y = Sum(y)/n
456            double MSE = sumErr / n;
457            double MST = sumY / n;
458            double SSE = 0.0;
459            double SST = 0.0;
460            ity = value.ys.iterator();
461            ityf = yfs.iterator();
462            while (ity.hasNext()) {
463                // Get next data point
464                Double dy = (Double) ity.next();
465                if (dy == null) {
466                    continue;
467                }
468                Double dyf = (Double) ityf.next();
469                if (dyf == null) {
470                    continue;
471                }
472    
473                double y = dy.doubleValue();
474                double yf = dyf.doubleValue();
475    
476                double error = yf - y;
477                SSE += (error - MSE) * (error - MSE);
478                SST += (y - MST) * (y - MST);
479            }
480            if (SST != 0.0) {
481                double rSquared =  1 - (SSE / SST);
482    
483                value.setRSquared(rSquared);
484            }
485    
486    
487            return value;
488        }
489    
490        public static LinReg.Value linearReg(List xlist, List ylist) {
491    
492            // y and x have same number of points
493            int size = ylist.size();
494            double sumX = 0.0;
495            double sumY = 0.0;
496            double sumXX = 0.0;
497            double sumXY = 0.0;
498    
499    debug("LinReg.linearReg","ylist.size()=" + ylist.size());
500    debug("LinReg.linearReg","xlist.size()=" + xlist.size());
501            int n = 0;
502            for (int i = 0; i < size; i++) {
503                Object yo = ylist.get(i);
504                Object xo = xlist.get(i);
505                if ((yo == null) || (xo == null)) {
506                    continue;
507                }
508                n++;
509                double y = ((Double) yo).doubleValue();
510                double x = ((Double) xo).doubleValue();
511    
512    debug("LinReg.linearReg"," " + i + " (" + x + "," + y + ")");
513                sumX += x;
514                sumY += y;
515                sumXX += x * x;
516                sumXY += x * y;
517            }
518    
519            double xMean = sumX / n;
520            double yMean = sumY / n;
521    
522    debug("LinReg.linearReg", "yMean=" + yMean);
523    debug("LinReg.linearReg", "(n*sumXX - sumX*sumX)=" + (n * sumXX - sumX * sumX));
524            // The regression line is the line that minimizes the variance of the
525            // errors. The mean error is zero; so, this means that it minimizes the
526            // sum of the squares errors.
527            double slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX);
528            double intercept = yMean - slope * xMean;
529    
530            LinReg.Value value = new LinReg.Value(intercept, slope, xlist, ylist);
531    debug("LinReg.linearReg","value=" + value);
532    
533            return value;
534        }
535    
536    
537        public static List forecast(LinReg.Value value) {
538            List yfs = new ArrayList(value.xs.size());
539    
540            Iterator it = value.xs.iterator();
541            while (it.hasNext()) {
542                Double d = (Double) it.next();
543                // If the value is missing we still must put a place
544                // holder in the y axis, otherwise there is a discontinuity
545                // between the data and the fit.
546                if (d == null) {
547                    yfs.add(null);
548                } else {
549                    double x = d.doubleValue();
550                    double yf = value.intercept + value.slope * x;
551                    yfs.add(new Double(yf));
552                }
553            }
554    
555            return yfs;
556        }
557    
558        private static class LinRegCalc extends AbstractDoubleCalc {
559            private final ListCalc listCalc;
560            private final DoubleCalc yCalc;
561            private final DoubleCalc xCalc;
562            private final boolean tuples;
563            private final int regType;
564    
565            public LinRegCalc(
566                    ResolvedFunCall call,
567                    ListCalc listCalc,
568                    DoubleCalc yCalc,
569                    DoubleCalc xCalc,
570                    boolean tuples,
571                    int regType) {
572                super(call, new Calc[]{listCalc, yCalc, xCalc});
573                this.listCalc = listCalc;
574                this.yCalc = yCalc;
575                this.xCalc = xCalc;
576                this.tuples = tuples;
577                this.regType = regType;
578            }
579    
580            public double evaluateDouble(Evaluator evaluator) {
581                Value value =
582                        process(evaluator, listCalc, yCalc, xCalc, tuples);
583                if (value == null) {
584                    return FunUtil.DoubleNull;
585                }
586                switch (regType) {
587                case Intercept:
588                    return value.getIntercept();
589                case Slope:
590                    return value.getSlope();
591                case Variance:
592                    return value.getVariance();
593                case R2:
594                    return value.getRSquared();
595                default:
596                case Point:
597                    throw Util.newInternal("unexpected value " + regType);
598                }
599            }
600        }
601    }
602    
603    // End LinReg.java