/* 
   Socket wrapper for http operations
   Copyright (C) 2003-2004, Lei Jiang <sledge10@hotmail.com>
   Copyright (C) 1998-2004, Joe Orton <joe@manyfish.co.uk>,
   Copyright (C) 1999-2000, Tommi Komulainen <Tommi.Komulainen@iki.fi>

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; if not, write to the Free
   Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
   MA 02111-1307, USA

   $Id: DavSocket.cpp 498 2015-05-19 02:12:43Z yone $
*/

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif /* HAVE_CONFIG_H */

#if defined(HAVE_SYS_TYPES_H) || defined(WIN32)
#include <sys/types.h>
#endif /* HAVE_SYS_TYPES_H */

#ifdef WIN32
#  include <winsock2.h>
#  include <stddef.h>
#endif

#ifdef HAVE_SYS_SOCKET_H
#  include <sys/socket.h>
#endif /* HAVE_SYS_SOCKET_H */
#ifdef HAVE_UNISTD_H
#  include <unistd.h>
#endif /* HAVE_UNISTD_H */
#ifdef HAVE_NETDB_H
#  include <netdb.h>
#endif /* HAVE_NETDB_H */
#ifdef HAVE_NETINET_IN_H
#  include <netinet/in.h>
#endif /* HAVE_NETINET_IN_H */
#ifdef HAVE_ARPA_INET_H
#  include <arpa/inet.h>
#endif /* HAVE_ARPA_INET_H */


#include <onion/DavSocket.h>
#include <onion/DavSocketIORaw.h>
#include <onion/DavSocketIOSSL.h>
#include <onion/DavWorkSession.h>

#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/pkcs12.h> /* for PKCS12_PBE_add */
#include <openssl/rand.h>

#ifndef INADDR_NONE
/* Fallback for platforms whose headers do not define INADDR_NONE.
   The value must be all-ones in 32 bits because in_addr.s_addr is a
   32-bit field: ((unsigned long)-1) is 64 bits wide on LP64 targets, so
   comparing s_addr against it could never succeed. */
#define INADDR_NONE 0xffffffffUL
#endif /* INADDR_NONE */

extern SSL_CTX* g_pSSLCtx;

/*
 * Initialise the platform socket layer.
 *
 * On Win32 this starts Winsock with the version built from
 * OI_SOCK_VER_MINOR / OI_SOCK_VER_MAJOR and rejects (after undoing the
 * startup) any Winsock that reports a different version. Elsewhere it is a
 * no-op.
 *
 * Returns OI_OK on success, OISEINITFAILED if WSAStartup() fails, or
 * OISESOCKVER on a version mismatch.
 */
OI_RESULT
Onion_initSocket()
{
#ifdef WIN32
  const WORD wRequested = MAKEWORD(OI_SOCK_VER_MINOR,
				   OI_SOCK_VER_MAJOR);
  WSADATA wsaInfo;
  if(WSAStartup(wRequested, &wsaInfo) != 0)
    return OISEINITFAILED;

  const bool bSameVersion =
    LOBYTE(wsaInfo.wVersion) == OI_SOCK_VER_MINOR
    && HIBYTE(wsaInfo.wVersion) == OI_SOCK_VER_MAJOR;
  if(!bSameVersion){
    // balance the successful WSAStartup() before reporting the mismatch
    WSACleanup();
    return OISESOCKVER;
  }
#endif /* WIN32 */
  return OI_OK;
}

/*
 * Release the platform socket layer acquired by Onion_initSocket().
 *
 * On Win32 this calls WSACleanup(). Its result is intentionally ignored:
 * this function returns void, so there is no channel to report a failure.
 * On other platforms it does nothing.
 */
void
Onion_cleanupSocket()
{
#ifdef WIN32
  WSACleanup();
#endif /* WIN32 */
}

static int
fnSSLVerifyCallback(int nOK, X509_STORE_CTX *pStoreCtx)
{
  if(nOK == 0){
    unsigned long ulFailures = 0;
    X509* pX509 = X509_STORE_CTX_get_current_cert(pStoreCtx);
    int nError = X509_STORE_CTX_get_error(pStoreCtx);
    SSL* pSSL = 
      (SSL*)X509_STORE_CTX_get_ex_data(pStoreCtx,
				       SSL_get_ex_data_X509_STORE_CTX_idx());

    CDavSocket* pSocket = (CDavSocket*)SSL_get_app_data(pSSL);
    OI_DEBUG("SSL certificate problem: %s\n",
	     X509_verify_cert_error_string(nError));

    switch(nError){
    case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
      ulFailures = OI_SSL_ERR_UNABLE_TO_GET_ISSUER_CERT;
      break;
    case X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE:
      ulFailures = OI_SSL_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE;
      break;
    case X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY:
      ulFailures = OI_SSL_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY;
      break;
    case X509_V_ERR_CERT_SIGNATURE_FAILURE:
      ulFailures = OI_SSL_ERR_CERT_SIGNATURE_FAILURE;
      break;
    case X509_V_ERR_CERT_NOT_YET_VALID:
      ulFailures = OI_SSL_ERR_CERT_NOT_YET_VALID;
      break;
    case X509_V_ERR_CERT_HAS_EXPIRED:
      ulFailures = OI_SSL_ERR_CERT_HAS_EXPIRED;
      break;
    case X509_V_ERR_OUT_OF_MEM:
      ulFailures = OI_SSL_ERR_OUT_OF_MEM;
      break;
    case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
      ulFailures = OI_SSL_ERR_DEPTH_ZERO_SELF_SIGNED_CERT;
      break;
    case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
      ulFailures = OI_SSL_ERR_SELF_SIGNED_CERT_IN_CHAIN;
      break;
    case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY:
      ulFailures = OI_SSL_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY;
      break;
    case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE:
      ulFailures = OI_SSL_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE;
      break;
    case X509_V_ERR_CERT_CHAIN_TOO_LONG:
      ulFailures = OI_SSL_ERR_CERT_CHAIN_TOO_LONG;
      break;
    case X509_V_ERR_CERT_REVOKED:
      ulFailures = OI_SSL_ERR_CERT_REVOKED;
      break;
    case X509_V_ERR_INVALID_CA:
      ulFailures = OI_SSL_ERR_INVALID_CA;
      break;
    case X509_V_ERR_PATH_LENGTH_EXCEEDED:
      ulFailures = OI_SSL_ERR_PATH_LENGTH_EXCEEDED;
      break;
    case X509_V_ERR_INVALID_PURPOSE:
      ulFailures = OI_SSL_ERR_INVALID_PURPOSE;
      break;
    case X509_V_ERR_CERT_UNTRUSTED:
      ulFailures = OI_SSL_ERR_CERT_UNTRUSTED;
      break;
    case X509_V_ERR_CERT_REJECTED:
      ulFailures = OI_SSL_ERR_CERT_REJECTED;
      break;
    case X509_V_ERR_SUBJECT_ISSUER_MISMATCH:
      ulFailures = OI_SSL_ERR_SUBJECT_ISSUER_MISMATCH;
      break;
    case X509_V_ERR_AKID_SKID_MISMATCH:
      ulFailures = OI_SSL_ERR_AKID_SKID_MISMATCH;
      break;
    case X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH:
      ulFailures = OI_SSL_ERR_AKID_ISSUER_SERIAL_MISMATCH;
      break;
    case X509_V_ERR_KEYUSAGE_NO_CERTSIGN:
      ulFailures = OI_SSL_ERR_KEYUSAGE_NO_CERTSIGN;
      break;
    default:
      OI_ASSERT(false);
      break;
    }
    pSocket->AddUnauthenticCert(pX509, ulFailures);
  }
  return 1;
}

/*
 * Construct an unconnected socket wrapper owned by pSession.
 *
 * Allocates the raw and SSL IO function objects and an OI_SOCKBUFSIZE
 * read buffer (all released in the destructor), clears the statistics
 * and SSL state, and selects plain (non-SSL) IO.
 */
CDavSocket::CDavSocket(CDavWorkSession* pSession):
  m_pSession(pSession),
  m_hSocket(0)
{
  m_bConnected = false;
  m_unTimeout = OI_SOCK_TIMEOUTSEC;

  m_pIOFuncRaw = new CDavSocketIORaw;
  m_pIOFuncSSL = new CDavSocketIOSSL;
  /* UseSSL() below calls IsUsingSSL(), which reads m_pIOFunc before
     assigning it. Give it a defined value first instead of letting that
     comparison read an indeterminate pointer. */
  m_pIOFunc = m_pIOFuncRaw;
  m_pszBuffer = new char[OI_SOCKBUFSIZE];
  m_pszBufPos = m_pszBuffer;
  m_unContentLen = 0;

  //statistics
  m_unReadCounter = 0;
  m_unWriteCounter = 0;

  initSSL();
  UseSSL(false);
}

/*
 * Destructor: tears down the SSL connection, drops the cached SSL session
 * and server certificate, closes the socket, then frees the IO function
 * objects and read buffer allocated in the constructor.
 *
 * NOTE(review): this class owns raw heap pointers; if the header does not
 * disable copying, a copied CDavSocket would double-delete them -- confirm.
 */
CDavSocket::~CDavSocket(void)
{
  // SSL teardown comes first: cleanupSSL() calls SSL_shutdown(), which
  // writes through the descriptor that Disconnect() is about to close.
  cleanupSSL();
  cleanupSSLSession();
  Disconnect();
  delete m_pIOFuncRaw;
  delete m_pIOFuncSSL;
  delete[] m_pszBuffer;
}

/*
 * Put the SSL-related members into their "nothing allocated" state:
 * no SSL connection, no cached session, no cached server certificate.
 * Allocates nothing and always returns OI_OK.
 */
OI_RESULT
CDavSocket::initSSL()
{
  m_pServerCertificate = NULL;
  m_pSSLSession = NULL;
  m_pSSL = NULL;

  return OI_OK;
}

/*
 * Open a stream socket and connect it to pSockaddr.
 *
 * The socket is created in the address family carried by pSockaddr, so
 * any sockaddr the platform's connect() accepts can be used (Connect()
 * passes an AF_INET sockaddr_in).
 *
 * Returns OI_OK and marks the object connected on success.
 * Returns OISEINVALIDSOCK if the socket cannot be created, or
 * OISECONNFAILED if connect() fails; in the latter case the socket is
 * closed again via Disconnect().
 */
OI_RESULT
CDavSocket::connect(const struct sockaddr* pSockaddr, 
		    int nSockaddrLen)
{
  OI_ASSERT(pSockaddr);

  m_hSocket = (int)socket(pSockaddr->sa_family, SOCK_STREAM, 0);
  if(m_hSocket == ONION_INVALIDSOCKET)
    return OISEINVALIDSOCK;

  int nError = ::connect(m_hSocket, pSockaddr, nSockaddrLen);
  if(nError) {
    /* TODO: error code processing for every platform */
    /*
      int nSockError = onion_errno;
      switch(nSockError)
      {
      case WSAETIMEDOUT:
        return OISETIMEOUT;
        break;
	......
      }
    */
    Disconnect();
    return OISECONNFAILED;
  }
  m_bConnected = true;
  return OI_OK;
}

/*
 * Close the socket if one is open, discarding any buffered input.
 * A handle of 0 (or below) means "no socket", so repeated calls are
 * harmless.
 */
void
CDavSocket::Disconnect()
{
  if(m_hSocket <= 0)
    return;

  ResetBuffer();
  ONION_CLOSE(m_hSocket);
  m_hSocket = 0;
  m_bConnected = false;
}

/*
 * Drop any buffered-but-unconsumed input by rewinding the read cursor to
 * the start of the internal buffer and marking it empty.
 */
void
CDavSocket::ResetBuffer()
{
  m_unContentLen = 0;
  m_pszBufPos = m_pszBuffer;
}

/*
 * Current OpenSSL connection object, or NULL when none exists.
 * The pointer remains owned by this CDavSocket.
 */
SSL*
CDavSocket::getSSL()
{
  SSL* pCurrent = this->m_pSSL;
  return pCurrent;
}

/*
 * Shut down and free the current SSL connection object, if any, and
 * reset m_pSSL to NULL. The cached m_pSSLSession is left alone.
 */
void
CDavSocket::cleanupSSL()
{
  if(!m_pSSL)
    return;

  // Detach the session from this connection before the shutdown; the
  // cached copy in m_pSSLSession holds its own reference.
  SSL_set_session(m_pSSL, NULL);
  SSL_shutdown(m_pSSL);
  SSL_free(m_pSSL);
  m_pSSL = NULL;
}

/*
 *	Release the cached SSL_SESSION and the cached server certificate,
 *	leaving both pointers NULL. Call this to reset that cached state
 *	whenever the hostname or SSL mode changes, or when a connection
 *	is terminated.
 */
void
CDavSocket::cleanupSSLSession()
{
  if(m_pSSLSession != NULL)
    SSL_SESSION_free(m_pSSLSession);
  m_pSSLSession = NULL;

  if(m_pServerCertificate != NULL)
    X509_free(m_pServerCertificate);
  m_pServerCertificate = NULL;
}

/*
 * Run an SSL handshake over the already-connected socket.
 *
 * A new SSL object is created from g_pSSLCtx and IO is switched to the SSL
 * functions. A previously cached session, if any, is offered for
 * resumption. Certificate problems are collected by fnSSLVerifyCallback
 * rather than failing the handshake.
 *
 * After a successful handshake:
 *  - the session is cached (if none was cached yet);
 *  - the server certificate is cached on the first connection;
 *  - on later connections, the presented certificate is compared with the
 *    cached one.
 * cleanupSSLSession() resets both caches.
 *
 * Returns OI_OK on success. On every failure, both the SSL connection and
 * the cached session/certificate are released, and one of these is
 * returned: OIEEDATAINITFAILED, OIEEHANDSHAKEFAILED, OIEEFATALCONNERROR,
 * OIEEGENERIC, OIEENOSERVERCERT, or OIEECERTCHANGED.
 */
OI_RESULT
CDavSocket::connectSSL()
{
  cleanupSSL();	//free ssl connection and set pointer to NULL

  OI_ASSERT(g_pSSLCtx);
  m_pSSL = SSL_new(g_pSSLCtx);
  if(!m_pSSL)
    return OIEEDATAINITFAILED;

  UseSSL(true);	//setting the IO function

  SSL_set_mode(m_pSSL, SSL_MODE_AUTO_RETRY);
  SSL_set_options(m_pSSL, SSL_OP_ALL);
  SSL_set_fd(m_pSSL, m_hSocket);
  SSL_set_app_data(m_pSSL, (char*)this);
  SSL_set_verify(m_pSSL, SSL_VERIFY_PEER, fnSSLVerifyCallback);

  if(m_pSSLSession)
    SSL_set_session(m_pSSL, m_pSSLSession);	//resume last session.

  int nResult = SSL_connect(m_pSSL);
  if(nResult != 1){
    /* ERR_reason_error_string() returns NULL for an empty error queue or
       an unknown code; never pass NULL to a %s conversion. */
    const char* pszReason = ERR_reason_error_string(ERR_get_error());
    OI_DEBUG("Error connecting via SSL: %s\n",
	     pszReason ? pszReason : "unknown error");
    ERR_clear_error();
    cleanupSSL();
    cleanupSSLSession();

    if(nResult == 0)
      return OIEEHANDSHAKEFAILED;	/*handshake failed*/
    if(nResult < 0)
      return OIEEFATALCONNERROR;	/*fatal connection error*/

    /* SSL_connect() is documented to return only 1, 0 or <0; anything
       else is an OpenSSL problem. */
    OI_ASSERT(false);
    return OIEEGENERIC;
  }

  if(!m_pSSLSession)
    //cache the session (SSL_get1_session increments its reference count)
    m_pSSLSession = SSL_get1_session(m_pSSL);

  //check server certificate (returned with an extra reference we own)
  X509* pServerCertificate = SSL_get_peer_certificate(m_pSSL);
  if(!pServerCertificate){
    cleanupSSL();
    cleanupSSLSession();
    return OIEENOSERVERCERT;
  }

  if(!m_pServerCertificate){
    /* First connection since the cache was reset: keep the reference so
       that later reconnections can be checked against it (it is released
       by cleanupSSLSession()). */
    m_pServerCertificate = pServerCertificate;
    return OI_OK;
  }

  int nCmp = X509_cmp(pServerCertificate, m_pServerCertificate);
  X509_free(pServerCertificate);
  if(nCmp != 0){
    cleanupSSL();
    cleanupSSLSession();
    return OIEECERTCHANGED;	//could be MITM attack
  }
  return OI_OK;
}

bool
CDavSocket::UseSSL(bool bUse)
{
  bool bRet = IsUsingSSL();
  if(bUse)
    m_pIOFunc = m_pIOFuncSSL;
  else
    m_pIOFunc = m_pIOFuncRaw;
  ResetBuffer();
  return bRet;
}

/*
 * True when the SSL IO functions are the active ones.
 */
bool
CDavSocket::IsUsingSSL()
{
  const bool bSSLSelected = (m_pIOFuncSSL == m_pIOFunc);
  return bSSLSelected;
}

/*
 * True when an SSL session is currently cached, i.e. a handshake has
 * completed since the cache was last reset.
 */
bool
CDavSocket::HasSSLConnection()
{
  if(m_pSSLSession == NULL)
    return false;
  return true;
}

/*
 * Connect to pszAddr:usPort over IPv4.
 *
 * pszAddr may be a dotted-quad literal or a host name. Host names are
 * resolved with gethostbyname(), and the first IPv4 address returned is
 * used.
 *
 * Returns OISEHOSTNOTFOUND if the name does not resolve to an IPv4 address.
 * Otherwise returns the result of connect().
 *
 * NOTE(review): gethostbyname() returns static storage and is not
 * thread-safe; consider getaddrinfo() if sockets are used concurrently.
 */
OI_RESULT
CDavSocket::Connect(const char* pszAddr, unsigned short usPort)
{
  sockaddr_in sockAddr;

  OI_ASSERT(pszAddr);

  memset(&sockAddr, 0, sizeof(sockAddr));
  sockAddr.sin_family = AF_INET;
  sockAddr.sin_port = htons(usPort);
  sockAddr.sin_addr.s_addr = inet_addr(pszAddr);

  if (sockAddr.sin_addr.s_addr == INADDR_NONE){
    hostent* pHost = gethostbyname(pszAddr);
    /* Only an IPv4 result of exactly sizeof(in_addr) bytes fits into a
       sockaddr_in; reject anything else instead of misreading it. */
    if (!pHost
	|| pHost->h_addrtype != AF_INET
	|| pHost->h_length != (int)sizeof(in_addr)
	|| pHost->h_addr == NULL){
      //WSASetLastError(WSAEINVAL);
      return OISEHOSTNOTFOUND;
    }
    /* h_addr is a char*, not guaranteed to be aligned for in_addr. */
    memcpy(&sockAddr.sin_addr, pHost->h_addr, sizeof(in_addr));
  }

  return connect((sockaddr*)&sockAddr, sizeof(sockAddr));
}

int
CDavSocket::GetHandle()
{
  return m_hSocket;
}

unsigned int
CDavSocket::GetTimeout()
{
  return m_unTimeout;
}

/*
 * Set the read timeout in seconds. Values above OI_SOCK_MAXTIMEOUTSEC are
 * silently capped at that maximum.
 */
void
CDavSocket::SetTimeout(unsigned int unTimeout)
{
  m_unTimeout = (unTimeout <= OI_SOCK_MAXTIMEOUTSEC)
    ? unTimeout
    : OI_SOCK_MAXTIMEOUTSEC;
}

/*
 * Number of bytes handed to callers by the read functions so far.
 */
size_t
CDavSocket::GetReadBytes()
{
  const size_t unTotal = this->m_unReadCounter;
  return unTotal;
}

/*
 * Write statistics counter maintained by Write().
 */
size_t
CDavSocket::GetWrittenBytes()
{
  const size_t unTotal = this->m_unWriteCounter;
  return unTotal;
}

/*
 * Work session this socket was constructed for (not owned by the socket's
 * destructor).
 */
CDavWorkSession*
CDavSocket::GetSession()
{
  CDavWorkSession* pOwner = this->m_pSession;
  return pOwner;
}

/*
 * Read up to *punReadLen bytes into pszBuf.
 *
 * On entry *punReadLen is the capacity of pszBuf; on OI_OK it holds the
 * number of bytes stored. The data source depends on the buffer state:
 *  - If the internal buffer holds data, it is served from there without
 *    touching the socket.
 *  - If the buffer is empty and at least OI_SOCKBUFSIZE bytes are
 *    requested, they are read straight into pszBuf.
 *  - Otherwise the internal buffer is refilled and its leading part is
 *    copied out.
 * Only bytes actually delivered on success are added to the read
 * statistics.
 */
OI_RESULT
CDavSocket::Read(char *pszBuf, size_t *punReadLen)
{
  OI_ASSERT(pszBuf && punReadLen);

  OI_RESULT enuRet;
  size_t unBufLen = *punReadLen;
  size_t unBytes = unBufLen;

  if (m_unContentLen > 0){
    // Deliver buffered data.
    if (unBufLen > m_unContentLen)
      unBufLen = m_unContentLen;
    memcpy(pszBuf, m_pszBufPos, unBufLen);
    m_pszBufPos += unBufLen;
    m_unContentLen -= unBufLen;
    //save read length into the parameter
    *punReadLen = unBufLen;
    m_unReadCounter += unBufLen;
    return OI_OK;
  } else if (unBufLen >= OI_SOCKBUFSIZE){
    /* No need for read buffer. */
    enuRet =  m_pIOFunc->Read(this,
			      pszBuf,
			      punReadLen,
			      m_unTimeout);
    /* On failure *punReadLen is not a trustworthy byte count, so only
       successful reads feed the statistics (matching the buffered path
       below, which reports 0 bytes on error). */
    if (OI_OK == enuRet)
      m_unReadCounter += *punReadLen;
    return enuRet;
  } else {
    // the internal buffer is empty and 
    // external buffer is smaller
    /* Fill read buffer. */
    OI_RESULT enuReadRet;
    unBufLen = OI_SOCKBUFSIZE;
    enuReadRet = m_pIOFunc->Read(this,
				 m_pszBuffer,
				 &unBufLen,
				 m_unTimeout);
    
    if (OI_OK != enuReadRet){
      *punReadLen = 0;
      return enuReadRet;
    }
    if (unBytes > unBufLen) 
      unBytes = unBufLen;
    memcpy(pszBuf, m_pszBuffer, unBytes);
    m_pszBufPos = m_pszBuffer + unBytes;
    m_unContentLen = unBufLen - unBytes;
    //save read length into the parameter
    *punReadLen = unBytes;
    m_unReadCounter += unBytes;
    return OI_OK; 
  }
}

/*
 * Read one '\n'-terminated line into pszBuf and NUL-terminate it.
 *
 * On entry *punReadLen is the capacity of pszBuf; on OI_OK it holds the
 * line length including the '\n' but excluding the added '\0'. The line is
 * consumed from the internal buffer and counted in the read statistics.
 *
 * Returns OISELINETOOLONG when the line plus its terminator does not fit in
 * pszBuf, or when OI_SOCKBUFSIZE bytes are buffered without any '\n'. In
 * that case nothing is consumed. Errors from the IO read function are
 * returned unchanged.
 *
 * NOTE(review): if the IO Read function can return OI_OK with 0 bytes
 * (e.g. at end of stream), the fill loop below never terminates -- confirm
 * the IO layer reports closure as an error.
 */
OI_RESULT
CDavSocket::ReadLine(char *pszBuf, size_t *punReadLen)
{
  char *pszLF;
  size_t unLen;
  size_t unBufLen = *punReadLen;
	
  if ((pszLF = (char*)memchr(m_pszBufPos, '\n', m_unContentLen)) == 0
      && m_unContentLen < OI_SOCKBUFSIZE){
    /* The buffered data does not contain a complete line: move it
     * to the beginning of the buffer. */
    if (m_unContentLen > 0)
      memmove(m_pszBuffer, m_pszBufPos, m_unContentLen);
    m_pszBufPos = m_pszBuffer;
    /* Loop filling the buffer whilst no newline is found in the data
     * buffered so far, and there is still buffer space available */ 
    do {
      /* Read more data onto end of buffer. */
      size_t unReadLen = OI_SOCKBUFSIZE - m_unContentLen;
      OI_RESULT enuResult;
      enuResult = m_pIOFunc->Read(this,
				  m_pszBuffer + m_unContentLen, 
				  &unReadLen,
				  m_unTimeout);

      if (OI_OK != enuResult)
	return enuResult;
      m_unContentLen += unReadLen;
      /* After the memmove above m_pszBufPos == m_pszBuffer, so searching
       * from m_pszBuffer covers exactly the buffered data. */
    } while ((pszLF = (char*)memchr(m_pszBuffer, '\n', m_unContentLen)) == 0
	     && m_unContentLen < OI_SOCKBUFSIZE);
    }

  /* Length of the line up to and including the '\n'. */
  if (pszLF)
    unLen = (size_t)(pszLF - m_pszBufPos + 1);
  else
    unLen = unBufLen; /* fall into "line too long" error... */
  
  /* +1 reserves room for the terminating '\0'. */
  if ((unLen + 1) > unBufLen){
    return OISELINETOOLONG;
  }
  
  memcpy(pszBuf, m_pszBufPos, unLen);
  pszBuf[unLen] = '\0';
  /* consume the line from buffer: */
  m_unContentLen -= unLen;
  m_pszBufPos += unLen;
  *punReadLen = unLen;
  m_unReadCounter += unLen;
  return OI_OK;
}

/*
 * Read exactly *punReadLen bytes into pszBuf, calling Read() as many times
 * as needed.
 *
 * Returns OI_OK once the whole request is satisfied (*punReadLen is left
 * equal to the requested length), or the first error from Read(). On error
 * the bytes already stored are not reported back.
 *
 * NOTE(review): if Read() can return OI_OK with 0 bytes (end of stream),
 * this loop never terminates -- confirm the IO layer reports closure as an
 * error.
 */
OI_RESULT
CDavSocket::FullRead(char *pszBuf, size_t *punReadLen)
{
  OI_ASSERT(punReadLen);

  size_t unRemaining = *punReadLen;

  while(unRemaining > 0){
    /* Request only what is still missing. Reusing the previous chunk
       size here could overrun pszBuf when the last chunk was larger than
       the remainder. */
    size_t unChunk = unRemaining;
    OI_RESULT enuRet = Read(pszBuf, &unChunk);
    if(OI_OK != enuRet)
      return enuRet;

    unRemaining -= unChunk;
    pszBuf += unChunk;
  }
  return OI_OK;
}

/*
 * Copy up to *punReadLen bytes of pending input into pszBuf without
 * consuming it.
 *
 * If the internal buffer is empty, it is first filled with one IO read.
 * On OI_OK *punReadLen holds the number of bytes copied. The peeked bytes
 * stay buffered and are returned again by the next Read()/ReadLine().
 * They are therefore not added to the read statistics here; they are
 * counted once, when actually consumed.
 */
OI_RESULT
CDavSocket::Peek(char *pszBuf, size_t *punReadLen)
{
  size_t unBufLen = *punReadLen;
  size_t unBytes;
  if (m_unContentLen){
    /* just return buffered data. */
    unBytes = m_unContentLen;
  } else {
    /* fill the buffer directly from the IO layer; going through Read()
       would count these bytes as consumed and they would be counted a
       second time when actually read. */
    unBytes = OI_SOCKBUFSIZE;
    OI_RESULT enuRet = m_pIOFunc->Read(this,
				       m_pszBuffer,
				       &unBytes,
				       m_unTimeout);
    if (OI_OK != enuRet)
      return enuRet;
    m_pszBufPos = m_pszBuffer;
    m_unContentLen = unBytes;
  }

  if (unBufLen > unBytes)
    unBufLen = unBytes;

  memcpy(pszBuf, m_pszBufPos, unBufLen);
  *punReadLen = unBufLen;
  return OI_OK;
}

/*
 * Send *punWrittenLen bytes from pszBuf through the active IO functions
 * (raw or SSL) and return their result.
 *
 * The write statistics are updated after the call, and only on success,
 * using the length left in *punWrittenLen. Failed or partial writes
 * therefore no longer inflate the counter by the requested length.
 */
OI_RESULT
CDavSocket::Write(const char *pszBuf, size_t *punWrittenLen)
{
  OI_ASSERT(pszBuf && punWrittenLen);
  OI_RESULT enuRet = m_pIOFunc->Write(this, pszBuf, punWrittenLen);
  if(OI_OK == enuRet)
    m_unWriteCounter += *punWrittenLen;
  return enuRet;
}

/*
 *	Forward a certificate that failed verification, together with its
 *	OI_SSL_ERR_* failure value, to the owning work session. This public
 *	pass-through lets the file-static verify callback reach the
 *	session's protected addUnauthenticCert(). The caller must not
 *	increment pX509's reference count before calling this function.
 */
void
CDavSocket::AddUnauthenticCert(X509* pX509, unsigned long ulFailures)
{
  CDavWorkSession* pOwner = m_pSession;
  pOwner->addUnauthenticCert(pX509, ulFailures);
}

