// evaljit.c: the JIT code execution machine
//
// R : A Computer Language for Statistical Data Analysis
// Copyright (C) 1995, 1996    Robert Gentleman and Ross Ihaka
// Copyright (C) 1998--2006    The R Development Core Team.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, a copy is available at
// http://www.r-project.org/Licenses/

#ifdef HAVE_CONFIG_H
  #include "config.h"
#endif
#include "Defn.h"
#include "Print.h"
#define JIT_INTERNAL 1  // tell jit.h to include JIT tables etc.
#define JIT_NAMES    1  // tell jit.h to include opcodeName
#include "jit.h"
#include "printsxp.h"

// Number of slots in the jit exec stack.  incStack raises an error rather
// than let istack go past MAX_STACK_LEN-1.
#define MAX_STACK_LEN 1000

static SEXP stack[MAX_STACK_LEN];   // jit exec stack, stack[0] isn't used
int  istack;                        // index into exec stack
                                    // if istack > 0 then in evalJit
// Shorthand for the top three stack slots.  These macros do no bounds
// checking; they simply index stack[] relative to istack.
#define TOS  (stack[istack])
#define TOS1 (stack[istack-1])
#define TOS2 (stack[istack-2])

// Top of stack viewed as a double or an int data pointer.
#define TOSR (REAL(stack[istack]))
#define TOSI (INTEGER(stack[istack]))

// Defs for binops e.g. "x + y" puts x at [istack-1] and y at [istack].
// Remember: "y is top of stack; x is top of stack minus 1".

#define X  (stack[istack-1])
#define Y  (stack[istack])
#define XR (REAL(stack[istack-1]))
#define YR (REAL(stack[istack]))
#define XI (INTEGER(stack[istack-1]))
#define YI (INTEGER(stack[istack]))

//-----------------------------------------------------------------------------
// All functions beginning with "D" contain only assertions and are
// not strictly needed.  "D" is for debug.  The D functions are in
// alphabetical order below.

#if DEBUG_JIT
// Failure handler behind DEvalAssert.  Prints the jitted expression being
// evaluated (when prec is non-NULL), the stack index, up to three entries
// from the top of the exec stack and R_CurrentExpr, then hands over to
// assertFail.
//
// filename, nline: source location of the failed assertion
// exp:             text of the failed assertion
// prec:            jit record being executed, or NULL if not available
//
// Returns 0 so that the call can be the right operand of "||" in
// DEvalAssert (presumably assertFail does not return -- TODO confirm).
static int jitEvalFail(const char *filename, unsigned nline,
                       const char *exp, const JIT_RECORD * const prec)
{
    if (prec)
        printf("Evaluating jitted expression %s\n",
               deparseAsShortString(prec->original));
    printf("istack %d\n", istack);
    // Each label names the slot that is actually printed on that line.
    if (istack >= 2)
        printfSxp(stack[istack],   "stack[istack]  ");
    if (istack >= 3)
        printfSxp(stack[istack-1], "stack[istack-1]");
    if (istack >= 4)
        printfSxp(stack[istack-2], "stack[istack-2]");
    printfSxp(R_CurrentExpr, "R_CurrentExpr");
    assertFail(filename, nline, exp);
    return 0;
}
#endif

// Specialized assert for evalJit.  When DEBUG_JIT is set and exp is
// false, jitEvalFail is called with the source location, the text of exp
// and prec, so the failure report includes some extra details.
// When DEBUG_JIT is not set the macro expands to nothing, so exp is not
// evaluated at all: never put side effects in exp.

#if DEBUG_JIT
#define DEvalAssert(exp, prec) (void)((exp) || \
                                jitEvalFail(__FILE__, __LINE__, #exp, prec))
#else
#define DEvalAssert(exp, prec)
#endif

// Debug check that x has the expected type, where logical and integer
// are treated as the same type.  Does nothing unless DEBUG_JIT is set.
static R_INLINE void DAssertCompatibleType(CSEXP x,
                                           SEXPTYPE expectedType)
{
#if DEBUG_JIT
    // fold LGLSXP into INTSXP on both sides before comparing
    SEXPTYPE xtype = (TYPEOF(x) == LGLSXP) ? INTSXP : TYPEOF(x);
    expectedType = (expectedType == LGLSXP) ? INTSXP : expectedType;
    DEvalAssert(xtype == expectedType, NULL);
#endif
}

// Debug check for a jitted assignment.  The value to assign is at TOS.
//
// op:              the assignment op, op->operand is the binding location
// expectedLhsType: expected type of the symbol's current value
// expectedRhsType: expected type of TOS
// expectedRhsLen:  expected length of TOS, or -1 to skip the length check
static R_INLINE void DCheckAssign(
        const JIT_OP *op,
        SEXPTYPE expectedLhsType, SEXPTYPE expectedRhsType,
        int expectedRhsLen)
{
    DEvalAssert(istack >= 2, NULL);
#if DEBUG_JIT_SYM
    DCheckSymOpConsistency(op, TOS);
#endif
#if DEBUG_JIT
    // the value currently bound to the symbol must match TOS in length
    SEXP currentVal = getSymValFromLoc(op->operand);
    DAssertCompatibleType(currentVal, expectedLhsType);
    DEvalAssert(LENGTH(currentVal) == LENGTH(TOS), NULL);
#endif
    DAssertCompatibleType(TOS, expectedRhsType);
    DEvalAssert(expectedRhsLen == -1 || LENGTH(TOS) == expectedRhsLen,
                NULL);
}

// Debug check for a binop where x and y are vectors of the same length:
// checks the operand types and that result is as long as x.
static R_INLINE void DCheckBinop_x_y(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(X) == LENGTH(Y), NULL);
    DEvalAssert(LENGTH(result) == length(X), NULL);
}

// Debug check for a binop where x is a vector and y has length 1:
// checks the operand types and that result is as long as x.
static R_INLINE void DCheckBinop_x_y1(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(Y) == 1, NULL);
    DEvalAssert(LENGTH(result) == length(X), NULL);
}

// Debug check for a binop where x has length 1 and y is a vector:
// checks the operand types and that result is as long as y.
static R_INLINE void DCheckBinop_x1_y(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(X) == 1, NULL);
    DEvalAssert(LENGTH(result) == length(Y), NULL);
}

// Debug check for a binop where x, y and the result all have length 1.
static R_INLINE void DCheckBinop_x1_y1(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(X) == 1, NULL);
    DEvalAssert(LENGTH(Y) == 1, NULL);
    DEvalAssert(LENGTH(result) == 1, NULL);
}

// Debug check for a jitted "if": the condition at TOS must have type
// iftype.  The op argument is not used here.
static R_INLINE void DCheckIf(const JIT_OP *op,
                              SEXPTYPE iftype)
{
    DAssertCompatibleType(TOS, iftype);
}

// Debug check for a math1 op with a vector arg at TOS: checks the types
// of the result and the arg, and that both have the same length.
static R_INLINE void DCheckMath1(
        CSEXP result, SEXPTYPE xtype, SEXPTYPE resultType)
{
    DEvalAssert(istack >= 2, NULL);     // the arg must be on the stack
    DAssertCompatibleType(result, resultType);
    DAssertCompatibleType(TOS, xtype);
    DEvalAssert(LENGTH(TOS) == LENGTH(result), NULL);
}

// Debug check for a math1 op with a scalar arg at TOS: checks the types
// of the result and the arg, and that both have length 1.
static R_INLINE void DCheckMath1_1(
        CSEXP result, SEXPTYPE xtype, SEXPTYPE resultType)
{
    DEvalAssert(istack >= 2, NULL);     // the arg must be on the stack
    DAssertCompatibleType(result, resultType);
    DAssertCompatibleType(TOS, xtype);
    DEvalAssert(LENGTH(TOS) == 1, NULL);
    DEvalAssert(LENGTH(result) == 1, NULL);
}

// Debug check for "not" with a vector arg at TOS: the arg must have type
// xtype, and the result must be logical and as long as the arg.
static R_INLINE void DCheckNot(CSEXP result,
                               SEXPTYPE xtype)
{
    DEvalAssert(istack >= 2, NULL);     // the arg must be on the stack
    DAssertCompatibleType(TOS, xtype);
    DAssertCompatibleType(result, LGLSXP);
    DEvalAssert(LENGTH(TOS) == LENGTH(result), NULL);
}

// Debug check for "not" with a scalar arg at TOS: the arg must have type
// xtype, the result must be logical, and both must have length 1.
static R_INLINE void DCheckNot1(CSEXP result,
                                SEXPTYPE xtype)
{
    DEvalAssert(istack >= 2, NULL);     // the arg must be on the stack
    DAssertCompatibleType(TOS, xtype);
    DAssertCompatibleType(result, LGLSXP);
    DEvalAssert(LENGTH(TOS) == 1, NULL);
    DEvalAssert(LENGTH(result) == 1, NULL);
}

// Debug check for a relop where x and y are vectors of the same length:
// checks the operand types and that result is as long as x.
static R_INLINE void DCheckRelop_x_y(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(X) == LENGTH(Y), NULL);
    DEvalAssert(LENGTH(result) == length(X), NULL);
}

// Debug check for a relop where x is a vector and y has length 1:
// checks the operand types and that result is as long as x.
static R_INLINE void DCheckRelop_x_y1(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(Y) == 1, NULL);
    DEvalAssert(LENGTH(result) == length(X), NULL);
}

// Debug check for a relop where x has length 1 and y is a vector:
// checks the operand types and that result is as long as y.
static R_INLINE void DCheckRelop_x1_y(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(X) == 1, NULL);
    DEvalAssert(LENGTH(result) == length(Y), NULL);
}

// Debug check for a relop where x, y and the result all have length 1.
static R_INLINE void DCheckRelop_x1_y1(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DEvalAssert(LENGTH(X) == 1, NULL);
    DEvalAssert(LENGTH(Y) == 1, NULL);
    DEvalAssert(LENGTH(result) == 1, NULL);
}

#if DEBUG_JIT_SYM
// Print the operand, sym and env fields of op.  Used to give context
// when a symbol consistency check fails.
static void printConsistencyInfo(const JIT_OP *op)
{
    printfSxp(op->operand, "\noperand");    // leading newline starts the report
    printfSxp(op->sym,     "sym");
    printfSxp(op->env,     "env");
}
#endif

// Check that the pointer to a symbol binding (the operand arg below)
// is the same as that returned by findVarLoc. Also check the value.

#if DEBUG_JIT_SYM
static R_INLINE void DCheckSymOpConsistency(const JIT_OP *op,
                                            CSEXP tos)
{
    // the binding location stored in the op must be what findVarLoc finds now
    SEXP findVarLoc1 = findVarLoc(op->sym, op->env);
    if (op->operand != findVarLoc1) {
        printConsistencyInfo(op);
        printfSxp(findVarLoc1, "findVarLoc1");
        assertFail(__FILE__, __LINE__, "see above messages");
    }

    // the value at that location must be what findVar finds now
    SEXP val = getSymValFromLoc(op->operand);
    SEXP findVarVal = findVar(op->sym, op->env);
    if (findVarVal != val) {
        printConsistencyInfo(op);
        printfSxp(val, "val");
        printfSxp(findVarLoc1, "findVarLoc1");
        printfSxp(findVarVal,  "findVarVal");
        assertFail(__FILE__, __LINE__, "see above messages");
    }

    // callers pass RNIL for tos when there is no value to compare lengths with
    assert(tos == RNIL || LENGTH(tos) == LENGTH(val));
}
#endif

// Debug check for a jitted subscripted assignment involving x, index and
// y: checks the type of each against the expected type, that result has
// the same type as x, and that index and result both have length 1.
// prec identifies the jitted expression in any failure report.
static R_INLINE void DCheckSubas(
        SEXP x, SEXP index, SEXP y,
        SEXPTYPE xtype, SEXPTYPE indextype, SEXPTYPE ytype,
        SEXP result, const JIT_RECORD * const prec)
{
    DEvalAssert(istack >= 4, prec);
    DAssertCompatibleType(x, xtype);
    DAssertCompatibleType(index, indextype);
    DAssertCompatibleType(y, ytype);
    DAssertCompatibleType(result, xtype);
    DEvalAssert(LENGTH(index) == 1, prec);
    DEvalAssert(LENGTH(result) == 1, prec);
}

// Debug check for a subscript op where the subscript y has length 1:
// checks the operand types, that result has the same type as x, and
// that result has length 1.
static R_INLINE void DCheckSubscript_x_y1(
        SEXPTYPE xtype, SEXPTYPE ytype, CSEXP result)
{
    DEvalAssert(istack >= 3, NULL);     // x and y must both be on the stack
    DAssertCompatibleType(X, xtype);    // x is at [istack-1]
    DAssertCompatibleType(Y, ytype);    // y is at [istack]
    DAssertCompatibleType(result, xtype);
    DEvalAssert(LENGTH(Y) == 1, NULL);
    DEvalAssert(LENGTH(result) == 1, NULL);
}

// Debug check for unary minus with a vector arg at TOS: the arg and the
// result must both have type xtype and the same length.
static R_INLINE void DCheckUminus(CSEXP result,
                                  SEXPTYPE xtype)
{
    DEvalAssert(istack >= 2, NULL);     // the arg must be on the stack
    DAssertCompatibleType(TOS, xtype);
    DAssertCompatibleType(result, xtype);
    DEvalAssert(LENGTH(TOS) == LENGTH(result), NULL);
}

// Debug check for unary minus with a scalar arg at TOS: the arg and the
// result must both have type xtype and length 1.
static R_INLINE void DCheckUminus1(CSEXP result,
                                   SEXPTYPE xtype)
{
    DEvalAssert(istack >= 2, NULL);     // the arg must be on the stack
    DAssertCompatibleType(TOS, xtype);
    DAssertCompatibleType(result, xtype);
    DEvalAssert(LENGTH(TOS) == 1, NULL);
    DEvalAssert(LENGTH(result) == 1, NULL);
}

//-----------------------------------------------------------------------------
// Execution tracing functions.  Empty unless DEBUG_JIT >= 2.

// Trace entry to evalJit: prints the depth and the header of jit record p.
static R_INLINE void traceEvalBegin(const JIT_RECORD * const p,
                                    int jitEvalDepth)
{
#if DEBUG_JIT >= 2
    printf("# %d evalJit ", jitEvalDepth-1);
    printJitHeader(p);
#endif
}
// Trace exit from evalJit: prints the depth and the value in p->ans.
static R_INLINE void traceEvalEnd(const JIT_RECORD * const p,
                                  int jitEvalDepth)
{
#if DEBUG_JIT >= 2
    printfSxp(p->ans, "# %d evalJit returns", jitEvalDepth-1);
    printf("#\n");
#endif
}
// Trace the start of one jit instruction: prints the depth and the op.
static R_INLINE void traceOpBegin(const JIT_OP * const op,
                                  int jitEvalDepth)
{
#if DEBUG_JIT >= 2
    printf("# %d evalJit    ", jitEvalDepth-1);
    printJitOp(op);
#endif
}
// Trace the end of one jit instruction: prints the depth and the value
// passed in tos.
static R_INLINE void traceOpEnd(CSEXP tos,
                                int jitEvalDepth)
{
#if DEBUG_JIT >= 2
    printfSxp(tos, "# %d evalJit\t\t\t\t\t\t\t-->", jitEvalDepth-1);
#endif
}

//-----------------------------------------------------------------------------
// Utilities for executing instructions

static void stackOverflowError(const JIT_RECORD * const prec)
{
    error(_("JIT stack overflow evaluating jitted expression %s"),
          deparseAsShortString(prec->original));
}
// Advance istack by one slot, raising an R error if that would go past
// the last slot of the exec stack.  prec is used for the error message.
static R_INLINE void incStack(const JIT_RECORD * const prec)
{
    const int lastSlot = MAX_STACK_LEN - 1;     // highest usable index

    if (lastSlot <= istack)
        stackOverflowError(prec);
    ++istack;
}
static void stackUnderflowError(const JIT_RECORD * const prec)
{
    error(_("JIT stack underflow evaluating jitted expression %s"),
          deparseAsShortString(prec->original));
}
// Pop one entry off the exec stack, raising an R error on underflow.
static R_INLINE void drop1(const JIT_RECORD * const prec)
{
    // stack starts at 1 (not 0), so istack must stay at 1 or above
    if (!(istack > 1))
        stackUnderflowError(prec);
    --istack;
}
// Pop two entries off the exec stack, raising an R error on underflow.
static R_INLINE void drop2(const JIT_RECORD * const prec)
{
    // stack starts at 1 (not 0), so istack must stay at 1 or above
    if (!(istack > 2))
        stackUnderflowError(prec);
    istack = istack - 2;
}
// Raise an R error unless i is a valid 0-based index into x, i.e. unless
// 0 <= i < LENGTH(x).  An integer NA index gets its own error message.
//
// x:    the object being subscripted
// i:    the index, already reduced by 1 to allow for R's 1-based indexing
// prec: jit record, used to show the jitted expression in the message
static R_INLINE void CheckSubscriptRange_i(CSEXP x, int i,
                                           const JIT_RECORD * const prec)
{
    if (i < 0 || i >= LENGTH(x)) {
        // An NA index shows up here as R_NaInt minus the 1 index offset.
        // The comparison is done in unsigned arithmetic because R_NaInt is
        // the most negative int in R (INT_MIN), so the signed expression
        // "R_NaInt-1" overflows, which is undefined behavior in C.
        // Unsigned arithmetic wraps and gives the intended bit pattern.
        if ((unsigned)i == (unsigned)R_NaInt - 1u)
            error(_("integer NA index in jitted expression %s"),
                  deparseAsShortString(prec->original));
        else
            error(_("out of range index %d (allowed range is 1 to %d)\n"
                  "       in jitted expression %s"),
                  i+1, LENGTH(x), deparseAsShortString(prec->original));
    }
}
// Raise an R error unless i is a valid 0-based index into x, i.e. unless
// 0 <= i < LENGTH(x).  A NaN index gets its own error message.
static R_INLINE void CheckSubscriptRange_r(CSEXP x, double i,
                                           const JIT_RECORD * const prec)
{
    // Test for "in range" rather than "out of range": comparisons with a
    // NaN are always false, so a NaN index is never accepted here.
    const int inRange = (i >= 0 && i < LENGTH(x));

    if (inRange)
        return;

    if (isnan(i)) {
        error(_("NaN index in jitted expression %s"),
              deparseAsShortString(prec->original));
    } else {
        error(_("out of range index %g (allowed range is 1 to %d)\n"
              "       in jitted expression %s"),
              i+1, LENGTH(x), deparseAsShortString(prec->original));
    }
}
static R_INLINE double FMOD(double x, double y) // for R %% operator
{
    return x - floor(x / y) * y;
}
// Integer modulus for the R %% operator.  Raises an R error when y is 0.
// prec is used to show the jitted expression in the error message.
static R_INLINE int IMOD(int x, int y, const JIT_RECORD * const prec)
{
    // "correct" use of modulo is tested first, for speed
    const int plainCase = (x >= 0 && y > 0);
    if (plainCase)
        return x % y;

    if (0 == y) {
        error(_("integer divide by 0 in \"%%%%\" "
              "in jitted expression %s\n"),
              deparseAsShortString(prec->original));
    }

    // remaining cases go through the floating point version,
    // same as arithmetic.c
    const double fmod = FMOD((double)x, (double)y);
    return (int)fmod;
}
static R_INLINE double FIDIV(double x, double y)   // for R %/% operator
{
    return floor(x / y);
}
// Integer division for the R %/% operator.  Raises an R error when y
// is 0.  prec is used to show the jitted expression in the error message.
static R_INLINE int IDIV(int x, int y, const JIT_RECORD * const prec)
{
    if (0 == y) {
        error(_("integer divide by 0 in \"%%/%%\" "
              "in jitted expression %s\n"),
              deparseAsShortString(prec->original));
    }

    // divide as doubles and round down, same as arithmetic.c
    const double quotient = (double)x / (double)y;
    return (int)floor(quotient);
}

static R_INLINE void pushSym(const JIT_OP *op,
                             const JIT_RECORD * const prec)
{
#if DEBUG_JIT_SYM
    DCheckSymOpConsistency(op, RNIL);
#endif
    incStack(prec);
    TOS = getSymValFromLoc(op->operand);
    if (TYPEOF(TOS) == PROMSXP)
        TOS = PRVALUE(TOS);
    DEvalAssert(TOS != R_UnboundValue && TOS != RNIL, prec);
}

// Execute a jitted assignment.  The value at TOS is copied into the
// buffer op->result, which then replaces TOS and is bound to the symbol
// whose binding location is op->operand.
//
// The expected* arguments are used only by the debug checks.
static R_INLINE void evalJitAs(JIT_OP *op,
                               SEXPTYPE expectedLhsType,
                               SEXPTYPE expectedRhsType,
                               int expectedRhsLen)
{
    DCheckAssign(op, expectedLhsType, expectedRhsType, expectedRhsLen);

    // Copy the data rather than "duplicate if NAMED": tests show that
    // the copy is faster, probably because no allocVectors are needed.

    const int nbytes = op->n;           // number of bytes to copy
    SEXP dest = op->result;

    if (nbytes == sizeof(int)) {        // common case: one int
        *INTEGER(dest) = *INTEGER(TOS);
    } else if (nbytes == sizeof(double)) { // common case: one double
        *REAL(dest) = *REAL(TOS);
    } else {
        memcpy(DATAPTR(dest), DATAPTR(TOS), nbytes);
    }

    TOS = dest;
    R_set_binding_value(op->operand, TOS);

    // Bump NAMED (0 to 1, 1 to 2) for compatibility with do_set in case
    // the result gets passed on to eval() later.

    switch (NAMED(TOS)) {
    case 0:
        SET_NAMED(TOS, 1);
        break;
    case 1:
        SET_NAMED(TOS, 2);
        break;
    default:                            // already at the top value
        break;
    }
}

// About disallowAssignToLoopVar: assigns to loop vars in evalJitFor
// give results different from standard for loops (because in evalJitFor
// we don't call setVar on every iteration of the loop).  To prevent users
// from shooting themselves in the foot, we don't allow users to assign
// to jitted loop vars.  For efficiency we only detect this with
// disallowAssignToLoopVar at the end of loop.  Example:
//
//      for (i in 1:3) {
//          cat("before", i)
//          i = 9L
//          cat(" after", i, "\n")
//      }
//
// Standard R execution:
//                      before 1 after 9
//                      before 2 after 9
//                      before 3 after 9
// Jitted code (note that the loop still gets correctly executed 3 times):
//                      before 1 after 9
//                      before 9 after 9
//                      before 9 after 9
//                      Error: assignment to loop variable "i" in a jit(2) block

// Raise an R error if the value now bound to sym in rho is no longer v,
// the value the loop variable had at the start of the loop.
static R_INLINE void disallowAssignToLoopVar(CSEXP v, CSEXP sym,
                                             CSEXP rho)
{
    if (findVar(sym, rho) != v) {
        error(_("assignment to loop variable \"%s\" in a jit(2) block\n"),
              CHAR(PRINTNAME(sym)));
    }
}

// Trace one iteration of a jitted loop.  Prints only in a DEBUG_JIT >= 2
// build, and then only when a jit directive is active and tracing is
// enabled via traceEvalFlag or a jitTrace level of at least 4.
static R_INLINE void traceLoopIteration1(const char msg[],
                                         const SEXP indexVar,
                                         int i)
{
#if DEBUG_JIT >= 2
    if (jitDirective && (traceEvalFlag || jitTrace >= 4)) {
        printLoopIteration(msg, indexVar, i);
    }
#endif
}

// This is an experimental fast for.  It is only invoked when
// jitDirective==2 and only on inner loops with an integer loop variable.
//
// It is called from evalJit. The instruction sequence looks like this:
//   address  opcode params
//   op       for_i  operand=indexVarBindingLoc  sym=sym  env=env
//   op+1     args   operand=evaluated rhs of index expression
//   op+2     args   operand=body (LANGSXP or JITSXP)
//
// NOTE: The generated JIT code only gives correct results if successive
// invocations of the "for" have the same rhs in the header.  For example,
// "for (i in 1:n)"  is ok only if n doesn't change the next time
// the "for" is invoked.

// NOTE: A break or next in the loop body will cause problems
//       since there is no call to setcontext.
//
// NOTE: named handling is broken in this routine. Example:
// foo <- function() {
//     jit(jit.flag); j <- 1; y <- 3L; x <- double(15);
//     for (k in 1:3) {
//       for (i in 5:9) {
//         x[j] <- y  # after this, y is NAMED
//         y <- i
//         j <- j + 1
//       }
//     }
//     cat("x", x, "\n") # results different when jit(2)
// }

// Execute a jitted "for" loop with an integer loop variable.
//
// op is the for instruction; op->operand is the binding cell of the loop
// variable.  The two instructions after op carry the evaluated sequence
// to iterate over and the loop body.  On each iteration the next element
// of the sequence is stored directly into the loop variable's integer
// buffer and the body is evaluated, with evalJit if the body is jitted
// and with eval otherwise.  On return istack is back at its entry value
// and TOS is RNIL.
static R_INLINE void evalJitFor(JIT_OP *op)
{
    // loop variable: binding cell, current value and its data buffer
    CSEXP indexLoc = op->operand;
    Dassert(TYPEOF(indexLoc) == LISTSXP);
    Dassert(TYPEOF(CAR(indexLoc)) == INTSXP);
    SEXP loopVar = CAR(indexLoc);
    int * const pLoopVar = INTEGER(loopVar);
    CSEXP sym = op->sym;
    CSEXP env = op->env;

    // sequence to iterate over
    Dassert((op+1)->opcode == JIT_arg);
    CSEXP rhs = op[1].operand;
    Dassert(TYPEOF(rhs) == INTSXP);
    const int nIter = LENGTH(rhs);
    const int * const pSeq = INTEGER(rhs);

    // loop body
    Dassert((op+2)->opcode == JIT_arg);
    CSEXP body = op[2].operand;

    const int istackOnEntry = istack;

    // The type of the body is tested once, outside the loops.
    // NOTE(review): the two trace labels look swapped ("bodyJitted" is
    // printed when the body is not a JITSXP) -- confirm before changing.
    if (JITSXP == TYPEOF(body)) {
        for (int iter = 0; iter < nIter; iter++) {
            traceLoopIteration1("in evalJit: ", indexLoc, iter);
            *pLoopVar = pSeq[iter];
            evalJit(body);
        }
    } else {
        for (int iter = 0; iter < nIter; iter++) {
            traceLoopIteration1("in evalJit: bodyJitted", indexLoc, iter);
            *pLoopVar = pSeq[iter];
            eval(body, env);
        }
    }

    disallowAssignToLoopVar(loopVar, sym, env);

    istack = istackOnEntry;     // discard anything the body left behind
    TOS = RNIL;                 // value returned by jitted for is RNIL
}

//-----------------------------------------------------------------------------
// The JIT code execution machine. This can be called recursively.
// The execution stack is global. The initial value of istack is set
// above 0 so we can use istack as an "in evalJit" indicator.
// RA_TODO I think no one uses istack for that purpose.
//
// RA_TODO istackStart is needed because expressions like "1;2;3" cause
//         stack growth and would eventually overflow stack if in loop
//         Is there a better i.e. faster solution?

// Execute the jitted code attached to the JITSXP e and return its value.
//
// The jit record is taken from e, then its ops are executed one after
// the other by the switch whose cases are in evaljit1.c, until one of
// those cases jumps to "done".  The value then at the top of the exec
// stack is stored in the record's ans field and returned.  istack is
// restored to its entry value before returning, so anything this
// invocation left on the stack is discarded.
//
// The local temps declared below are for use by the cases in evaljit1.c.
SEXP evalJit(CSEXP e)
{
    JIT_RECORD * const p = (JIT_RECORD *)RAW(e->u.jitsxp.pjit);
    JIT_OP *op = p->ops;                // pointer to current jit op
    int istackStart = istack;

    double r, *xr, *yr, *ar, xr0, yr0;  // temps for faster code
    int i, n, *xi, *yi, *ai, xi0, yi0;
    FUNC_TYPE   func;
    IFUNC_TYPE  ifunc;
    SEXP result;
    static int jitEvalDepth;            // only used in debug prints

#if DEBUG_JIT >= 2
    jitEvalDepth++;
#endif
    Dassert(TYPEOF(e) == JITSXP);
    Dassert(TYPEOF(CAR(e)) == RAWSXP);
    traceEvalBegin(p, jitEvalDepth);

    // Start at non zero.  Use the checked increment: a bare istack++
    // could move istack to MAX_STACK_LEN when evalJit is entered with a
    // full stack, and TOS would then index past the end of stack[].
    incStack(p);

    while(1) { // JIT_endop in evaljit1.c will exit this loop
        traceOpBegin(op, jitEvalDepth);
        switch (op->opcode) {
            #include "evaljit1.c"
        }
        traceOpEnd(TOS, jitEvalDepth);
        op++;
    }
    done:
    p->ans = stack[istack];
    istack = istackStart;
    traceEvalEnd(p, jitEvalDepth);
#if DEBUG_JIT >= 2
    jitEvalDepth--;
#endif
    return p->ans;
}
