/*
 *   Written by Bradley Broom (2002).
 *
 *   Copyright (c) 2002 Bradley Broom
 *
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation; either version 2, or (at your option)
 *   any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
#include <stdio.h>
#include <math.h>

#include "MRI.h"
#include "vmedian.h"

/* Squared distance below which an input is treated as coincident with the
 * current estimate (its 1/distance weight would otherwise blow up). */
#define SQRCLOSE	(1e-10)
/* Upper bound on the total number of update steps per median computation. */
enum { MAXITER = 50 };

/*
 * One step of a modified Weiszfeld iteration for the weighted vector
 * (geometric) median of x[0..n-1].
 *
 *   n  - number of entries in x[] and w[].
 *   y1 - current estimate; only read.
 *   y2 - receives the next estimate (r, g, b are written).
 *   x  - input vectors.
 *   w  - weight of each input.
 *
 * Inputs whose squared distance to y1 is below SQRCLOSE are treated as
 * coincident with y1.  Instead of contributing a w/distance term (which would
 * divide by ~0), their weights are summed into wy.  The ordinary Weiszfeld
 * step (the w/distance-weighted mean of the other inputs) is then pulled back
 * toward y1 by wy/||R||, where R is the weighted sum of unit vectors from y1
 * toward the remaining inputs.  If wy >= ||R||, y1 itself is kept.
 * (This appears to follow the Vardi-Zhang treatment of coincident points.)
 */
static void updatew (int n, fpixel *y1, fpixel *y2, fpixel x[], double w[])
{
	fpixel R;
	double wy, ry, den;
	int i;

	y2->r = y2->b = y2->g = 0;
	wy = 0.0;	/* total weight of inputs coincident with y1 */
	den = 0.0;	/* sum of w[i]/||x[i]-y1|| over the remaining inputs */
	for (i = 0; i < n; i++) {
		double mag = (y1->r - x[i].r) * (y1->r - x[i].r) +
		             (y1->g - x[i].g) * (y1->g - x[i].g) +
		             (y1->b - x[i].b) * (y1->b - x[i].b);
		if (mag < SQRCLOSE) {
			wy += w[i];
#ifdef DEBUG
			printf ("y close to x[%d]\n", i+1);
#endif
		}
		else {
			double wt = w[i] / sqrt (mag);
			den += wt;
			y2->r += wt * x[i].r;
			y2->b += wt * x[i].b;
			y2->g += wt * x[i].g;
		}
	}
	if (den == 0.0) {
		/* No input contributes a direction to move in (all are coincident
		 * with y1, the remaining weights are zero, or n == 0).  Dividing by
		 * den would yield NaN/Inf, so keep the current estimate. */
		y2->r = y1->r;
		y2->g = y1->g;
		y2->b = y1->b;
		return;
	}
	y2->r /= den;
	y2->g /= den;
	y2->b /= den;
	if (wy > 0) {
		/* R = sum over non-coincident inputs of w[i] * (x[i]-y1)/||x[i]-y1||. */
		R.r = R.b = R.g = 0;
		for (i = 0; i < n; i++) {
			double d = (y1->r - x[i].r) * (y1->r - x[i].r) +
				   (y1->g - x[i].g) * (y1->g - x[i].g) +
				   (y1->b - x[i].b) * (y1->b - x[i].b);
			if (d >= SQRCLOSE) {
				double wt = w[i] / sqrt(d);
				R.r += wt * (x[i].r - y1->r);
				R.g += wt * (x[i].g - y1->g);
				R.b += wt * (x[i].b - y1->b);
			}
		}
		ry = sqrt ((R.r * R.r) + (R.g * R.g) + (R.b * R.b));
#ifdef DEBUG
		printf ("ry=%g wy=%g wy/ry=%g\n", ry, wy, wy/ry);
#endif
		/* ry == 0 gives wy == +Inf here, which selects the "keep y1" branch. */
		wy /= ry;
		if (wy < 1.0) {
			y2->r = y2->r * (1.0 - wy) + wy * y1->r;
			y2->g = y2->g * (1.0 - wy) + wy * y1->g;
			y2->b = y2->b * (1.0 - wy) + wy * y1->b;
		}
		else {
			y2->r = y1->r;
			y2->g = y1->g;
			y2->b = y1->b;
		}
	}
}

/*
 * One step of a modified Weiszfeld iteration for the unweighted vector
 * (geometric) median of x[0..n-1].  This is the same computation as
 * updatew() with every weight equal to 1.
 *
 *   n  - number of entries in x[].
 *   y1 - current estimate; only read.
 *   y2 - receives the next estimate (r, g, b are written).
 *   x  - input vectors.
 *
 * Inputs within squared distance SQRCLOSE of y1 are counted in wy instead of
 * contributing a 1/distance term.  When any exist, the plain Weiszfeld step
 * is blended back toward y1 by wy/||R||, where R is the sum of unit vectors
 * from y1 toward the other inputs.
 *
 * The per-input weights are only needed again in the (rare) wy > 0 branch.
 * They are recomputed there rather than cached in alloca() buffers on every
 * call.  This avoids non-standard alloca and stack usage proportional to n,
 * and the expressions are identical, so the values are the same.
 */
static void update (int n, fpixel *y1, fpixel *y2, fpixel x[])
{
	fpixel R;
	double wy, ry, den;
	int i;

	y2->r = y2->b = y2->g = 0;
	wy = 0.0;	/* number of inputs coincident with y1 */
	den = 0.0;	/* sum of 1/||x[i]-y1|| over the remaining inputs */
	for (i = 0; i < n; i++) {
		double mag = (y1->r - x[i].r) * (y1->r - x[i].r) +
		             (y1->g - x[i].g) * (y1->g - x[i].g) +
		             (y1->b - x[i].b) * (y1->b - x[i].b);
		if (mag < SQRCLOSE) {
			wy ++;
#ifdef DEBUG
			printf ("y close to x[%d]\n", i+1);
#endif
		}
		else {
			double mywt = 1.0 / sqrt (mag);
			den += mywt;
			y2->r += mywt * x[i].r;
			y2->b += mywt * x[i].b;
			y2->g += mywt * x[i].g;
		}
	}
	if (den == 0.0) {
		/* Every input is coincident with y1 (or n == 0): there is nothing to
		 * move toward, and dividing by den would yield NaN/Inf. */
		y2->r = y1->r;
		y2->g = y1->g;
		y2->b = y1->b;
		return;
	}
	y2->r /= den;
	y2->g /= den;
	y2->b /= den;
	if (wy > 0) {
		/* R = sum over non-coincident inputs of (x[i]-y1)/||x[i]-y1||. */
		R.r = R.b = R.g = 0;
		for (i = 0; i < n; i++) {
			double mag = (y1->r - x[i].r) * (y1->r - x[i].r) +
			             (y1->g - x[i].g) * (y1->g - x[i].g) +
			             (y1->b - x[i].b) * (y1->b - x[i].b);
			if (!(mag < SQRCLOSE)) {
				double mywt = 1.0 / sqrt (mag);
				R.r += mywt * (x[i].r - y1->r);
				R.g += mywt * (x[i].g - y1->g);
				R.b += mywt * (x[i].b - y1->b);
			}
		}
		ry = sqrt ((R.r * R.r) + (R.g * R.g) + (R.b * R.b));
#ifdef DEBUG
		printf ("ry=%g wy=%g wy/ry=%g\n", ry, wy, wy/ry);
#endif
		/* ry == 0 gives wy == +Inf here, which selects the "keep y1" branch. */
		wy /= ry;
		if (wy < 1.0) {
			y2->r = y2->r * (1.0 - wy) + wy * y1->r;
			y2->g = y2->g * (1.0 - wy) + wy * y1->g;
			y2->b = y2->b * (1.0 - wy) + wy * y1->b;
		}
		else {
			y2->r = y1->r;
			y2->g = y1->g;
			y2->b = y1->b;
		}
	}
}

/*
 * Weighted vector (geometric) median of n RGB vectors.
 *
 * Starts from the (unweighted) mean of x[] and alternately applies updatew()
 * until the squared distance between successive estimates is <= tol * M, or
 * until MAXITER update steps have run.  M is the squared norm of the
 * per-channel mean absolute value of the inputs.
 *
 *   n   - number of inputs; n <= 0 yields a zero vector.
 *   x   - input vectors.
 *   w   - per-input weights, passed to updatew().
 *   tol - relative convergence tolerance.
 *
 * Returns the final estimate.  If the unweighted mean has squared norm below
 * 1e-5, that mean is returned immediately without iterating.  On reaching
 * MAXITER, the inputs and the last estimate are dumped to stderr and the
 * last estimate is returned.
 */
fpixel VectorMedianW (int n, fpixel x[], double w[], double tol)
{
	fpixel y1, y2;
	int i;
	double conv;

	if (n <= 0) {
		/* No inputs: avoid 0/0 NaNs below and a pointless MAXITER loop. */
		y1.r = y1.g = y1.b = 0.0;
		return y1;
	}

	/* Starting point: unweighted mean of the inputs. */
	y1.r = y1.g = y1.b = 0.0;
	for (i = 0; i < n; i++) {
		y1.r += x[i].r;
		y1.g += x[i].g;
		y1.b += x[i].b;
	}
	y1.r /= n;
	y1.g /= n;
	y1.b /= n;
	conv = (y1.r * y1.r + y1.g * y1.g + y1.b*y1.b);
	if (conv < 1e-5) return y1;

	/* Scale tolerance by average magnitude of inputs. */
	y2.r = y2.g = y2.b = 0.0;
	for (i = 0; i < n; i++) {
		y2.r += fabs(x[i].r);
		y2.g += fabs(x[i].g);
		y2.b += fabs(x[i].b);
	}
	y2.r /= n;
	y2.g /= n;
	y2.b /= n;
	conv = (y2.r * y2.r + y2.g * y2.g + y2.b*y2.b);
	tol *= conv;

	/* Each pass does two steps, ping-ponging between y1 and y2. */
	for (i = 0; i < MAXITER; i += 2) {
		updatew (n, &y1, &y2, x, w);
		conv = (y2.r - y1.r) * (y2.r - y1.r) +
		       (y2.g - y1.g) * (y2.g - y1.g) +
		       (y2.b - y1.b) * (y2.b - y1.b);
#ifdef WDEBUG
		printf ("i=%d y=%g,%g,%g conv=%g\n", i, y2.r, y2.g, y2.b, conv);
#endif
		if (conv <= tol) return y2;
		updatew (n, &y2, &y1, x, w);
		conv = (y2.r - y1.r) * (y2.r - y1.r) +
		       (y2.g - y1.g) * (y2.g - y1.g) +
		       (y2.b - y1.b) * (y2.b - y1.b);
#ifdef WDEBUG
		printf ("i=%d y=%g,%g,%g conv=%g\n", i+1, y1.r, y1.g, y1.b, conv);
#endif
		if (conv <= tol) return y1;
	}
	/* y1 holds the most recent estimate here. */
	fprintf (stderr, "vmedian: maximum iterations (%d) exceeded. Conv=%g, tol=%g: \n", MAXITER, conv, tol);
	for (i = 0; i < n; i++)
		fprintf (stderr, "  x[%d] = %g,%g,%g\n", i, x[i].r, x[i].g, x[i].b);
	fprintf (stderr, "  ===>   %g,%g,%g conv=%g\n", y1.r, y1.g, y1.b, conv);
	return y1;
}

/*
 * Unweighted vector (geometric) median of n RGB vectors.
 *
 * Starts from the mean of x[] and alternately applies update() until the
 * squared distance between successive estimates is <= tolr * M, or until
 * MAXITER update steps have run.  M is the squared norm of the per-channel
 * mean absolute value of the inputs.
 *
 *   n    - number of inputs; n <= 0 yields a zero vector.
 *   x    - input vectors.
 *   tolr - relative convergence tolerance.
 *
 * Returns the final estimate.  If the mean of x[] has squared norm below
 * 1e-5, the mean is returned without iterating.  Reaching MAXITER is silent
 * unless TESTING is defined.
 */
fpixel
VectorMedian (int n, fpixel x[], double tolr)
{
	fpixel y1, y2;
	int i;
	double tol, conv;
#ifdef TESTING
	double tmp;
#endif

	if (n <= 0) {
		/* No inputs: avoid 0/0 NaNs below and a pointless MAXITER loop. */
		y1.r = y1.g = y1.b = 0.0;
		return y1;
	}

	/* Determine average vector. */
	y1.r = y1.g = y1.b = 0.0;
	for (i = 0; i < n; i++) {
		y1.r += x[i].r;
		y1.g += x[i].g;
		y1.b += x[i].b;
	}
	y1.r /= n;
	y1.g /= n;
	y1.b /= n;

	/* Return if the average vector is pretty much zero. */
	conv = (y1.r * y1.r + y1.g * y1.g + y1.b*y1.b);
	if (conv < 1e-5) return y1;

	/* Scale tolerance by average magnitude of inputs. */
	y2.r = y2.g = y2.b = 0.0;
	for (i = 0; i < n; i++) {
		y2.r += fabs(x[i].r);
		y2.g += fabs(x[i].g);
		y2.b += fabs(x[i].b);
	}
	y2.r /= n;
	y2.g /= n;
	y2.b /= n;
	conv = (y2.r * y2.r + y2.g * y2.g + y2.b*y2.b);
	tol = tolr * conv;

	/* Loop until converged or maximum iterations exceeded. */
	/* Each loop does two iterations to avoid data shuffling. */
	for (i = 0; i < MAXITER; i += 2) {
		update (n, &y1, &y2, x);
		conv = (y2.r - y1.r) * (y2.r - y1.r) +
		       (y2.g - y1.g) * (y2.g - y1.g) +
		       (y2.b - y1.b) * (y2.b - y1.b);
#ifdef DEBUG
		printf ("i=%d y=%g,%g,%g conv=%g\n", i, y2.r, y2.g, y2.b, conv);
#endif
		if (conv <= tol) return y2;
		update (n, &y2, &y1, x);
		conv = (y2.r - y1.r) * (y2.r - y1.r) +
		       (y2.g - y1.g) * (y2.g - y1.g) +
		       (y2.b - y1.b) * (y2.b - y1.b);
#ifdef DEBUG
		printf ("i=%d y=%g,%g,%g conv=%g\n", i+1, y1.r, y1.g, y1.b, conv);
#endif
		if (conv <= tol) return y1;
	}
	/* For sufficiently tight tolerances, some inputs take excessively
	 * long to converge.  With the default parameters, the difference
	 * between the result after MAXITER iterations and the true value is
	 * small enough that it is probably imperceptible.
	 * If you would like to see details of the tail end of the algorithm,
	 * define TESTING to get verbose debugging output on stderr.
	 */
#ifdef TESTING
	fprintf (stderr, "vmedian: maximum iterations (%d) exceeded: conv=%g, tol=%g. \n", MAXITER, conv, tol);
	for (i = 0; i < n; i++)
		fprintf (stderr, "  x[%d] = %g,%g,%g\n", i, x[i].r, x[i].g, x[i].b);

	/* For comparison: the old tolerance scaling, based on the signed mean. */
	y2.r = y2.g = y2.b = 0.0;
	for (i = 0; i < n; i++) {
		y2.r += (x[i].r);
		y2.g += (x[i].g);
		y2.b += (x[i].b);
	}
	y2.r /= n;
	y2.g /= n;
	y2.b /= n;
	tmp = (y2.r * y2.r + y2.g * y2.g + y2.b*y2.b);
	tol = tolr * tmp;
	fprintf (stderr, "Old: tolr=%g conv=%g -> tol=%g\n", tolr, tmp, tol);

	/* The new (current) scaling, based on the mean absolute value. */
	y2.r = y2.g = y2.b = 0.0;
	for (i = 0; i < n; i++) {
		y2.r += fabs(x[i].r);
		y2.g += fabs(x[i].g);
		y2.b += fabs(x[i].b);
	}
	y2.r /= n;
	y2.g /= n;
	y2.b /= n;
	tmp = (y2.r * y2.r + y2.g * y2.g + y2.b*y2.b);
	tol = tolr * tmp;
	fprintf (stderr, "New: tolr=%g conv=%g -> tol=%g\n", tolr, tmp, tol);

	/* Trace a few more iterations past MAXITER. */
	for (i = 0; i < 6; i += 2) {
		fprintf (stderr, "         %g,%g,%g conv=%g\n", y1.r, y1.g, y1.b, conv);
		update (n, &y1, &y2, x);
		conv = (y2.r - y1.r) * (y2.r - y1.r) +
		       (y2.g - y1.g) * (y2.g - y1.g) +
		       (y2.b - y1.b) * (y2.b - y1.b);
		fprintf (stderr, "         %g,%g,%g conv=%g\n", y2.r, y2.g, y2.b, conv);
		update (n, &y2, &y1, x);
		conv = (y2.r - y1.r) * (y2.r - y1.r) +
		       (y2.g - y1.g) * (y2.g - y1.g) +
		       (y2.b - y1.b) * (y2.b - y1.b);
	}
	fprintf (stderr, "  ===>   %g,%g,%g conv=%g\n", y1.r, y1.g, y1.b, conv);
#endif
	return y1;
}

#ifdef DEBUG
/*
 * Debug driver: reads lines of "r g b weight" from stdin, computes their
 * weighted vector median with VectorMedianW() and prints it.
 *
 * At most 1024 entries are read (the size of the x[]/w[] arrays).  Lines that
 * do not parse as four numbers are reported on stderr and skipped, so no
 * partially filled entry is passed on.
 */
int main (void)
{
	fpixel x[1024];
	double w[1024];
	fpixel m;
	int i, line;
	char	buffer[1024];
	const int maxpts = (int)(sizeof x / sizeof x[0]);

	i = 0;
	line = 0;
	/* fgets is bounded by the buffer size, unlike gets (removed in C11). */
	while (fgets (buffer, sizeof buffer, stdin) != NULL) {
		line++;
		if (i >= maxpts) {
			fprintf (stderr, "Too many input lines; ignoring line %d onward\n", line);
			break;
		}
		if (sscanf (buffer, "%lf %lf %lf %lf", &x[i].r, &x[i].g, &x[i].b, &w[i]) != 4) {
			fprintf (stderr, "Error line %d\n", line);
			continue;
		}
		i++;
	}
	m = VectorMedianW (i, x, w, 1.0/(0x0FFF*0x0FFF));
	printf ("m.r=%g m.g=%g m.b=%g\n", m.r, m.g, m.b);
	return 0;
}
#endif
