/*===========================================================================
	Title: BSpline.cpp
	Module: Pi/MathLib
	Author: Ignacio Castaño
	Date: 29/05/2000
	License: Public Domain
===========================================================================*/

/*----------------------------------------------------------------------------
	Doc:
----------------------------------------------------------------------------*/

/** @file BSpline.cpp
 * @brief BSpline interpolation and approximation (NUNRBS).
 *
 * You can find some good notes about splines here:
 * http://www.cs.mtu.edu/~shene/COURSES/cs3621/NOTES/notes.html 
**/


/*----------------------------------------------------------------------------
	Includes:
----------------------------------------------------------------------------*/

// Core
#include "Pi/Core/Memory.h"

// MathLib
#include "Pi/MathLib/BSpline.h"
#include "Pi/MathLib/Sparse.h"



/*----------------------------------------------------------------------------
	Methods:
----------------------------------------------------------------------------*/

/** Create the basis function.
 *
 * Stores the control point count and degree and fills the knot vector, which
 * always has cpn + d + 1 entries. The resulting knots are dumped with piDebug().
 *
 * @param cpn  Number of control points; must be > 1 and > d.
 * @param d    Degree of the basis; must be > 0.
 * @param k    Optional knot vector of cpn + d + 1 non-decreasing values. When
 *             NULL, knots are generated so that the valid parameter range
 *             [knot[d], knot[cpn]] is [0, 1], split into cpn - d equal spans.
 * @param loop Only used when k is NULL.
 *             true:  d+1 knots at 0, evenly spaced interior knots, d+1 knots at 1
 *                    (a clamped knot vector).
 *             false: every knot evenly spaced, extending below 0 and above 1.
 *             NOTE(review): earlier comments called these layouts "closed" and
 *             "open" uniform knots respectively; confirm the intended naming.
 */
void BSplineBasis::Create(uint cpn, uint d, const float k[], bool loop) {
	piCheck(cpn > 1);
	piCheck(d > 0);
	piCheck(cpn > d);

	// Init members.
	cpoint_num = cpn;
	degree = d;

	// A degree d basis over cpn control points needs cpn + d + 1 knots.
	knot_array.Resize(cpoint_num + degree + 1);

	if( k != NULL ) {
		// Copy the caller's knots.
		foreach(i, knot_array) {
			knot_array[i] = *k++;
		}
	}
	else {
		// Number of knot spans covering the parameter range [0, 1].
		const float span_num = float(cpoint_num - degree);
		uint i = 0;
		if( loop ) {
			// Clamped: degree+1 zeros, evenly spaced interior knots, degree+1 ones.
			while(i < degree) {
				knot_array[i] = 0.0f;
				i++;
			}
			while(i < cpoint_num) {
				knot_array[i] = float(i - degree) / span_num;
				i++;
			}
			while(i < knot_array.Size()) {
				knot_array[i] = 1.0f;
				i++;
			}
		}
		else {
			// Uniform: knot i sits at (i - degree) / span_num, so knot[degree] == 0
			// and knot[cpoint_num] == 1. The int casts make knots below degree negative.
			while(i < knot_array.Size()) {
				knot_array[i] = float(int(i) - int(degree)) / span_num;
				i++;
			}
		}
	}

	// Span lookup and basis evaluation assume a non-decreasing knot vector.
	for(uint i = 1; i < knot_array.Size(); i++) {
		piDebugCheck(knot_array[i-1] <= knot_array[i]);
	}

	// The size is unsigned, so print it with %u (was %d).
	piDebug("%u Knots.\n", uint(knot_array.Size()));
	foreach(i, knot_array) {
		piDebug(" - %f.\n", knot_array[i]);
	}
}


/** Locate the knot span that contains @a t, using bisection.
 *
 * Returns a span index i in [degree, cpoint_num-1]. For t inside the domain this
 * is the span with knot_array[i] <= t < knot_array[i+1]. Values below
 * knot_array[degree+1] map to span degree. Values at or above
 * knot_array[cpoint_num-1] map to span cpoint_num-1, so the end of the domain
 * (t == knot_array[cpoint_num]) falls into the last span. NaN also yields
 * cpoint_num-1.
 *
 * Requires a non-decreasing knot vector.
 */
uint BSplineBasis::GetKey(float t) {
	piCheck(knot_array.Size() == uint(cpoint_num + degree + 1));

	// Invariant: the answer lies in [lo, hi-1]. lo acts as a sentinel with
	// knot <= t and hi as a sentinel with t < knot, so only the interior
	// knots degree+1 .. cpoint_num-1 are ever compared.
	uint lo = degree;
	uint hi = cpoint_num;
	while( hi - lo > 1 ) {
		const uint mid = lo + (hi - lo) / 2;
		if( t < knot_array[mid] ) {
			hi = mid;
		}
		else {
			lo = mid;
		}
	}
	return lo;
}


/** Evaluate the basis for the given t, returns the indices of the first and last non-zero basis. 
 * @todo This method can be optimized for fixed degree evaluation.
 */
void BSplineBasis::EvaluateBasis(float t, int * first_out, float N[] ) {
//	piDebugCheck(t >= 0.0f);
//	piDebugCheck(t <= 1.0f);
	piDebugCheck(first_out != NULL);
	piDebugCheck(N != NULL);
	//piDebug("> %f\n", t);
	
	/*if( t <= 0.0f ) {
		first = 0;

		N[0] = 1.0f;
		for(uint d = 1; d <= degree; d++) {
			N[d] = 0.0f;
		}
	}
	else if( t >= 1.0f ) {
		first = cpoint_num - 1 - degree;

		for(uint d = 0; d < degree; d++) {
			N[d] = 0.0f;
		}
		N[degree] = 1.0f;
	}
	else*/ {
			   
		int first = GetKey(t);

		piDebugCheck(knot_array[first] <= t);
		piDebugCheck(t <= knot_array[first+1]);

		float n0, n1;

		// First step.
		N[degree] = 1.0f;

		for(int d = 1; d <= degree; d++) {
			{
				uint i = first;
				
				n1 = (knot_array[i+1] - t) / (knot_array[i+1] - knot_array[i-d+1]);
				
				N[degree-d] = n1 * N[degree-d+1];
			}
			for(int k = 1; k < d; k++) {
				uint i = first + k;
				
				n0 = (t - knot_array[i-d]) / (knot_array[i] - knot_array[i-d]);
				n1 = (knot_array[i+1] - t) / (knot_array[i+1] - knot_array[i-d+1]);
				
				N[degree-d+k] = n0 * N[degree-d+k] + n1 * N[degree-d+k+1];
			}
			{
				uint i = first + d;
				
				n0 = (t - knot_array[i-d]) / (knot_array[i] - knot_array[i-d]);
				
				N[degree] = n0 * N[degree];
			}
		}
		
		*first_out = first - degree;
	}
}


/** Create the curve: set up its basis and allocate @a cpn control points.
 *
 * @param cpn Number of control points.
 * @param d   Degree of the curve.
 * @param cp  Optional array of cpn initial control points. When NULL, the
 *            control point array is only resized and its contents are whatever
 *            Resize() leaves there.
 * @param k   Optional knot vector forwarded to BSplineBasis::Create(). When NULL,
 *            the basis generates its own knots, using the default `loop` argument
 *            declared in the header.
 */
void BSplineCurve::Create(uint cpn, uint d, const Vec3 * cp, const float * k) {

	basis.Create(cpn, d, k);

	cpoint_array.Resize(cpn);

	// No initial points supplied: nothing more to do.
	if( cp == NULL ) {
		return;
	}

	for(uint i = 0; i < cpn; i++) {
		cpoint_array[i] = cp[i];
	}
}


/** Evaluate the curve at the given t. */
void BSplineCurve::Evaluate(float t, Vec3 * out) {

	const uint d = basis.GetDegree();	
	int first;
	//float N[d+1]; // gcc extension
	STACK_ARRAY(float, N, d+1);

	basis.EvaluateBasis(t, &first, N);

	*out = Vec3::Origin;

	for(uint i = 0; i <= d; i++) {
		out->Mad(*out, cpoint_array[first+i], N[i]);
	}
}


/** Adjust the control points to aproximate the given set of points. */
void BSplineCurve::Aproximate(const PiArray<float> & time_array, const PiArray<Vec3> & point_array) {
	
	const uint cpoint_num = GetControlPointNum();
	const uint sample_num = time_array.Size();	
	piCheck(sample_num == point_array.Size());
	piCheck(sample_num >= GetControlPointNum());

	DenseVector bx(sample_num), by(sample_num), bz(sample_num);
	
	// Init b vector.
	for(uint i = 0; i < sample_num; i++) {
		bx[i] = point_array[i].x;
		by[i] = point_array[i].y;
		bz[i] = point_array[i].z;
	}

	const uint d = basis.GetDegree();
	
#if 0	// Using a least squares solver

	// Init N matrix.
	SparseMatrix N(cpoint_num-2, sample_num);

	int first;
	//float B[d+1]; // gcc extension
	STACK_ARRAY(float, B, d+1);

	for(uint s = 0; s < sample_num; s++) {

		basis.EvaluateBasis(time_array[s], &first, B);

		for(uint i = 0; i <= d; i++) {
			if(first+i > 0 && first+i < cpoint_num-1)  {
				if( B[i] != 0.0f ) {
					N.SetElem(first+i-1, s, B[i]);
				}
			}
		}
	}


	DenseVector px(cpoint_num-2), py(cpoint_num-2), pz(cpoint_num-2);

	LSQRSolve( N, bx, px );
	LSQRSolve( N, by, py );
	LSQRSolve( N, bz, pz );
		
	//	CGNRSolve( N, bx, px );
	//	CGNRSolve( N, by, py );
	//	CGNRSolve( N, bz, pz );
	
	ControlPoint(0) = point_array[0];
	for(uint i = 1; i < cpoint_num-1; i++) {
		ControlPoint(i) = Vec3(px[i-1], py[i-1], pz[i-1]);
		ControlPoint(i).Print();
	}
	ControlPoint(cpoint_num-1) = point_array[sample_num-1];
	

#else	// Using the normal form.

	// Init N matrix.
	SparseMatrix N(cpoint_num, sample_num);

	int first;
	STACK_ARRAY(float, B, d+1);

	for(uint s = 0; s < sample_num; s++) {

		basis.EvaluateBasis(time_array[s], &first, B);

		for(uint i = 0; i <= d; i++) {
			if( B[i] != 0.0f ) {
				N.SetElem(first+i, s, B[i]);
			}
		}
	}


	DenseVector px(cpoint_num), py(cpoint_num), pz(cpoint_num);

	// Set initial solution.
	/*for(uint i = 0; i < cpoint_num; i++) {
		px[i] = point_array[ (i * sample_num) / cpoint_num ].x;
		py[i] = point_array[ (i * sample_num) / cpoint_num ].y;
		pz[i] = point_array[ (i * sample_num) / cpoint_num ].z;
	}*/

//	piDebug("solved in: %d\n", CGNRSolve( N, bx, px, 1e-6 ) );
//	piDebug("solved in: %d\n", CGNRSolve( N, by, py, 1e-6 ) );
//	piDebug("solved in: %d\n", CGNRSolve( N, bz, pz, 1e-6 ) );
	
	CholeskySolve( N, bx, px );
	CholeskySolve( N, by, py );
	CholeskySolve( N, bz, pz );

//	LSQRSolve( N, bx, px, 1e-6f );
//	LSQRSolve( N, by, py, 1e-6f );
//	LSQRSolve( N, bz, pz, 1e-6f );


	for(uint i = 0; i < cpoint_num; i++) {
		ControlPoint(i) = Vec3(px[i], py[i], pz[i]);
	//	ControlPoint(i).Print();
	}

#endif

}



/** Interpolate the given points.
 *
 * Computes the control points so that the curve passes through point_array[s]
 * at parameter time_array[s]. The number of points must equal the number of
 * control points, which makes the collocation system square. It is solved once
 * per coordinate with BiCGSTABSolve().
 *
 * TODO: handle closed and periodic curves.
 *
 * @param time_array  Curve parameter of each point.
 * @param point_array Points to interpolate; same length as time_array.
 */
void BSplineCurve::Interpolate(const PiArray<float> & time_array, const PiArray<Vec3> & point_array) {

	const uint cpoint_num = GetControlPointNum();
	piCheck(cpoint_num == time_array.Size());
	piCheck(cpoint_num == point_array.Size());

	DenseVector bx(cpoint_num), by(cpoint_num), bz(cpoint_num);

	// Init b vector.
	for(uint i = 0; i < cpoint_num; i++) {
		bx[i] = point_array[i].x;
		by[i] = point_array[i].y;
		bz[i] = point_array[i].z;
	}

	// Collocation matrix: element (control point, sample) holds that control
	// point's basis value at the sample's parameter. Zero weights are not stored.
	SparseMatrix N(cpoint_num, cpoint_num);

	const uint d = basis.GetDegree();
	int first;
	STACK_ARRAY(float, B, d+1);

	for(uint s = 0; s < cpoint_num; s++) {

		basis.EvaluateBasis(time_array[s], &first, B);

		for(uint i = 0; i <= d; i++) {
			if( B[i] != 0.0f ) {
				N.SetElem(first+i, s, B[i]);
			}
		}
	}

	// Solve N*p=b.
	DenseVector px(cpoint_num), py(cpoint_num), pz(cpoint_num);

	// BiCGSTAB is iterative and presumably starts from the contents of px/py/pz.
	// Whether DenseVector(n) zero-fills is not visible here, so start from a
	// defined guess.
	for(uint i = 0; i < cpoint_num; i++) {
		px[i] = 0.0f;
		py[i] = 0.0f;
		pz[i] = 0.0f;
	}

	BiCGSTABSolve( N, bx, px );
	BiCGSTABSolve( N, by, py );
	BiCGSTABSolve( N, bz, pz );

	// Setup control points. The per-point Print() debug dump was removed.
	for(uint i = 0; i < cpoint_num; i++) {
		ControlPoint(i) = Vec3(px[i], py[i], pz[i]);
	}
}

