/*****
 * types.h
 * Andy Hammerlindl 2002/06/20
 *
 * Used by the compiler as a way to keep track of the type of a variable
 * or expression.
 *
 *****/

#ifndef TYPES_H
#define TYPES_H

#include <iostream>
#include <cstdio>
#include <cassert>

#include "errormsg.h"
#include "symbol.h"
#include "common.h"
#include "util.h"

using std::ostream;

using sym::symbol;

// Forward declaration.
namespace trans {
class access;
class varEntry;
}
namespace absyntax {
class varinit;
extern varinit *Default;
}

namespace types {

enum ty_kind {
  ty_null,
  ty_record,    // "struct" in Asymptote language
  ty_function,
  ty_overloaded,

#define PRIMITIVE(name,Name,asyName) ty_##name,
#define PRIMERROR
#include "primitives.h"
#undef PRIMERROR
#undef PRIMITIVE

  ty_array
};

// Forward declarations.
class ty;
struct signature;
typedef mem::vector<ty *> ty_vector;
typedef ty_vector::iterator ty_iterator;

// Checks if two types are equal in the sense of the language.
// That is primitive types are equal if they are the same kind.
// Structures are equal if they come from the same struct definition.
// Arrays are equal if their cell types are equal.
bool equivalent(const ty *t1, const ty *t2);

// If special is true, this is the same as above.  If special is false, just
// the signatures are compared.
bool equivalent(const ty *t1, const ty *t2, bool special);

class caster {
public:
  virtual ~caster() {}
  virtual trans::access *operator() (ty *target, ty *source) = 0;
  virtual bool castable(ty *target, ty *source) = 0;
};

class ty : public gc {
public:
  const ty_kind kind;
  ty(ty_kind kind)
    : kind(kind) {}
  virtual ~ty();

  virtual void print (ostream& out) const;
  virtual void printVar (ostream& out, string name) const {
    print(out);
    out << " " << name;
  }


  // Returns true if the type is a user-defined type or the null type.
  // While the pair, path, etc. are stored by reference, this is
  // transparent to the user.
  virtual bool isReference() {
    return true;
  }

  virtual signature *getSignature() {
    return 0;
  }

  virtual const signature *getSignature() const {
    return 0;
  }

  virtual bool primitive() {
    return false;
  }

  bool isError() const { return kind == ty_error; }
  bool isNotError() const { return !isError(); }

  // The following are only used by the overloaded type, but it is so common
  // to test for an overloaded type then iterate over its types, that this
  // allows the code:
  // if (t->isOverloaded()) {
  //   for (ty_iterator i = t->begin(); i != t->end(); ++i) {
  //     ...
  //   }
  // }
  // For speed reasons, only begin has an assert to test if t is overloaded.
  bool isOverloaded() const {
    return kind == ty_overloaded;
  }
  bool isNotOverloaded() const { return !isOverloaded(); }
  ty_iterator begin();
  ty_iterator end();

  // If a default initializer is not stored in the environment, the abstract
  // syntax asks the type if it has a "default" default initializer, by calling
  // this method.
  virtual trans::access *initializer() {
    return 0;
  }

  // If a cast function is not stored in the environment, ask the type itself.
  // This handles null->record casting, and the like.  The caster is used as a
  // callback to the environment for casts of subtypes.
  virtual trans::access *castTo(ty *, caster &) {
    return 0;
  }

  // Just checks if a cast is possible.
  virtual bool castable(ty *target, caster &c) {
    return castTo(target, c);
  }

  // For pair's x and y, and array's length, this is a special type of
  // "field".
  // In actually, it returns a function which takes the object as its
  // parameter and returns the necessary result.
  // These should not have public permission, as modifying them would
  // have strange results.
  virtual trans::varEntry *virtualField(symbol, signature *) {
    return 0;
  }

  // varGetType for virtual fields.
  // Unless you are using functions for virtual fields, the base implementation
  // should work fine.
  virtual ty *virtualFieldGetType(symbol id);

#if 0
  // Returns the type.  In case of functions, return the equivalent type
  // but with no default values for parameters.
  virtual ty *stripDefaults()
  {
    return this;
  }
#endif

  // Returns true if the other type is equivalent to this one.
  // The general function equivalent should be preferably used, as it properly
  // handles overloaded type comparisons.
  virtual bool equiv(const ty *other) const
  {
    return this==other;
  }


  // Returns a number for the type for use in a hash table.  Equivalent types
  // must yield the same number.
  virtual size_t hash() const = 0;
};

class primitiveTy : public ty {
public:
  primitiveTy(ty_kind kind)
    : ty(kind) {}

  bool primitive() {
    return true;
  }

  bool isReference() {
    return false;
  }

  ty *virtualFieldGetType(symbol );
  trans::varEntry *virtualField(symbol, signature *);

  bool equiv(const ty *other) const
  {
    return this->kind==other->kind;
  }

  size_t hash() const {
    return (size_t)kind + 47;
  }
};

class nullTy : public primitiveTy {
public:
  nullTy()
    : primitiveTy(ty_null) {}

  bool isReference() {
    return true;
  }

  trans::access *castTo(ty *target, caster &);

  size_t hash() const {
    return (size_t)kind + 47;
  }
};

// Ostream output, just defer to print.
inline ostream& operator<< (ostream& out, const ty& t)
{ t.print(out); return out; }

struct array : public ty {
  ty *celltype;
  ty *pushtype;
  ty *poptype;
  ty *appendtype;
  ty *inserttype;
  ty *deletetype;

  array(ty *celltype)
    : ty(ty_array), celltype(celltype), pushtype(0), poptype(0),
      appendtype(0), inserttype(0), deletetype(0) {}

  virtual bool isReference() {
    return true;
  }

  bool equiv(const ty *other) const {
    return other->kind==ty_array &&
      equivalent(this->celltype,((array *)other)->celltype);
  }

  size_t hash() const {
    return 1007 * celltype->hash();
  }

  Int depth() {
    if (array *cell=dynamic_cast<array *>(celltype))
      return cell->depth() + 1;
    else
      return 1;
  }

  void print(ostream& out) const
  { out << *celltype << "[]"; }

  ty *pushType();
  ty *popType();
  ty *appendType();
  ty *insertType();
  ty *deleteType();

  // Initialize to an empty array by default.
  trans::access *initializer();

  // NOTE: General vectorization of casts would be here.

  // Add length and push as virtual fields.
  ty *virtualFieldGetType(symbol id);
  trans::varEntry *virtualField(symbol id, signature *sig);
};

/* Base types */
#define PRIMITIVE(name,Name,asyName)            \
  ty *prim##Name();                             \
  ty *name##Array();                            \
  ty *name##Array2();                           \
  ty *name##Array3();
#define PRIMERROR
#include "primitives.h"
#undef PRIMERROR
#undef PRIMITIVE

ty *primNull();


struct formal {
  ty *t;
  symbol name;
  bool defval;
  bool Explicit;

  formal(ty *t,
         symbol name=symbol::nullsym,
         bool optional=false,
         bool Explicit=false)
    : t(t), name(name),
      defval(optional), Explicit(Explicit) {}

  // string->symbol translation is costly if done too many times.  This
  // constructor has been disabled to make this cost more visible to the
  // programmer.
#if 0
  formal(ty *t,
         const char *name,
         bool optional=false,
         bool Explicit=false)
    : t(t), name(symbol::trans(name)),
      defval(optional ? absyntax::Default : 0), Explicit(Explicit) {}
#endif

  friend ostream& operator<< (ostream& out, const formal& f);
};

bool equivalent(const formal& f1, const formal& f2);
bool argumentEquivalent(const formal &f1, const formal& f2);

typedef mem::vector<formal> formal_vector;

// Holds the parameters of a function and if they have default values
// (only applicable in some cases).
struct signature : public gc {
  formal_vector formals;

  // The number of keyword-only formals.  These formals always come after the
  // regular formals.
  size_t numKeywordOnly;

  // Formal for the rest parameter.  If there is no rest parameter, then the
  // type is null.
  formal rest;

  bool isOpen;

  signature()
    : numKeywordOnly(0), rest(0), isOpen(false)
  {}

  struct OPEN_t {};

  static const OPEN_t OPEN;

  explicit signature(OPEN_t) : numKeywordOnly(0), rest(0), isOpen(true) {}

  signature(signature &sig)
    : formals(sig.formals), numKeywordOnly(sig.numKeywordOnly),
      rest(sig.rest), isOpen(sig.isOpen)
  {}

  virtual ~signature() {}

  void add(formal f) {
    formals.push_back(f);
  }

  void addKeywordOnly(formal f) {
    add(f);
    ++numKeywordOnly;
  }

  void addRest(formal f) {
    rest=f;
  }

  bool hasRest() const {
    return rest.t;
  }
  size_t getNumFormals() const {
    return rest.t ? formals.size() + 1 : formals.size();
  }

  formal& getFormal(size_t n) {
    assert(n < formals.size());
    return formals[n];
  }
  const formal& getFormal(size_t n) const {
    assert(n < formals.size());
    return formals[n];
  }

  formal& getRest() {
    return rest;
  }
  const formal& getRest() const {
    return rest;
  }

  bool formalIsKeywordOnly(size_t n) const
  {
    assert(n < formals.size());
    return n >= formals.size() - numKeywordOnly;
  }

  friend string toString(const signature& s);
  friend ostream& operator<< (ostream& out, const signature& s);

  friend bool equivalent(const signature *s1, const signature *s2);

  // Check if a signature of argument types (as opposed to formal parameters)
  // are equivalent.  Here, the arguments, if named, must have the same names,
  // and (for simplicity) no overloaded arguments are allowed.
  friend bool argumentEquivalent(const signature *s1, const signature *s2);
#if 0
  friend bool castable(signature *target, signature *source);
  friend Int numFormalsMatch(signature *s1, signature *s2);
#endif

  size_t hash() const;

  // Return a unique handle for this signature
  size_t handle();
};

struct function : public ty {
  ty *result;
  signature sig;

  function(ty *result)
    : ty(ty_function), result(result) {}
  function(ty *result, signature::OPEN_t)
    : ty(ty_function), result(result), sig(signature::OPEN) {}
  function(ty *result, signature *sig)
    : ty(ty_function), result(result), sig(*sig) {}
  function(ty *result, formal f1)
    : ty(ty_function), result(result) {
    add(f1);
  }
  function(ty *result, formal f1, formal f2)
    : ty(ty_function), result(result) {
    add(f1);
    add(f2);
  }
  function(ty *result, formal f1, formal f2, formal f3)
    : ty(ty_function), result(result) {
    add(f1);
    add(f2);
    add(f3);
  }
  function(ty *result, formal f1, formal f2, formal f3, formal f4)
    : ty(ty_function), result(result) {
    add(f1);
    add(f2);
    add(f3);
    add(f4);
  }
  virtual ~function() {}

  void add(formal f) {
    sig.add(f);
  }

  void addRest(formal f) {
    sig.addRest(f);
  }

  virtual bool isReference() {
    return true;
  }

  bool equiv(const ty *other) const
  {
    if (other->kind==ty_function) {
      function *that=(function *)other;
      return equivalent(this->result,that->result) &&
        equivalent(&this->sig,&that->sig);
    }
    else return false;
  }

  size_t hash() const {
    return sig.hash()*0x1231+result->hash();
  }

  void print(ostream& out) const
  { out << *result << sig; }

  void printVar (ostream& out, string name) const {
    result->printVar(out,name);
    out << sig;
  }

  ty *getResult() {
    return result;
  }

  signature *getSignature() {
    return &sig;
  }

  const signature *getSignature() const {
    return &sig;
  }

#if 0
  ty *stripDefaults();
#endif

  // Initialized to null.
  trans::access *initializer();
};

// This is used in getType expressions when an overloaded variable is accessed.
class overloaded : public ty {
public:
  ty_vector sub;

  // Warning: The venv endScope routine relies heavily on the current
  // implementation of overloaded.
public:
  overloaded()
    : ty(ty_overloaded) {}
  overloaded(ty *t)
    : ty(ty_overloaded) { add(t); }
  virtual ~overloaded() {}

  bool equiv(const ty *other) const
  {
    for(ty_vector::const_iterator i=sub.begin();i!=sub.end();++i)
      if (equivalent(*i,other))
        return true;
    return false;
  }

  size_t hash() const {
    // Overloaded types should not be hashed.
    assert(False);
    return 0;
  }


#ifdef __clang__
#elif __GNUC__
#pragma GCC push_options
#pragma GCC optimize("O2")
#endif
  void add(ty *t) {
    if (t->kind == ty_overloaded) {
      overloaded *ot = (overloaded *)t;
      copy(ot->sub.begin(), ot->sub.end(),
           inserter(this->sub, this->sub.end()));
  }
    else
      sub.push_back(t);
  }
#ifdef __clang__
#elif __GNUC__
#pragma GCC pop_options
#endif

  // Only add a type distinct from the ones currently in the overloaded type.
  // If special is false, just the distinct signatures are added.
  void addDistinct(ty *t, bool special=false);

  // If there are less than two overloaded types, the type isn't really
  // overloaded.  This gives a more appropriate type in this case.
  ty *simplify() {
    switch (sub.size()) {
      case 0:
        return 0;
      case 1: {
        return sub.front();
      }
      default:
        return new overloaded(*this);
    }
  }

  // Returns the signature-less type of the set.
  ty *signatureless();

  // True if one of the subtypes is castable.
  bool castable(ty *target, caster &c);

  size_t size() const { return sub.size(); }

  // Use default printing for now.
};

inline ty_iterator ty::begin() {
  assert(this->isOverloaded());
  return ((overloaded *)this)->sub.begin();
}
inline ty_iterator ty::end() {
  return ((overloaded *)this)->sub.end();
}

// This is used to encapsulate iteration over the subtypes of an overloaded
// type.  The base method need only be implemented to handle non-overloaded
// types.
class collector {
public:
  virtual ~collector() {}
  virtual ty *base(ty *target, ty *source) = 0;

  virtual ty *collect(ty *target, ty *source) {
    if (overloaded *o=dynamic_cast<overloaded *>(target)) {
      ty_vector &sub=o->sub;

      overloaded *oo=new overloaded;
      for(ty_vector::iterator x = sub.begin(); x != sub.end(); ++x) {
        types::ty *t=collect(*x, source);
        if (t)
          oo->add(t);
      }

      return oo->simplify();
    }
    else if (overloaded *o=dynamic_cast<overloaded *>(source)) {
      ty_vector &sub=o->sub;

      overloaded *oo=new overloaded;
      for(ty_vector::iterator y = sub.begin(); y != sub.end(); ++y) {
        // NOTE: A possible speed optimization would be to replace this with a
        // call to base(), but this is only correct if we can guarantee that an
        // overloaded type has no overloaded sub-types.
        types::ty *t=collect(target, *y);
        if (t)
          oo->add(t);
      }

      return oo->simplify();
    }
    else
      return base(target, source);
  }
};

class tester {
public:
  virtual ~tester() {}
  virtual bool base(ty *target, ty *source) = 0;

  virtual bool test(ty *target, ty *source) {
    if (overloaded *o=dynamic_cast<overloaded *>(target)) {
      ty_vector &sub=o->sub;

      for(ty_vector::iterator x = sub.begin(); x != sub.end(); ++x)
        if (test(*x, source))
          return true;

      return false;
    }
    else if (overloaded *o=dynamic_cast<overloaded *>(source)) {
      ty_vector &sub=o->sub;

      for(ty_vector::iterator y = sub.begin(); y != sub.end(); ++y)
        if (base(target, *y))
          return true;

      return false;
    }
    else
      return base(target, source);
  }
};

} // namespace types

GC_DECLARE_PTRFREE(types::primitiveTy);
GC_DECLARE_PTRFREE(types::nullTy);

#endif
