/* @HEADER@ */
/* ***********************************************************************
// 
//           TSFExtended: Trilinos Solver Framework Extended
//                 Copyright (2004) Sandia Corporation
// 
// Under terms of Contract DE-AC04-94AL85000, there is a non-exclusive
// license for use of this work by or on behalf of the U.S. Government.
// 
// This library is free software; you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as
// published by the Free Software Foundation; either version 2.1 of the
// License, or (at your option) any later version.
//  
// This library is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
// Lesser General Public License for more details.
//  
// You should have received a copy of the GNU Lesser General Public
// License along with this library; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
// USA
// Questions? Contact Michael A. Heroux (maherou@sandia.gov) 
// 
// **********************************************************************/
 /* @HEADER@ */

#ifndef TSFBLOCKTRIANGULARSOLVER_HPP
#define TSFBLOCKTRIANGULARSOLVER_HPP

#include "SundanceDefs.hpp"
#include "TSFLinearSolverDecl.hpp" 
#include "TSFLinearCombinationDecl.hpp" 
#include "TSFCommonOperatorsDecl.hpp" 


namespace TSFExtended
{
  /**
   * Linear solver for operators with a block-triangular structure.
   *
   * Each diagonal block is solved with a user-supplied LinearSolver. The
   * off-diagonal blocks are used only to update the right-hand side during
   * block forward or back substitution (see solve()).
   */
  template <class Scalar>
  class BlockTriangularSolver
    : public LinearSolverBase<Scalar>,
      public SundanceUtils::Handleable<LinearSolverBase<Scalar> >
  {
  public:
    /** Construct a solver that applies \c solver to every diagonal block. */
    BlockTriangularSolver(const LinearSolver<Scalar>& solver)
      : LinearSolverBase<Scalar>(ParameterList()),
        solvers_(tuple(solver))
    {}

    /**
     * Construct a solver that applies \c solvers[r] to diagonal block (r,r).
     * An array with a single entry is shared by all diagonal blocks.
     */
    BlockTriangularSolver(const Array<LinearSolver<Scalar> >& solvers)
      : LinearSolverBase<Scalar>(ParameterList()),
        solvers_(solvers)
    {}

    /** */
    virtual ~BlockTriangularSolver() {}

    /** Solve op * soln = rhs by block substitution. */
    virtual SolverState<Scalar> solve(const LinearOperator<Scalar>& op,
                                      const Vector<Scalar>& rhs,
                                      Vector<Scalar>& soln) const;

    /* Handleable hook; GET_RCP is a macro defined elsewhere in the project. */
    GET_RCP(LinearSolverBase<Scalar>);

  private:
    /** Diagonal-block solvers: either one shared solver or one per block row. */
    Array<LinearSolver<Scalar> > solvers_;
  };


  /**
   * Solve op * soln = rhs, where op is a square block operator with either
   * an upper- or a lower-triangular block structure. A block-diagonal
   * operator is swept in forward order, like a lower-triangular one.
   *
   * Diagonal blocks are solved one at a time with the solvers supplied at
   * construction:
   * - a lower-triangular operator uses block forward substitution;
   * - an upper-triangular operator uses block back substitution.
   *
   * A block whose pointer is null, or which is a SimpleZeroOp, is treated
   * as zero.
   *
   * \param op   square block operator; every diagonal block must be nonzero
   * \param rhs  right-hand side with op.numBlockRows() blocks
   * \param soln output; reset to a new member of op.domain() before solving
   *
   * \return the state of the first diagonal-block solve that did not
   *         converge. If every block solve converged, returns a
   *         SolveConverged state with 0 iterations and a zero residual.
   *
   * \throws std::runtime_error if:
   *         - the block counts of op and rhs disagree;
   *         - the block structure is not square;
   *         - a diagonal block is zero;
   *         - the operator is not block triangular;
   *         - the number of solvers is neither 1 nor at least the number
   *           of block rows.
   */
  template <class Scalar> inline
  SolverState<Scalar> BlockTriangularSolver<Scalar>
  ::solve(const LinearOperator<Scalar>& op,
          const Vector<Scalar>& rhs,
          Vector<Scalar>& soln) const
  {
    const int nRows = op.numBlockRows();
    const int nCols = op.numBlockCols();
    const int nSolvers = (int) solvers_.size();

    soln = op.domain().createMember();

    TEST_FOR_EXCEPTION(nRows != rhs.space().numBlocks(), std::runtime_error,
                       "number of rows in operator " << op
                       << " not equal to number of blocks on RHS "
                       << rhs);

    TEST_FOR_EXCEPTION(nRows != nCols, std::runtime_error,
                       "nonsquare block structure in block triangular "
                       "solver: nRows=" << nRows << " nCols=" << nCols);

    // The diagonal solves below index solvers_[r]. That needs one solver per
    // block row unless a single solver is shared by all rows. Reject a
    // mismatch here instead of reading past the end of solvers_.
    TEST_FOR_EXCEPTION(nSolvers != 1 && nSolvers < nRows, std::runtime_error,
                       "block triangular solver was given " << nSolvers
                       << " solvers for an operator with " << nRows
                       << " block rows; expected 1 or " << nRows);

    // Classify the block structure. Any nonzero block above the diagonal
    // marks the operator as upper triangular; any nonzero block below it
    // marks it as lower triangular.
    bool isUpper = false;
    bool isLower = false;

    for (int r=0; r<nRows; r++)
      {
        for (int c=0; c<nCols; c++)
          {
            LinearOperator<Scalar> A_rc = op.getBlock(r,c);
            const bool isZeroBlock = A_rc.ptr().get() == 0
              || dynamic_cast<const SimpleZeroOp<Scalar>* >(A_rc.ptr().get()) != 0;
            if (isZeroBlock)
              {
                TEST_FOR_EXCEPTION(r==c, std::runtime_error,
                                   "zero diagonal block (" << r << ", " << c 
                                   << ") detected in block "
                                   "triangular solver. Operator is " << op);
                continue;
              }
            if (r < c) isUpper = true;
            if (c < r) isLower = true;
          }
      }

    TEST_FOR_EXCEPTION(isUpper && isLower, std::runtime_error, 
                       "block triangular solver detected non-triangular operator "
                       << op);

    // Sweep order: an upper-triangular operator goes from the last block row
    // upward (back substitution). Otherwise rows go from the first downward
    // (forward substitution).
    for (int i=0; i<nRows; i++)
      {
        const int r = isUpper ? nRows - 1 - i : i;
        Vector<Scalar> rhs_r = rhs.getBlock(r);

        // Subtract the contributions of the blocks already solved, i.e. the
        // i blocks visited before row r in the sweep order. Zero blocks add
        // nothing, so skip them instead of applying them.
        for (int j=0; j<i; j++)
          {
            const int c = isUpper ? nCols - 1 - j : j;
            LinearOperator<Scalar> A_rc = op.getBlock(r,c);
            if (A_rc.ptr().get() != 0
                && dynamic_cast<const SimpleZeroOp<Scalar>* >(A_rc.ptr().get()) == 0)
              {
                rhs_r = rhs_r - A_rc * soln.getBlock(c);
              }
          }

        // Use either one solver shared by every diagonal block, or one per row.
        const LinearSolver<Scalar>& solver_r
          = (nSolvers == 1) ? solvers_[0] : solvers_[r];

        Vector<Scalar> soln_r;
        SolverState<Scalar> state
          = solver_r.solve(op.getBlock(r,r), rhs_r, soln_r);

        // With a single block row, the block solution is assigned directly
        // instead of through setBlock. NOTE(review): presumably the domain is
        // not a block space in that case; confirm.
        if (nRows > 1) soln.setBlock(r, soln_r);
        else soln = soln_r;

        if (state.finalState() != SolveConverged)
          {
            return state;
          }
      }

    return SolverState<Scalar>(SolveConverged, "block solves converged",
                               0, ScalarTraits<Scalar>::zero());
  }
  
}

#endif
