// Copyright (C) 2010, Guy Barrand. All rights reserved.
// See the file tools.license for terms.

#ifndef tools_mat
#define tools_mat

#ifdef TOOLS_MEM
#include "mem"
#endif

#include "MATCOM"

//#include <cstring> //memcpy

//#define TOOLS_MAT_NEW
namespace tools {

// Square matrix of T with a compile-time dimension D (D*D elements).
// The TOOLS_MATCOM macro expanded below is defined outside this file
// (presumably by the MATCOM include) ; it is expected to supply the members
// used here but not declared in this class, such as zero() and _copy().
// Storage : an in-object array T[D*D] by default, or a heap array of the
// same length when TOOLS_MAT_NEW is defined.
template <class T,unsigned int D>
class mat {
  static const unsigned int _D2 = D*D;
  unsigned int dim2() const {return _D2;}
#define TOOLS_MAT_CLASS mat
  TOOLS_MATCOM
#undef TOOLS_MAT_CLASS
private:
#ifdef TOOLS_MEM
  // name passed to mem::increment()/mem::decrement().
  static const std::string& s_class() {
    static const std::string s_v("tools::mat");
    return s_v;
  }
#endif
public:
  // default constructor : each of the D*D elements is assigned zero().
  mat() {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
#ifdef TOOLS_MAT_NEW
    m_vec = new T[_D2];
#endif
    unsigned int index = 0;
    while(index<_D2) {
      m_vec[index] = zero();
      ++index;
    }
  }
  virtual ~mat() {
#ifdef TOOLS_MAT_NEW
    delete [] m_vec;
#endif
#ifdef TOOLS_MEM
    mem::decrement(s_class().c_str());
#endif
  }
public:
  // copy constructor : elements are taken from a_from through _copy().
  mat(const mat& a_from) {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
#ifdef TOOLS_MAT_NEW
    m_vec = new T[_D2];
#endif
    _copy(a_from.m_vec);
  }
  // assignment : both sides have the same D, the storage is reused.
  mat& operator=(const mat& a_from){
    if(this!=&a_from) _copy(a_from.m_vec);
    return *this;
  }
public:
  // construct from a raw array ; a_v is handed to _copy().
  // NOTE(review) : a_v presumably has to hold at least D*D elements - confirm against _copy().
  mat(const T a_v[]){
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
#ifdef TOOLS_MAT_NEW
    m_vec = new T[_D2];
#endif
    _copy(a_v);
  }
public:
  // number of rows (= number of columns).
  unsigned int dimension() const {return D;}
protected:
#ifdef TOOLS_MAT_NEW
  T* m_vec;
#else
  T m_vec[D*D];
#endif

  // never called ; only instantiates mat<float,2>.
private:static void check_instantiation() {mat<float,2> dummy;}
};

// Square matrix of T whose dimension is chosen at run time.
// The TOOLS_MATCOM macro expanded below is defined outside this file
// (presumably by the MATCOM include) ; it is expected to supply the members
// used here but not declared in this class, such as zero() and _copy().
// Storage : a heap array of m_D2 = m_D*m_D elements owned by the object.
template <class T>
class nmat {
  unsigned int dim2() const {return m_D2;}
#define TOOLS_MAT_CLASS nmat
  TOOLS_MATCOM
#undef TOOLS_MAT_CLASS
private:
#ifdef TOOLS_MEM
  // name passed to mem::increment()/mem::decrement().
  static const std::string& s_class() {
    static const std::string s_v("tools::nmat");
    return s_v;
  }
#endif
public:
  // a_D x a_D matrix ; each element is assigned zero().
  nmat(unsigned int a_D):m_D(a_D),m_D2(a_D*a_D),m_vec(0) {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
    m_vec = new T[m_D2];
    for(unsigned int i=0;i<m_D2;i++) m_vec[i] = zero();
  }
  virtual ~nmat() {
    delete [] m_vec;
#ifdef TOOLS_MEM
    mem::decrement(s_class().c_str());
#endif
  }
public:
  // copy constructor : same dimension as a_from, elements taken through _copy().
  nmat(const nmat& a_from)
  :m_D(a_from.m_D),m_D2(a_from.m_D2),m_vec(0)
  {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
    m_vec = new T[m_D2];
    _copy(a_from.m_vec);
  }
  // assignment : the storage is reallocated only when the dimensions differ.
  nmat& operator=(const nmat& a_from){
    if(&a_from==this) return *this;
    if(a_from.m_D!=m_D) {
      // Allocate the new storage BEFORE releasing the old one : if new[]
      // throws, this object is left untouched instead of holding a
      // dangling m_vec that the destructor would delete a second time.
      T* vec = new T[a_from.m_D2];
      delete [] m_vec;
      m_vec = vec;
      m_D = a_from.m_D;
      m_D2 = a_from.m_D2;
    }
    _copy(a_from.m_vec);
    return *this;
  }
public:
  // a_D x a_D matrix built from a raw array ; a_v is handed to _copy().
  // NOTE(review) : a_v presumably has to hold at least a_D*a_D elements - confirm against _copy().
  nmat(unsigned int a_D,const T a_v[])
  :m_D(a_D),m_D2(a_D*a_D),m_vec(0)
  {
#ifdef TOOLS_MEM
    mem::increment(s_class().c_str());
#endif
    m_vec = new T[m_D2];
    _copy(a_v);
  }
public:
  // number of rows (= number of columns).
  unsigned int dimension() const {return m_D;}
protected:
  unsigned int m_D;  //dimension.
  unsigned int m_D2; //m_D*m_D, length of m_vec.
  T* m_vec;          //owned, allocated with new T[m_D2].
  // never called ; only instantiates nmat<float>.
private:static void check_instantiation() {nmat<float> dummy(2);}
};

// Builds a run-time sized nmat<T> of dimension D and fills it, element by
// element through operator[], with the D*D elements of a_from.
template <class T,unsigned int D>
inline nmat<T> copy(const mat<T,D>& a_from) {
  const unsigned int count = D*D;
  nmat<T> result(D);
  unsigned int index = 0;
  while(index<count) {
    result[index] = a_from[index];
    ++index;
  }
  return result;
}

// Replaces, in place, each of the dimension()*dimension() elements of a_m
// by _conj(element). The elements are written through the pointer returned
// by a_m.data(), after casting its constness away.
// MAT::elem_t is expected to be a std::complex<>.
template <class MAT>
inline void conjugate(MAT& a_m) {
  typedef typename MAT::elem_t elem_t;
  elem_t* cursor = const_cast<elem_t*>(a_m.data());
  const unsigned int count = a_m.dimension()*a_m.dimension();
  unsigned int done = 0;
  while(done<count) {
    *cursor = _conj(*cursor);
    ++cursor;
    ++done;
  }
}

// In place : applies conjugate<MAT>() to a_m, then calls a_m.transpose().
template <class MAT> 
inline void dagger(MAT& a_m) {
  conjugate<MAT>(a_m);
  a_m.transpose();
}

//for sf, mf :
//template <class T,unsigned int D>
//inline const T* get_data(const mat<T,D>& a_v) {return a_v.data();}

// Returns a1*a2 - a2*a1, computed with MAT's operator* and operator-.
template <class MAT>
inline MAT commutator(const MAT& a1,const MAT& a2) {
  const MAT prod_12 = a1*a2;
  const MAT prod_21 = a2*a1;
  return prod_12-prod_21;
}

// Returns a1*a2 + a2*a1, computed with MAT's operator* and operator+.
template <class MAT>
inline MAT anticommutator(const MAT& a1,const MAT& a2) {
  const MAT prod_12 = a1*a2;
  const MAT prod_21 = a2*a1;
  return prod_12+prod_21;
}

// Commutator without returning a matrix by value :
//   a_res = a1 ; a_res *= a2    (a_res holds "a1 *= a2")
//   a_tmp = a2 ; a_tmp *= a1    (a_tmp holds "a2 *= a1")
//   a_res -= a_tmp
// Both a_tmp and a_res are overwritten. They should be objects distinct
// from each other and from a1/a2 : each one is assigned before the inputs
// have all been read.
template <class MAT>
inline void commutator(const MAT& a1,const MAT& a2,MAT& a_tmp,MAT& a_res) {
  a_res = a1;
  a_res *= a2;
  a_tmp = a2;
  a_tmp *= a1;
  a_res -= a_tmp;
}

// Same sequence as the (a1,a2,a_tmp,a_res) commutator, but the products are
// done with MAT::mul_mtx(other,a_vtmp) instead of operator*= :
//   a_res = a1 ; a_res.mul_mtx(a2,a_vtmp)
//   a_tmp = a2 ; a_tmp.mul_mtx(a1,a_vtmp)
//   a_res -= a_tmp
// a_vtmp is only forwarded to mul_mtx().
// NOTE(review) : a_vtmp is presumably a caller provided work array - confirm
// its required size against mul_mtx().
// Both a_tmp and a_res are overwritten ; they should be distinct from each
// other and from a1/a2.
template <class MAT,class T>
inline void commutator(const MAT& a1,const MAT& a2,MAT& a_tmp,T a_vtmp[],MAT& a_res) {
  a_res = a1;
  a_res.mul_mtx(a2,a_vtmp);
  a_tmp = a2;
  a_tmp.mul_mtx(a1,a_vtmp);
  a_res -= a_tmp;
}

// Anticommutator without returning a matrix by value :
//   a_res = a1 ; a_res *= a2    (a_res holds "a1 *= a2")
//   a_tmp = a2 ; a_tmp *= a1    (a_tmp holds "a2 *= a1")
//   a_res += a_tmp
// Both a_tmp and a_res are overwritten. They should be objects distinct
// from each other and from a1/a2 : each one is assigned before the inputs
// have all been read.
template <class MAT>
inline void anticommutator(const MAT& a1,const MAT& a2,MAT& a_tmp,MAT& a_res) {
  a_res = a1;
  a_res *= a2;
  a_tmp = a2;
  a_tmp *= a1;
  a_res += a_tmp;
}

// Same sequence as the (a1,a2,a_tmp,a_res) anticommutator, but the products
// are done with MAT::mul_mtx(other,a_vtmp) instead of operator*= :
//   a_res = a1 ; a_res.mul_mtx(a2,a_vtmp)
//   a_tmp = a2 ; a_tmp.mul_mtx(a1,a_vtmp)
//   a_res += a_tmp
// a_vtmp is only forwarded to mul_mtx().
// NOTE(review) : a_vtmp is presumably a caller provided work array - confirm
// its required size against mul_mtx().
// Both a_tmp and a_res are overwritten ; they should be distinct from each
// other and from a1/a2.
template <class MAT,class T>
inline void anticommutator(const MAT& a1,const MAT& a2,MAT& a_tmp,T a_vtmp[],MAT& a_res) {
  a_res = a1;
  a_res.mul_mtx(a2,a_vtmp);
  a_tmp = a2;
  a_tmp.mul_mtx(a1,a_vtmp);
  a_res += a_tmp;
}

// Checks, element by element, that a_1*a_2 - a_2*a_1 matches a_c.
// Element (row,col) of a matrix is read at offset row + col*D of data().
// Returns false as soon as one absolute difference is >= a_prec, else true.
// No temporary matrix is built.
template <class T,unsigned int D>
inline bool commutator_equal(const mat<T,D>& a_1,const mat<T,D>& a_2,const mat<T,D>& a_c,const T& a_prec) {
  const T* const m1 = a_1.data();
  const T* const m2 = a_2.data();
  const T* const mc = a_c.data();
  for(unsigned int row=0;row<D;++row) {
    for(unsigned int col=0;col<D;++col) {
      // (a_1*a_2)(row,col) :
      T sum_12 = T();
      for(unsigned int k=0;k<D;++k) sum_12 += m1[row+k*D] * m2[k+col*D];
      // (a_2*a_1)(row,col) :
      T sum_21 = T();
      for(unsigned int k=0;k<D;++k) sum_21 += m2[row+k*D] * m1[k+col*D];
      T delta = (sum_12-sum_21) - mc[row+col*D];
      if(delta<T()) delta *= T(-1); //absolute value.
      if(delta>=a_prec) return false;
    }
  }
  return true;
}

// Checks, element by element, that a_1*a_2 + a_2*a_1 matches a_c.
// Element (row,col) of a matrix is read at offset row + col*D of data().
// Returns false as soon as one absolute difference is >= a_prec, else true.
// No temporary matrix is built.
template <class T,unsigned int D>
inline bool anticommutator_equal(const mat<T,D>& a_1,const mat<T,D>& a_2,const mat<T,D>& a_c,const T& a_prec) {
  const T* const m1 = a_1.data();
  const T* const m2 = a_2.data();
  const T* const mc = a_c.data();
  for(unsigned int row=0;row<D;++row) {
    for(unsigned int col=0;col<D;++col) {
      // (a_1*a_2)(row,col) :
      T sum_12 = T();
      for(unsigned int k=0;k<D;++k) sum_12 += m1[row+k*D] * m2[k+col*D];
      // (a_2*a_1)(row,col) :
      T sum_21 = T();
      for(unsigned int k=0;k<D;++k) sum_21 += m2[row+k*D] * m1[k+col*D];
      T delta = (sum_12+sum_21) - mc[row+col*D];
      if(delta<T()) delta *= T(-1); //absolute value.
      if(delta>=a_prec) return false;
    }
  }
  return true;
}

// a1 - a2 : copies a1, then applies operator-= with a2 on the copy.
template <class T,unsigned int D>
inline mat<T,D> operator-(const mat<T,D>& a1,const mat<T,D>& a2) {
  mat<T,D> difference(a1);
  difference -= a2;
  return difference;
}
// a1 + a2 : copies a1, then applies operator+= with a2 on the copy.
template <class T,unsigned int D>
inline mat<T,D> operator+(const mat<T,D>& a1,const mat<T,D>& a2) {
  mat<T,D> sum(a1);
  sum += a2;
  return sum;
}
// a1 * a2 : copies a1, then applies operator*= with a2 on the copy.
template <class T,unsigned int D>
inline mat<T,D> operator*(const mat<T,D>& a1,const mat<T,D>& a2) {
  mat<T,D> product(a1);
  product *= a2;
  return product;
}
// a_fac * a_m : copies a_m, then applies operator*= with the scalar a_fac
// on the copy.
template <class T,unsigned int D>
inline mat<T,D> operator*(const T& a_fac,const mat<T,D>& a_m) {
  mat<T,D> scaled(a_m);
  scaled *= a_fac;
  return scaled;
}
/*
template <class T,unsigned int D>
inline mat<T,D> operator*(const mat<T,D>& a_m,const T& a_fac) {
  mat<T,D> res(a_m);
  res *= a_fac;
  return res;
}
*/

// a1 - a2 : copies a1, then applies operator-= with a2 on the copy.
template <class T>
inline nmat<T> operator-(const nmat<T>& a1,const nmat<T>& a2) {
  nmat<T> difference(a1);
  difference -= a2;
  return difference;
}
// a1 + a2 : copies a1, then applies operator+= with a2 on the copy.
template <class T>
inline nmat<T> operator+(const nmat<T>& a1,const nmat<T>& a2) {
  nmat<T> sum(a1);
  sum += a2;
  return sum;
}
// a1 * a2 : copies a1, then applies operator*= with a2 on the copy.
template <class T>
inline nmat<T> operator*(const nmat<T>& a1,const nmat<T>& a2) {
  nmat<T> product(a1);
  product *= a2;
  return product;
}
// a_fac * a_m : copies a_m, then applies operator*= with the scalar a_fac
// on the copy.
template <class T>
inline nmat<T> operator*(const T& a_fac,const nmat<T>& a_m) {
  nmat<T> scaled(a_m);
  scaled *= a_fac;
  return scaled;
}
/*
template <class T>
inline nmat<T> operator*(const nmat<T>& a_m,const T& a_fac) {
  nmat<T> res(a_m);
  res *= a_fac;
  return res;
}
*/

}

namespace tools {

////////////////////////////////////////////////
/// specific D=2 ///////////////////////////////
////////////////////////////////////////////////

// Sets the four elements of a 2x2 matrix. Arguments are given row by row
// (a_<R><C>) ; element (R,C) is written at offset R + C*2 of a_m.data(),
// after casting the constness of that pointer away.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11
){
  typedef typename MAT::elem_t elem_t;
  const unsigned int dim = 2;
  elem_t* const cells = const_cast<elem_t*>(a_m.data());
  //1 row :
  cells[0+0*dim] = a_00;
  cells[0+1*dim] = a_01;
  //2 row :
  cells[1+0*dim] = a_10;
  cells[1+1*dim] = a_11;
}

// Sets the 2x2 matrix a_m, through the four value matrix_set(), to the rows :
//    0  1
//   -1  0
template <class MAT>
inline void set_epsilon(MAT& a_m) {
  matrix_set<MAT>(a_m, 0, 1,
                      -1, 0);
}

//  Pauli matrices :
//  P1      P2      P3
//    0  1    0 -i    1  0
//    1  0    i  0    0 -1
// Sets the 2x2 matrix a_m to the Pauli matrix P1, rows :
//   0  1
//   1  0
template <class MAT>
inline void set_P1(MAT& a_m) {
  matrix_set<MAT>(a_m,  0, 1,
                        1, 0);
}

// Sets the 2x2 matrix a_m to the Pauli matrix P2, rows :
//   0 -i
//   i  0
// MAT::elem_t has to be constructible from the pair (0,1), which stands
// for i (presumably a std::complex<>).
template <class MAT>
inline void set_P2(MAT& a_m) {
  typedef typename MAT::elem_t elem_t;
  elem_t plus_i(0,1);
  elem_t minus_i = elem_t(-1)*plus_i;
  matrix_set<MAT>(a_m,      0, minus_i,
                       plus_i,       0);
}
// Sets the 2x2 matrix a_m to the Pauli matrix P3, rows :
//   1  0
//   0 -1
template <class MAT>
inline void set_P3(MAT& a_m) {
  matrix_set<MAT>(a_m, 1, 0,
                       0,-1);
}

////////////////////////////////////////////////
/// specific D=3 ///////////////////////////////
////////////////////////////////////////////////
// Sets the nine elements of a 3x3 matrix. Arguments are given row by row
// (a_<R><C>) ; element (R,C) is written at offset R + C*3 of a_m.data(),
// after casting the constness of that pointer away.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01,const typename MAT::elem_t& a_02 //1 row
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11,const typename MAT::elem_t& a_12 //2 row
,const typename MAT::elem_t& a_20,const typename MAT::elem_t& a_21,const typename MAT::elem_t& a_22 //3 row
){
  typedef typename MAT::elem_t elem_t;
  const unsigned int dim = 3;
  elem_t* const cells = const_cast<elem_t*>(a_m.data());
  //1 row :
  cells[0+0*dim] = a_00;
  cells[0+1*dim] = a_01;
  cells[0+2*dim] = a_02;
  //2 row :
  cells[1+0*dim] = a_10;
  cells[1+1*dim] = a_11;
  cells[1+2*dim] = a_12;
  //3 row :
  cells[2+0*dim] = a_20;
  cells[2+1*dim] = a_21;
  cells[2+2*dim] = a_22;
}

// Generators of rotation group :
// Rk(i,j) = epsilon(k,i,j)  i,j,k=1,2,3
// Sets the 3x3 matrix a_m to R1, with R1(i,j) = epsilon(1,i,j), rows :
//   0  0  0
//   0  0  1
//   0 -1  0
template <class MAT>
inline void set_R1(MAT& a_m) {
  matrix_set<MAT>(a_m, 0, 0, 0,
                       0, 0, 1,
                       0,-1, 0);
}
// Sets the 3x3 matrix a_m to R2, with R2(i,j) = epsilon(2,i,j), rows :
//   0  0 -1
//   0  0  0
//   1  0  0
template <class MAT>
inline void set_R2(MAT& a_m) {
  matrix_set<MAT>(a_m, 0, 0,-1,
                       0, 0, 0,
                       1, 0, 0);
}
// Sets the 3x3 matrix a_m to R3, with R3(i,j) = epsilon(3,i,j), rows :
//   0  1  0
//  -1  0  0
//   0  0  0
template <class MAT>
inline void set_R3(MAT& a_m) {
  matrix_set<MAT>(a_m, 0, 1, 0,
                      -1, 0, 0,
                       0, 0, 0);
}

////////////////////////////////////////////////
/// specific D=4 ///////////////////////////////
////////////////////////////////////////////////
// Sets the sixteen elements of a 4x4 matrix. Arguments are given row by row
// (a_<R><C>) ; element (R,C) is written at offset R + C*4 of a_m.data(),
// after casting the constness of that pointer away.
template <class MAT>
inline void matrix_set(MAT& a_m
,const typename MAT::elem_t& a_00,const typename MAT::elem_t& a_01,const typename MAT::elem_t& a_02,const typename MAT::elem_t& a_03 //1 row
,const typename MAT::elem_t& a_10,const typename MAT::elem_t& a_11,const typename MAT::elem_t& a_12,const typename MAT::elem_t& a_13 //2 row
,const typename MAT::elem_t& a_20,const typename MAT::elem_t& a_21,const typename MAT::elem_t& a_22,const typename MAT::elem_t& a_23 //3 row
,const typename MAT::elem_t& a_30,const typename MAT::elem_t& a_31,const typename MAT::elem_t& a_32,const typename MAT::elem_t& a_33 //4 row
){
  typedef typename MAT::elem_t elem_t;
  const unsigned int dim = 4;
  elem_t* const cells = const_cast<elem_t*>(a_m.data());
  //1 row :
  cells[0+0*dim] = a_00;
  cells[0+1*dim] = a_01;
  cells[0+2*dim] = a_02;
  cells[0+3*dim] = a_03;
  //2 row :
  cells[1+0*dim] = a_10;
  cells[1+1*dim] = a_11;
  cells[1+2*dim] = a_12;
  cells[1+3*dim] = a_13;
  //3 row :
  cells[2+0*dim] = a_20;
  cells[2+1*dim] = a_21;
  cells[2+2*dim] = a_22;
  cells[2+3*dim] = a_23;
  //4 row :
  cells[3+0*dim] = a_30;
  cells[3+1*dim] = a_31;
  cells[3+2*dim] = a_32;
  cells[3+3*dim] = a_33;
}

// Fills the 4x4 matrix a_m with an antisymmetric pattern : zero diagonal,
// six values drawn with a_rd.shoot() above the diagonal and their
// opposites below it. shoot() is called six times, in the order
// (0,1) (0,2) (0,3) (1,2) (1,3) (2,3).
template <class MAT,class RANDOM>
inline void set_random_antisym(MAT& a_m,RANDOM& a_rd) {
  typedef typename MAT::elem_t elem_t;
  //upper triangle, one draw per element :
  elem_t up_01 = a_rd.shoot();
  elem_t up_02 = a_rd.shoot();
  elem_t up_03 = a_rd.shoot();
  elem_t up_12 = a_rd.shoot();
  elem_t up_13 = a_rd.shoot();
  elem_t up_23 = a_rd.shoot();
  matrix_set(a_m,
                      0, up_01, up_02, up_03,
                 -up_01,     0, up_12, up_13,
                 -up_02,-up_12,     0, up_23,
                 -up_03,-up_13,-up_23,     0);
}

// Sets a_m to diag(1,-1,-1,-1) : a_m.set_zero() first, then set_value()
// on the four diagonal elements (0,0) to (3,3).
template <class MAT>
inline void set_eta(MAT& a_m) {
  a_m.set_zero();
  a_m.set_value(0,0, 1);
  for(int diag=1;diag<4;++diag) {
    a_m.set_value(diag,diag,-1);
  }
}

////////////////////////////////////////////////
/// specific D=6 ///////////////////////////////
////////////////////////////////////////////////

/*
template <class T>
inline void set_matrix(mat<T,6>& a_m
,const T& a_00,const T& a_01,const T& a_02,const T& a_03,const T& a_04,const T& a_05 //1 row
,const T& a_10,const T& a_11,const T& a_12,const T& a_13,const T& a_14,const T& a_15 //2 row
,const T& a_20,const T& a_21,const T& a_22,const T& a_23,const T& a_24,const T& a_25 //3 row
,const T& a_30,const T& a_31,const T& a_32,const T& a_33,const T& a_34,const T& a_35 //4 row
,const T& a_40,const T& a_41,const T& a_42,const T& a_43,const T& a_44,const T& a_45 //5 row
,const T& a_50,const T& a_51,const T& a_52,const T& a_53,const T& a_54,const T& a_55 //6 row
){
  //a_<R><C>
  //vec[R + C * 6];
  T* vec = const_cast<T*>(a_m.data());
  vec[0] = a_00;vec[6]  = a_01;vec[12] = a_02;vec[18] = a_03;vec[24] = a_04;vec[30] = a_05;
  vec[1] = a_10;vec[7]  = a_11;vec[13] = a_12;vec[19] = a_13;vec[25] = a_14;vec[31] = a_15;
  vec[2] = a_20;vec[8]  = a_21;vec[14] = a_22;vec[20] = a_23;vec[26] = a_24;vec[32] = a_25;
  vec[3] = a_30;vec[9]  = a_31;vec[15] = a_32;vec[21] = a_33;vec[27] = a_34;vec[33] = a_35;
  vec[4] = a_40;vec[10] = a_41;vec[16] = a_42;vec[22] = a_43;vec[28] = a_44;vec[34] = a_45;
  vec[5] = a_50;vec[11] = a_51;vec[17] = a_52;vec[23] = a_53;vec[29] = a_54;vec[35] = a_55;
}
*/

}

////////////////////////////////////////////////
////////////////////////////////////////////////
////////////////////////////////////////////////
#include <ostream>

namespace tools {

//NOTE : print is a Python keyword.
// Writes a_matrix on a_out, one row per line, each element preceded by a
// space and read with a_matrix.value(row,column). If aCMT is not empty it
// is written first, on its own line. Every line is ended with std::endl.
template <class MAT>
inline void dump(std::ostream& a_out,const std::string& aCMT,const MAT& a_matrix) {
  if(!aCMT.empty()) a_out << aCMT << std::endl;
  const unsigned int dim = a_matrix.dimension();
  unsigned int row = 0;
  while(row<dim) {
    unsigned int col = 0;
    while(col<dim) {
      a_out << " " << a_matrix.value(row,col);
      ++col;
    }
    a_out << std::endl;
    ++row;
  }
}

// Checks the inversion of a_matrix :
//  - returns false if a_matrix.invert() fails ;
//  - else combines the inverse with a_matrix through mul_mtx() and compares
//    the result, with equal(), to a matrix on which set_identity() was
//    called. On mismatch a_matrix is dumped on a_out and false is returned.
// MAT has to be default constructible.
template <class MAT>
inline bool check_invert(const MAT& a_matrix,std::ostream& a_out) {
  MAT identity;
  identity.set_identity();
  MAT work;
  if(!a_matrix.invert(work)) return false;
  work.mul_mtx(a_matrix);
  if(work.equal(identity)) return true;
  dump(a_out,"problem with inv of :",a_matrix);
  return false;
}

}

#endif
