/*	pbpdemo.cpp

	This is a trivial demo of the reference C++ library
	implementation of the new Parallel Bit Pattern
	computation model.

	Copyright (C) 2022 by Henry Dietz, All rights reserved.

	This library is free software available under CC BY 4.0,
	https://creativecommons.org/licenses/by/4.0/

	This version is available at http://aggregate.org/PBP
	along with any relevant documentation, such as the
	reference card and citations to publications.

	This code is distributed in the hope that it will be
	useful, but WITHOUT ANY WARRANTY; without even the
	implied warranty of MERCHANTABILITY or FITNESS FOR A
	PARTICULAR PURPOSE.
*/

// This program generates a lot of output to stderr
// if REWAYS has a large value; 8 is sufficient for
// all the demos to run correctly
#define	REWAYS	8

#include "pbp.h"

// Exercise the pint API one operation at a time.  Each section writes a
// label to stderr and then either calls Summary()/Show() on the result
// (presumably these print to stderr as well -- see pbp.h) or prints an
// int result directly.  Nothing is asserted and nothing is returned; the
// output is meant to be inspected by eye.
void
simpletest()
{
	// Try various pint operations

	// ---- Construction and assignment ----

	// Default-constructed pint; its Valid() result is printed further down
	fprintf(stderr, "nan = ");
	pint nan; nan.Summary();

	// Zero built two ways: direct-initialized and copy-initialized from an int
	fprintf(stderr, "zero = ");
	pint zero(0); zero.Summary();
	pint naught = 0; naught.Summary();

	fprintf(stderr, "one = ");
	pint one(1); one.Summary();

	fprintf(stderr, "two = ");
	pint two(2); two.Summary();

	fprintf(stderr, "fortytwo = ");
	pint fortytwo = 42; fortytwo.Summary();

	// h2 holds the result of H(2) and is the main operand for most of
	// the operator tests below
	fprintf(stderr, "h2 = ");
	pint h2; h2 = H(2); h2.Summary();

	// Compare Valid() on the default-constructed pint and on an ordinary value
	fprintf(stderr, "nan %s\n", (nan.Valid() ? "Valid" : "Not valid"));

	fprintf(stderr, "two %s\n", (two.Valid() ? "Valid" : "Not valid"));

	// ---- Resizing and conversion ----

	// Two-argument constructor, then Minimize() and two Extend() calls,
	// summarizing after each step.
	// NOTE(review): the meaning of the (8, 3) arguments is defined in
	// pbp.h, not here -- confirm before relying on it.
	fprintf(stderr, "oversize = ");
	pint oversize(8, 3); oversize.Summary();
	oversize = oversize.Minimize(); oversize.Summary();
	oversize = oversize.Extend(6); oversize.Summary();
	oversize = oversize.Extend(4); oversize.Summary();

	// From here on 'one' is overwritten: first by Promote(), then by
	// Logic().  Every later use of 'one' sees the Logic() result, not
	// the original pint(1).
	fprintf(stderr, "one promoted = ");
	one = one.Promote(oversize); one.Summary();

	fprintf(stderr, "one as logic = ");
	one = one.Logic(); one.Summary();

	// Default-construct, then assign from an existing pint
	fprintf(stderr, "too = ");
	pint too; too = two; too.Summary();

	// ---- Bitwise and logical operators ----
	// 'math' is reset to pint(2) before each variant so that every
	// spelling of an operation starts from the same left operand.

	// And: named method, binary operator, compound assignment
	fprintf(stderr, "And gives ");
	pint math(2); math = math.And(h2); math.Summary();
	math = pint(2); math = (math & h2); math.Summary();
	math = pint(2); math &= h2; math.Summary();

	fprintf(stderr, "LAnd gives ");
	math = pint(2); math = (math && h2); math.Summary();

	fprintf(stderr, "Or gives ");
	math = pint(2); math = (math | h2); math.Summary();
	math = pint(2); math |= h2; math.Summary();

	fprintf(stderr, "LOr gives ");
	math = pint(2); math = (math || h2); math.Summary();

	fprintf(stderr, "Xor gives ");
	math = pint(2); math = (math ^ h2); math.Summary();
	math = pint(2); math ^= h2; math.Summary();

	// LXor is only exercised through its named method here
	fprintf(stderr, "LXor gives ");
	math = pint(2); math = math.LXor(h2); math.Summary();

	fprintf(stderr, "Not gives ");
	math = ~h2; math.Summary();

	fprintf(stderr, "LNot gives ");
	math = !h2; math.Summary();

	// ---- Equality ----

	fprintf(stderr, "EQ gives ");
	math = (one == h2); math.Summary();

	fprintf(stderr, "NE gives ");
	math = (one != h2); math.Summary();

	// ---- Named methods taking an int argument ----
	// These results are printed with Show() rather than Summary().
	// NOTE(review): the meaning of the int argument (1 or 42) to each
	// of these methods is defined in pbp.h -- confirm there.

	fprintf(stderr, "Rot gives "); 
	math = h2.Rot(1); math.Show();

	fprintf(stderr, "Flip gives ");
	math = h2.Flip(42); math.Show();

	fprintf(stderr, "Reset gives ");
	math = one.Reset(42); math.Show();

	fprintf(stderr, "Set gives ");
	math = zero.Set(42); math.Show();

	fprintf(stderr, "Dom gives ");
	math = zero.Dom(42); math.Show();

	// ---- Queries whose result is printed with %d ----

	fprintf(stderr, "Ones gives %d\n", one.Ones());

	// Meas() with an explicit argument (on h2 and on a temporary
	// pint(-42)), then four calls with no argument.  main() seeds
	// rand() "for random measurements", so the repeated no-argument
	// calls presumably demonstrate varying results -- TODO confirm.
	fprintf(stderr, "Meas gives %d\n", h2.Meas(42));
	fprintf(stderr, "Meas gives %d\n", pint(-42).Meas(42));
	fprintf(stderr, "Meas gives %d\n", h2.Meas());
	fprintf(stderr, "Meas gives %d\n", h2.Meas());
	fprintf(stderr, "Meas gives %d\n", h2.Meas());
	fprintf(stderr, "Meas gives %d\n", h2.Meas());

	fprintf(stderr, "First gives %d\n", h2.First());

	fprintf(stderr, "Ones gives %d\n", h2.Ones());

	// ---- Ordering comparisons and Min/Max ----

	fprintf(stderr, "GT gives ");
	math = (one > h2); math.Show();

	fprintf(stderr, "LT gives ");
	math = (one < h2); math.Show();

	fprintf(stderr, "GE gives ");
	math = (one >= h2); math.Show();

	fprintf(stderr, "LE gives ");
	math = (one <= h2); math.Show();

	fprintf(stderr, "Min gives ");
	math = one.Min(h2); math.Show();

	fprintf(stderr, "Max gives ");
	math = one.Max(h2); math.Show();

	// ---- Shifts ----
	// The binary-operator form is printed with Show(), the
	// compound-assignment form with Summary().

	fprintf(stderr, "ShR gives ");
	math = (h2 >> one); math.Show();
	math = h2; math >>= one; math.Summary();

	fprintf(stderr, "ShL gives ");
	math = (one << one); math.Show();
	math = one; math <<= one; math.Summary();

	// ---- Arithmetic ----

	// Add: binary, compound, then prefix and postfix increment.  For
	// the increments both the modified operand (math) and the value of
	// the expression (mathy) are summarized.
	fprintf(stderr, "Add gives ");
	math = (one + h2); math.Summary();
	math = one; math += h2; math.Summary();
	pint mathy;
	math = h2; mathy = (++math); math.Summary(); mathy.Summary();
	math = h2; mathy = (math++); math.Summary(); mathy.Summary();

	// Sub: same pattern as Add, with prefix and postfix decrement
	fprintf(stderr, "Sub gives ");
	math = (one - h2); math.Summary();
	math = one; math -= h2; math.Summary();
	math = h2; mathy = (--math); math.Summary(); mathy.Summary();
	math = h2; mathy = (math--); math.Summary(); mathy.Summary();

	// Unary minus, shown as-is and after Extend(8)
	fprintf(stderr, "Neg gives ");
	math = -one; math.Summary();
	math = (-one).Extend(8); math.Summary();

	fprintf(stderr, "Mul gives ");
	math = (h2 * h2); math.Summary();
	math = h2; math *= h2; math.Summary();

	// Named Mul() with an extra int argument of 2; the label suggests
	// it limits the multiply to 2 bits -- TODO confirm against pbp.h
	fprintf(stderr, "Mul 2 bits gives ");
	math = h2.Mul(h2, 2); math.Summary();

	fprintf(stderr, "Abs gives ");
	math = (-one).Abs(); math.Summary();

	fprintf(stderr, "Signed gives ");
	math = pint(2).Signed(); math.Summary();

	fprintf(stderr, "UnSigned gives ");
	math = pint(-2).UnSigned(); math.Summary();

	// The original author tagged Divide and Modulus as "wrong";
	// presumably the printed results are known to be incorrect in this
	// version of the library -- TODO confirm.
	fprintf(stderr, "Divide gives "); // wrong
	math = (h2 / pint(2)); math.Summary();
	math = h2; math /= pint(2); math.Summary();

	fprintf(stderr, "Modulus gives "); // wrong
	math = (h2 % pint(2)); math.Summary();
	math = h2; math %= pint(2); math.Summary();

	// ---- Reductions ----
	// Any() and All() are each tried on zero, h2, and one (which by
	// now holds the Logic() result assigned earlier).

	fprintf(stderr, "Any gives %d\n", zero.Any());
	fprintf(stderr, "Any gives %d\n", h2.Any());
	fprintf(stderr, "Any gives %d\n", one.Any());

	fprintf(stderr, "All gives %d\n", zero.All());
	fprintf(stderr, "All gives %d\n", h2.All());
	fprintf(stderr, "All gives %d\n", one.All());
}

void
pintsqrt(int val)
{
	// Compute square root of val
	pint a(val); // 8-bit number
	pint b = H(4); // 4-bit possible square roots
	pint c = (b * b); // square them
	pint d = (c == a); // which were 169?
	int pos = d.First(); // first non-0 is answer
	printf("Square root of %d is %d\n", val, pos);
}

void
pintfactor(int val)
{
	// Factor val
	pint a(val); // 8-bit number
	pint b = H(4,0x0f); // 4-bit possible 1st factor
	pint c = H(4,0xf0); // 4-bit possible 2nd factor
	pint d = b * c; // multiply 'em
	pint e = (d == a); // which were 143?
	pint f = e * b;
	int spot = f.First(); // factors
	int one = c.Meas(spot);
	int two = b.Meas(spot);
	printf("%d, %d are factors of %d\n", one, two, val);
}

void
pbitripple()
{
	// 4-bit wide pbitripple-carry adder
	// per Cuccaro et al
	// arXiv:quant-ph/0410184v1
	pbit a0(0), a1(0), a2(0), a3(0);
	pbit b0(1), b1(0), b2(0), b3(0);
	pbit z(0), x(0);
	H(a0, 0);
	H(a1, 1);
	H(a2, 2);
	H(a3, 3);
	CNOT(a1,b1); CNOT(a2,b2);
	CNOT(a3,b3); CNOT(a1,x);
	CCNOT(a0,b0,x); CNOT(a2,a1);
	CCNOT(x,b1,a1); CNOT(a3,a2);
	CCNOT(a1,b2,a2); CNOT(a3,z);
	CCNOT(a2,b3,z); NOT(b1);
	NOT(b2); CNOT(x,b1);
	CNOT(a1,b2); CNOT(a2,b3);
	CCNOT(a1,b2,a2);
	CCNOT(x,b1,a1);
	CNOT(a3,a2); NOT(b2);
	CCNOT(a0,b0,x); CNOT(a2,a1);
	NOT(b1); CNOT(a1,x);
	CNOT(a0,b0); CNOT(a1,b1);
	CNOT(a2,b2); CNOT(a3,b3);
	SETMEAS();
	printf("a=%d b=%d\n",
	       MEAS(a0)+(MEAS(a1)<<1)+(MEAS(a2)<<2)+(MEAS(a3)<<3),
	       MEAS(b0)+(MEAS(b1)<<1)+(MEAS(b2)<<2)+(MEAS(b3)<<3));
}

// Program entry point: seeds the C random number generator from the
// clock, runs each demo in turn, and finishes by asking the library's
// 're' object for its statistics.  Command-line arguments are ignored.
int
main(int argc, char **argv)
{
	// Arguments are accepted but not used
	(void)argc;
	(void)argv;

	// Seed rand() from the current time, for random measurements
	const unsigned seed = static_cast<unsigned>(time(0));
	srand(seed);

	// Try each pint operation
	simpletest();

	// Compute sqrt(169)
	pintsqrt(169);

	// Factor 143
	pintfactor(143);

	// 4-bit ripple-carry adder
	pbitripple();

	// Show lower-level statistics
	re.Stats();

	return 0;
}
