/* -*- C++ -*-
 *
 * ---------------------------------------------------------------------
 * $Id: index_iter.h,v 1.2.2.4 2004/05/06 23:15:13 drory Exp $
 * ---------------------------------------------------------------------
 *
 * Copyright (C) 2000-2002 Niv Drory <drory@usm.uni-muenchen.de>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA 
 *
 * ---------------------------------------------------------------------
 *
 */

#ifndef __LTL_IN_FILE_MARRAY__
#error "<ltl/marray/index_iter.h> must be included via <ltl/marray.h>, never alone!"
#endif

#ifndef __LTL_INDEXITER__
#define __LTL_INDEXITER__

#include <ltl/config.h>

LTL_BEGIN_NAMESPACE

/*! \ingroup marray_class
  Iterator over the index space of an N-dimensional shape.

  Rather than yielding array elements, an IndexIter yields the current
  index vector (one \c int per dimension, dimensions numbered 1..N).
  Dimension 1 is the innermost, fastest-varying loop.
  Used for implementing ltl::IndexList and ltl::where().

  \c T is the element type of the MArray the iterator is built from,
  \c N is the number of dimensions.
*/
template<class T, int N>
class IndexIter
{
   public:
      enum { numIndexIter = 1 };
      enum { isVectorizable = 0};

      // construct from the index ranges of an MArray
      IndexIter( const MArray<T,N>& array );
      // construct from the index ranges described by a Shape
      IndexIter( const Shape<N>* s );
      // member-wise copy; the Shape pointer itself is copied, so both
      // iterators refer to the same Shape object
      IndexIter( const IndexIter<T,N>& other )
            : strides_(other.strides_), first_(other.first_), last_(other.last_),
            pos_(other.pos_), stride_(other.stride_), done_(other.done_),
            shape_(other.shape_)
      { }

      // reset iterator (back to first element)
      void reset();

      // Step to the next index position. The innermost index is
      // incremented; only when it runs past its last value (and N > 1)
      // are the outer dimensions touched via advanceDim().
      inline IndexIter<T,N>& operator++()
      {
         advance(); // increment innermost loop

         if( N == 1 )  // will be const-folded away, dont worry ...
            return *this;

         if( pos_(1) <= last_(1) )
            // We hit this case almost all the time.
            return *this;


         advanceDim(); // hit end of row/columns/whatever, need to reset loops
         return *this;
      }

      // increment the innermost index by its stride, without any
      // end-of-row handling
      inline void advance()
      {
         pos_(1) += stride_;
      }

      // not supported on an index iterator: always trips LTL_ASSERT_
      inline void advanceN( const int i )
      {
         LTL_ASSERT_( 0, "LTL internal error : advanceN() on IndexIter!\n" );
      }

      // increment the innermost index by exactly one
      inline void advanceWithStride1()
      {
         ++pos_(1);
      }

      // not supported on an index iterator: always trips LTL_ASSERT_,
      // the returned 0 is a dummy value
      inline int readWithStride( const int i ) const
      {
         LTL_ASSERT_( 0, "LTL internal error : readWithStride() on IndexIter!\n" );
         return 0;
      }

      // not supported on an index iterator: always trips LTL_ASSERT_,
      // the returned 0 is a dummy value
      inline int readWithoutStride( const int i ) const
      {
         LTL_ASSERT_( 0, "LTL internal error : readWithoutStride() on IndexIter!\n" );
         return 0;         
      }

      // carry into the outer dimensions after the innermost index has
      // run past its last value
      void advanceDim();

      // true if the innermost index has run past its last value and
      // advanceDim() has to be called; never true for N == 1
      inline bool needAdvanceDim() const
      {
         if( N == 1 )
            return false;
         else
            return pos_(1) > last_(1);
      }

      // postfix increment; performs the same step as the prefix form
      // and returns nothing
      void operator++( int )
      {
         ++(*this);
      }

      // true once all index positions have been visited. For N == 1 this
      // is decided from the position itself, otherwise from the flag
      // maintained by advanceDim().
      bool done() const
      {
         if( N == 1 )
            return pos_(1) > last_(1);
         else
            return done_;
      }

      // the current index vector
      const FixedVector<int,N>& index() const
      {
         return pos_;
      }

      // the current index along dimension dim (1-based, 1 <= dim <= N)
      int index( const int dim ) const
      {
         LTL_ASSERT( (dim>0 && dim<=N), "Bad dimension "<<dim
                     <<" for IndexIterator<"<<N<<">!" );
         return pos_(dim);
      }

      // same as index()
      const FixedVector<int,N>& operator*() const
      {
         return index();
      }

      // same as index()
      const FixedVector<int,N>& operator()() const
      {
         return index();
      }

      // same as index(dim)
      const int operator()( const int dim ) const
      {
         return index(dim);
      }

#ifdef LTL_USE_SIMD
      // placeholder for the vectorized read; isVectorizable is 0 for
      // this class, the value returned here is a dummy
      inline typename VEC_TYPE(int) readVec( const int i ) const
      {
         return (typename VEC_TYPE(int))(0); 
      }

      // an index iterator never reports a matching alignment
      inline bool sameAlignmentAs( void* p ) const
      {
         return false;
      }
#endif

      // print the extent (last - first + 1) of every dimension to cerr
      void printRanges() const;

      // the Shape pointer this iterator was constructed with
      const Shape<N>* shape() const
      {
         return shape_;
      }

      // we can't use loop collapsing, since we need to keep track
      // of all indices ...
      bool isStorageContiguous( void ) const
      {
         return N==1;
      }

      // true if the innermost index advances in steps of one
      bool isStride1() const
      {
         return stride_==1;
      }

      // conformability check, delegated to the stored Shape
      bool isConformable( const Shape<N>& other ) const
      {
         return shape_->isConformable( other );
      }

   protected:
      FixedVector<int,N> strides_;  // index increment per dimension
      FixedVector<int,N> first_;    // first index value per dimension
      FixedVector<int,N> last_;     // last index value per dimension
      FixedVector<int,N> pos_;      // current index value per dimension
      int stride_;                  // copy of strides_(1), set by the constructors
      bool done_;                   // set by advanceDim() when iteration has finished
      const Shape<N> *shape_;       // stored as a plain pointer; this class never deletes it
};


// Construct an iterator over the index ranges of an MArray.
//
// For every dimension the first/last index and the index increment are
// copied out of the array so that iterating does not have to go back to
// the array. The current position starts at the first index of every
// dimension.
//
template<class T, int N>
IndexIter<T,N>::IndexIter( const MArray<T,N>&array )
{
   // becomes true if any dimension reports a length of zero
   bool emptyDim = false;

   // copy information to avoid dereferencing array all the time
   for( int i=1; i<=N; i++ )
   {
      last_(i) = array.maxIndex(i);
      pos_(i) = first_(i) = array.minIndex(i);

      if( array.length(i) == 0 )
      {
         // Dividing by a zero length would be undefined behaviour.
         // Use a unit increment and remember that there is nothing to
         // iterate over.
         // NOTE(review): assumes length(i) == 0 denotes an empty
         // dimension -- confirm against MArray.
         strides_(i) = 1;
         emptyDim = true;
      }
      else
         strides_(i) = (last_(i) - first_(i)) / array.length(i) + 1;
   }

   stride_ = strides_(1);
   // an iterator over an empty index range is finished right away
   done_ = emptyDim;
   shape_ = array.shape();
}

// Construct an iterator over the index ranges described by a Shape.
//
// For every dimension the first index (s->base), the last index
// (s->last) and the index increment are stored, and the current position
// starts at the first index of every dimension. The pointer s itself is
// kept (not a copy of the Shape) and is later handed out by shape() and
// dereferenced by isConformable().
//
template<class T, int N>
IndexIter<T,N>::IndexIter( const Shape<N>* s )
{
   // becomes true if any dimension reports a length of zero
   bool emptyDim = false;

   for( int i=1; i<=N; i++ )
   {
      last_(i) = s->last(i);
      pos_(i) = first_(i) = s->base(i);

      if( s->length(i) == 0 )
      {
         // Dividing by a zero length would be undefined behaviour.
         // Use a unit increment and remember that there is nothing to
         // iterate over.
         // NOTE(review): assumes length(i) == 0 denotes an empty
         // dimension -- confirm against Shape.
         strides_(i) = 1;
         emptyDim = true;
      }
      else
         strides_(i) = (last_(i) - first_(i)) / s->length(i) + 1;
   }

   stride_ = strides_(1);
   // an iterator over an empty index range is finished right away
   done_ = emptyDim;
   shape_ = s;
}

// Rewind the iterator: clear the end-of-iteration flag and put the
// current position back onto the first index of every dimension.
//
template<class T, int N>
void IndexIter<T,N>::reset()
{
   done_ = false;

   // the per-dimension assignments are independent of each other
   int dim = N;
   while( dim >= 1 )
   {
      pos_(dim) = first_(dim);
      --dim;
   }
}


// Carry into the outer dimensions.
//
// Called once the innermost index has run past its last value.
// Starting with dimension 2, each dimension is incremented in turn until
// one is found that is still within range; every dimension below that
// one is then put back to its first index. If no dimension remains in
// range the iteration is over and done_ is set.
//
template<class T, int N>
void IndexIter<T,N>::advanceDim()
{
   // look for the lowest outer dimension that can still be advanced
   int dim = 2;
   while( dim <= N )
   {
      pos_(dim) += strides_(dim);
      if( pos_(dim) <= last_(dim) )
         break;
      ++dim;
   }

   if( dim > N )
   {
      // ran out of dimensions: nothing left to visit
      done_ = true;
      return;
   }

   // restart all loops nested inside the dimension just advanced
   for( int inner = dim - 1; inner >= 1; --inner )
      pos_(inner) = first_(inner);
}

// Debugging aid: write the extent (last - first + 1) of every dimension
// to the standard error stream, each one in parentheses. No newline is
// written.
//
template<class T, int N>
void IndexIter<T,N>::printRanges() const
{
   // qualified with std:: so that this header does not depend on a
   // using-declaration for cerr being in effect where it is included
   std::cerr << "Ranges: ";
   for(int i=1; i<=N; i++)
      std::cerr << "(" << last_(i) - first_(i) + 1 << ")  ";
}


//
// -----------------------------------------------------------------
//



// Expression node that evaluates to the index along one fixed dimension.
//
// Holds its own copy of an IndexIter plus a dimension number; reading
// the node yields the iterator's current index in that dimension,
// converted to RetType. All other operations are forwarded unchanged to
// the contained IndexIter.
//
template<class T, class RetType, int N>
class IndexIterDimExpr : public _et_parse_base
{
   public:
      typedef RetType value_type;
      enum { numIndexIter = IndexIter<T,N>::numIndexIter };
      enum { isVectorizable = IndexIter<T,N>::isVectorizable };

      // --- construction ------------------------------------------------

      // copies the iterator a and remembers the dimension to report
      IndexIterDimExpr( const IndexIter<T,N>& a, const int dim )
            : iter_(a), dim_(dim)
      { }

      IndexIterDimExpr( const IndexIterDimExpr<T,RetType,N>& other )
            : iter_(other.iter_), dim_(other.dim_)
      { }

      // --- reading -----------------------------------------------------

      // current index along dim_, converted to value_type
      value_type operator*() const
      {
         const int current = iter_.index(dim_);
         return static_cast<value_type>(current);
      }

      value_type readWithStride( const int i ) const
      {
         return iter_.readWithStride(i);
      }

      value_type readWithoutStride( const int i ) const
      {
         return iter_.readWithoutStride(i);
      }

#ifdef LTL_USE_SIMD
      inline typename VEC_TYPE(value_type) readVec( const int i ) const
      {
         return iter_.readVec(i);
      }

      inline bool sameAlignmentAs( void* p ) const
      {
         return iter_.sameAlignmentAs(p);
      }
#endif

      // --- stepping ----------------------------------------------------

      void operator++()
      {
         ++iter_;
      }

      void advance()
      {
         iter_.advance();
      }

      void advanceN( const int i )
      {
         iter_.advanceN(i);
      }

      void advanceWithStride1()
      {
         iter_.advanceWithStride1();
      }

      void advanceDim()
      {
         iter_.advanceDim();
      }

      void reset()
      {
         iter_.reset();
      }

      // --- shape and layout queries ------------------------------------

      bool isStorageContiguous() const
      {
         return iter_.isStorageContiguous();
      }

      bool isStride1() const
      {
         return iter_.isStride1();
      }

      bool isConformable( const Shape<N>& other ) const
      {
         return iter_.isConformable( other );
      }

      const Shape<N>* shape() const
      {
         return iter_.shape();
      }

   protected:
      IndexIter<T,N> iter_;   // private copy of the index iterator
      const int dim_;         // dimension whose index is reported
};


// and the wrapper...
// use indexPos( A ) in an expression to refer to the current
// index in MArray A during expression evaluation
//
//template<class RetType>
//template<class T1, int N>
//inline TExpr<IndexIterExpr<T1,RetType,N>, N>
//indexPos( const MArray<T1,N>& a )
//{
//  typedef IndexIterExpr<T1,RetType,N> ExprT;
//  return TExpr<ExprT,N>( ExprT( a.indexBegin() ) );
//}

// and the wrapper...
/*! \ingroup marray_class
  Use indexPos( A, i ) inside an expression to refer to the index along
  the i-th dimension of MArray A while the expression is evaluated.
  The index is delivered as an \c int. */
template<class T1, int N>
inline TExpr<IndexIterDimExpr<T1,int,N>, N>
indexPos( const MArray<T1,N>& a, const int dim )
{
   typedef IndexIterDimExpr<T1,int,N> DimExpr;
   typedef TExpr<DimExpr,N> Result;

   // node reporting the index of dimension dim, starting at a's first index
   DimExpr node( a.indexBegin(), dim );
   return Result( node );
}

/*! \ingroup marray_class
  indexPosFlt() returns the index value as a \c float.
  This is necessary because of the lack of arbitrary type
  applicative templates, only X(T a, T b) ... */
template<class T1, int N>
inline TExpr<IndexIterDimExpr<T1,float,N>, N>
indexPosFlt( const MArray<T1,N>& a, const int dim )
{
   typedef IndexIterDimExpr<T1,float,N> DimExpr;
   typedef TExpr<DimExpr,N> Result;

   // node reporting the index of dimension dim, starting at a's first index
   DimExpr node( a.indexBegin(), dim );
   return Result( node );
}

/*! \ingroup marray_class
  indexPosDbl() returns the index value as a \c double.
  This is necessary because of the lack of arbitrary type
  applicative templates, only X(T a, T b) ... */
template<class T1, int N>
inline TExpr<IndexIterDimExpr<T1,double,N>, N>
indexPosDbl( const MArray<T1,N>& a, const int dim )
{
   typedef IndexIterDimExpr<T1,double,N> DimExpr;
   typedef TExpr<DimExpr,N> Result;

   // node reporting the index of dimension dim, starting at a's first index
   DimExpr node( a.indexBegin(), dim );
   return Result( node );
}

LTL_END_NAMESPACE

#endif // __LTL_INDEXITER__

