//train-tiny2.cpp
//compila train-tiny2 -c -t
#include <cektiny.h>

int main(int argc, char** argv) {
  if (argc!=4) {
    printf("train-tiny: Cria modelo para segmentar Mat_<COR> pela cor\n");
    printf("train-tiny 'zer*' 'one*' modelo.net\n");
    erro("Erro: Numero de argumentos invalido");
  }

  vector< vec_t > ax; 
  vector< vec_t > ay; 
  
  vector<string> arqZero; vsWildCard(argv[1],arqZero);
  for (unsigned i=0; i<arqZero.size(); i++) {
    cout << "Lendo " << arqZero[i] << endl;
    Mat_<COR> a; le(a,arqZero[i]);
    for (unsigned j=0; j<a.total(); j++) {
      COR cor=a(j); 
      vec_t tx(3);
      vec_t ty(1);
      tx[0]=G2F(cor[0])-0.5; tx[1]=G2F(cor[1])-0.5; tx[2]=G2F(cor[2])-0.5;
      //A funcao G2F converte grayscale (0 a 255) para float (0 a 1)
      //Subtrai 0.5 para que o intervalo de cor va' de -0.5 a +0.5
      ty[0]=0;
      ax.push_back(tx); ay.push_back(ty);
    }
  }

  vector<string> arqOne; vsWildCard(argv[2],arqOne);
  for (unsigned i=0; i<arqOne.size(); i++) {
    cout << "Lendo " << arqOne[i] << endl;
    Mat_<COR> a; le(a,arqOne[i]);
    for (unsigned j=0; j<a.total(); j++) {
      COR cor=a(j); 
      vec_t tx(3);
      vec_t ty(1);
      tx[0]=G2F(cor[0])-0.5; tx[1]=G2F(cor[1])-0.5; tx[2]=G2F(cor[2])-0.5;
      ty[0]=1;
      ax.push_back(tx); ay.push_back(ty);
    }
  }
  // Neste ponto do programa, cada linha do vetor ax tem uma amostras de treinamento.
  // A cor de um pixel numa entrada de ax. As intensidades vao de -0.5 a +0.5
  // Cada linha do vetor ay tem uma saida: 0 se o pixel for feijao, 1 se for fundo

  printf("Treinando rede de tiny_dnn...\n");
  network<sequential> net;
  adagrad opt; // Otimizador adagrad funciona melhor que os outros.
  
  net << // Coloque aqui a sua estrutura de rede
      
  int batch_size = // Preencha adequadamente
  int epochs = // Preencha adequadamente

  cout << "start training" << endl;
  progress_display disp(ax.size());
    
  int epoch=1;

  //Esta funcao-lambda e' chamada toda vez que terminar um epoch
  auto on_enumerate_epoch = [&]() {
    float erro = net.get_loss<mse>(ax,ay)/ax.size();
    cout << endl << "Epoch " << epoch << "/" << epochs << " finished. "
         << "MSE: " << erro << endl;
    ++epoch;
    disp.restart(ax.size());
  };

  //Esta funcao-lambda e' chamada toda vez que terminar um minibatch
  auto on_enumerate_minibatch = [&]() { disp += batch_size; };
  
  net.fit<mse>(opt, ax, ay, batch_size, epochs, on_enumerate_minibatch, on_enumerate_epoch);
  
  net.save(argv[3]);
  cout << "Gravado arquivo " << string(argv[3]) << endl;
}
