// CS 305 - Computer Algorithms
// Fall 2003
// An example binary search tree implementation.


#include <unistd.h>

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <vector>


typedef int value;


// A set of values kept in an unbalanced binary search tree.
//
// The tree proper hangs off the left link of a sentinel "pseudo root" node, so
// every real node - including the topmost one - has a parent and no operation
// needs a special case for the root.  The pseudo root's own value and right
// link are never used.
//
// Nothing rebalances the tree, so each operation walks one root-to-leaf path
// and costs time proportional to the tree's height.  Lookup and teardown are
// iterative so that a degenerate tree (e.g. one built from sorted input) does
// not need one stack frame per level.

class binary_search_tree {

  public:


    // Add the given value v to this binary search tree; if this binary search
    // tree already contains v, it is not added again.

    void add(value v) {

      const node_pair p = find(v, &pseudo_root, pseudo_root.left);

      if (p.child != 0)
        return;                         // v is already present

      // The pseudo root's value is meaningless, so it is never compared
      // against; the whole tree lives under its left link.

      if ((p.parent == &pseudo_root) || (p.parent->v > v)) {
        p.parent->left = new node(v);
      }
      else {
        assert(p.parent->v < v);
        p.parent->right = new node(v);
      }
    }


    // Create an empty binary search tree.

    binary_search_tree() : pseudo_root(0) { }


    // Free up this binary search tree.
    //
    // Nodes are released without recursion: while the current node has a left
    // child, that child is rotated above it; once it has none, the node is
    // deleted and the walk continues with its right child.

    ~binary_search_tree() {

      assert(0 == pseudo_root.right);

      node * cur = pseudo_root.left;
      pseudo_root.left = 0;

      while (cur != 0) {
        if (cur->left != 0) {
          node * const l = cur->left;   // rotate right around cur
          cur->left = l->right;
          l->right = cur;
          cur = l;
        }
        else {
          node * const next = cur->right;
          delete cur;
          cur = next;
        }
      }
    }


    // Return true iff this binary search tree contains no elements.

    bool empty() const {
      return 0 == pseudo_root.left;
    }


    // Return true iff this binary search tree contains the given value.

    bool find(value v) const {

      const node * cur = pseudo_root.left;

      while ((cur != 0) && !(cur->v == v))
        cur = (cur->v < v) ? cur->right : cur->left;

      return cur != 0;
    }


    // Remove the given value v from this binary search tree; nothing happens
    // if v isn't in this binary search tree.

    void remove(value v) {

      const node_pair p = find(v, &pseudo_root, pseudo_root.left);

      if (p.child != 0) {
        assert(p.child->v == v);
        remove_child(p.parent, p.child);
      }
    }


  private:


    // Copying is not supported: both operations are private and deliberately
    // left undefined, so any accidental use fails at compile or link time.

    binary_search_tree(const binary_search_tree &);
    binary_search_tree & operator = (const binary_search_tree &);


    // A tree node.  Nodes do not own their children; the enclosing tree
    // allocates and frees every node itself.

    struct node {

      value v;
      node * left, * right;

      explicit node(value val) : v(val), left(0), right(0) { }

      private:

        // Not copyable; declared but never defined.

        node(const node &);
        node & operator = (const node &);
    };


    // A node together with its parent, as returned by the private find().

    struct node_pair {
      node * parent, * child;
      node_pair(node * p, node * c) : parent(p), child(c) { }
    };


    // Sentinel whose left link holds the real tree (0 when the tree is empty).

    node pseudo_root;


    node_pair find(value v, node * parent, node * child) {

      // Search the subtree rooted at the given child node, which has the given
      // parent, for a node containing given value v.  Return the node
      // containing v and its parent; if v can't be found, return as the parent
      // the last node examined and set the child to 0.

      assert((parent != 0) &&
             ((parent->left == child) || (parent->right == child)));

      while ((child != 0) && !(child->v == v)) {
        parent = child;
        child = (child->v < v) ? child->right : child->left;
      }

      return node_pair(parent, child);
    }


    value find_max(node * parent, node * child) {

      // Find and remove the node containing the largest value in the subtree
      // rooted at the given child, which has the given parent; return the
      // largest value.

      while (child->right != 0) {
        parent = child;
        child = child->right;
      }

      // The node found has no right child, so removing it is a plain splice
      // and does not come back into find_max().

      const value v = child->v;
      remove_child(parent, child);
      return v;
    }


    void remove_child(node * parent, node * child) {

      // Remove the given child, which has the given parent.

      // Decide which link to patch before remove_root() runs, because it may
      // delete the child.

      const bool is_left = (parent->left == child);
      assert(is_left || (parent->right == child));

      node * const replacement = remove_root(child);

      if (is_left)
        parent->left = replacement;
      else
        parent->right = replacement;
    }


    node * remove_root(node * root) {

      // Remove the given subtree root; return a pointer to the new root or 0
      // if the subtree's empty after removing the root.

      if ((root->left != 0) && (root->right != 0)) {

        // Two children: keep the node, overwrite its value with the largest
        // value of its left subtree, and unlink the node that held that value.

        root->v = find_max(root, root->left);
        return root;
      }

      // At most one child: that child (possibly 0) takes the root's place.

      node * const survivor = (root->left != 0) ? root->left : root->right;
      delete root;
      return survivor;
    }
  };



static void
shuffle_ints(std::vector<int> & xs) {

  // Put the elements of xs into a random order with a Fisher-Yates shuffle
  // driven by random().  This stands in for std::random_shuffle, which was
  // removed in C++17 and whose source of randomness is implementation-defined
  // (so a seed installed with srandom() need not affect it).

  for (std::vector<int>::size_type i = xs.size(); i > 1; --i) {
    const std::vector<int>::size_type j =
      static_cast<std::vector<int>::size_type>(random()) % i;
    std::swap(xs[i - 1], xs[j]);
  }
}


static void
test_bst() {

  // Create a random binary search tree and then uncreate it, checking all the
  // way: the values 0..cnt-1 are added in one random order and removed in
  // another, with membership asserted before and after every removal.

  const unsigned cnt = static_cast<unsigned>(random() % 100000) + 1;
  std::vector<int> ints(cnt);
  binary_search_tree bst;

  for (unsigned i = 0; i < cnt; i++)
    ints[i] = static_cast<int>(i);

  shuffle_ints(ints);

  for (unsigned i = 0; i < cnt; i++)
    bst.add(ints[i]);

  shuffle_ints(ints);

  for (unsigned i = 0; i < cnt; i++) {
    assert(bst.find(ints[i]));
    bst.remove(ints[i]);
    assert(not bst.find(ints[i]));
  }

  assert(bst.empty());
}


// g++ -o test-bst -gstabs -Wall -ansi -pedantic bst.cc && ./test-bst

int 
main() {

  binary_search_tree bst;
  assert(bst.empty());
  assert(not bst.find(1));
  bst.add(1);
  assert(bst.find(1));
  assert(not bst.find(2));
  bst.add(2);
  assert(bst.find(1));
  assert(bst.find(2));
  bst.remove(1);
  assert(not bst.find(1));
  assert(bst.find(2));
  bst.remove(2);
  assert(not bst.find(1));
  assert(not bst.find(2));
  assert(bst.empty());

  srandom(getpid());
  for (unsigned i = 0; i < 10; i++)
    test_bst();
  }


// $Log: bst.cc,v $
// Revision 1.1  2003/12/08 16:03:30  rclayton
// Initial revision
//


// syntax highlighted by Code2HTML, v. 0.9