/*****
 * runmath.in
 *
 * Runtime functions for math operations.
 *
 *****/

pair     => primPair()
realarray* => realArray()
pairarray* => pairArray()

#include <chrono>
#include <random>
#include <memory>
#include <cmath>
#include <inttypes.h>

#include "mathop.h"
#include "path.h"

// Names from the camp namespace are used unqualified in this file.
using namespace camp;

// Both aliases name the same array type; the distinct spellings let the
// type map at the top of this file associate realarray* with realArray()
// and pairarray* with pairArray().
typedef array realarray;
typedef array pairarray;

using types::realArray;
using types::pairArray;

using run::integeroverflow;
using vm::vmFrame;

// Message passed to error() by the factorial and choose routines below.
const char *invalidargument="invalid argument";

// 32-bit count-leading-zeros primitive defined outside this file; the Int
// version of CLZ further down is built on top of it.
extern uint32_t CLZ(uint32_t a);

// Number of bits in an Int, derived from Int_MAX: the smallest shift s with
// (1ULL << s) >= Int_MAX, plus one. The value is computed on the first call
// and cached in a function-local static for later calls.
inline unsigned intbits() {
  static unsigned count=0;
  if(count == 0) {
    unsigned shift=0;
    while((1ULL << shift) < Int_MAX)
      ++shift;
    count=shift+1;
  }
  return count;
}

// Lookup table for reversing the bits of a byte: entry i holds the value of
// i with its 8 bits in the opposite order (entry 1 is 128, entry 2 is 64,
// entry 3 is 192, ...).
//
// The 256 entries are generated by the nested macros below. R2 enumerates
// the two lowest index bits, which land in the two highest result bits
// (multiples of 64); R4 does the same for the next two index bits
// (multiples of 16) and R6 for the two after that (multiples of 4). The
// final line supplies 0,2,1,3, i.e. the two highest index bits reversed
// into the two lowest result bits.
static const unsigned char BitReverseTable8[256]=
{
#define R2(n)     n,    n+2*64,    n+1*64,    n+3*64
#define R4(n) R2(n),R2(n+2*16),R2(n+1*16),R2(n+3*16)
#define R6(n) R4(n),R4(n+2*4 ),R4(n+1*4 ),R4(n+3*4 )
  R6(0),R6(2),R6(1),R6(3)
};
// The generator macros are only needed for the initializer above.
#undef R2
#undef R4
#undef R6

// Reverse the order of the low 8 bits of a.
//
// Only the low byte of a is used; it is masked before indexing the
// 256-entry table, so a value wider than 8 bits can no longer read past the
// end of BitReverseTable8. Results for inputs below 256 are unchanged.
unsigned long long bitreverse8(unsigned long long a)
{
  return
    (unsigned long long) BitReverseTable8[a & 0xff];
}

// Reverse the order of the low 16 bits of a.
//
// Each byte is reversed through BitReverseTable8 and the two bytes swap
// places. Both table indices are masked to 8 bits, so bits of a above the
// 16-bit word are ignored instead of indexing past the end of the table.
// Results for inputs that fit in 16 bits are unchanged.
unsigned long long bitreverse16(unsigned long long a)
{
  return
    ((unsigned long long) BitReverseTable8[a & 0xff] << 8) |
    ((unsigned long long) BitReverseTable8[(a >> 8) & 0xff]);
}

// Reverse the order of the low 24 bits of a.
//
// The word is consumed one byte at a time from the least significant end;
// each byte is bit-reversed through BitReverseTable8 and appended to the
// result, so the first byte read ends up most significant. Every table
// index is masked to 8 bits, so bits of a above the 24-bit word are ignored
// instead of indexing past the end of the table. Results for inputs that
// fit in 24 bits are unchanged.
unsigned long long bitreverse24(unsigned long long a)
{
  unsigned long long reversed=0;
  for(unsigned int byte=0; byte < 3; ++byte) {
    reversed=(reversed << 8) | BitReverseTable8[a & 0xff];
    a >>= 8;
  }
  return reversed;
}

// Reverse the order of the low 32 bits of a.
//
// The word is consumed one byte at a time from the least significant end;
// each byte is bit-reversed through BitReverseTable8 and appended to the
// result, so the first byte read ends up most significant. Every table
// index is masked to 8 bits, so bits of a above the 32-bit word are ignored
// instead of indexing past the end of the table. Results for inputs that
// fit in 32 bits are unchanged.
unsigned long long bitreverse32(unsigned long long a)
{
  unsigned long long reversed=0;
  for(unsigned int byte=0; byte < 4; ++byte) {
    reversed=(reversed << 8) | BitReverseTable8[a & 0xff];
    a >>= 8;
  }
  return reversed;
}

// Reverse the order of the low 40 bits of a.
//
// The word is consumed one byte at a time from the least significant end;
// each byte is bit-reversed through BitReverseTable8 and appended to the
// result, so the first byte read ends up most significant. Every table
// index is masked to 8 bits, so bits of a above the 40-bit word are ignored
// instead of indexing past the end of the table. Results for inputs that
// fit in 40 bits are unchanged.
unsigned long long bitreverse40(unsigned long long a)
{
  unsigned long long reversed=0;
  for(unsigned int byte=0; byte < 5; ++byte) {
    reversed=(reversed << 8) | BitReverseTable8[a & 0xff];
    a >>= 8;
  }
  return reversed;
}

// Reverse the order of the low 48 bits of a.
//
// The word is consumed one byte at a time from the least significant end;
// each byte is bit-reversed through BitReverseTable8 and appended to the
// result, so the first byte read ends up most significant. Every table
// index is masked to 8 bits, so bits of a above the 48-bit word are ignored
// instead of indexing past the end of the table. Results for inputs that
// fit in 48 bits are unchanged.
unsigned long long bitreverse48(unsigned long long a)
{
  unsigned long long reversed=0;
  for(unsigned int byte=0; byte < 6; ++byte) {
    reversed=(reversed << 8) | BitReverseTable8[a & 0xff];
    a >>= 8;
  }
  return reversed;
}

// Reverse the order of the low 56 bits of a.
//
// The word is consumed one byte at a time from the least significant end;
// each byte is bit-reversed through BitReverseTable8 and appended to the
// result, so the first byte read ends up most significant. Every table
// index is masked to 8 bits, so bits of a above the 56-bit word are ignored
// instead of indexing past the end of the table. Results for inputs that
// fit in 56 bits are unchanged.
unsigned long long bitreverse56(unsigned long long a)
{
  unsigned long long reversed=0;
  for(unsigned int byte=0; byte < 7; ++byte) {
    reversed=(reversed << 8) | BitReverseTable8[a & 0xff];
    a >>= 8;
  }
  return reversed;
}

// Reverse the order of all 64 bits of a.
//
// The word is consumed one byte at a time from the least significant end;
// each byte is bit-reversed through BitReverseTable8 and appended to the
// result, so the first byte read ends up most significant.
unsigned long long bitreverse64(unsigned long long a)
{
  unsigned long long reversed=0;
  for(unsigned int byte=0; byte < 8; ++byte) {
    reversed=(reversed << 8) | BitReverseTable8[a & 0xff];
    a >>= 8;
  }
  return reversed;
}

// Fallback population count, compiled only when HAVE_POPCOUNT is not
// defined. Counts the set bits of a in parallel, following
// https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
// T is a temporary alias for unsignedInt so the masks adapt to its width.
#define T unsignedInt
Int popcount(T a)
{
  // (T)~(T)0/3 is the 0101... mask: each 2-bit field now holds its bit count.
  a=a-((a >> 1) & (T)~(T)0/3);
  // (T)~(T)0/15*3 is the 0011... mask: add neighbouring 2-bit counts into
  // 4-bit fields.
  a=(a & (T)~(T)0/15*3)+((a >> 2) & (T)~(T)0/15*3);
  // (T)~(T)0/255*15 is the 00001111... mask: add neighbouring 4-bit counts
  // so every byte holds its own count.
  a=(a+(a >> 4)) & (T)~(T)0/255*15;
  // Multiplying by 0x0101...01 sums all byte counts into the top byte, which
  // the shift then moves down to the low end.
  return (T)(a*((T)~(T)0/255)) >> (sizeof(T)-1)*CHAR_BIT;
}
#undef T
#endif

// Look up n! for a non-negative integer n.
//
// The table is built on the first call and holds every factorial that fits
// in an Int (judged against Int_MAX); it is kept for the lifetime of the
// program. Requests beyond the table are reported through integeroverflow.
// Negative n is not checked here; callers must reject it beforehand.
Int factorial(Int n)
{
  static Int *table;
  static Int size=0;
  if(size == 0) {
    // Count the entries: stop at the first k for which k! would overflow.
    Int product=1;
    Int entries=2;
    while(product <= Int_MAX/entries) {
      product *= entries;
      ++entries;
    }
    size=entries;
    table=new Int[size];
    // Fill the table incrementally: k! = (k-1)! * k.
    table[0]=1;
    for(Int k=1; k < size; ++k)
      table[k]=table[k-1]*k;
  }
  if(n >= size) integeroverflow(0);
  return table[n];
}

// Round x to the nearest integer, with halfway cases going away from zero:
// shift by half a unit in the direction of the sign, then let the
// conversion to Int truncate toward zero. No range check is done here.
static inline Int Round(double x)
{
  if(x >= 0)
    return Int(x+0.5);
  return Int(x-0.5);
}

// Sign of x: 1 when x is positive, -1 when it is negative, and 0 otherwise
// (zero, or a value that compares neither greater nor less than zero).
inline Int sgn(double x)
{
  if(x > 0.0) return 1;
  if(x < 0.0) return -1;
  return 0;
}

namespace
{
// Produce a seed from std::random_device. The default-constructed
// distribution spans 0 up to the largest int64_t, so the seed is never
// negative.
Int makeRandomSeed() {
  std::random_device device;
  std::uniform_int_distribution<int64_t> seedRange;
  const int64_t seed=seedRange(device);
  return seed;
}

// Random engine shared by the rand, srand and unitrand routines below. It
// is seeded from makeRandomSeed when this file's statics are initialized.
std::mt19937_64 randEngine(makeRandomSeed());
}

// Autogenerated routines:


// Power operator for a real base and an integer exponent, computed by pow.
real ^(real x, Int y)
{
  return pow(x,y);
}

// Power operator for a pair base and an integer exponent, computed by pow.
pair ^(pair z, Int y)
{
  return pow(z,y);
}

// Integer quotient of x by y, delegated to the quotient<Int> functor.
// Rounding and the handling of y == 0 are whatever that functor implements.
Int quotient(Int x, Int y)
{
  return quotient<Int>()(x,y);
}

// Absolute value of an integer, computed by the Abs helper.
Int abs(Int x)
{
  return Abs(x);
}

// Sign of x as an integer. Forwards to the sgn helper defined in the
// preamble, which yields 1, -1 or 0.
// NOTE(review): this relies on the generated wrapper getting a different
// name, so the call is not self-recursive; confirm in the generator.
Int sgn(real x)
{
  return sgn(x);
}

// Random integer drawn from a uniform distribution over [a,b] using the
// shared randEngine. By default the range is [0,Int_MAX].
// NOTE(review): both bounds are cast to unsigned long long and a <= b is
// not checked, although std::uniform_int_distribution requires it; confirm
// that callers never pass a negative or reversed range.
Int rand(Int a=0, Int b=Int_MAX)
{
  std::uniform_int_distribution dist((unsigned long long) a, (unsigned long long) b);
  return dist(randEngine);
}

// Reseed the shared randEngine. A negative seed requests a fresh seed from
// makeRandomSeed; any other value is used as given.
void srand(Int seed)
{
  if(seed < 0)
    seed=makeRandomSeed();
  randEngine=std::mt19937_64(seed);
}

// A random number uniformly distributed in the interval [0,1), drawn from
// the shared randEngine.
real unitrand()
{
  std::uniform_real_distribution<double> dist(0.0, 1.0);
  return dist(randEngine);
}

// Smallest integer not less than x: applies ceil, then converts the result
// with Intcast.
// NOTE(review): Intcast presumably range-checks the conversion; confirm.
Int ceil(real x)
{
  return Intcast(ceil(x));
}

// Largest integer not greater than x: applies floor, then converts the
// result with Intcast.
// NOTE(review): Intcast presumably range-checks the conversion; confirm.
Int floor(real x)
{
  return Intcast(floor(x));
}

// Nearest integer to x, with halfway cases rounded away from zero by the
// Round helper in the preamble. When validInt rejects x, integeroverflow is
// called instead.
Int round(real x)
{
  if(validInt(x)) return Round(x);
  integeroverflow(0);
}

// Capitalized variant of ceil, delegated to the Ceil helper declared
// outside this file.
// NOTE(review): judging by the Round routine below, the capitalized
// variants presumably clamp out-of-range values instead of raising an
// overflow error; confirm against the helper.
Int Ceil(real x)
{
  return Ceil(x);
}

// Capitalized variant of floor, delegated to the Floor helper declared
// outside this file.
// NOTE(review): judging by the Round routine below, the capitalized
// variants presumably clamp out-of-range values instead of raising an
// overflow error; confirm against the helper.
Int Floor(real x)
{
  return Floor(x);
}

// Nearest integer to x, with halfway cases rounded away from zero. Unlike
// round, x is first passed through Intcap and no overflow error is raised
// in this routine.
// NOTE(review): Intcap presumably clamps x to the Int range; confirm.
Int Round(real x)
{
  return Round(Intcap(x));
}

// Floating-point remainder of x divided by y, as computed by fmod.
// A zero divisor is reported through dividebyzero.
real fmod(real x, real y)
{
  if (y == 0.0) dividebyzero();
  return fmod(x,y);
}

// Two-argument arctangent, as computed by atan2. Note the argument order:
// y comes first, then x.
real atan2(real y, real x)
{
  return atan2(y,x);
}

// Length of the hypotenuse with legs x and y, as computed by hypot.
real hypot(real x, real y)
{
  return hypot(x,y);
}

// Remainder of x divided by y, as computed by remainder.
// NOTE(review): unlike fmod above, y == 0 is not rejected here; confirm
// that the difference is intentional.
real remainder(real x, real y)
{
  return remainder(x,y);
}

// Bessel function of the first kind of order n at x, as computed by jn.
real Jn(Int n, real x)
{
  return jn(n,x);
}

// Bessel function of the second kind of order n at x, as computed by yn.
real Yn(Int n, real x)
{
  return yn(n,x);
}

// Error function of x, as computed by erf.
real erf(real x)
{
  return erf(x);
}

// Complementary error function of x, as computed by erfc.
real erfc(real x)
{
  return erfc(x);
}

// Factorial of n. A negative n is reported as an invalid argument; the
// value itself comes from the table-driven factorial helper in the
// preamble, which signals integer overflow when n is beyond its table.
Int factorial(Int n) {
  if(n < 0) error(invalidargument);
  return factorial(n);
}

// Binomial coefficient: the number of ways to choose k items out of n.
// Arguments outside 0 <= k <= n are reported as an invalid argument, and
// integeroverflow is called when an intermediate product would exceed
// Int_MAX.
//
// The coefficient is built up multiplicatively as C(n,1), C(n,2), ... so
// every division is exact. Because C(n,k) equals C(n,n-k), only
// min(k,n-k) steps are taken; this keeps the intermediate values small, so
// that for example choose(100,99) yields 100 instead of overflowing on the
// way through the central coefficients.
Int choose(Int n, Int k) {
  if(n < 0 || k < 0 || k > n) error(invalidargument);
  Int f=1;
  // Lower end of the descending factor range, namely max(k,n-k)
  Int r=n-k;
  if(r < k) r=k;
  for(Int i=n; i > r; --i) {
    if(f > Int_MAX/i) integeroverflow(0);
    f=(f*i)/(n-i+1);
  }
  return f;
}

// Gamma function of x, as computed by std::tgamma.
real gamma(real x)
{
  return std::tgamma(x);
}

// Roots found by the quadraticroots solver for the real coefficients a, b
// and c, packed into a new array of length q.roots with t1 first and t2
// second.
// NOTE(review): presumably these are the real roots of a*x^2+b*x+c=0;
// confirm against the solver.
realarray *quadraticroots(real a, real b, real c)
{
  quadraticroots q(a,b,c);
  array *roots=new array(q.roots);
  if(q.roots >= 1) (*roots)[0]=q.t1;
  if(q.roots == 2) (*roots)[1]=q.t2;
  return roots;
}

// Roots found by the Quadraticroots solver for the pair coefficients a, b
// and c, packed into a new array of length q.roots with z1 first and z2
// second.
// NOTE(review): presumably these are the complex roots of a*z^2+b*z+c=0;
// confirm against the solver.
pairarray *quadraticroots(explicit pair a, explicit pair b, explicit pair c)
{
  Quadraticroots q(a,b,c);
  array *roots=new array(q.roots);
  if(q.roots >= 1) (*roots)[0]=q.z1;
  if(q.roots == 2) (*roots)[1]=q.z2;
  return roots;
}

// Roots found by the cubicroots solver for the real coefficients a, b, c
// and d, packed into a new array of length q.roots in the order t1, t2, t3.
// NOTE(review): presumably these are the real roots of
// a*x^3+b*x^2+c*x+d=0; confirm against the solver.
realarray *cubicroots(real a, real b, real c, real d)
{
  cubicroots q(a,b,c,d);
  array *roots=new array(q.roots);
  if(q.roots >= 1) (*roots)[0]=q.t1;
  if(q.roots >= 2) (*roots)[1]=q.t2;
  if(q.roots == 3) (*roots)[2]=q.t3;
  return roots;
}


// Logical operations

// Logical negation of b.
bool !(bool b)
{
  return !b;
}

// True when a and b point to the same frame. This is pointer identity; the
// frame contents are not compared.
bool :boolMemEq(vmFrame *a, vmFrame *b)
{
  return a == b;
}

// True when a and b point to different frames. This is pointer identity;
// the frame contents are not compared.
bool :boolMemNeq(vmFrame *a, vmFrame *b)
{
  return a != b;
}

// Equality of two callables, as decided by a->compare(b).
// NOTE(review): a is dereferenced without a null check; confirm that the
// caller guarantees a non-null pointer.
bool :boolFuncEq(callable *a, callable *b)
{
  return a->compare(b);
}

// Inequality of two callables: the negation of a->compare(b).
// NOTE(review): a is dereferenced without a null check; confirm that the
// caller guarantees a non-null pointer.
bool :boolFuncNeq(callable *a, callable *b)
{
  return !(a->compare(b));
}


// Bit operations

// Bitwise AND of a and b.
Int AND(Int a, Int b)
{
  return a & b;
}

// Bitwise inclusive OR of a and b.
Int OR(Int a, Int b)
{
  return a | b;
}

// Bitwise exclusive OR of a and b.
Int XOR(Int a, Int b)
{
  return a ^ b;
}

// Bitwise complement of a.
Int NOT(Int a)
{
  return ~a;
}

// Number of leading zero bits in a, viewed as an unsigned word, built on
// the 32-bit CLZ declared in the preamble.
// If any bit above the low 32 is set, the count is taken within the upper
// half. Otherwise the low half is counted and the intbits()-32 bits above
// it are added. A zero argument yields the full width intbits().
// NOTE(review): the upper-half branch shifts by 32 and so assumes a 64-bit
// Int; confirm for other widths.
Int CLZ(Int a)
{
  if((unsigned long long) a > 0xFFFFFFFF)
    return CLZ((uint32_t) ((unsigned long long) a >> 32));
  else {
    int bits=intbits();
    if(a != 0) return bits-32+CLZ((uint32_t) a);
    return bits;
  }
}

// Number of set bits in a, delegated to the popcount helper. When
// HAVE_POPCOUNT is not defined that helper is the fallback in the preamble,
// which takes its argument as unsignedInt.
Int popcount(Int a)
{
  return popcount(a);
}

// Number of trailing zero bits in a.
// The expression a&-a isolates the lowest set bit; subtracting one turns it
// into a mask covering exactly the zero bits below it, and the population
// count of that mask is the answer. For a == 0 the mask is all ones, so
// every bit is counted.
Int CTZ(Int a)
{
  return popcount((a&-a)-1);
}

// Reverse the bit order of a within a word of length bits.
// Yields -1 when bits is not in the range 1..maxbits, when a is negative,
// or when a does not fit in a word of that length.
Int bitreverse(Int a, Int bits)
{
  typedef unsigned long long Bitreverse(unsigned long long a);
  // Byte-wise reversers indexed by word size in bytes minus one.
  static Bitreverse *B[]={bitreverse8,bitreverse16,bitreverse24,bitreverse32,
                          bitreverse40,bitreverse48,bitreverse56,bitreverse64};
  int maxbits=intbits()-1; // Drop sign bit
#if Int_MAX2 >= 0x7fffffffffffffffLL
  --maxbits;               // Drop extra bit for reserved values
#endif
  if(bits <= 0 || bits > maxbits || a < 0 ||
     (unsigned long long) a >= (1ULL << bits))
    return -1;
  // Reverse within the smallest whole number of bytes that covers the word,
  // then shift out the padding bits, which end up at the low-order end.
  unsigned int bytes=(bits+7)/8;
  return B[bytes-1]((unsigned long long) a) >> (8*bytes-bits);
}
