#include "../include/solve.hpp"

namespace ngsolve
{
  using namespace ngsolve;
  // using namespace ngbla;
  // using namespace ngla;











  /// Numproc "zzerrorestimator": Zienkiewicz-Zhu error estimator.
  /// The flux of the first integrator of a bilinear form is projected into
  /// an H1 high-order space; the element-wise error contributions are
  /// stored in a gridfunction.
  class NumProcZZErrorEstimator : public NumProc
  {
  public:
    NumProcZZErrorEstimator (PDE & apde, const Flags & flags);
    virtual ~NumProcZZErrorEstimator();

    /// factory used when registering the numproc
    static NumProc * Create (PDE & pde, const Flags & flags)
    {
      NumProc * np = new NumProcZZErrorEstimator (pde, flags);
      return np;
    }

    /// writes the user documentation (required flags) to ost
    static void PrintDoc (ostream & ost);

    virtual void Do(LocalHeap & lh);
    virtual void PrintReport (ostream & ost);

    virtual string GetClassName () const
    {
      return "ZZ Error Estimator";
    }

  private:
    /// bilinear form whose first integrator defines the flux (-bilinearform)
    BilinearForm * bfa;
    /// finite element solution (-solution)
    GridFunction * gfu;
    /// output: element-wise error contributions (-error)
    GridFunction * gferr;
  };



  /// Looks up the objects named by -bilinearform, -solution and -error
  /// in the PDE.
  NumProcZZErrorEstimator :: NumProcZZErrorEstimator (PDE & apde, const Flags & flags)
    : NumProc (apde),
      bfa (pde.GetBilinearForm (flags.GetStringFlag ("bilinearform", ""))),
      gfu (pde.GetGridFunction (flags.GetStringFlag ("solution", ""))),
      gferr (pde.GetGridFunction (flags.GetStringFlag ("error", "")))
  {
  }

  /// Releases nothing: the referenced forms and gridfunctions are only
  /// looked up, not created, by this numproc.
  NumProcZZErrorEstimator :: ~NumProcZZErrorEstimator()
  { }
  
  /// Writes the user documentation of the "zzerrorestimator" numproc
  /// (purpose and required flags) to ost.
  void NumProcZZErrorEstimator :: PrintDoc (ostream & ost)
  {
    // Adjacent string literals are concatenated by the compiler; no line
    // continuation is needed.
    ost << 
      "\n\nNumproc ZZ-error estimator:\n"
      "---------------------------\n"
      "Computes the Zienkiewicz-Zhu error estimator\n\n"
      "Required flags:\n"
      "-bilinearform=<bfname>\n"
      "    takes first integrator of bilinearform to compute the flux\n"
      "-solution=<solname>\n"
      "    gridfunction storing the finite element solution\n"
      "-error=<errname>\n"
      "    piece-wise constant gridfunction to store the computed element-wise error\n";
  }
  
  

  /// Runs the ZZ estimator.
  /// Builds an H1 high-order space with the order of the solution space and
  /// DimFlux() components. For every domain it projects the flux of the
  /// first integrator (CalcFluxProject) and accumulates the element-wise
  /// error (CalcError) into gferr. sqrt of the summed contributions is
  /// printed and appended to "error.out" together with level and ndof.
  void NumProcZZErrorEstimator :: Do(LocalHeap & lh)
  {
    cout << "ZZ error-estimator" << endl;

    if (bfa->NumIntegrators() == 0)
      throw Exception ("ZZErrorEstimator: Bilinearform needs an integrator");

    BilinearFormIntegrator & integrator = *bfa->GetIntegrator(0);
    const bool is_complex = bfa->GetFESpace().IsComplex();

    Flags fesflags;
    fesflags.SetFlag ("order", bfa->GetFESpace().GetOrder());
    fesflags.SetFlag ("dim", integrator.DimFlux());
    if (is_complex)
      fesflags.SetFlag ("complex");

    // NOTE(review): neither the flux space nor the flux gridfunction is
    // ever deleted, so every call leaks them.
    H1HighOrderFESpace * fesflux = new H1HighOrderFESpace (ma, fesflags);
    fesflux->Update();

    Flags gfflags;
    GridFunction * flux = CreateGridFunction (fesflux, "fluxzz", gfflags);
    flux->Update();

    FlatVector<double> err =
      dynamic_cast<T_BaseVector<double>&> (gferr->GetVector()).FV();
    err = 0;

    const int ndomains = ma.GetNDomains();
    for (int dom = 0; dom < ndomains; dom++)
      {
        if (is_complex)
          {
            const S_GridFunction<Complex> & u =
              dynamic_cast<const S_GridFunction<Complex>&> (*gfu);
            S_GridFunction<Complex> & fl =
              dynamic_cast<S_GridFunction<Complex>&> (*flux);
            CalcFluxProject (ma, u, fl, integrator, 1, dom, lh);
            CalcError (ma, u, fl, integrator, err, dom, lh);
          }
        else
          {
            const S_GridFunction<double> & u =
              dynamic_cast<const S_GridFunction<double>&> (*gfu);
            S_GridFunction<double> & fl =
              dynamic_cast<S_GridFunction<double>&> (*flux);
            CalcFluxProject (ma, u, fl, integrator, 1, dom, lh);
            CalcError (ma, u, fl, integrator, err, dom, lh);
          }
      }

    double sum = 0;
    for (int i = 0; i < err.Size(); i++)
      sum += err(i);
    const double esterr = sqrt (sum);
    cout << "estimated error = " << esterr << endl;

    static ofstream errout ("error.out");
    errout << ma.GetNLevels()
           << "  " << bfa->GetFESpace().GetNDof()
           << "  " << sqrt (double (bfa->GetFESpace().GetNDof()))
           << " " << esterr << endl;
  }




  /// Prints a short report identifying this numproc.
  void NumProcZZErrorEstimator :: PrintReport (ostream & ost)
  {
    ost << "NumProcZZErrorEstimator:" << endl
        << "Bilinear-form = " << endl;
  }










  /// Numproc "rtzzerrorestimator": ZZ-type error estimator that projects
  /// the flux into an H(div) high-order space instead of an H1 space.
  class NumProcRTZZErrorEstimator : public NumProc
  {
  public:
    NumProcRTZZErrorEstimator (PDE & apde, const Flags & flags);
    virtual ~NumProcRTZZErrorEstimator();

    /// factory used when registering the numproc
    static NumProc * Create (PDE & pde, const Flags & flags)
    {
      NumProc * np = new NumProcRTZZErrorEstimator (pde, flags);
      return np;
    }

    virtual void Do(LocalHeap & lh);
    virtual void PrintReport (ostream & ost);

    virtual string GetClassName () const
    {
      return "RTZZ Error Estimator";
    }

  private:
    /// bilinear form whose first integrator defines the flux (-bilinearform)
    BilinearForm * bfa;
    /// finite element solution (-solution)
    GridFunction * gfu;
    /// output: element-wise error contributions (-error)
    GridFunction * gferr;
  };



  /// Looks up the objects named by -bilinearform, -solution and -error
  /// in the PDE.
  NumProcRTZZErrorEstimator :: NumProcRTZZErrorEstimator (PDE & apde, const Flags & flags)
    : NumProc (apde),
      bfa (pde.GetBilinearForm (flags.GetStringFlag ("bilinearform", ""))),
      gfu (pde.GetGridFunction (flags.GetStringFlag ("solution", ""))),
      gferr (pde.GetGridFunction (flags.GetStringFlag ("error", "")))
  {
  }

  /// Releases nothing: the referenced forms and gridfunctions are only
  /// looked up, not created, by this numproc.
  NumProcRTZZErrorEstimator :: ~NumProcRTZZErrorEstimator()
  { }

  /// Runs the RT-ZZ estimator.
  /// Builds an H(div) high-order space of order (solution order + 8), then
  /// projects the flux of the first integrator and computes the element-wise
  /// error into gferr in a single call each (domain argument -1). sqrt of
  /// the summed contributions is printed and appended to "error.out".
  void NumProcRTZZErrorEstimator :: Do(LocalHeap & lh)
  {
    cout << "RTZZ error-estimator" << endl;

    if (bfa->NumIntegrators() == 0)
      throw Exception ("RTZZErrorEstimator: Bilinearform needs an integrator");

    BilinearFormIntegrator & integrator = *bfa->GetIntegrator(0);
    const bool is_complex = bfa->GetFESpace().IsComplex();

    // NOTE(review): the +8 order increase is hard-coded; its rationale is
    // not visible here.
    Flags fesflags;
    fesflags.SetFlag ("order", bfa->GetFESpace().GetOrder()+8);
    if (is_complex)
      fesflags.SetFlag ("complex");

    // NOTE(review): neither the flux space nor the flux gridfunction is
    // ever deleted, so every call leaks them.
    HDivHighOrderFESpace * fesflux = new HDivHighOrderFESpace (ma, fesflags);
    fesflux->Update();

    Flags gfflags;
    GridFunction * flux = CreateGridFunction (fesflux, "fluxzz", gfflags);
    flux->Update();

    FlatVector<double> err =
      dynamic_cast<T_BaseVector<double>&> (gferr->GetVector()).FV();
    err = 0;

    // domain index -1: presumably "all domains at once" (the H1 variant
    // loops over the domains explicitly instead)
    if (is_complex)
      {
        const S_GridFunction<Complex> & u =
          dynamic_cast<const S_GridFunction<Complex>&> (*gfu);
        S_GridFunction<Complex> & fl =
          dynamic_cast<S_GridFunction<Complex>&> (*flux);
        CalcFluxProject (ma, u, fl, integrator, 1, -1, lh);
        CalcError (ma, u, fl, integrator, err, -1, lh);
      }
    else
      {
        const S_GridFunction<double> & u =
          dynamic_cast<const S_GridFunction<double>&> (*gfu);
        S_GridFunction<double> & fl =
          dynamic_cast<S_GridFunction<double>&> (*flux);
        CalcFluxProject (ma, u, fl, integrator, 1, -1, lh);
        CalcError (ma, u, fl, integrator, err, -1, lh);
      }

    double sum = 0;
    for (int i = 0; i < err.Size(); i++)
      sum += err(i);
    const double esterr = sqrt (sum);
    cout << "estimated error = " << esterr << endl;

    static ofstream errout ("error.out");
    errout << ma.GetNLevels()
           << "  " << bfa->GetFESpace().GetNDof()
           << "  " << sqrt (double (bfa->GetFESpace().GetNDof()))
           << " " << esterr << endl;
  }




  /// Prints a short report identifying this numproc.
  void NumProcRTZZErrorEstimator :: PrintReport (ostream & ost)
  {
    ost << "NumProcRTZZErrorEstimator:" << endl
        << "Bilinear-form = " << endl;
  }

  /// Numproc "hierarchicalerrorestimator": hierarchical error estimator.
  /// Calls CalcErrorHierarchical with the bilinear form(s), the linear form,
  /// the solution and a test space. The element-wise errors are written to
  /// the gridfunction given by -error.
  class NumProcHierarchicalErrorEstimator : public NumProc
  {
  private:
    /// bilinear form of the problem (-bilinearform)
    BilinearForm * bfa;
    /// second bilinear form (-bilinearform2); falls back to bfa when the
    /// lookup returns null
    BilinearForm * bfa2;
    /// right-hand side (-linearform)
    LinearForm * lff;
    /// finite element solution (-solution)
    GridFunction * gfu;
    /// output: element-wise error contributions (-error)
    GridFunction * gferr;
    /// test space passed to CalcErrorHierarchical (-testfespace)
    FESpace * vtest;
  public:
    NumProcHierarchicalErrorEstimator (PDE & apde, const Flags & flags)
      : NumProc (apde)
    {
      bfa = pde.GetBilinearForm (flags.GetStringFlag ("bilinearform", ""));
      bfa2 = pde.GetBilinearForm (flags.GetStringFlag ("bilinearform2", ""), 1);
      if (!bfa2) bfa2 = bfa;
      lff = pde.GetLinearForm (flags.GetStringFlag ("linearform", ""));
      gfu = pde.GetGridFunction (flags.GetStringFlag ("solution", ""));
      vtest = pde.GetFESpace (flags.GetStringFlag ("testfespace", ""));
      gferr = pde.GetGridFunction (flags.GetStringFlag ("error", ""));
    }

    virtual ~NumProcHierarchicalErrorEstimator()
    {
      ;
    }
  
    static NumProc * Create (PDE & pde, const Flags & flags)
    {
      return new NumProcHierarchicalErrorEstimator (pde, flags);
    }

    /// Computes the element-wise errors (real-valued problems only) and
    /// prints sqrt of their sum.
    virtual void Do(LocalHeap & lh)
    {
      cout << "Hierarchical error-estimator" << endl;
      
      FlatVector<double> err = gferr->GetVector().FVDouble();
      if (!bfa->GetFESpace().IsComplex())
        {
          CalcErrorHierarchical (ma, 
                                 dynamic_cast<const S_BilinearForm<double>&> (*bfa), 
                                 dynamic_cast<const S_BilinearForm<double>&> (*bfa2), 
                                 dynamic_cast<const S_LinearForm<double>&> (*lff), 
                                 dynamic_cast<S_GridFunction<double>&> (*gfu), 
                                 *vtest, err, lh);
        }
      else
        {
          // Only the real-valued case is handled. Previously the complex
          // case was skipped silently, and the "estimated error" below was
          // computed from whatever gferr already contained.
          cout << "Warning: hierarchical error estimator is not implemented for complex problems; "
               << "error vector left unchanged" << endl;
        }

      double sum = 0;
      for (int i = 0; i < err.Size(); i++)
        sum += err(i);
      cout << "estimated error = " << sqrt (sum) << endl;
    }


    virtual void PrintReport (ostream & ost)
    {
      ost << "NumProcHierarchicalErrorEstimator:" << endl;
      ost << "Bilinear-form = " << endl;
    }
    
    virtual string GetClassName () const
    {
      return "Hierarchical Error Estimator";
    }
  };












  /**
     Numproc "markelements": marks elements for refinement based on an
     element-wise error gridfunction.
  */
  class NumProcMarkElements : public NumProc
  {
  public:
    NumProcMarkElements (PDE & apde, const Flags & flags);
    virtual ~NumProcMarkElements();

    /// factory used when registering the numproc
    static NumProc * Create (PDE & pde, const Flags & flags)
    {
      NumProc * np = new NumProcMarkElements (pde, flags);
      return np;
    }

    /// sets the refinement flags of all elements
    virtual void Do();

    virtual string GetClassName () const
    {
      return "Element Marker";
    }

    virtual void PrintReport (ostream & ost);

  protected:
    /// element-wise error (-error)
    GridFunction * gferr;
    /// optional second element-wise error for goal-driven marking (-error2)
    GridFunction * gferr2;
    /// no marking is done while the number of mesh levels is below this
    int minlevel;
    /// threshold scaling; the -fac flag itself is rejected as obsolete
    double fac;
    /// fraction of the total error the marked elements must carry (-factor)
    double factor;
  };







  /// Reads the flags of the element marker:
  ///  -error=<gf>    element-wise error
  ///  -error2=<gf>   optional second error for goal-driven marking
  ///  -minlevel=<n>  no marking while fewer mesh levels exist (default 0)
  ///  -factor=<x>    fraction of the total error carried by the marked
  ///                 elements (default 0.5, lower: less refinement)
  /// Throws if the obsolete flag -fac is given.
  NumProcMarkElements :: NumProcMarkElements (PDE & apde, const Flags & flags)
    : NumProc (apde)
  {
    gferr = pde.GetGridFunction (flags.GetStringFlag ("error", ""));
    gferr2 = pde.GetGridFunction (flags.GetStringFlag ("error2", ""), 1);
    minlevel = int(flags.GetNumFlag ("minlevel", 0));
    fac = flags.GetNumFlag ("fac", -1);
    if (fac != -1)
      throw Exception ("numproc markelements:\n Flag 'fac' not supported anymore\nNew one is -factor=xxx, default 0.5, lower: less refinement");
    // Previously the default was the still-uninitialized member `factor`
    // itself (undefined behaviour). Use the documented default 0.5 instead.
    factor = flags.GetNumFlag ("factor", 0.5);
  }

  /// Releases nothing: the error gridfunctions are only looked up.
  NumProcMarkElements :: ~NumProcMarkElements()
  { }

  void NumProcMarkElements :: Do()
  {
    cout << "Element marker, " << flush;

    if (ma.GetNLevels() < minlevel) 
      {
	cout << endl;
	return;
      }


    int i;
    FlatVector<double> err =
      dynamic_cast<T_BaseVector<double>&> (gferr->GetVector()).FV();

    double maxerr = 0;
    double toterr = 0;

    if (gferr2)
      {
	const FlatVector<double> & err2 =
	  dynamic_cast<T_BaseVector<double>&> (gferr2->GetVector()).FV();
      
	for (i = 0; i < err.Size(); i++)
	  {
	    err(i) = sqrt (err(i) * err2(i));

	    if (err(i) > maxerr) maxerr = err(i);
	    toterr += err(i);
	  }

	cout << "goal driven error estimator, est.err. = " << toterr << endl;
      }
    else
      {
	for (i = 0; i < err.Size(); i++)
	  {
	    toterr += err(i);
	    if (err(i) > maxerr) maxerr = err(i);
	  }
      }

    // cout << "maxerr = " << maxerr << endl;

    int nref;
    fac = 1;
  
    if (factor > 0.9999) factor = 0.9999;

    while (1)
      {
	fac *= 0.8;
	double markerr = 0;
	nref = 0;

	for (i = 0; i < err.Size(); i++)
	  {
	    if (err(i) > fac * maxerr)
	      {
		nref++;
		Ng_SetRefinementFlag (i+1, 1);
		markerr += err(i);
	      }
	    else
	      Ng_SetRefinementFlag (i+1, 0);
	  }

	/*
	  cout << "fac = " << fac << ", nmark = " << nref 
	  << ", markerr = " << markerr << ", toterr = " << toterr << endl;
	*/

	if (markerr >= factor * toterr) break;
      }
    cout << nref << "/" << err.Size() << " elements marked." << endl;

    if (ma.GetDimension() == 3)
      {
	int nse = ma.GetNSE();
	for (int i = 0; i < nse; i++)
	  Ng_SetSurfaceRefinementFlag (i+1, 0);
      }
	  
  }




  /// Prints a one-line report identifying this numproc.
  void NumProcMarkElements :: PrintReport (ostream & ost)
  {
    ost << "NumProcMarkElements:"
        << endl;
  }








  /// Numproc that stores a copy of its flags. It is meant to set
  /// visualization parameters, but the body of Do() is currently disabled.
  class NumProcSetVisual : public NumProc
  {
  public:
    NumProcSetVisual (PDE & apde, const Flags & flags);
    virtual ~NumProcSetVisual ();

    /// factory creating the numproc
    static NumProc * Create (PDE & pde, const Flags & flags)
    {
      NumProc * np = new NumProcSetVisual (pde, flags);
      return np;
    }

    virtual void Do();

  private:
    /// copy of the flags passed to the constructor
    Flags visflags;
  };






  /// Keeps a copy of the flags and echoes them to cout.
  NumProcSetVisual :: NumProcSetVisual (PDE & apde, const Flags & flags)
    : NumProc (apde),
      visflags (flags)
  {
    cout << "SetVisual has flags" << endl;
    visflags.PrintFlags (cout);
  }

  /// Nothing to release; visflags is a value member.
  NumProcSetVisual :: ~NumProcSetVisual ()
  { }


  /// Currently a no-op. The former implementation, which forwarded every
  /// string and numeric flag to Ng_SetVisualizationParameter, is kept
  /// below in disabled form.
  void NumProcSetVisual :: Do()
  {
    /*
      cout << "Set Visualization Flag:" << endl;

      for (int i = 0; i < visflags.GetNStringFlags(); i++)
      {
      const char * name;
      const char * str;
      str = visflags.GetStringFlag (i, name);
      Ng_SetVisualizationParameter (name, str);
      }

      for (int i = 0; i < visflags.GetNNumFlags(); i++)
      {
      const char * name;
      double val;
      char str[100];
      val = visflags.GetNumFlag (i, name);
      sprintf (str, "%f", val);
      cout << "set flag " << name << " to " << str << endl;
      Ng_SetVisualizationParameter (name, str);
      }
    */
  }







  /// Numproc "primaldualerrorestimator": evaluates CalcError between the
  /// solution and a flux gridfunction that is supplied by the user
  /// (-flux) rather than computed by projection.
  class NumProcPrimalDualErrorEstimator : public NumProc
  {
  public:
    NumProcPrimalDualErrorEstimator (PDE & apde, const Flags & flags);
    virtual ~NumProcPrimalDualErrorEstimator();

    /// factory used when registering the numproc
    static NumProc * Create (PDE & pde, const Flags & flags)
    {
      NumProc * np = new NumProcPrimalDualErrorEstimator (pde, flags);
      return np;
    }

    virtual void Do(LocalHeap & lh);
    virtual void PrintReport (ostream & ost);

    virtual string GetClassName () const
    {
      return "PrimalDual Error Estimator";
    }

  private:
    /// bilinear form whose first integrator defines the flux (-bilinearform)
    BilinearForm * bfa;
    /// finite element solution (-solution)
    GridFunction * gfu;
    /// flux gridfunction compared against the solution (-flux)
    GridFunction * gfflux;
    /// output: element-wise error contributions (-error)
    GridFunction * gferr;
  };



  /// Looks up the objects named by -bilinearform, -solution, -flux and
  /// -error in the PDE.
  NumProcPrimalDualErrorEstimator :: NumProcPrimalDualErrorEstimator (PDE & apde, const Flags & flags)
    : NumProc (apde),
      bfa (pde.GetBilinearForm (flags.GetStringFlag ("bilinearform", ""))),
      gfu (pde.GetGridFunction (flags.GetStringFlag ("solution", ""))),
      gfflux (pde.GetGridFunction (flags.GetStringFlag ("flux", ""))),
      gferr (pde.GetGridFunction (flags.GetStringFlag ("error", "")))
  {
  }

  /// Releases nothing: the referenced forms and gridfunctions are only
  /// looked up, not created, by this numproc.
  NumProcPrimalDualErrorEstimator :: ~NumProcPrimalDualErrorEstimator()
  { }

  /// Evaluates CalcError between gfu and the given flux gfflux, using the
  /// first integrator of bfa and domain argument -1. The element-wise
  /// contributions are stored in gferr and sqrt of their sum is printed.
  void NumProcPrimalDualErrorEstimator :: Do(LocalHeap & lh)
  {
    cout << "PrimalDual error-estimator" << endl;

    if (bfa->NumIntegrators() == 0)
      throw Exception ("PrimalDualErrorEstimator: Bilinearform needs an integrator");

    BilinearFormIntegrator & integrator = *bfa->GetIntegrator(0);

    FlatVector<double> err =
      dynamic_cast<T_BaseVector<double>&> (gferr->GetVector()).FV();
    err = 0;

    if (bfa->GetFESpace().IsComplex())
      CalcError (ma,
                 dynamic_cast<const S_GridFunction<Complex>&> (*gfu),
                 dynamic_cast<const S_GridFunction<Complex>&> (*gfflux),
                 integrator, err, -1, lh);
    else
      CalcError (ma,
                 dynamic_cast<const S_GridFunction<double>&> (*gfu),
                 dynamic_cast<const S_GridFunction<double>&> (*gfflux),
                 integrator, err, -1, lh);

    double sum = 0;
    for (int i = 0; i < err.Size(); i++)
      sum += err(i);
    cout << "estimated error = " << sqrt (sum) << endl;
  }




  /// Prints a short report identifying this numproc.
  void NumProcPrimalDualErrorEstimator :: PrintReport (ostream & ost)
  {
    ost << "NumProcPrimalDualErrorEstimator:" << endl
        << "Bilinear-form = " << endl;
  }









  


  namespace
  {
    /// Registers the numprocs of this file in the table returned by
    /// GetNumProcs(). Registration happens when the static object `init`
    /// below is constructed.
    class Init
    { 
    public: 
      Init ();
    };

    Init::Init()
    {
      // Only the ZZ estimator provides a PrintDoc function.
      GetNumProcs().AddNumProc ("zzerrorestimator",
                                NumProcZZErrorEstimator::Create,
                                NumProcZZErrorEstimator::PrintDoc);
      GetNumProcs().AddNumProc ("rtzzerrorestimator",
                                NumProcRTZZErrorEstimator::Create);
      GetNumProcs().AddNumProc ("hierarchicalerrorestimator",
                                NumProcHierarchicalErrorEstimator::Create);
      GetNumProcs().AddNumProc ("primaldualerrorestimator",
                                NumProcPrimalDualErrorEstimator::Create);
      GetNumProcs().AddNumProc ("markelements",
                                NumProcMarkElements::Create);
    }

    /// static instance whose constructor performs the registration
    Init init;
  }
  


}
