/* 
   Authentication session
   Copyright (C) 2003-2004, Lei Jiang <sledge10@hotmail.com>
   Copyright (C) 1999-2004, Joe Orton <joe@manyfish.co.uk>

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Library General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.
   
   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Library General Public License for more details.

   You should have received a copy of the GNU Library General Public
   License along with this library; if not, write to the Free
   Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
   MA 02111-1307, USA

   $Id: DavAuthSession.cpp 432 2008-10-27 05:44:14Z yone $
*/

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif /* HAVE_CONFIG_H */

#include <ctype.h>
#include <string.h>
#include <stdio.h>
#include <onion/DavAuthSession.h>
#include <onion/DavAuthManager.h>
#include <onion/DavRequest.h>
#include <onion/HandlerAuth.h>
#include <onion/HandlerAuthInfo.h>
#include <openssl/rand.h>
#include <openssl/des.h>
#include <openssl/md4.h>
#include <openssl/ssl.h>

/* NTLM auth copied from the code Daniel Stenberg submitted
 *  to neon mailing list
 * <http://mailman.webdav.org/pipermail/neon/2003-September/001413.html>
 */

/* OpenSSL DES API compatibility.  When OPENSSL_VERSION_NUMBER is below
 * 0x00907001L (pre-0.9.7) the DES_* names used in this file are mapped to
 * the old lower-case des_* functions and DESKEYARG/DESKEY expand to the
 * bare identifier.  For newer releases the key schedule is handled through
 * a pointer: DESKEYARG(x) declares "*x" and DESKEY(x) passes "&x". */
#if OPENSSL_VERSION_NUMBER < 0x00907001L
#define DES_key_schedule	des_key_schedule
#define DES_cblock		des_cblock
#define DES_set_odd_parity	des_set_odd_parity
#define DES_set_key		des_set_key
#define DES_ecb_encrypt		des_ecb_encrypt

/* This is how things were done in the old days */
#define DESKEY(x)	x
#define DESKEYARG(x)	x
#else
/* Modern version */
#define DESKEYARG(x)	*x
#define DESKEY(x)	&x
#endif /* OPENSSL_VERSION_NUMBER */

/* Byte splitters for building little-endian NTLM messages with sprintf:
 * each expands to a comma-separated argument list meant for "%c"
 * conversions.  SHORTPAIR(x) yields the low and then the high byte of a
 * 16-bit value; LONGQUARTET(x) yields the four bytes of a 32-bit value,
 * least significant first. */
#define SHORTPAIR(x) ((unsigned char)((x) & 0xff)), (unsigned char)(((x) >> 8))
#define LONGQUARTET(x) ((x) & 0xff), (((x) >> 8)&0xff), \
  (((x) >>16)&0xff), ((x)>>24)

/* Define this to make the type-3 message include the NT response message.
 * It is left undefined here, so only the LanManager response is sent. */
#undef USE_NTRESPONSES


/* Base64 alphabet for the binary-safe encoder below (NTLM messages contain
 * NUL bytes, so a string-based encoder cannot be used).  The trailing '='
 * at index 64 is the pad character. */
static const char *pszBase64 =  
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/=";

/*
 * Base64-encode nInLen raw bytes from pszIn into pszOut and NUL-terminate
 * the result.
 *
 * Returns the number of bytes written INCLUDING the terminating NUL, or -1
 * if nBufLen cannot hold the padded output plus the terminator (pszOut is
 * not touched in that case).
 */
static int
base64Encode(const unsigned char *pszIn,
	     int nInLen,
	     unsigned char* pszOut,
	     int nBufLen)
{
  int nRem = nInLen % 3;
  /* encoded size is always a multiple of 4 characters */
  int nNeeded = (nInLen * 4) / 3 + ((nRem > 0) ? 4 - nRem : 0);

  if (nBufLen <= nNeeded)
    return -1;

  unsigned char *pDst = pszOut;
  const unsigned char *pSrc = pszIn;
  int nLeft = nInLen;

  /* whole 3-byte groups -> 4 characters each */
  while (nLeft >= 3) {
    unsigned int unGroup = ((unsigned int)pSrc[0] << 16) |
                           ((unsigned int)pSrc[1] << 8) |
                           (unsigned int)pSrc[2];
    pDst[0] = pszBase64[(unGroup >> 18) & 0x3f];
    pDst[1] = pszBase64[(unGroup >> 12) & 0x3f];
    pDst[2] = pszBase64[(unGroup >> 6) & 0x3f];
    pDst[3] = pszBase64[unGroup & 0x3f];
    pDst += 4;
    pSrc += 3;
    nLeft -= 3;
  }

  /* one or two leftover bytes -> two or three characters plus padding */
  if (nLeft > 0) {
    unsigned int unFirst = pSrc[0];
    unsigned int unSecond = (nLeft == 2) ? pSrc[1] : 0;

    pDst[0] = pszBase64[unFirst >> 2];
    pDst[1] = pszBase64[((unFirst & 0x03) << 4) | (unSecond >> 4)];
    if (nLeft == 2)
      pDst[2] = pszBase64[(unSecond << 2) & 0x3c];
    else
      pDst[2] = '=';
    pDst[3] = '=';
    pDst += 4;
  }

  *pDst = '\0';
  return (int)(pDst - pszOut) + 1;
}

/* VALID_B64: fail if 'ch' is not a valid base64 character */
#define VALID_B64(ch) (((ch) >= 'A' && (ch) <= 'Z') || \
                       ((ch) >= 'a' && (ch) <= 'z') || \
                       ((ch) >= '0' && (ch) <= '9') || \
                       (ch) == '/' || (ch) == '+' || (ch) == '=')

/* DECODE_B64: decodes a valid base64 character. */
#define DECODE_B64(ch) ((ch) >= 'a' ? ((ch) + 26 - 'a') : \
                        ((ch) >= 'A' ? ((ch) - 'A') : \
                        ((ch) >= '0' ? ((ch) + 52 - '0') : \
                        ((ch) == '+' ? 62 : 63))))

/*
 * Decode exactly nInLen characters of base64 text at pszIn into pszOut.
 *
 * Only nInLen characters are examined, so pszIn does not need to be
 * NUL-terminated and trailing data beyond nInLen is ignored.
 *
 * Returns the number of bytes written; 0 if nInLen is not a positive
 * multiple of 4; -1 if nBufLen is smaller than the worst-case output
 * (nInLen / 4 * 3) or the text is malformed: a character outside the
 * alphabet, '=' in the first two positions of a quantum, "x=y" style
 * padding, or any data following a padded quantum.
 */
static int
base64Decode(const unsigned char* pszIn,
	     int nInLen,
	     unsigned char* pszOut,
	     int nBufLen)
{
  unsigned char* pOut = pszOut;
  int nPos;

  if (nInLen <= 0 || (nInLen % 4) != 0)
    return 0;

  if (nBufLen < (nInLen / 4) * 3)
    return -1;

  for (nPos = 0; nPos < nInLen; nPos += 4) {
    const unsigned char* pQuad = pszIn + nPos;
    unsigned int unTmp;

    if (!VALID_B64(pQuad[0]) || !VALID_B64(pQuad[1]) ||
	!VALID_B64(pQuad[2]) || !VALID_B64(pQuad[3]) ||
	pQuad[0] == '=' || pQuad[1] == '=' ||
	(pQuad[2] == '=' && pQuad[3] != '=')) {
      return -1;
    }

    /* padding terminates the data: it may only appear in the last quantum */
    if (pQuad[3] == '=' && nPos + 4 < nInLen)
      return -1;

    unTmp =
      (DECODE_B64(pQuad[0]) & 0x3f) << 18 | (DECODE_B64(pQuad[1]) & 0x3f) << 12;
    *pOut++ = (unTmp >> 16) & 0xff;
    if (pQuad[2] != '=') {
      unTmp |= (DECODE_B64(pQuad[2]) & 0x3f) << 6;
      *pOut++ = (unTmp >> 8) & 0xff;
      if (pQuad[3] != '=') {
	unTmp |= DECODE_B64(pQuad[3]) & 0x3f;
	*pOut++ = unTmp & 0xff;
      }
    }
  }

  return (int)(pOut - pszOut);
}

/*
 * Expand the 7-byte (56-bit) key at pszKey56 into an 8-byte DES key block.
 * Each output byte receives the next 7 key bits in its high bits, leaving
 * the low bit for parity.  Odd parity is then applied and the key is loaded
 * into the schedule ks.
 */
static void
setupDESKey(unsigned char* pszKey56,
	    DES_key_schedule DESKEYARG(ks))
{
  DES_cblock key;
  int i;

  key[0] = pszKey56[0];
  /* byte i = low i bits of key byte i-1 followed by top (8-i) bits of byte i */
  for (i = 1; i < 7; i++)
    key[i] = ((pszKey56[i - 1] << (8 - i)) & 0xff) | (pszKey56[i] >> i);
  key[7] = (pszKey56[6] << 1) & 0xff;

  DES_set_odd_parity(&key);
  DES_set_key(&key, ks);
}

/*
 * Split the 21-byte key material at pszKeys into three 7-byte DES keys and
 * DES-ECB encrypt the same 8-byte plaintext with each.  The three 8-byte
 * ciphertexts are written back to back into pszResults (24 bytes total).
 */
static void
calcResponse(unsigned char* pszKeys,
	     unsigned char* pszPlainText,
	     unsigned char* pszResults)
{
  DES_key_schedule ks;
  int nBlock;

  for (nBlock = 0; nBlock < 3; nBlock++) {
    setupDESKey(pszKeys + 7 * nBlock, DESKEY(ks));
    DES_ecb_encrypt((DES_cblock*) pszPlainText,
		    (DES_cblock*) (pszResults + 8 * nBlock),
		    DESKEY(ks), DES_ENCRYPT);
  }
}

/*
 * Compute the NTLM LanManager response (and, when USE_NTRESPONSES is
 * defined, the NT response) for a password and server challenge.
 *
 * pszPassword   NUL-terminated clear-text password.
 * pszNonce      8-byte server challenge from the type-2 message.
 * pszLMResponse receives the 24-byte LanManager response.
 * pszNTResp     (USE_NTRESPONSES only) receives the 24-byte NT response.
 *
 * If the scratch buffer cannot be allocated the response buffers are
 * zeroed, so callers never transmit uninitialized stack memory.
 */
static void
makeHash(char* pszPassword,
	 unsigned char* pszNonce,
	 unsigned char* pszLMResponse
#ifdef USE_NTRESPONSES
	 , unsigned char *pszNTResp
#endif /* USE_NTRESPONSES */
	 )
{
  unsigned char chLMBuffer[21];
#ifdef USE_NTRESPONSES
  unsigned char chNTBuffer[21];
#endif /* USE_NTRESPONSES */
  unsigned char* pszPW;
  static const unsigned char chMagic[] = {
    0x4B, 0x47, 0x53, 0x21, 0x40, 0x23, 0x24, 0x25   /* "KGS!@#$%" */
  };
  int nIndex;
  int nLen = strlen(pszPassword);

  /* Scratch space: at least 14 bytes for the padded LM password and, for
   * the NT hash, two bytes per password character. */
  pszPW = (unsigned char*)malloc(nLen < 7 ? 14 : nLen * 2);
  if(!pszPW) {
    memset(pszLMResponse, 0, 0x18);
#ifdef USE_NTRESPONSES
    memset(pszNTResp, 0, 0x18);
#endif /* USE_NTRESPONSES */
    return;
  }

  /* LM password: at most 14 characters, upper-cased, NUL-padded to 14.
   * toupper() is only defined for unsigned char values (and EOF), so the
   * possibly-signed char must be converted first. */
  if(nLen > 14)
    nLen = 14;

  for(nIndex = 0; nIndex < nLen; nIndex++)
    pszPW[nIndex] = (unsigned char)toupper((unsigned char)pszPassword[nIndex]);

  for(; nIndex < 14; nIndex++)
    pszPW[nIndex] = 0;

  {
    /* LanManager hash: DES-encrypt the magic constant with each 7-byte
     * half of the password; the 16-byte hash is zero-padded to 21 bytes. */
    DES_key_schedule ks;

    setupDESKey(pszPW, DESKEY(ks));
    DES_ecb_encrypt((DES_cblock*)chMagic, (DES_cblock*)chLMBuffer,
		    DESKEY(ks), DES_ENCRYPT);

    setupDESKey(pszPW + 7, DESKEY(ks));
    DES_ecb_encrypt((DES_cblock*)chMagic, (DES_cblock*)(chLMBuffer + 8),
		    DESKEY(ks), DES_ENCRYPT);

    memset(chLMBuffer + 16, 0 , 5);
  }

  /* LM response: the padded hash used as three DES keys over the challenge */
  calcResponse(chLMBuffer, pszNonce, pszLMResponse);

#ifdef USE_NTRESPONSES

  {
    /* NT hash: MD4 over the password with every byte followed by a zero
     * byte, zero-padded from 16 to 21 bytes. */
    MD4_CTX MD4;
    nLen = strlen(pszPassword);
    for(nIndex = 0; nIndex < nLen; nIndex++)
      {
	pszPW[2 * nIndex] = pszPassword[nIndex];
	pszPW[2 * nIndex + 1] = 0;
      }
    
    MD4_Init(&MD4);
    MD4_Update(&MD4, pszPW, 2 * nLen);
    MD4_Final(chNTBuffer, &MD4);
    
    /* chNTBuffer holds 21 bytes: only bytes 16..20 may be cleared */
    memset(chNTBuffer + 16, 0, 5);
  }
  
  calcResponse(chNTBuffer, pszNonce, pszNTResp);
  
#endif /* USE_NTRESPONSES */

  free(pszPW);
}


/*
 * Create an authentication session bound to pMgr, which is asked for
 * credentials.  pszRequestHeader is the header credentials are sent in;
 * pszResponseHeader / pszResponseInfoHeader are the challenge and
 * authentication-info headers watched on responses.  A response with
 * status unStatusCode triggers challenge handling.  The pointers are
 * stored as given; the destructor does not release them.
 */
CDavAuthSession::CDavAuthSession(CDavAuthManager* pMgr,
				 const char* pszRequestHeader,
				 const char* pszResponseHeader,
				 const char* pszResponseInfoHeader,
				 unsigned int unStatusCode,
				 OI_AUTH_CLASS enuClass)
{
  OI_ASSERT(pMgr);
  OI_ASSERT(pszRequestHeader);
  OI_ASSERT(pszResponseHeader);
  OI_ASSERT(pszResponseInfoHeader);

  // configuration that never changes for the lifetime of the session
  m_enuClass = enuClass;
  m_nStatusCode = (int)unStatusCode;
  m_pszResponseInfoHeader = pszResponseInfoHeader;
  m_pszResponseHeader = pszResponseHeader;
  m_pszRequestHeader = pszRequestHeader;
  m_pAuthManager = pMgr;

  // everything else starts out unauthenticated
  Reset();
}

/* Intentionally empty: the manager and header-name pointers handed to the
 * constructor are not released here. */
CDavAuthSession::~CDavAuthSession() {}

/*
 * Replace m_strCNonce with the MD5Hash of 256 fresh random bytes.
 * RAND_pseudo_bytes() signals failure with a negative result; in that case
 * the previous cnonce is kept.
 */
void 
CDavAuthSession::updateCNonce()
{
  unsigned char yEntropy[256];

  if(RAND_pseudo_bytes(yEntropy, sizeof(yEntropy)) < 0)
    return;

  MD5Hash((const void*)yEntropy, sizeof(yEntropy), m_strCNonce);
}

const char* 
CDavAuthSession::GetPasswd()
{
  return m_strPasswd.c_str();
}

const char* 
CDavAuthSession::GetUsername()
{
  return m_strUsername.c_str();
}

void 
CDavAuthSession::SetPasswd(const char* pszPasswd)
{
  if(pszPasswd)
    m_strPasswd = pszPasswd;
}

void 
CDavAuthSession::SetUsername(const char* pszUsername)
{
  if(pszUsername)
    m_strUsername = pszUsername;
}

/* Name of the response header carrying the challenge (pszResponseHeader
 * given to the constructor). */
const char* 
CDavAuthSession::GetAuthHeaderName()
{
  const char* pszName = m_pszResponseHeader;
  return pszName;
}

/* Name of the authentication-info response header (pszResponseInfoHeader
 * given to the constructor). */
const char* 
CDavAuthSession::GetAuthInfoHeaderName()
{
  const char* pszName = m_pszResponseInfoHeader;
  return pszName;
}

/* Scheme selected by the last parsed challenge (H_AUTH_NOAUTH after Reset()). */
OI_AUTH_SCHEME 
CDavAuthSession::GetAuthScheme()
{
  OI_AUTH_SCHEME enuScheme = m_enuScheme;
  return enuScheme;
}

/* Current handshake state (S_AUTH_UNAUTHENTICATED after Reset()). */
OI_AUTH_STATE 
CDavAuthSession::GetAuthState()
{
  OI_AUTH_STATE enuState = m_enuState;
  return enuState;
}

void 
CDavAuthSession::SetAuthInfoHeader(const char* pszValue)
{
  if(pszValue){
    m_strAuthInfoHeader = pszValue;
    m_bGotAuthInfoHeader = true;
  }
}

/*
 * Return the session to its initial state.  The following are reset:
 *   - scheme to H_AUTH_NOAUTH, state to S_AUTH_UNAUTHENTICATED, qop to none;
 *   - the attempt and nonce counters to zero;
 *   - every "got ..." flag and the stale flag;
 *   - the cached credentials, generated header values, cnonce, HA1,
 *     auth-info header and NTLM host/domain strings (emptied).
 *
 * Realm, nonce, opaque, the stored digest and the NTLM server challenge are
 * not touched.
 */
void 
CDavAuthSession::Reset()
{
  /* cached credentials */
  m_strUsername.erase();
  m_strPasswd.erase();

  /* values generated for, or received about, previous requests */
  m_strBasic.erase();
  m_strDigest.erase();
  m_strNTLM.erase();
  m_strCNonce.erase();
  m_strHA1.erase();
  m_strAuthInfoHeader.erase();

  /* NTLM host / domain strings */
  m_strNTLMDomain.erase();
  m_strNTLMHost.erase();

  /* flags describing the last parsed headers */
  m_bGotAuthHeader = false;
  m_bGotAuthInfoHeader = false;
  m_bGotScheme = false;
  m_bGotRealm = false;
  m_bGotNonce = false;
  m_bGotOpaque = false;
  m_bGotStale = false;
  m_bGotAlgorithm = false;
  m_bGotQop = false;
  m_bStale = false;

  /* counters */
  m_unAttempt = 0;
  m_unNonceCount = 0;

  /* protocol state */
  m_enuScheme = H_AUTH_NOAUTH;
  m_enuState = S_AUTH_UNAUTHENTICATED;
  m_enuQop = Q_AUTH_NONE;
}

/*
 * Attach the response-header handlers for this session to a new request:
 * a CHandlerAuth for the challenge header and a CHandlerAuthInfo for the
 * authentication-info header, both constructed with a pointer to this
 * session.  (Presumably pReq takes ownership of the handlers — confirm in
 * CDavRequest::AddHandler.)
 */
OI_RESULT 
CDavAuthSession::OnCreateRequest(CDavRequest* pReq)
{
  CHandlerAuth* pChallengeHandler = new CHandlerAuth(this);
  pReq->AddHandler(pChallengeHandler, m_pszResponseHeader);

  CHandlerAuthInfo* pInfoHandler = new CHandlerAuthInfo(this);
  pReq->AddHandler(pInfoHandler, m_pszResponseInfoHeader);

  return OI_OK;
}

/*
 * Parse one challenge header value (the header registered as
 * m_pszResponseHeader) and record the scheme and its parameters.
 *
 * The value is split on ',' and ' '.  A token without '=' is a scheme name
 * (digest / basic / ntlm; anything else selects H_AUTH_INVALID).
 * "name=value" tokens are digest parameters whose quote characters are
 * stripped.  For NTLM, the token following the scheme, if present, is the
 * base64 type-2 message; its 8-byte server challenge (bytes 24..31) is
 * copied to m_chNTLMNonce.
 *
 * Always returns OI_OK; problems surface later through challenge().
 *
 * TODO: error handling in this function.  The tokenizer also splits quoted
 * values containing ',' or ' ' (e.g. realm="My Realm").
 */
OI_RESULT
CDavAuthSession::ParseAuthHeader(const char* pszValue)
{
  m_bGotAuthHeader = true;

  CDavStringTokenizer headers(pszValue, ", ");
  OI_STRING_A strField;

  for(;headers.GetNextToken(strField);){
    CDavStringTokenizer field(strField.c_str(), "= ");
    OI_STRING_A strName, strValue;
    if(field.GetNextToken(strName)){
      MakeLowerA(strName);
      bool bIsScheme = !(field.GetNextToken(strValue));
      if(bIsScheme){
	m_bGotScheme = true;
	if(strName == "digest"){
	  m_enuScheme = H_AUTH_DIGEST;
	} else if(strName == "basic"){
	  m_enuScheme = H_AUTH_BASIC;
	} else if(strName == "ntlm"){
	  if(m_enuScheme != H_AUTH_NTLM) {
	    Reset();
	    /* Reset() clears these, but we are in the middle of parsing a
	     * challenge that does carry a scheme */
	    m_bGotAuthHeader = true;
	    m_bGotScheme = true;
	  }
	  m_enuScheme = H_AUTH_NTLM;
	  if(!headers.GetNextToken(strValue)) {
	    /* bare "NTLM": we must answer with a type-1 message */
	    m_enuState = S_AUTH_NTLMTYPE1;
	  } else {
	    /* "NTLM <base64>": a type-2 message */
	    m_enuState = S_AUTH_NTLMTYPE2;
	    TrimLeftA(strValue);
	    TrimRightA(strValue);

	    int nInLen = strValue.length();
	    /* decoded size is at most 3 bytes per 4 characters */
	    int nBufLen = (nInLen / 4) * 3 + 3;
	    unsigned char* pszBuf = (unsigned char*)malloc(nBufLen);

	    OI_ASSERT(pszBuf);
	    if(pszBuf) {
	      int nOutLen = base64Decode((const unsigned char*)strValue.c_str(),
					 nInLen, pszBuf, nBufLen);
	      /* type-2 layout: signature(8) type(4) target-name buffer(8)
	       * flags(4) challenge(8) -> at least 32 bytes, challenge at 24 */
	      if(nOutLen >= 32) {
		memcpy((void*)m_chNTLMNonce, pszBuf + 24, 8);
	      }
	      free(pszBuf);
	    }
	  }
	} else {
	  m_enuScheme = H_AUTH_INVALID;
	}
      } else {
	//remove quotation marks from value
	OI_STRING_A::size_type npos = OI_STRING_A::npos;
	OI_STRING_A::size_type pos;

	for(; (pos = strValue.find_first_of('\"')) != npos;)
	  strValue.erase(pos, 1);

	for(; (pos = strValue.find_first_of('\'')) != npos;)
	  strValue.erase(pos, 1);

	if(strName == "realm"){
	  m_bGotRealm = true;
	  m_strRealm = strValue;
	} else if(strName == "nonce"){
	  m_bGotNonce = true;
	  m_strNonce = strValue;
	} else if(strName == "opaque"){
	  m_bGotOpaque = true;
	  m_strOpaque = strValue;
	} else if(strName == "stale"){
	  m_bGotStale = true;
	  m_bStale = (strValue == "true");
	} else if(strName == "algorithm"){
	  m_bGotAlgorithm = true;
	  MakeLowerA(strValue);
	  if(strValue == "md5"){
	    //TODO: handle this case
	  } else {
	    //TODO: unsupported algorithm. This value comes from the server,
	    //so it must not trip an assertion.
	    OI_DEBUG("*** unsupported digest algorithm: [%s] ***\n",
		     strValue.c_str());
	  }
	} else if(strName == "qop"){
	  MakeLowerA(strValue);
	  m_bGotQop = true;
	  if(strValue == "auth"){
	    m_enuQop = Q_AUTH_AUTH;
	  } else if(strValue == "auth-int"){
	    m_enuQop = Q_AUTH_AUTHINT;
	  } else {
	    m_enuQop = Q_AUTH_NONE;
	  }
	} else {
	  //TODO: handle this case
	  OI_DEBUG("*** unknown auth header field: [%s] ***\n",
		   strName.c_str());
	}
      }
    }
  }
  return OI_OK;
}

/*
 * Add this session's credentials header (m_pszRequestHeader) to pReq just
 * before it is sent:
 *   Basic  - the cached m_strBasic value;
 *   Digest - a digest freshly computed by requestDigest();
 *   NTLM   - m_strNTLM, but only while a type-1 or type-3 message is due.
 * Other schemes add nothing; H_AUTH_INVALID is dealt with in
 * QueryEndRequest().
 */
OI_RESULT 
CDavAuthSession::OnPreSendRequest(CDavRequest* pReq)
{
  if(H_AUTH_BASIC == m_enuScheme) {
    pReq->AddRequestHeader(m_pszRequestHeader, m_strBasic.c_str());
  } else if(H_AUTH_DIGEST == m_enuScheme) {
    requestDigest(pReq);
    pReq->AddRequestHeader(m_pszRequestHeader, m_strDigest.c_str());
  } else if(H_AUTH_NTLM == m_enuScheme) {
    // NTLM is connection oriented: only the negotiate (type-1) and the
    // authenticate (type-3) legs carry a client header
    bool bSendNTLM = (S_AUTH_NTLMTYPE1 == m_enuState) ||
                     (S_AUTH_NTLMTYPE3 == m_enuState);
    if(bSendNTLM)
      pReq->AddRequestHeader(m_pszRequestHeader, m_strNTLM.c_str());
  }

  return OI_OK;
}

/*
 * Check a Digest authentication-info header value (pszValue) received for
 * pReq against what requestDigest() sent.
 *
 * Recognised fields: qop, nextnonce, rspauth, cnonce, nc.  Quote characters
 * are stripped from every value.  When the server used a qop, rspauth,
 * cnonce and nc are required, cnonce and nc must match the values we sent,
 * and rspauth must equal
 *   MD5(m_strStoredDigest + qop ":" MD5(":" digest-uri))
 * where m_strStoredDigest is the "HA1:nonce:nc:cnonce:" prefix saved by
 * requestDigest().  A nextnonce replaces m_strNonce when the function gets
 * as far as returning OI_OK or OIAEDIGMISMATCH.
 *
 * Returns OI_OK, OIAEHEADERSYNTAX (not Digest, field without value,
 * unparsable nc, missing required field), OIAECNONCEMISMATCH,
 * OIAENONCEMISMATCH or OIAEDIGMISMATCH.
 */
OI_RESULT 
CDavAuthSession::verifyResponse(CDavRequest* pReq, 
				const char* pszValue)
{
  OI_RESULT enuRet = OIAEDIGMISMATCH;

  bool bGotNextNonce = false;
  bool bGotRspAuth = false;
  bool bGotCNonce = false;
  bool bGotQop = false;
  bool bGotNonceCount = false;

  OI_STRING_A strNextNonce;
  OI_STRING_A strRspAuth;
  OI_STRING_A strCNonce;
  OI_STRING_A strQop;

  unsigned int unNonceCount = 0;
  OI_AUTH_QOP enuQop = Q_AUTH_NONE;

  if(m_enuScheme != H_AUTH_DIGEST)
    return OIAEHEADERSYNTAX;

  CDavStringTokenizer fields(pszValue, ", ");
  OI_STRING_A strValPair;
  for(;fields.GetNextToken(strValPair);){
    CDavStringTokenizer field(strValPair.c_str(), "= ");
    OI_STRING_A strName, strValue;
    if(field.GetNextToken(strName)){
      MakeLowerA(strName);
      bool bIsValid = field.GetNextToken(strValue);
      if(!bIsValid){
	//every field in RFC2069 has its own value,
	//however, 2069 did not mention the 'rspauth' field,
	//which appeared in 2831. Further investigation is needed
	return OIAEHEADERSYNTAX;
      }

      //removing quotations
      OI_STRING_A::size_type npos = OI_STRING_A::npos;
      OI_STRING_A::size_type pos;

      for(; (pos = strValue.find_first_of('\"')) != npos;)
	strValue.erase(pos, 1);

      for(; (pos = strValue.find_first_of('\'')) != npos;)
	strValue.erase(pos, 1);

      if(strName == "qop"){
	bGotQop = true;
	// keep the value verbatim: it is hashed into the rspauth digest;
	// compare case-insensitively on a lower-cased copy
	strQop = strValue;
	MakeLowerA(strValue);
	if(strValue == "auth-int"){
	  enuQop = Q_AUTH_AUTHINT;
	} else if(strValue == "auth"){
	  enuQop = Q_AUTH_AUTH;
	} else {
	  enuQop = Q_AUTH_NONE;
	}
      } else if(strName == "nextnonce"){
	bGotNextNonce = true;
	strNextNonce = strValue;
      } else if(strName == "rspauth"){
	bGotRspAuth = true;
	strRspAuth = strValue;
      } else if(strName == "cnonce"){
	bGotCNonce = true;
	strCNonce = strValue;
      } else if(strName == "nc"){
	bGotNonceCount = true;
	// sscanf returns EOF (not 0) on empty input, so require exactly 1
	if(sscanf(strValue.c_str(), "%x", &unNonceCount) != 1)
	  return OIAEHEADERSYNTAX;
      }
    }
  }

  if(bGotQop && (enuQop != Q_AUTH_NONE)){
    if(!(bGotRspAuth && bGotCNonce && bGotNonceCount))
      return OIAEHEADERSYNTAX;
    if(strCNonce != m_strCNonce){
      OI_DEBUG("cnonce mismatch!\n");
      return OIAECNONCEMISMATCH;
    }
    if(unNonceCount != m_unNonceCount){
      OI_DEBUG("nonce mismatch!\n");
      return OIAENONCEMISMATCH;
    }
    if(enuQop == Q_AUTH_AUTHINT){
      //TODO: Implement this
      OI_ASSERT(false);
    }

    // A2 for rspauth is ":" digest-uri (no request method)
    OI_STRING_A strA2 = ":";
    strA2 += Escape(pReq->m_strURI.UTF8());
    OI_STRING_A strHA2;
    MD5HashString(strA2, strHA2);

    // Work on a copy so verification does not alter the stored prefix
    OI_STRING_A strKD = m_strStoredDigest;
    strKD += strQop;
    strKD += ":";
    strKD += strHA2;

    OI_STRING_A strDigest;
    MD5HashString(strKD, strDigest);
    // hex digests compare case-insensitively
    MakeLowerA(strDigest);
    MakeLowerA(strRspAuth);
    if(strDigest == strRspAuth){
      enuRet = OI_OK;
    } else {
      enuRet = OIAEDIGMISMATCH;
    }
  } else {
    //No qop directive, auth okay
    enuRet = OI_OK;
  }

  if(bGotNextNonce){
    OI_DEBUG("got next nonce\n");
    m_strNonce = strNextNonce;
  }

  return enuRet;
}

/*
 * Decide what to do once the response to pReq has been received.
 *
 * - If an authentication-info header was seen, verify it with
 *   verifyResponse() and return that result.
 * - If the status equals the configured challenge status (m_nStatusCode),
 *   answer the challenge.  Return OIAEHEADERSYNTAX if no challenge header
 *   arrived, the challenge() error if it failed, or OI_RETRY so the request
 *   is re-sent with credentials.
 * - Otherwise return OI_OK.
 */
OI_RESULT 
CDavAuthSession::QueryEndRequest(CDavRequest* pReq)
{
  OI_RESULT enuRet;
  int nStatusCode = pReq->GetStatusCode();

  OI_ASSERT(m_enuScheme != H_AUTH_INVALID);
  /*TODO:remove H_AUTH_INVALID*/
  
  if(m_bGotAuthInfoHeader) {
    enuRet = verifyResponse(pReq, m_strAuthInfoHeader.c_str());
    //reset the flag
    m_bGotAuthInfoHeader = false;
    return enuRet;
  }

  if(nStatusCode == m_nStatusCode){
    //server-auth
    if(!m_bGotAuthHeader)
      return OIAEHEADERSYNTAX;
    enuRet = challenge();
    //reset the flag
    m_bGotAuthHeader = false;
    if(OI_OK != enuRet)
      return enuRet;

    // NTLM is a multi-leg handshake: ntlmChallenge() leaves the session in
    // S_AUTH_NTLMTYPE1 or S_AUTH_NTLMTYPE3, and OnPreSendRequest() only
    // sends the NTLM header in those states, so they must survive the retry.
    if(m_enuScheme != H_AUTH_NTLM)
      m_enuState = S_AUTH_UNAUTHENTICATED;
    return OI_RETRY;
  }

  //TODO: handle other situations
  return OI_OK;
}

/*
 * Dispatch the last parsed challenge to the handler for its scheme.
 * Returns OIAEHEADERSYNTAX when no scheme token was seen,
 * OIAEINVALIDSCHEME for an unsupported scheme, else the handler's result.
 */
OI_RESULT 
CDavAuthSession::challenge()
{
  if(!m_bGotScheme)
    return OIAEHEADERSYNTAX;

  if(H_AUTH_BASIC == m_enuScheme)
    return basicChallenge();
  if(H_AUTH_DIGEST == m_enuScheme)
    return digestChallenge();
  if(H_AUTH_NTLM == m_enuScheme)
    return ntlmChallenge();

  return OIAEINVALIDSCHEME;
}

/*
 * Ask the auth manager for credentials for the (unescaped) realm.
 * m_strUsername / m_strPasswd are handed to OnAuthentication (presumably
 * filled in by reference — confirm in CDavAuthManager).  Each successful
 * request bumps m_unAttempt.  Returns the manager's answer.
 */
bool 
CDavAuthSession::requestCredential()
{
  X strRealm = (XMLByte*)(Unescape(m_strRealm).c_str());

  if(!m_pAuthManager->OnAuthentication(strRealm,
				       m_strUsername,
				       m_strPasswd,
				       m_unAttempt,
				       m_enuClass))
    return false;

  ++m_unAttempt;
  return true;
}

/*
 * Answer a Basic challenge: obtain credentials and cache
 * "Basic " + Base64Encode("user:password") in m_strBasic.
 * Returns OIAEHEADERSYNTAX without a scheme, OIAEFAILEDTOGETCRED when no
 * credentials were supplied, OI_OK otherwise.
 */
OI_RESULT 
CDavAuthSession::basicChallenge()
{
  if(!m_bGotScheme)
    return OIAEHEADERSYNTAX;
  if(!requestCredential())
    return OIAEFAILEDTOGETCRED;

  OI_STRING_A strUserPass(m_strUsername);
  strUserPass += ":";
  strUserPass += m_strPasswd;

  m_strBasic = "Basic ";
  m_strBasic += Base64Encode(strUserPass);
  return OI_OK;
}

/*
 * Answer a Digest challenge: obtain credentials, restart the nonce count
 * when the challenge carried a qop, and (unless the challenge was marked
 * stale) recompute HA1 = MD5("user:realm:password") into m_strHA1.
 * Returns OIAEFAILEDTOGETCRED when no credentials were supplied, else OI_OK.
 */
OI_RESULT 
CDavAuthSession::digestChallenge()
{
  if(!requestCredential())
    return OIAEFAILEDTOGETCRED;

  if(m_bGotStale)
    OI_DEBUG("got stale challenge\n");

  if(m_bGotQop)
    m_unNonceCount = 0;

  // stale=true: only the nonce expired, so the cached HA1 is kept
  if(m_bStale)
    return OI_OK;

  OI_STRING_A strA1(m_strUsername);
  strA1 += ":";
  strA1 += m_strRealm;
  strA1 += ":";
  strA1 += m_strPasswd;
  MD5HashString(strA1, m_strHA1);

  return OI_OK;
}

OI_RESULT
CDavAuthSession::ntlmChallenge()
{
  if(!m_bGotScheme)
    return OIAEHEADERSYNTAX;

  switch(m_enuState) {
  case S_AUTH_NTLMTYPE1:
    {
      if(!requestCredential())
	return OI_USERCANCELED;

      unsigned long unFlags, unDomainLen, unHostLen;
      unsigned long unHostOff, unDomainOff, unBufLen;

      unFlags = NTLMFLAG_NEGOTIATE_OEM | NTLMFLAG_NEGOTIATE_NTLM_KEY;
      unHostLen = m_strNTLMHost.length();
      unDomainLen = m_strNTLMDomain.length();
      unHostOff = 32;
      unDomainOff = unHostOff + unHostLen;
      unBufLen = unDomainOff + unDomainLen + 1;
      unsigned char* pszBuf = (unsigned char*)malloc(unBufLen);

      sprintf((char*)pszBuf,
	      "NTLMSSP%c"
	      "\x01%c%c%c" /* 32-bit type = 1 */
	      "%c%c%c%c"   /* 32-bit NTLM flag field */
	      "%c%c"  /* domain length */
	      "%c%c"  /* domain allocated space */
	      "%c%c"  /* domain name offset */
	      "%c%c"  /* 2 zeroes */
	      "%c%c"  /* host length */
	      "%c%c"  /* host allocated space */
	      "%c%c"  /* host name offset */
	      "%c%c"  /* 2 zeroes */
	      "%s"   /* host name */
	      "%s",  /* domain string */
	      0,     /* trailing zero */
	      0,0,0, /* part of type-1 long */
	      LONGQUARTET(NTLMFLAG_NEGOTIATE_OEM |	  /*   2 */
			  NTLMFLAG_NEGOTIATE_NTLM_KEY /* 200 */
			  ),
	      SHORTPAIR(unDomainLen),
	      SHORTPAIR(unDomainLen),
	      SHORTPAIR(unDomainOff),
	      0,0,
	      SHORTPAIR(unHostLen),
	      SHORTPAIR(unHostLen),
	      SHORTPAIR(unHostOff),
	      0,0,
	      m_strNTLMHost.c_str(),
	      m_strNTLMDomain.c_str());

      unsigned int unSize = unBufLen - 1;
      unsigned int unOutSize = unSize * 4 / 3;
      if((unSize % 3) > 0)
	unOutSize += 4 - (unSize % 3);

      unsigned char* pszOut = (unsigned char*)malloc(unOutSize + 1);
      base64Encode(pszBuf, unSize, pszOut, unOutSize + 1);

      m_strNTLM = "NTLM";
      m_strNTLM += (const char*)pszOut;

      free(pszOut);
      free(pszBuf);
      return OI_OK;
    }
  case S_AUTH_NTLMTYPE2:
    {
      /* We received the type-2 already, create a type-3 message:

      Index   Description            Content
      0       NTLMSSP Signature      Null-terminated ASCII "NTLMSSP"
      (0x4e544c4d53535000)
      8       NTLM Message Type      long (0x03000000)
      12      LM/LMv2 Response       security buffer(*)
      20      NTLM/NTLMv2 Response   security buffer(*)
      28      Domain Name            security buffer(*)
      36      User Name              security buffer(*)
      44      Workstation Name       security buffer(*)
      (52)    Session Key (optional) security buffer(*)
      (60)    Flags (optional)       long
      52 (64) start of data block
      */

      int nLMRespOff;
      int nNTRespOff;
      int nDomainOff;
      int nHostOff;
      int nUserOff;
      int nDomainLen;
      int nUsernameLen;
      int nHostLen;
      unsigned int unBufLen;
      unsigned char* pszBuf;
      unsigned char chLMResp[0x18]; //fixed-size
#ifdef USE_NTRESPONSES
      unsigned char chNTResp[0x18];
#endif /* USE_NTRESPONSES */

      OI_STRING_A strUsername;
      OI_STRING_A strDomain;

      OI_STRING_A::size_type nIndex = m_strUsername.find('\\');
      if(nIndex == OI_STRING_A::npos)
	nIndex = m_strUsername.find('/');
      if(nIndex == OI_STRING_A::npos) {
	strUsername = m_strUsername;
      } else {
	strDomain = m_strUsername.substr(0, nIndex);
	strUsername = m_strUsername.substr(nIndex + 1);
      }

      MakeUpperA(m_strUsername);
      makeHash((char*)m_strPasswd.c_str(), m_chNTLMNonce, chLMResp
#ifdef USE_NTRESPONSES
	       , chNTResp
#endif /* USE_NTRESPONSES */
	       );

      nDomainOff	= 64; //fixed
      nDomainLen	= strDomain.length();
      nUsernameLen	= strUsername.length();
      nHostLen		= m_strNTLMHost.length();
      nUserOff		= nDomainOff + nDomainLen;
      nHostOff		= nUserOff + nUsernameLen;
      nLMRespOff	= nHostOff + nHostLen;
      nNTRespOff	= nLMRespOff + 0x18;

      /* calculate buffer size */
#ifdef USE_NTRESPONSES
      unBufLen		= nNTRespOff + 0x18 + 1;
#else
      unBufLen		= nNTRespOff + 1;
#endif /* USE_NTRESPONSES */
      pszBuf = (unsigned char*)malloc(unBufLen);
      OI_ASSERT(pszBuf);

      int nSize = sprintf((char*)pszBuf,
			  "NTLMSSP%c"
			  "\x03%c%c%c" /* type-3, 32 bits */
		  
			  "%c%c%c%c" /* LanManager length + allocated space */
			  "%c%c" /* LanManager offset */
			  "%c%c" /* 2 zeroes */
			  
			  "%c%c" /* NT-response length */
			  "%c%c" /* NT-response allocated space */
			  "%c%c" /* NT-response offset */
			  "%c%c" /* 2 zeroes */
			  
			  "%c%c"  /* domain length */
			  "%c%c"  /* domain allocated space */
			  "%c%c"  /* domain name offset */
			  "%c%c"  /* 2 zeroes */
			  
			  "%c%c"  /* user length */
			  "%c%c"  /* user allocated space */
			  "%c%c"  /* user offset */
			  "%c%c"  /* 2 zeroes */
			  
			  "%c%c"  /* host length */
			  "%c%c"  /* host allocated space */
			  "%c%c"  /* host offset */
			  "%c%c%c%c%c%c"  /* 6 zeroes */
			  
			  "\xff\xff"  /* message length */
			  "%c%c"  /* 2 zeroes */
			  
			  "\x01\x82" /* flags */
			  "%c%c"  /* 2 zeroes */
			  
			  /* domain string */
			  /* user string */
			  /* host string */
			  /* LanManager response */
			  /* NT response */
			  ,
			  0, /* zero termination */
			  0,0,0, /* type-3 long, the 24 upper bits */
			  
			  SHORTPAIR(0x18),  /* LanManager response length, twice */
			  SHORTPAIR(0x18),
			  SHORTPAIR(nLMRespOff),
			  0x0, 0x0,
			  
#ifdef USE_NTRESPONSES
			  SHORTPAIR(0x18),  /* NT-response length, twice */
			  SHORTPAIR(0x18),
#else
			  0x0, 0x0,
			  0x0, 0x0,
#endif
			  SHORTPAIR(nNTRespOff),
			  0x0, 0x0,
			  
			  SHORTPAIR(nDomainLen),
			  SHORTPAIR(nDomainLen),
			  SHORTPAIR(nDomainOff),
			  0x0, 0x0,
			  
			  SHORTPAIR(nUsernameLen),
			  SHORTPAIR(nUsernameLen),
			  SHORTPAIR(nUserOff),
			  0x0, 0x0,
			  
			  SHORTPAIR(nHostLen),
			  SHORTPAIR(nHostLen),
			  SHORTPAIR(nHostOff),
			  0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
			  
			  0x0, 0x0,
			  
			  0x0, 0x0 );
      
      nSize = 64;
      pszBuf[62] = pszBuf[63] = 0;
      
      //copy domain name
      memcpy((void*)(pszBuf + nSize), strDomain.c_str(), nDomainLen);
      nSize += nDomainLen;
      
      //copy user name
      memcpy((void*)(pszBuf + nSize), strUsername.c_str(), nUsernameLen);
      nSize += nUsernameLen;

      //copy host  name
      memcpy((void*)(pszBuf + nSize), m_strNTLMHost.c_str(), nHostLen);
      nSize += nHostLen;

      if(nSize < ((int)unBufLen - 0x18)) {
	memcpy((void*)(pszBuf + nSize), chLMResp, 0x18);
	nSize += 0x18;
      }

#ifdef USE_NTRESPONSES
      if(nSize < ((int)unBufLen - 0x18)) {
	memcpy((void*)(pszBuf + nSize), chNTResp, 0x18);
	nSize += 0x18;
      }
#endif /* USE_NTRESPONSES */
      
      pszBuf[56] = nSize & 0xff;
      pszBuf[57] = nSize >> 8;
      
      *((unsigned int*)(pszBuf + 60)) = 0x206;

      unsigned unOutSize = (nSize*4)/3;

      if ((nSize % 3) > 0) /* got to pad */
	unOutSize += 4 - (nSize % 3);


      unsigned char* pszOut = (unsigned char*)malloc(unOutSize + 1);

      base64Encode(pszBuf, nSize, pszOut, unOutSize + 1);

      m_strNTLM = "NTLM ";
      m_strNTLM += (const char*)pszOut;
      m_enuState = S_AUTH_NTLMTYPE3;

      free(pszOut);
      free(pszBuf);

      return OI_OK;

    }
    break;
  default:
    break;
  }

  return OI_OK;
}

/*
 * Build the Digest credentials for pReq into m_strDigest.
 *
 * With a qop the nonce count is advanced (restarted at 0 for HEAD), and
 * the request digest is
 *   MD5(HA1 ":" nonce ":" nc ":" cnonce ":" qop ":" HA2)
 * Without a qop (RFC 2069 style) it is
 *   MD5(HA1 ":" nonce ":" HA2)
 * In both cases HA2 = MD5(method ":" escaped-uri).  A new cnonce is
 * generated on every call.  The prefix up to the cnonce (or up to the
 * nonce without qop) is saved in m_strStoredDigest for verifyResponse().
 *
 * The algorithm, nc and qop directives are emitted as unquoted tokens, as
 * the RFC 2617 / RFC 7616 grammar requires.  Always returns true.
 */
bool
CDavAuthSession::requestDigest(CDavRequest* pReq)
{
  OI_STRING_A strTmp, strQop, strNonceCount;
  OI_STRING_A strHA2, strRDigest;
  char chBuf[80];

  if(m_enuQop != Q_AUTH_NONE) {
    // NOTE(review): HEAD restarts the count at 0 instead of incrementing;
    // RFC 7616 expects nc to grow for each use of a nonce — confirm intent.
    if(T_REQ_HEAD == pReq->GetMethod()) {
      m_unNonceCount = 0;
    } else {
      m_unNonceCount++;
    }
    sprintf(chBuf, "%08x", m_unNonceCount);
    strNonceCount = chBuf;
  }

  updateCNonce();
  switch(m_enuQop) {
  case Q_AUTH_AUTH:
    strQop = "auth";
    break;
  case Q_AUTH_AUTHINT:
    //TODO: implement
    strQop = "auth-int";
    OI_ASSERT(false);
    break;
  case Q_AUTH_NONE:
    // challenge without qop: legitimate, nothing to add
    break;
  default:
    //TODO: error code
    OI_ASSERT(false);
    break;
  }

  // HA2 = MD5(method ":" uri)
  strTmp = pReq->GetMethodStr();
  strTmp += ":";
  strTmp += Escape(pReq->m_strURI.UTF8());
  MD5HashString(strTmp, strHA2);

  strTmp = m_strHA1;
  strTmp += ":";
  strTmp += m_strNonce;
  strTmp += ":";

  if(m_enuQop != Q_AUTH_NONE) {
    strTmp += strNonceCount;
    strTmp += ":";
    strTmp += m_strCNonce;
    strTmp += ":";

    //store a copy of digest
    m_strStoredDigest = strTmp;

    strTmp += strQop;
    strTmp += ":";
  } else {
    //store a copy of digest
    m_strStoredDigest = strTmp;
  }

  strTmp += strHA2;

  MD5HashString(strTmp, strRDigest);

  m_strDigest = "Digest username=\"";
  m_strDigest += m_strUsername;
  m_strDigest += "\", realm=\"";
  m_strDigest += m_strRealm;
  m_strDigest += "\", nonce=\"";
  m_strDigest += m_strNonce;
  m_strDigest += "\", uri=\"";
  m_strDigest += Escape(pReq->m_strURI.UTF8());
  m_strDigest += "\", response=\"";
  m_strDigest += strRDigest;
  m_strDigest += "\", algorithm=MD5";

  if(m_bGotOpaque) {
    m_strDigest += ", opaque=\"";
    m_strDigest += m_strOpaque;
    m_strDigest += "\"";
  }

  if(m_enuQop != Q_AUTH_NONE) {
    m_strDigest += ", cnonce=\"";
    m_strDigest += m_strCNonce;
    m_strDigest += "\", nc=";
    m_strDigest += strNonceCount;
    m_strDigest += ", qop=";
    m_strDigest += strQop;
  }

  return true;
}
