// (C) 2002 Jasper Bedaux, ITFA, University of Amsterdam
// http://www.bedaux.net/jasper/
// C++ file for neural network simulations

// #define FOLLOW_VARS // uncomment to follow activities, potential and weight

#include "read_parms.h"
#include "neural_net.h"
#include "io_patterns.h"
#include "mtrand.h"
#include <cmath>
#include <ctime> // to initialize random number generator
#include <exception>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
using namespace std;

// Program entry point: runs a series of neural network learning simulations.
//
// Usage: <program> [parameter file [output file]]
//   argv[1]  parameter file to read (default "brain.ini")
//   argv[2]  file that receives the performance data (default "p.dat")
//
// ndata parameter sets are read from the parameter file. For every set the
// measurement is repeated nmean times: a network is built, run for
// nequilibrium steps on random inputs and then trained on the input/output
// patterns for at most nmaxcycles cycles. One line per parameter set is
// written to the output file:
//   xvalue  tmin/tmean  tmin/(tmean + error)  tmin/(tmean - error)
// When FOLLOW_VARS is defined, layer activities, one potential and one weight
// per layer are additionally logged to ah.dat, ao.dat, hh.dat, ho.dat, wh.dat
// and wo.dat.
//
// Returns 1 after printing a message to stderr when a string or a
// std::exception is thrown.
int main(int argc, char* argv[]) {
  try {
// ***************************** INITIALIZATION
    MTRand_normal gaussian(time(0)); // construct and seed pseudo random number generator
// The default names live in writable arrays: binding a string literal directly
// to a char* is ill-formed since C++11. The char* type itself is kept because
// the signatures of the parameter reading helpers are not visible here.
    char default_inifile[] = "brain.ini";
    char default_outfile[] = "p.dat";
    char* inifile = default_inifile; // set default parameter file
    char* outfile = default_outfile; // default output file containing performance data
    if (argc > 1) inifile = argv[1]; // parameter file if specified
    if (argc > 2) outfile = argv[2]; // output file if specified
    ofstream out(outfile, ios::out);
    if (!out) throw string("unable to open ") + outfile + " for output";
#ifdef FOLLOW_VARS // output activities of layers and follow potential, weight
    ofstream ah_out("ah.dat", ios::out);
    ofstream ao_out("ao.dat", ios::out);
    ofstream hh_out("hh.dat", ios::out);
    ofstream ho_out("ho.dat", ios::out);
    ofstream wh_out("wh.dat", ios::out);
    ofstream wo_out("wo.dat", ios::out);
    if (!(ah_out && ao_out && hh_out && ho_out && wh_out && wo_out))
      throw string("unable to open output file");
#endif // FOLLOW_VARS
// read number of measurements parameters
    int ndata = readValue<int>(inifile, "ndata");
    if (ndata < 1) throw string("ndata must be 1 at least");
    vector<int> nmean = readRange<int>(inifile, "nmean", ndata);
    if ((nmean[0] < 1) || (nmean[ndata - 1] < 1))
      throw string("nmean must be 1 at least");
// read network and input-output relation parameters
    vector<int> ni = readRange<int>(inifile, "ni", ndata);
    vector<int> nh = readRange<int>(inifile, "nh", ndata);
    vector<int> no = readRange<int>(inifile, "no", ndata);
    vector<int> nio = readRange<int>(inifile, "nio", ndata);
    vector<int> nai = readRange<int>(inifile, "nai", ndata);
    vector<int> nah = readRange<int>(inifile, "nah", ndata);
    vector<int> nao = readRange<int>(inifile, "nao", ndata);
    vector<float> thetah = readRange<float>(inifile, "thetah", ndata);
    vector<float> thetao = readRange<float>(inifile, "thetao", ndata);
    vector<float> ph = readRange<float>(inifile, "ph", ndata);
    vector<float> po = readRange<float>(inifile, "po", ndata);
// read learning rule parameters
    vector<float> eta = readRange<float>(inifile, "eta", ndata);
    vector<float> kappa = readRange<float>(inifile, "kappa", ndata);
    vector<float> rho = readRange<float>(inifile, "rho", ndata);
    vector<float> delta = readRange<float>(inifile, "delta", ndata);
// read program control parameters
    bool randomIO = readValue<bool>(false, inifile, "randomIO");
    bool extremal = readValue<bool>(false, inifile, "extremal");
// "xvalue" names the parameter whose values form the first output column
    vector<double> xvalue = readRange<double>(inifile,
      readValue<string>(inifile, "xvalue").c_str(), ndata);
// if xvalue is of int type get the xvalues from int vector
    if (readValue<string>(string("float"), inifile, "xvalue", 1) == string("int")) {
      vector<int> xint = readRange<int>(inifile,
        readValue<string>(inifile, "xvalue").c_str(), ndata);
      for (int i = 0; i < static_cast<int>(xvalue.size()); ++i)
        xvalue[i] = static_cast<double>(xint[i]);
    }
    vector<int> nmaxsearch = readRange<int>(inifile, "nmaxsearch", ndata);
    vector<int> nmaxcycles = readRange<int>(inifile, "nmaxcycles", ndata);
    vector<int> nequilibrium = readRange<int>(inifile, "nequilibrium", ndata);
// activities and learning/punishing rates
    vector<float> ai(ndata);
    vector<float> ah(ndata);
    vector<float> ao(ndata);
    vector<float> rhoh(ndata);
    vector<float> rhoo(ndata);
    vector<float> etah(ndata);
    vector<float> etao(ndata);
// calculate activities (fraction of active neurons per layer)
    for (int i = 0; i < ndata; ++i) {
// test every size on its own: a product of the sizes could overflow
      if (ni[i] == 0 || nh[i] == 0 || no[i] == 0)
        throw string("ni, nh or no is zero");
      ai[i] = static_cast<float>(nai[i]) / ni[i];
      ah[i] = static_cast<float>(nah[i]) / nh[i];
      ao[i] = static_cast<float>(nao[i]) / no[i];
    }
// create input/output patterns
    RelationSet patterns;
    if (!randomIO) {
      patterns = RelationSet(readValue<string>(inifile, "patternfile"));
// get nio, ni, no, nai, nao from read patterns
      nio = vector<int>(ndata, patterns.size());
      if (nio[0] < 1) throw string("nio must be > 0");
      ni = vector<int>(ndata, patterns[0].input.size());
      no = vector<int>(ndata, patterns[0].output.size());
      ai = vector<float>(ndata, patterns.ai());
      ao = vector<float>(ndata, patterns.ao());
      nai = vector<int>(ndata, static_cast<int>(ai[0] * ni[0] + .5)); // round
      nao = vector<int>(ndata, static_cast<int>(ao[0] * no[0] + .5)); // round
    }
// calculate eta's, rho's
    for (int i = 0; i < ndata; ++i) {
// test every divisor on its own: a product of the values could overflow or
// underflow and hide (or fake) a zero
      if (ni[i] == 0 || nh[i] == 0 || ai[i] == 0 || ah[i] == 0 ||
          ph[i] == 0 || po[i] == 0)
        throw string("ni, nh, ai, ah, ph or po is zero");
// the starting weights below divide by rho; a zero would make them all NaN
      if (rho[i] == 0) throw string("rho is zero");
      etah[i] = eta[i] / (ni[i] * ai[i] * ph[i]);
      etao[i] = eta[i] / (nh[i] * ah[i] * po[i]);
      rhoh[i] = rho[i] / (ni[i] * ai[i] * ph[i]);
      rhoo[i] = rho[i] / (nh[i] * ah[i] * po[i]);
    }
// one simulation per parameter set
    for (int data = 0; data < ndata; ++data) {
      cout << "Simulation " << (data + 1) << " of " << ndata << ":" << endl;
      double tmean = 0.; // mean total time of learning cycle
      double t2mean = 0.; // total time squared mean
      for (int n = 0; n < nmean[data]; ++n) {
        cout << "repeat measurement " << (n + 1) << " of " << nmean[data] << ":" << endl;
// for random patterns, create new patterns
        if (randomIO) patterns = RelationSet(nio[data], ni[data], no[data], nai[data], nao[data]);
// create network
        Network network(ni[data], nh[data], no[data], ph[data], po[data]);
// set thresholds
        for (int i = network.offset(hidden); i < network.offset(output); ++i)
          network[i].setThreshold(thetah[data]);
        for (int i = network.offset(output); i < network.size(); ++i)
          network[i].setThreshold(thetao[data]);
// set starting weights: gaussian spread around thetah / (ni * ai * ph)
// respectively thetao / (nh * ah * po)
        for (int i = 0; i < network.offset(hidden); ++i)
          for (int j = 0; j < network[i].size(); ++j)
            network[i][j].setWeight(rhoh[data] * (.5 * gaussian() + thetah[data] / rho[data]));
        for (int i = network.offset(hidden); i < network.offset(output); ++i)
          for (int j = 0; j < network[i].size(); ++j)
            network[i][j].setWeight(rhoo[data] * (.5 * gaussian() + thetao[data] / rho[data]));
// wait for equilibrium: feed newly created input patterns and apply antiHebb
// after every step
        time_t start_time = time(0); // time_t avoids narrowing the clock value
        for (int i = 0; i < nequilibrium[data]; ++i) {
          if (!(i % 100))
            cout << "waiting for equilibrium: " << i << " of " << nequilibrium[data] << "\r" << flush;
          RelationSet temp = RelationSet(1, ni[data], 0, nai[data], 0);
          if (extremal) network.findFixed(temp[0].input(), nah[data], nao[data]);
          else network.findFixed(temp[0].input());
          antiHebb(network.begin(input), network.end(input), rhoh[data], ah[data], delta[data]);
          antiHebb(network.begin(hidden), network.end(hidden), rhoo[data], ao[data], delta[data]);
        }
        cout << "waiting for equilibrium: " << nequilibrium[data] << " of " <<
          nequilibrium[data] << "\n" << endl;
        cout << (time(0) - start_time) << " seconds used waiting for equilibrium\n" << endl;
// ***************************** SIMULATION
        double cycletime = 0.; // learning steps summed over all patterns and cycles
        bool totalRecall; // true while every pattern of the cycle was recalled at once
        for (int cycle = 0; cycle < nmaxcycles[data]; ++cycle) {
          totalRecall = true;
          cout << "cycle " << cycle << ":" << endl;
          for (int mu = 0; mu < patterns.size(); ++mu) {
            int learntime = 0; // stays 0 if the pattern is not learned within nmaxsearch steps
            for (int i = 0 ; i < nmaxsearch[data]; ++i) {
              if (extremal) network.findFixed(patterns[mu].input(), nah[data], nao[data]);
              else network.findFixed(patterns[mu].input());
#ifdef FOLLOW_VARS
// log fraction of active hidden and output neurons
              int count = 0;
              for (int j = network.offset(hidden); j < network.offset(output); ++j)
                count += network[j].isActive();
              ah_out << static_cast<float>(count) / network.size(hidden) << "\n";
              count = 0;
              for (int j = network.offset(output); j < network.size(); ++j)
                count += network[j].isActive();
              ao_out << static_cast<float>(count) / network.size(output) << "\n";
// log potential of the first hidden and the first output neuron
              hh_out << network[network.offset(hidden)].potential() << "\n";
              ho_out << network[network.offset(output)].potential() << "\n";
              for (int j = 0; j < network.offset(hidden); ++j) {
                if (network[j].size()) { // find first neuron that has connection
                  wh_out << network[j][0].weight() << "\n";
                  break;
                }
              }
              for (int j = network.offset(hidden); j < network.offset(output); ++j) {
                if (network[j].size()) { // find first neuron that has connection
                  wo_out << network[j][0].weight() << "\n";
                  break;
                }
              }
#endif // FOLLOW_VARS
              if (network.getOutput() == patterns[mu].output()) { // correct output
//                if (i) // do not reinforce if recalled at once
                Hebb(network.begin(input), network.end(input), etah[data], kappa[data], delta[data]);
                Hebb(network.begin(hidden), network.end(hidden), etao[data], kappa[data], delta[data]);
                learntime = i + 1; // learning succeeded in (i + 1) steps
                break;
              }
// wrong output: apply antiHebb to both connection layers and try again
              antiHebb(network.begin(input), network.end(input), rhoh[data], ah[data], delta[data]);
              antiHebb(network.begin(hidden), network.end(hidden), rhoo[data], ao[data], delta[data]);
            }
            cycletime += learntime;
            cout << "pattern " << patterns[mu].id();
            if (learntime == 1)  cout << ": RECALLED" << endl;
            else { // not recalled
              totalRecall = false;
              if (!learntime) cout << ": FAILED" << endl;
              else cout << " learned in " << learntime << " steps" << endl;
            }
          }
          cout << '\n';
          if (totalRecall) break; // every pattern recalled at once: done
          patterns.shuffle();
        }
        cout << "total learning steps: " << cycletime << "\n" << endl;
// incremental update of the means over the (n + 1) measurements so far
        tmean += (cycletime - tmean) / (n + 1.);
        t2mean += (cycletime * cycletime - t2mean) / (n + 1.);
      }
// standard error of the mean; rounding can push the variance estimate
// slightly below zero, which would make sqrt return NaN
      double variance = t2mean - tmean * tmean;
      if (variance < 0.) variance = 0.;
      double error = sqrt(variance / nmean[data]);
// NOTE(review): tmin looks like the expected number of trials when output
// patterns are guessed at random - confirm
      double tmin = nio[data] / (pow(ao[data], nao[data]) * pow(1. - ao[data], no[data] - nao[data]));
      out << xvalue[data] << " " << tmin / tmean << " " << tmin / (tmean + error)
        << " " << tmin / (tmean - error) << endl;
    }
  }
  catch (const string& error) {
    cerr << "Error: " << error << endl;
    return 1;
  }
  catch (const std::exception& error) { // e.g. allocation failure in the standard library
    cerr << "Error: " << error.what() << endl;
    return 1;
  }
}