/*
   orxonox - the future of 3D-vertical-scrollers

   Copyright (C) 2004 orx

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

### File Specific:
   main-programmer: claudio
   co-programmer:
*/


/* this is for debug output. It just says, that all calls to PRINT() belong to the DEBUG_MODULE_NETWORK module
   For more information refere to https://www.orxonox.net/cgi-bin/trac.cgi/wiki/DebugOutput
*/
#define DEBUG_MODULE_NETWORK


#include "base_object.h"
#include "network_protocol.h"
#include "udp_socket.h"
#include "udp_server_socket.h"
#include "connection_monitor.h"
#include "synchronizeable.h"
#include "network_game_manager.h"
#include "shared_network_data.h"
#include "message_manager.h"
#include "preferences.h"
#include "zip.h"

#include "src/lib/util/loading/resource_manager.h"

#include "network_log.h"


#include "lib/util/loading/factory.h"

#include "debug.h"
#include "class_list.h"
#include <algorithm>

/* include your own header */
#include "network_stream.h"

/* probably unnecessary */
using namespace std;


#define PACKAGE_SIZE  256


NetworkStream::NetworkStream()
    : DataStream()
{
  this->init();
  /* initialize the references */
  this->type = NET_CLIENT;
}


NetworkStream::NetworkStream( std::string host, int port )
{
  this->type = NET_CLIENT;
  this->init();
  this->peers[0].socket = new UdpSocket( host, port );
  this->peers[0].userId = 0;
  this->peers[0].isServer = true;
  this->peers[0].connectionMonitor = new ConnectionMonitor( 0 );
}


NetworkStream::NetworkStream( int port )
{
  this->type = NET_SERVER;
  this->init();
  this->serverSocket = new UdpServerSocket(port);
  this->bActive = true;
}


void NetworkStream::init()
{
  /* set the class id for the base object */
  this->setClassID(CL_NETWORK_STREAM, "NetworkStream");
  this->bActive = false;
  this->serverSocket = NULL;
  this->networkGameManager = NULL;
  myHostId = 0;
  currentState = 0;
  
  remainingBytesToWriteToDict = Preferences::getInstance()->getInt( "compression", "writedict", 0 );
  
  assert( Zip::getInstance()->loadDictionary( "testdict" ) >= 0 );
  this->dictClient = Zip::getInstance()->loadDictionary( "dict2pl_client" );
  assert( this->dictClient >= 0 );
  this->dictServer = Zip::getInstance()->loadDictionary( "dict2p_server" );
  assert( this->dictServer >= 0 );
}


NetworkStream::~NetworkStream()
{
  if ( this->serverSocket )
  {
    serverSocket->close();
    delete serverSocket;
    serverSocket = NULL;
  }
  for ( PeerList::iterator i = peers.begin(); i!=peers.end(); i++)
  {
    if ( i->second.socket )
    {
      i->second.socket->disconnectServer();
      delete i->second.socket;
      i->second.socket = NULL;
    }
    
    if ( i->second.handshake )
    {
      delete i->second.handshake;
      i->second.handshake = NULL;
    }
    
    if ( i->second.connectionMonitor )
    {
      delete i->second.connectionMonitor;
      i->second.connectionMonitor = NULL;
    }
  }
  for ( SynchronizeableList::const_iterator it = getSyncBegin(); it != getSyncEnd(); it ++ )
    (*it)->setNetworkStream( NULL );
}


void NetworkStream::createNetworkGameManager()
{
  this->networkGameManager = NetworkGameManager::getInstance();
  // setUniqueID( maxCon+2 ) because we need one id for every handshake
  // and one for handshake to reject client maxCon+1
  this->networkGameManager->setUniqueID( SharedNetworkData::getInstance()->getNewUniqueID() );
  MessageManager::getInstance()->setUniqueID( SharedNetworkData::getInstance()->getNewUniqueID() );
}


void NetworkStream::startHandshake()
{
  Handshake* hs = new Handshake(false);
  hs->setUniqueID( 0 );
  assert( peers[0].handshake == NULL );
  peers[0].handshake = hs;
//   peers[0].handshake->setSynchronized( true );
  //this->connectSynchronizeable(*hs);
  //this->connectSynchronizeable(*hs);
  PRINTF(0)("NetworkStream: Handshake created: %s\n", hs->getName());
}


void NetworkStream::connectSynchronizeable(Synchronizeable& sync)
{
  this->synchronizeables.push_back(&sync);
  sync.setNetworkStream( this );

  this->bActive = true;
}


void NetworkStream::disconnectSynchronizeable(Synchronizeable& sync)
{
  // removing the Synchronizeable from the List.
  std::list<Synchronizeable*>::iterator disconnectSynchro = std::find(this->synchronizeables.begin(), this->synchronizeables.end(), &sync);
  if (disconnectSynchro != this->synchronizeables.end())
    this->synchronizeables.erase(disconnectSynchro);
  
  oldSynchronizeables[sync.getUniqueID()] = SDL_GetTicks();
}


void NetworkStream::processData()
{
  int tick = SDL_GetTicks();
  
  currentState++;
  
  if ( this->type == NET_SERVER )
  {
    if ( serverSocket )
      serverSocket->update();
    
    this->updateConnectionList();
  }
  else
  {
    if ( peers[0].socket && ( !peers[0].socket->isOk() || peers[0].connectionMonitor->hasTimedOut() ) )
    {
      PRINTF(1)("lost connection to server\n");

      peers[0].socket->disconnectServer();
      delete peers[0].socket;
      peers[0].socket = NULL;

      if ( peers[0].handshake )
        delete peers[0].handshake;
      peers[0].handshake = NULL;
      
      if ( peers[0].connectionMonitor )
        delete peers[0].connectionMonitor;
      peers[0].connectionMonitor = NULL;
    }
  }

  cleanUpOldSyncList();
  handleHandshakes();
  
  // order of up/downstream is important!!!!
  // don't change it
  handleDownstream( tick );
  handleUpstream( tick );

}

void NetworkStream::updateConnectionList( )
{
  //check for new connections

  NetworkSocket* tempNetworkSocket = serverSocket->getNewSocket();

  if ( tempNetworkSocket )
  {
    int clientId;
    if ( freeSocketSlots.size() >0 )
    {
      clientId = freeSocketSlots.back();
      freeSocketSlots.pop_back();
      peers[clientId].socket = tempNetworkSocket;
      peers[clientId].handshake = new Handshake(true, clientId, this->networkGameManager->getUniqueID(), MessageManager::getInstance()->getUniqueID() );
      peers[clientId].connectionMonitor = new ConnectionMonitor( clientId );
      peers[clientId].handshake->setUniqueID(clientId);
      peers[clientId].userId = clientId;
      peers[clientId].isServer = false;
    } else
    {
      clientId = 1;
      
      for ( PeerList::iterator it = peers.begin(); it != peers.end(); it++ )
        if ( it->first >= clientId )
          clientId = it->first + 1;
      
      peers[clientId].socket = tempNetworkSocket;
      peers[clientId].handshake = new Handshake(true, clientId, this->networkGameManager->getUniqueID(), MessageManager::getInstance()->getUniqueID());
      peers[clientId].handshake->setUniqueID(clientId);
      peers[clientId].connectionMonitor = new ConnectionMonitor( clientId );
      peers[clientId].userId = clientId;
      peers[clientId].isServer = false;
      
      PRINTF(0)("num sync: %d\n", synchronizeables.size());
    }

    if ( clientId > MAX_CONNECTIONS )
    {
      peers[clientId].handshake->doReject( "too many connections" );
      PRINTF(0)("Will reject client %d because there are to many connections!\n", clientId);
    }
    else

    PRINTF(0)("New Client: %d\n", clientId);

    //this->connectSynchronizeable(*handshakes[clientId]);
  }

  //check if connections are ok else remove them
  for ( PeerList::iterator it = peers.begin(); it != peers.end(); )
  {
    if (  
          it->second.socket &&
          ( 
            !it->second.socket->isOk()  ||
            it->second.connectionMonitor->hasTimedOut()
          )
       )
    {
      std::string reason = "disconnected";
      if ( it->second.connectionMonitor->hasTimedOut() )
        reason = "timeout";
      PRINTF(0)("Client is gone: %d (%s)\n", it->second.userId, reason.c_str());
      
      //assert(false);

      it->second.socket->disconnectServer();
      delete it->second.socket;
      it->second.socket = NULL;

      if ( it->second.connectionMonitor )
        delete it->second.connectionMonitor;
      it->second.connectionMonitor = NULL;
      
      if ( it->second.handshake )
        delete it->second.handshake;
      it->second.handshake = NULL;
      
      for ( SynchronizeableList::iterator it2 = synchronizeables.begin(); it2 != synchronizeables.end(); it2++ )
      {
        (*it2)->cleanUpUser( it->second.userId );
      }

      NetworkGameManager::getInstance()->signalLeftPlayer(it->second.userId);

      freeSocketSlots.push_back( it->second.userId );
      
      PeerList::iterator delit = it;
      it++;
      
      peers.erase( delit );
      
      continue;
    }
    
    it++;
  }


}

void NetworkStream::debug()
{
  if( this->isServer())
    PRINT(0)(" Host ist Server with ID: %i\n", this->myHostId);
  else
    PRINT(0)(" Host ist Client with ID: %i\n", this->myHostId);

  PRINT(0)(" Got %i connected Synchronizeables, showing active Syncs:\n", this->synchronizeables.size());
  for (SynchronizeableList::iterator it = synchronizeables.begin(); it!=synchronizeables.end(); it++)
  {
    if( (*it)->beSynchronized() == true)
      PRINT(0)("  Synchronizeable of class: %s::%s, with unique ID: %i, Synchronize: %i\n", (*it)->getClassName(), (*it)->getName(),
               (*it)->getUniqueID(), (*it)->beSynchronized());
  }
  PRINT(0)(" Maximal Connections: %i\n", MAX_CONNECTIONS );

}


int NetworkStream::getSyncCount()
{
  int n = 0;
  for (SynchronizeableList::iterator it = synchronizeables.begin(); it!=synchronizeables.end(); it++)
    if( (*it)->beSynchronized() == true)
      ++n;

  //return synchronizeables.size();
  return n;
}

/**
 * check if handshakes completed
 */
void NetworkStream::handleHandshakes( )
{
  for ( PeerList::iterator it = peers.begin(); it != peers.end(); it++ )
  {
    if ( it->second.handshake )
    {
      if ( it->second.handshake->completed() )
      {
        if ( it->second.handshake->ok() )
        {
          if ( !it->second.handshake->allowDel() )
          {
            if ( type != NET_SERVER )
            {
              SharedNetworkData::getInstance()->setHostID( it->second.handshake->getHostId() );
              myHostId = SharedNetworkData::getInstance()->getHostID();

              this->networkGameManager = NetworkGameManager::getInstance();
              this->networkGameManager->setUniqueID( it->second.handshake->getNetworkGameManagerId() );
              MessageManager::getInstance()->setUniqueID( it->second.handshake->getMessageManagerId() );
            }
              

            PRINT(0)("handshake finished id=%d\n", it->second.handshake->getNetworkGameManagerId());

            it->second.handshake->del();
          }
          else
          {
            if ( it->second.handshake->canDel() )
            {
              if ( type == NET_SERVER )
              {
                handleNewClient( it->second.userId );
              }
              
              PRINT(0)("handshake finished delete it\n");
              delete it->second.handshake;
              it->second.handshake = NULL;
            }
          }

        }
        else
        {
          PRINT(1)("handshake failed!\n");
          it->second.socket->disconnectServer();
        }
      }
    }
  }
}

/**
 * handle upstream network traffic
 */
void NetworkStream::handleUpstream( int tick )
{
  int offset;
  int n;
  
  for ( PeerList::reverse_iterator peer = peers.rbegin(); peer != peers.rend(); peer++ )
  {
    offset = INTSIZE; //make already space for length 
    
    if ( !peer->second.socket )
      continue;
    
    n = Converter::intToByteArray( currentState, buf + offset, UDP_PACKET_SIZE - offset );
    assert( n == INTSIZE );
    offset += n;
    
    n = Converter::intToByteArray( peer->second.lastAckedState, buf + offset, UDP_PACKET_SIZE - offset );
    assert( n == INTSIZE );
    offset += n;
    
    n = Converter::intToByteArray( peer->second.lastRecvedState, buf + offset, UDP_PACKET_SIZE - offset );
    assert( n == INTSIZE );
    offset += n;
    
    for ( SynchronizeableList::iterator it = synchronizeables.begin(); it != synchronizeables.end(); it++ )
    {
      int oldOffset = offset;
      Synchronizeable & sync = **it;
      
      if ( !sync.beSynchronized() || sync.getUniqueID() < 0 )
        continue;

      //if handshake not finished only sync handshake
      if ( peer->second.handshake && sync.getLeafClassID() != CL_HANDSHAKE )
        continue;
      
      if ( isServer() && sync.getLeafClassID() == CL_HANDSHAKE && sync.getUniqueID() != peer->second.userId )
        continue;
      
      //do not sync null parent
      if ( sync.getLeafClassID() == CL_NULL_PARENT )
        continue;

      assert( offset + INTSIZE <= UDP_PACKET_SIZE );
      
      //server fakes uniqueid=0 for handshake
      if ( this->isServer() && sync.getUniqueID() < MAX_CONNECTIONS - 1 )
        n = Converter::intToByteArray( 0, buf + offset, UDP_PACKET_SIZE - offset );
      else
        n = Converter::intToByteArray( sync.getUniqueID(), buf + offset, UDP_PACKET_SIZE - offset );
      assert( n == INTSIZE );
      offset += n;
      
      //make space for size
      offset += INTSIZE;

      n = sync.getStateDiff( peer->second.userId, buf + offset, UDP_PACKET_SIZE-offset, currentState, peer->second.lastAckedState, -1000 );
      offset += n;
      //NETPRINTF(0)("GGGGGEEEEETTTTT: %s (%d) %d\n",sync.getClassName(), sync.getUniqueID(), n);
      
      assert( Converter::intToByteArray( n, buf + offset - n - INTSIZE, INTSIZE ) == INTSIZE );
      
      //check if all bytes == 0 -> remove data 
      //TODO not all synchronizeables like this maybe add Synchronizeable::canRemoveZeroDiff()
      bool allZero = true; 
      for ( int i = 0; i < n; i++ ) 
      { 
         if ( buf[i+oldOffset+2*INTSIZE] != 0 ) 
           allZero = false; 
      } 

      if ( allZero ) 
      { 
        //NETPRINTF(n)("REMOVE ZERO DIFF: %s (%d)\n", sync.getClassName(), sync.getUniqueID()); 
        offset = oldOffset; 
      } 

      
    }
    
    for ( SynchronizeableList::iterator it = synchronizeables.begin(); it != synchronizeables.end(); it++ )
    {
      Synchronizeable & sync = **it;
      
      if ( !sync.beSynchronized() || sync.getUniqueID() < 0 )
        continue;
      
      sync.handleSentState( peer->second.userId, currentState, peer->second.lastAckedState );
    }
    
    assert( Converter::intToByteArray( offset, buf, INTSIZE ) == INTSIZE );
    
    int compLength = 0;
    if ( this->isServer() )
      compLength = Zip::getInstance()->zip( buf, offset, compBuf, UDP_PACKET_SIZE, dictServer );
    else
      compLength = Zip::getInstance()->zip( buf, offset, compBuf, UDP_PACKET_SIZE, dictClient );
    
    if ( compLength <= 0 )
    {
      PRINTF(1)("compression failed!\n");
      continue;
    }
    
    assert( peer->second.socket->writePacket( compBuf, compLength ) );
    
    if ( this->remainingBytesToWriteToDict > 0 )
      writeToNewDict( buf, offset, true );
    
    peer->second.connectionMonitor->processUnzippedOutgoingPacket( tick, buf, offset, currentState );
    peer->second.connectionMonitor->processZippedOutgoingPacket( tick, compBuf, compLength, currentState );
    
    //NETPRINTF(n)("send packet: %d userId = %d\n", offset, peer->second.userId);
  }
}

/**
 * handle downstream network traffic
 */
void NetworkStream::handleDownstream( int tick )
{
  int offset = 0;
  
  int length = 0;
  int packetLength = 0;
  int compLength = 0;
  int uniqueId = 0;
  int state = 0;
  int ackedState = 0;
  int fromState = 0;
  int syncDataLength = 0;
  
  for ( PeerList::iterator peer = peers.begin(); peer != peers.end(); peer++ )
  {
    
    if ( !peer->second.socket )
      continue;

    while ( 0 < (compLength = peer->second.socket->readPacket( compBuf, UDP_PACKET_SIZE )) )
    {
      peer->second.connectionMonitor->processZippedIncomingPacket( tick, compBuf, compLength );
      
      //PRINTF(0)("GGGGGOOOOOOOOOOTTTTTTTT: %d\n", compLength);
      packetLength = Zip::getInstance()->unZip( compBuf, compLength, buf, UDP_PACKET_SIZE );

      if ( packetLength < 4*INTSIZE )
      {
        if ( packetLength != 0 )
          PRINTF(1)("got too small packet: %d\n", packetLength);
        continue;
      }
      
      if ( this->remainingBytesToWriteToDict > 0 )
        writeToNewDict( buf, packetLength, false );
    
      assert( Converter::byteArrayToInt( buf, &length ) == INTSIZE );
      assert( Converter::byteArrayToInt( buf + INTSIZE, &state ) == INTSIZE );
      assert( Converter::byteArrayToInt( buf + 2*INTSIZE, &fromState ) == INTSIZE );
      assert( Converter::byteArrayToInt( buf + 3*INTSIZE, &ackedState ) == INTSIZE );
      //NETPRINTF(n)("ackedstate: %d\n", ackedState);
      offset = 4*INTSIZE;
      
      peer->second.connectionMonitor->processUnzippedIncomingPacket( tick, buf, packetLength, state, ackedState );

      //NETPRINTF(n)("got packet: %d, %d\n", length, packetLength);
    
    //if this is an old state drop it
      if ( state <= peer->second.lastRecvedState )
        continue;
    
      if ( packetLength != length )
      {
        PRINTF(1)("real packet length (%d) and transmitted packet length (%d) do not match!\n", packetLength, length);
        peer->second.socket->disconnectServer();
        continue;
      }
      
      while ( offset + 2*INTSIZE < length )
      {
        assert( offset > 0 );
        assert( Converter::byteArrayToInt( buf + offset, &uniqueId ) == INTSIZE );
        offset += INTSIZE;
      
        assert( Converter::byteArrayToInt( buf + offset, &syncDataLength ) == INTSIZE );
        offset += INTSIZE;
        
        assert( syncDataLength > 0 );
        assert( syncDataLength < 10000 );
      
        Synchronizeable * sync = NULL;
        
        for ( SynchronizeableList::iterator it = synchronizeables.begin(); it != synchronizeables.end(); it++ )
        { 
        //                                        client thinks his handshake has id 0!!!!!
          if ( (*it)->getUniqueID() == uniqueId || ( uniqueId == 0 && (*it)->getUniqueID() == peer->second.userId ) )
          {
            sync = *it;
            break;
          }
        }
        
        if ( sync == NULL )
        {
          PRINTF(0)("could not find sync with id %d. try to create it\n", uniqueId);
          if ( oldSynchronizeables.find( uniqueId ) != oldSynchronizeables.end() )
          {
            offset += syncDataLength;
            continue;
          }
          
          if ( !peers[peer->second.userId].isServer )
          {
            offset += syncDataLength;
            continue;
          }
          
          int leafClassId;
          if ( INTSIZE > length - offset )
          {
            offset += syncDataLength;
            continue;
          }

          Converter::byteArrayToInt( buf + offset, &leafClassId );
          
          assert( leafClassId != 0 );
        
          BaseObject * b = NULL;
          /* These are some small exeptions in creation: Not all objects can/should be created via Factory */
          /* Exception 1: NullParent */
          if( leafClassId == CL_NULL_PARENT || leafClassId == CL_SYNCHRONIZEABLE || leafClassId == CL_NETWORK_GAME_MANAGER )
          {
            PRINTF(1)("Can not create Class with ID %x!\n", (int)leafClassId);
            offset += syncDataLength;
            continue;
          }
          else
            b = Factory::fabricate( (ClassID)leafClassId );

          if ( !b )
          {
            PRINTF(1)("Could not fabricate Object with classID %x\n", leafClassId);
            offset += syncDataLength;
            continue;
          }

          if ( b->isA(CL_SYNCHRONIZEABLE) )
          {
            sync = dynamic_cast<Synchronizeable*>(b);
            sync->setUniqueID( uniqueId );
            sync->setSynchronized(true);
  
            PRINTF(0)("Fabricated %s with id %d\n", sync->getClassName(), sync->getUniqueID());
          }
          else
          {
            PRINTF(1)("Class with ID %x is not a synchronizeable!\n", (int)leafClassId);
            delete b;
            offset += syncDataLength;
            continue;
          }
        }
        

        int n = sync->setStateDiff( peer->second.userId, buf+offset, syncDataLength, state, fromState ); 
        offset += n;
        //NETPRINTF(0)("SSSSSEEEEETTTTT: %s %d\n",sync->getClassName(), n);

      }
      
      if ( offset != length )
      {
        PRINTF(0)("offset (%d) != length (%d)\n", offset, length);
        peer->second.socket->disconnectServer();
      }
      
      
      for ( SynchronizeableList::iterator it = synchronizeables.begin(); it != synchronizeables.end(); it++ )
      {
        Synchronizeable & sync = **it;
      
        if ( !sync.beSynchronized() || sync.getUniqueID() < 0 )
          continue;
      
        sync.handleRecvState( peer->second.userId, state, fromState );
      }
      
      assert( peer->second.lastAckedState <= ackedState );
      peer->second.lastAckedState = ackedState;
      
      assert( peer->second.lastRecvedState < state );
      peer->second.lastRecvedState = state;

    }
  
  }
  
}

/**
 * is executed when a handshake has finished
 * @todo create playable for new user
 */
void NetworkStream::handleNewClient( int userId )
{
  MessageManager::getInstance()->initUser( userId );
  
  networkGameManager->signalNewPlayer( userId );
}

/**
 * removes old items from oldSynchronizeables
 */
void NetworkStream::cleanUpOldSyncList( )
{
  int now = SDL_GetTicks();
  
  for ( std::map<int,int>::iterator it = oldSynchronizeables.begin(); it != oldSynchronizeables.end();  )
  {
    if ( it->second < now - 10*1000 )
    {
      std::map<int,int>::iterator delIt = it;
      it++;
      oldSynchronizeables.erase( delIt );
      continue;
    }
    it++;
  }
}

/**
 * writes data to DATA/dicts/newdict
 * @param data pointer to data
 * @param length length
 */
void NetworkStream::writeToNewDict( byte * data, int length, bool upstream )
{
  if ( remainingBytesToWriteToDict <= 0 )
    return;
  
  if ( length > remainingBytesToWriteToDict )
    length = remainingBytesToWriteToDict;
  
  std::string fileName = ResourceManager::getInstance()->getDataDir();
  fileName += "/dicts/newdict";
  
  if ( upstream )
    fileName += "_upstream";
  else
    fileName += "_downstream";
  
  FILE * f = fopen( fileName.c_str(), "a" );
  
  if ( !f )
  {
    PRINTF(2)("could not open %s\n", fileName.c_str());
    remainingBytesToWriteToDict = 0;
    return;
  }
  
  if ( fwrite( data, 1, length, f ) != length )
  {
    PRINTF(2)("could not write to file\n");
    fclose( f );
    return;
  }
  
  fclose( f );
  
  remainingBytesToWriteToDict -= length;  
}






