// (C) 2002 Jasper Bedaux, ITFA, University of Amsterdam
// http://www.bedaux.net/jasper/
// C++ file for input/output patterns for neural network

#include "io_patterns.h"
#include "mtrand.h"
#include <string>
#include <fstream>
using namespace std;

// Builds a random pattern of n neurons of which exactly a are active.
//   n - size of the pattern
//   a - number of active neurons, also stored in nactive
// Throws a string when n or a is negative or when a exceeds n.
Pattern::Pattern(int n, int a) throw (string) : nactive(a) {
  const bool valid = (n >= 0) && (a >= 0) && (a <= n);
  if (!valid)
    throw string("wrong size or number of active neurons in pattern");
  p.resize(n, false); // every neuron starts inactive
  MTRand_int32 irand;
// keep drawing positions; a draw that lands on an already active neuron
// is discarded, so exactly a distinct neurons end up active
  int remaining = a;
  while (remaining > 0) {
    const int candidate = irand() % n;
    if (p[candidate]) continue;
    p[candidate] = true;
    --remaining;
  }
}

// Builds a pattern from its text form: one character per neuron,
// '1' for an active neuron and '0' for an inactive one.
// nactive ends up as the number of '1' characters.
// Throws a string when any other character is encountered.
Pattern::Pattern(const string& pattern) throw (string) : nactive(0) {
  const int length = static_cast<int>(pattern.size());
  p.resize(pattern.size(), false); // every neuron starts inactive
  for (int pos = 0; pos < length; ++pos) {
    switch (pattern[pos]) {
      case '1': // active neuron: mark it and count it
        p[pos] = true;
        ++nactive;
        break;
      case '0': // inactive neuron: already initialized to false
        break;
      default:
        throw string("illegal character in pattern");
    }
  }
}

// Constructor for a random set of input-output relations whose input
// patterns are pairwise distinct.
//   nio - number of relations to generate
//   ni  - size of every input pattern
//   no  - size of every output pattern
//   nai - number of active neurons in every input pattern
//   nao - number of active neurons in every output pattern
// ainput and aoutput start as nai and nao and are divided by ni and no
// when those sizes are non-zero.
// Throws a string when nio is negative, when fewer than nio distinct
// input patterns exist, or when a Pattern constructor rejects its arguments.
RelationSet::RelationSet(int nio, int ni, int no, int nai, int nao)
    throw (string) : ainput(nai), aoutput(nao) {
  if (nio < 0) throw string("negative number of patterns");
  if (ni != 0) ainput /= ni;
  if (no != 0) aoutput /= no;
// the number of distinct input patterns with activity nai is
// ni! / (nai!(ni - nai)!); it is accumulated factor by factor in a
// double so that large values cannot overflow an integer
  double combinations = (nai != 0) ? static_cast<double>(ni) / nai : 1.;
  for (int k = 1; k < nai; ++k)
    combinations *= static_cast<double>(ni - k) / k;
// the .5 margin presumably absorbs rounding error in the product
  if (nio > (combinations + .5))
    throw string("number of possible unique input patterns smaller than nio");
  p.reserve(nio); // one allocation for all relations
  for (int k = 0; k < nio; ++k) {
    Pattern candidate;
    bool unique;
// redraw the input until it differs from every input accepted so far
    do {
      candidate = Pattern(ni, nai);
      unique = true;
      for (int j = 0; unique && (j < k); ++j)
        unique = !(p[j].input() == candidate());
    } while (!unique);
    p.push_back(Relation(k, candidate, Pattern(no, nao)));
  }
}

// Constructor that reads the relations from a text file.
// The file starts with three integers nio, ni and no, followed by nio
// whitespace-separated pairs of input and output pattern strings.
//   patternfile - name of the file to read
// Afterwards ainput and aoutput hold the summed na() values of the
// inputs and outputs divided by nio times the size of the first pattern
// (the division is skipped when that size or nio is zero).
// Throws a string when the file cannot be opened, when reading fails,
// when nio, ni or no is negative, when a pattern has a size other than
// ni (input) or no (output), or when two input patterns are equal.
RelationSet::RelationSet(const string& patternfile) throw (string) :
    ainput(0.), aoutput(0.) {
  ifstream in(patternfile.c_str(), ios::in);
  if (in.fail()) throw string("unable to open '") + patternfile + "' for input";
  int nio, ni, no;
  in >> nio >> ni >> no;
// the stream state is tested first so the counts are only inspected
// after a successful read
  const bool headerok = !in.fail() && (nio >= 0) && (ni >= 0) && (no >= 0);
  if (!headerok) throw string("error reading nio, ni, no");
  p.reserve(nio); // one allocation for all relations
  string inputtext, outputtext;
  for (int k = 0; k < nio; ++k) {
    in >> inputtext >> outputtext;
    if (in.fail())
      throw string("error reading input-output line from '") + patternfile + "'";
    p.push_back(Relation(k, inputtext, outputtext));
  }
  for (int k = 0; k < nio; ++k) {
// both patterns of relation k must have the announced sizes
    const bool inputfits = static_cast<int>(p[k].input.size()) == ni;
    if (!inputfits || (static_cast<int>(p[k].output.size()) != no))
      throw string("read pattern has wrong size");
// relation k must not share its input with any later relation
    for (int other = k + 1; other < nio; ++other) {
      if (p[k].input() == p[other].input())
        throw string("duplicate input pattern read from file");
    }
// sum the activities; they are turned into means below
    ainput += p[k].input.na();
    aoutput += p[k].output.na();
  }
  if (nio == 0) return; // nothing read, activities stay zero
  const double relations = static_cast<double>(nio);
  if (p[0].input.size()) ainput /= relations * p[0].input.size();
  if (p[0].output.size()) aoutput /= relations * p[0].output.size();
}

// Puts the relations of this set in random order (Fisher-Yates shuffle).
// Working from the back, position i receives a relation picked from the
// positions 0..i that have not been fixed yet. Picking the partner from
// the whole range instead would make some orderings more likely than
// others. The small bias of the modulo reduction of irand() remains.
// A set with fewer than two relations is left untouched.
void RelationSet::shuffle() {
  MTRand_int32 irand;
  const int count = static_cast<int>(size());
  for (int i = count - 1; i > 0; --i) {
    const int partner = static_cast<int>(irand() % (i + 1));
    if (partner == i) continue; // relation stays where it is
    Relation temp(p[partner]);
    p[partner] = p[i];
    p[i] = temp;
  }
}