// An example solution for programming assignment 7.
// CS 509, Fall 2001.

#include <algorithm>
#include <functional>
#include <iostream>
#include "group.h"


// The nice thing about this problem is that it responds well to a little
// thought.  Given a vector V of numbers whose sum is S, the problem is to
// group the numbers of V into n vectors such that the sum of each vector is s
// and they all sum to S; that is, find n and s such that n*s = S. Also, n
// should be as small as possible.
//
// It is not too hard to see that n, the number of groups formed, must be
// prime.  Suppose n was not prime; that is, suppose n = i*j where neither i
// nor j is 1.  Then S = n*s = (i*j)*s = i*(j*s), and the groups can be
// combined into i new vectors, each new vector consisting of j old vectors.
// However, because i is a non-trivial factor of n, i is less than n, which
// contradicts the assumption that n is as small as possible.  If n*s = S is a
// solution to the problem, then n must be a prime factor of S.
//
// That n must be prime has two immediate consequences.  The first is that,
// given a particular vector V, the solution needs to consider only the
// possibilities that contain p vectors, where p is a prime factor of S.  The
// second consequence is that, for any particular p, each vector must sum to
// S/p.
// 
// Assuming the existence of a procedure that accepts a number n and returns
// the prime factors of n in a vector,
//
     static uvector prime_factors(unsigned n);
//
// and a procedure that accepts a vector of numbers v and a sum s and tries to
// partition v into vectors that each sum to s,
//
     static bool make_groups(uvector uv, unsigned total, uvvector & uvv);
//
// then we can solve the problem with the following routine:
//
     uvvector group(const uvector & uv) {

       // Group the numbers in uv such that each group sums to the same value
       // and there are the smallest number of groups possible.  Return the
       // groups in a vector; if no such grouping exists, return an empty
       // vector.

       const unsigned total = sum(uv);
       const uvector pf = prime_factors(total);
       uvvector uvv;

       // The candidate group counts come out of prime_factors() smallest
       // first, so the first one that works uses the fewest groups.

       for (uvec_citer i = pf.begin(); i != pf.end(); ++i)
         if (make_groups(uv, total/(*i), uvv))
           return uvv;

       // prime_factors() leaves out total itself, so when total is prime the
       // loop above had nothing to try.  The one prime candidate left is
       // total groups that each sum to 1.  Totals of 0 and 1 are skipped:
       // they would give zero groups or a single group.

       if (pf.empty() && (total > 1) && make_groups(uv, 1, uvv))
         return uvv;

       return uvvector();
       }
//
// If none of the prime factors work, then there is no possible grouping; return
// an empty vector.
//
// Finding prime factors is boring, so let's consider make_groups() instead, which
// accepts a vector of numbers and a value and determines if all the values in
// the vector can be formed into subsets that each sum to the given value.
//
// One way to solve this problem is to assume the procedure
// 
     static bool subset_sum(
       const uvec_citer & b, 
       const uvec_citer & e,
       uvector & ss, 
       unsigned ss_sum, 
       unsigned total);
//
// which finds a subset of the numbers in the range [b, e) that sums to total
// and stores the subset in ss.  If no such subset exists, then the numbers
// can't be grouped.  If such a subset exists, remove the numbers in the subset
// from the original vector using the procedure
//
     static void difference(uvector & uv1, const uvector & uv2);
//
// and try to make another subset with the reduced vector.  Keep going until no
// more subsets can be formed either because the numbers don't sum properly
// (bad) or there are no more numbers to sum (good).
//
     static bool make_groups(uvector uv, unsigned total, uvvector & uvv) {

       // Group the numbers in uv such that each group sums to total.  If such
       // a grouping is possible, store the groups in uvv and return true;
       // otherwise return false.  The contents of uvv is undefined if
       // make_groups() returns false.

       uvv.clear();

       do {
	 uvector g;
	 if (!subset_sum(uv.begin(), uv.end(), g, 0, total))
	   return false;
	 difference(uv, g);
	 uvv.push_back(g);
	 }
       while (!uv.empty());

       return true;
       }
//
// There's a simple, recursive approach for creating subsets: each element in
// the original vector either is or is not in the subset.
//
     static bool subset_sum(
       const uvec_citer & b, 
       const uvec_citer & e,
       uvector & subset, 
       unsigned ss_sum, 
       unsigned total) {

       // Given the range of numbers (b, e), select a subset of those numbers
       // such that the sum of the numbers in the subset equals total.  ss_sum
       // is the sum of the numbers currently stored in subset.

       // If the subset sum equals the total, we've found the subset we want.

	  if (ss_sum == total)
	    return true;

       // If the subset-sum is larger than total, or there's no more numbers to
       // add to the subset, the subset-sum isn't possible.

	  if ((ss_sum > total) || (b == e))
	    return false;

       // Assume *b isn't part of the subset and see what happens.

	  if (subset_sum(b + 1, e, subset, ss_sum, total))
	    return true;

       // Assume *b is part of the subset and see what happens.

	  subset.push_back(*b);
	  if (subset_sum(b + 1, e, subset, ss_sum + *b, total))
	    return true;
	  subset.pop_back();

       // The problem can't be solved with *b in the subset and the problem
       // can't be solved with *b not in the subset.  Becase these are the only
       // two choices, the problem can't be solved.

          return false;
       }
//
// difference() implements the usual multi-set difference, swapping each
// element of uv1 that matches an element of uv2 to the end of uv1 and then
// erasing the end of the vector.
//
     static void difference(uvector & uv1, const uvector & uv2) {

       // Compute the multi-set difference uv1 = uv1 - uv2.  Each element of
       // uv2 must occur in uv1 (at least as often as it occurs in uv2); that
       // is checked only by assert().  The elements left in uv1 do not keep
       // their original order.

       // [uv1.begin(), e) holds the elements still present; [e, uv1.end())
       // collects the ones being removed.

       uvec_iter e = uv1.end();

       for (uvec_citer i = uv2.begin(); i != uv2.end(); ++i) {

         // Qualified so the call doesn't depend on argument-dependent lookup
         // or a using-directive to find std::find.

         uvec_iter j = std::find(uv1.begin(), e, *i);
         assert(j != e);
         --e;
         std::swap(*j, *e);
         }

       assert(static_cast<unsigned>(uv1.end() - e) == uv2.size());

       uv1.erase(e, uv1.end());
       }
//
// Assuming you know how to find all the primes less than n
//
     static uvector primes(unsigned n);
//
// finding the prime factors of n involves moving the non-factors to the end of
// the list and then erasing the end of the list.
//
     // An unsigned zero, used as the "crossed out" marker in vectors of
     // numbers.
     const unsigned zero = 0;

     // True when f does not divide n evenly.  f must not be zero.
     static bool is_not_factor(unsigned f, unsigned n) {
       const unsigned remainder = n % f;
       return remainder != 0;
       }

     static uvector prime_factors(unsigned n) {

       // Return a vector containing the prime factors of n, excluding n.

       uvector pf = primes(n);

       std::replace_if(pf.begin(), pf.end(), 
		       std::bind2nd(std::ptr_fun(is_not_factor), n), 
		       zero);
       pf.erase(std::remove(pf.begin(), pf.end(), zero), pf.end());

       return pf;
       }
//
// Our ol' pal, the sieve of Eratosthenes, STL style.
//
     // Predicate for std::find_if: true for any value other than zero.
     static bool non_zero(unsigned i) {
       return i != 0;
       }

     static uvector primes(unsigned n) {

       // Return a vector containing all the primes that are smaller than n,
       // in increasing order.

       if (n < 3)
         return uvector();

       // p[i] is i while i is still a prime candidate, and 0 once it has been
       // crossed out.  Apart from 2, only odd numbers start as candidates.

       uvector p(n);

       p[2] = 2;
       for (unsigned i = 3; i < n; i += 2)
         p[i] = i;

       uvec_iter b = p.begin() + 2;
       const uvec_iter e = p.end();

       // Cross out the odd multiples of each surviving odd candidate *b,
       // starting at (*b)*(*b); smaller multiples have a smaller prime factor
       // and are already gone.  Once (*b)*(*b) >= n there is nothing left to
       // cross out, so the loop stops there.  The test is written as
       // *b <= (n - 1)/(*b) so that (*b)*(*b) is never computed for
       // *b >= 65536, where it would wrap around (32-bit unsigned) and cross
       // out numbers that are not multiples of *b.

       while (((b = std::find_if(b + 1, e, non_zero)) != e) &&
              (*b <= (n - 1)/(*b)))
         for (unsigned i = (*b)*(*b); i < n; i += 2*(*b))
           p[i] = 0;

       p.erase(std::remove(p.begin(), e, zero), e);

       return p;
       }
//
// That's it - problem solved.
//

#ifdef TESTING_GROUP

// g++ -o testing-group -g -ansi -pedantic -Wall -DTESTING_GROUP group.cc && testing-group

int main() {

  uvector p = primes(0);
  assert(p.size() == 0);

  p = primes(2);
  assert(p.size() == 0);

  const unsigned c[] = { 2, 3, 5, 7 };
  uvector ck(c, c + sizeof(c)/sizeof(unsigned));

  p = primes(11);
  assert(p == ck);

  p = primes(12);
  ck.push_back(11);
  assert(p == ck);

  p = primes(13);
  assert(p == ck);

  p = prime_factors(11);
  assert(p.size() == 0);
  p = prime_factors(2*3);
  assert((p.size() == 2) && (p[0] == 2) && (p[1] == 3));
  p = prime_factors(2*2*3*3);
  assert((p.size() == 2) && (p[0] == 2) && (p[1] == 3));
  p = prime_factors(7*11*13);
  assert((p.size() == 3) && (p[0] == 7) && (p[1] == 11) && (p[2] == 13));
  }

#endif

// $Log: group.cc,v $
// Revision 1.4  2001/12/09 15:52:00  clayton
// *** empty log message ***
//
// Revision 1.3  2001/12/09 02:02:01  clayton
// *** empty log message ***
//
// Revision 1.2  2001/12/09 01:47:30  clayton
// *** empty log message ***
//
// Revision 1.1  2001/12/08 20:28:29  clayton
// Initial revision
//


// syntax highlighted by Code2HTML, v. 0.9