#include "sspi-private.h"
#include "debug-private.h"
#pragma comment(lib, "Crypt32.lib")
#pragma comment(lib, "Secur32.lib")
#pragma comment(lib, "Ws2_32.lib")
/*
 * Fallback values for the certificate-check flags stored in conn->certFlags
 * (see _sspiSetAllowsAnyRoot() / _sspiSetAllowsExpiredCerts()), used only
 * when the SDK headers in use do not already define them.
 */
#if !defined(SECURITY_FLAG_IGNORE_UNKNOWN_CA)
# define SECURITY_FLAG_IGNORE_UNKNOWN_CA 0x00000100
#endif
#if !defined(SECURITY_FLAG_IGNORE_CERT_DATE_INVALID)
# define SECURITY_FLAG_IGNORE_CERT_DATE_INVALID 0x00002000
#endif
/*
 * Local functions: builds and checks the server certificate chain for
 * _sspiConnect(); defined at the end of this file.
 */
static DWORD sspi_verify_certificate(PCCERT_CONTEXT serverCert,
const CHAR *serverName,
DWORD dwCertFlags);
/*
 * '_sspiAlloc()' - Allocate a new SSPI connection object.
 *
 * Every field starts out zeroed (no buffers, no context, no certificate
 * flags, no stream sizes) except the socket, which is set to
 * INVALID_SOCKET so that _sspiFree() knows there is nothing to close.
 *
 * Returns the new object, or NULL if memory could not be allocated.
 */
_sspi_struct_t*
_sspiAlloc(void)
{
  /* calloc(count, size): one element of the object's own size. */
  _sspi_struct_t *conn = calloc(1, sizeof(*conn));

  if (conn)
    conn->sock = INVALID_SOCKET;

  return (conn);
}
/*
 * '_sspiGetCredentials()' - Acquire Schannel credentials for a connection.
 *
 * Creates the key container "container" in the machine key set (or opens
 * it if it already exists). It then looks up a certificate whose subject
 * matches "cn" in the local machine "MY" store. If none is found, it
 * generates an AT_KEYEXCHANGE key and a self-signed certificate that expires
 * 10 years from now, adds that certificate to the store and records the
 * key container in its CERT_KEY_PROV_INFO property.
 *
 * The certificate is then used to acquire an inbound (isServer) or outbound
 * Schannel credentials handle, stored in conn->creds.
 *
 * Returns TRUE on success, FALSE on failure.
 */
BOOL
_sspiGetCredentials(_sspi_struct_t *conn,
                    const LPWSTR   container,
                    const TCHAR    *cn,
                    BOOL           isServer)
{
  HCERTSTORE          store = NULL;
  PCCERT_CONTEXT      storedContext = NULL;
  PCCERT_CONTEXT      createdContext = NULL;
  DWORD               dwSize = 0;
  PBYTE               p = NULL;
  HCRYPTPROV          hProv = (HCRYPTPROV) NULL;
  CERT_NAME_BLOB      sib;
  SCHANNEL_CRED       SchannelCred;
  TimeStamp           tsExpiry;
  SECURITY_STATUS     Status;
  HCRYPTKEY           hKey = (HCRYPTKEY) NULL;
  CRYPT_KEY_PROV_INFO kpi;
  SYSTEMTIME          et;
  CERT_EXTENSIONS     exts;
  CRYPT_KEY_PROV_INFO ckp;
  BOOL                ok = TRUE;

  if (!conn)
    return (FALSE);
  if (!cn)
    return (FALSE);

  /*
   * Create the key container, or open it if it already exists.  Any other
   * failure is fatal: previously it was ignored and the code went on with a
   * NULL provider handle, only to fail later with a less useful error.
   */
  if (!CryptAcquireContextW(&hProv, (LPWSTR) container, MS_DEF_PROV_W,
                            PROV_RSA_FULL,
                            CRYPT_NEWKEYSET | CRYPT_MACHINE_KEYSET))
  {
    if (GetLastError() != NTE_EXISTS ||
        !CryptAcquireContextW(&hProv, (LPWSTR) container, MS_DEF_PROV_W,
                              PROV_RSA_FULL, CRYPT_MACHINE_KEYSET))
    {
      DEBUG_printf(("_sspiGetCredentials: CryptAcquireContext failed: %x\n",
                    GetLastError()));
      ok = FALSE;
      goto cleanup;
    }
  }

  store = CertOpenStore(CERT_STORE_PROV_SYSTEM,
                        X509_ASN_ENCODING|PKCS_7_ASN_ENCODING,
                        hProv,
                        CERT_SYSTEM_STORE_LOCAL_MACHINE |
                        CERT_STORE_NO_CRYPT_RELEASE_FLAG |
                        CERT_STORE_OPEN_EXISTING_FLAG,
                        L"MY");

  if (!store)
  {
    DEBUG_printf(("_sspiGetCredentials: CertOpenSystemStore failed: %x\n",
                  GetLastError()));
    ok = FALSE;
    goto cleanup;
  }

  /* First call sizes the encoded subject name, second call encodes it. */
  dwSize = 0;

  if (!CertStrToName(X509_ASN_ENCODING, cn, CERT_OID_NAME_STR,
                     NULL, NULL, &dwSize, NULL))
  {
    DEBUG_printf(("_sspiGetCredentials: CertStrToName failed: %x\n",
                  GetLastError()));
    ok = FALSE;
    goto cleanup;
  }

  p = (PBYTE) malloc(dwSize);

  if (!p)
  {
    DEBUG_printf(("_sspiGetCredentials: malloc failed for %d bytes", dwSize));
    ok = FALSE;
    goto cleanup;
  }

  if (!CertStrToName(X509_ASN_ENCODING, cn, CERT_OID_NAME_STR, NULL,
                     p, &dwSize, NULL))
  {
    DEBUG_printf(("_sspiGetCredentials: CertStrToName failed: %x",
                  GetLastError()));
    ok = FALSE;
    goto cleanup;
  }

  sib.cbData = dwSize;
  sib.pbData = p;

  storedContext = CertFindCertificateInStore(store, X509_ASN_ENCODING|PKCS_7_ASN_ENCODING,
                                             0, CERT_FIND_SUBJECT_NAME, &sib, NULL);

  if (!storedContext)
  {
    /*
     * No certificate for this name yet: generate a key and a self-signed
     * certificate, store it, and link it to the key container.
     */
    if (!CryptGenKey(hProv, AT_KEYEXCHANGE, CRYPT_EXPORTABLE, &hKey))
    {
      DEBUG_printf(("_sspiGetCredentials: CryptGenKey failed: %x",
                    GetLastError()));
      ok = FALSE;
      goto cleanup;
    }

    ZeroMemory(&kpi, sizeof(kpi));
    kpi.pwszContainerName = (LPWSTR) container;
    kpi.pwszProvName      = MS_DEF_PROV_W;
    kpi.dwProvType        = PROV_RSA_FULL;
    kpi.dwFlags           = CERT_SET_KEY_CONTEXT_PROP_ID;
    kpi.dwKeySpec         = AT_KEYEXCHANGE;

    GetSystemTime(&et);
    et.wYear += 10;

    ZeroMemory(&exts, sizeof(exts));

    createdContext = CertCreateSelfSignCertificate(hProv, &sib, 0, &kpi, NULL, NULL,
                                                   &et, &exts);

    if (!createdContext)
    {
      DEBUG_printf(("_sspiGetCredentials: CertCreateSelfSignCertificate failed: %x",
                    GetLastError()));
      ok = FALSE;
      goto cleanup;
    }

    if (!CertAddCertificateContextToStore(store, createdContext,
                                          CERT_STORE_ADD_REPLACE_EXISTING,
                                          &storedContext))
    {
      DEBUG_printf(("_sspiGetCredentials: CertAddCertificateContextToStore failed: %x",
                    GetLastError()));
      ok = FALSE;
      goto cleanup;
    }

    ZeroMemory(&ckp, sizeof(ckp));
    ckp.pwszContainerName = (LPWSTR) container;
    ckp.pwszProvName      = MS_DEF_PROV_W;
    ckp.dwProvType        = PROV_RSA_FULL;
    ckp.dwFlags           = CRYPT_MACHINE_KEYSET;
    ckp.dwKeySpec         = AT_KEYEXCHANGE;

    if (!CertSetCertificateContextProperty(storedContext,
                                           CERT_KEY_PROV_INFO_PROP_ID,
                                           0, &ckp))
    {
      DEBUG_printf(("_sspiGetCredentials: CertSetCertificateContextProperty failed: %x",
                    GetLastError()));
      ok = FALSE;
      goto cleanup;
    }
  }

  ZeroMemory(&SchannelCred, sizeof(SchannelCred));

  SchannelCred.dwVersion = SCHANNEL_CRED_VERSION;
  SchannelCred.cCreds    = 1;
  SchannelCred.paCred    = &storedContext;

  if (isServer)
    SchannelCred.grbitEnabledProtocols = SP_PROT_SSL3TLS1;

  Status = AcquireCredentialsHandle(NULL, UNISP_NAME,
                                    isServer ? SECPKG_CRED_INBOUND:SECPKG_CRED_OUTBOUND,
                                    NULL, &SchannelCred, NULL, NULL, &conn->creds,
                                    &tsExpiry);

  if (Status != SEC_E_OK)
  {
    DEBUG_printf(("_sspiGetCredentials: AcquireCredentialsHandle failed: %x", Status));
    ok = FALSE;
    goto cleanup;
  }

cleanup:

  if (hKey)
    CryptDestroyKey(hKey);

  if (createdContext)
    CertFreeCertificateContext(createdContext);

  if (storedContext)
    CertFreeCertificateContext(storedContext);

  free(p);

  if (store)
    CertCloseStore(store, 0);

  if (hProv)
    CryptReleaseContext(hProv, 0);

  return (ok);
}
/*
 * '_sspiConnect()' - Perform the client side of the TLS handshake.
 *
 * Drives InitializeSecurityContext() with conn->creds over conn->sock until
 * the handshake completes. It then verifies the server certificate against
 * "hostname" using conn->certFlags and caches the stream sizes that
 * _sspiWrite() needs. Application data that arrives together with the final
 * handshake message is left at the start of conn->decryptBuffer for
 * _sspiRead().
 *
 * Returns TRUE on success, FALSE on failure.  If the handshake fails before
 * conn->contextInitialized is set, the partially built security context is
 * deleted here.
 */
BOOL
_sspiConnect(_sspi_struct_t *conn,
             const CHAR     *hostname)
{
  PCCERT_CONTEXT  serverCert = NULL;
  DWORD           dwSSPIFlags;
  DWORD           dwSSPIOutFlags = 0;
  TimeStamp       tsExpiry;
  SECURITY_STATUS scRet;
  DWORD           cbData;
  SecBufferDesc   inBuffer;
  SecBuffer       inBuffers[2];
  SecBufferDesc   outBuffer;
  SecBuffer       outBuffers[1];
  BOOL            haveContext = FALSE;
  BOOL            ok = TRUE;

  if (!conn)
    return (FALSE);

  dwSSPIFlags = ISC_REQ_SEQUENCE_DETECT |
                ISC_REQ_REPLAY_DETECT |
                ISC_REQ_CONFIDENTIALITY |
                ISC_REQ_EXTENDED_ERROR |
                ISC_REQ_ALLOCATE_MEMORY |
                ISC_REQ_STREAM;

  outBuffers[0].pvBuffer   = NULL;
  outBuffers[0].BufferType = SECBUFFER_TOKEN;
  outBuffers[0].cbBuffer   = 0;

  outBuffer.cBuffers  = 1;
  outBuffer.pBuffers  = outBuffers;
  outBuffer.ulVersion = SECBUFFER_VERSION;

  scRet = InitializeSecurityContext(&conn->creds, NULL, TEXT(""), dwSSPIFlags,
                                    0, SECURITY_NATIVE_DREP, NULL, 0, &conn->context,
                                    &outBuffer, &dwSSPIOutFlags, &tsExpiry);

  if (scRet != SEC_I_CONTINUE_NEEDED)
  {
    DEBUG_printf(("_sspiConnect: InitializeSecurityContext(1) failed: %x", scRet));
    if (outBuffers[0].pvBuffer)
      FreeContextBuffer(outBuffers[0].pvBuffer);
    ok = FALSE;
    goto cleanup;
  }

  /* From here on conn->context holds a context that must be released. */
  haveContext = TRUE;

  if (outBuffers[0].cbBuffer && outBuffers[0].pvBuffer)
  {
    cbData = send(conn->sock, outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0);
    if ((cbData == SOCKET_ERROR) || !cbData)
    {
      DEBUG_printf(("_sspiConnect: send failed: %d", WSAGetLastError()));
      FreeContextBuffer(outBuffers[0].pvBuffer);
      ok = FALSE;
      goto cleanup;
    }

    DEBUG_printf(("_sspiConnect: %d bytes of handshake data sent", cbData));
  }

  if (outBuffers[0].pvBuffer)
  {
    FreeContextBuffer(outBuffers[0].pvBuffer);
    outBuffers[0].pvBuffer = NULL;
  }

  dwSSPIFlags = ISC_REQ_MANUAL_CRED_VALIDATION |
                ISC_REQ_SEQUENCE_DETECT |
                ISC_REQ_REPLAY_DETECT |
                ISC_REQ_CONFIDENTIALITY |
                ISC_REQ_EXTENDED_ERROR |
                ISC_REQ_ALLOCATE_MEMORY |
                ISC_REQ_STREAM;

  conn->decryptBufferUsed = 0;
  scRet                   = SEC_I_CONTINUE_NEEDED;

  while (scRet == SEC_I_CONTINUE_NEEDED ||
         scRet == SEC_E_INCOMPLETE_MESSAGE ||
         scRet == SEC_I_INCOMPLETE_CREDENTIALS)
  {
    /* Read more handshake data when the buffer is empty or Schannel needs more. */
    if ((conn->decryptBufferUsed == 0) || (scRet == SEC_E_INCOMPLETE_MESSAGE))
    {
      if (conn->decryptBufferLength <= conn->decryptBufferUsed)
      {
        /* Grow via a temporary so a failed realloc neither leaks the old
           buffer nor leaves a length that does not match the allocation. */
        BYTE *newBuffer = (BYTE *) realloc(conn->decryptBuffer,
                                           conn->decryptBufferLength + 4096);

        if (!newBuffer)
        {
          DEBUG_printf(("_sspiConnect: unable to allocate %d byte decrypt buffer",
                        conn->decryptBufferLength + 4096));
          SetLastError(E_OUTOFMEMORY);
          ok = FALSE;
          goto cleanup;
        }

        conn->decryptBuffer        = newBuffer;
        conn->decryptBufferLength += 4096;
      }

      cbData = recv(conn->sock, conn->decryptBuffer + conn->decryptBufferUsed,
                    (int) (conn->decryptBufferLength - conn->decryptBufferUsed), 0);

      if (cbData == SOCKET_ERROR)
      {
        DEBUG_printf(("_sspiConnect: recv failed: %d", WSAGetLastError()));
        ok = FALSE;
        goto cleanup;
      }
      else if (cbData == 0)
      {
        DEBUG_printf(("_sspiConnect: server unexpectedly disconnected"));
        ok = FALSE;
        goto cleanup;
      }

      DEBUG_printf(("_sspiConnect: %d bytes of handshake data received",
                    cbData));

      conn->decryptBufferUsed += cbData;
    }

    inBuffers[0].pvBuffer   = conn->decryptBuffer;
    inBuffers[0].cbBuffer   = (unsigned long) conn->decryptBufferUsed;
    inBuffers[0].BufferType = SECBUFFER_TOKEN;

    inBuffers[1].pvBuffer   = NULL;
    inBuffers[1].cbBuffer   = 0;
    inBuffers[1].BufferType = SECBUFFER_EMPTY;

    inBuffer.cBuffers  = 2;
    inBuffer.pBuffers  = inBuffers;
    inBuffer.ulVersion = SECBUFFER_VERSION;

    outBuffers[0].pvBuffer   = NULL;
    outBuffers[0].BufferType = SECBUFFER_TOKEN;
    outBuffers[0].cbBuffer   = 0;

    outBuffer.cBuffers  = 1;
    outBuffer.pBuffers  = outBuffers;
    outBuffer.ulVersion = SECBUFFER_VERSION;

    dwSSPIOutFlags = 0;

    scRet = InitializeSecurityContext(&conn->creds, &conn->context, NULL, dwSSPIFlags,
                                      0, SECURITY_NATIVE_DREP, &inBuffer, 0, NULL,
                                      &outBuffer, &dwSSPIOutFlags, &tsExpiry);

    /* Forward any token (including an error alert) to the server. */
    if (scRet == SEC_E_OK ||
        scRet == SEC_I_CONTINUE_NEEDED ||
        (FAILED(scRet) && (dwSSPIOutFlags & ISC_RET_EXTENDED_ERROR)))
    {
      if (outBuffers[0].cbBuffer && outBuffers[0].pvBuffer)
      {
        cbData = send(conn->sock, outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0);
        if ((cbData == SOCKET_ERROR) || !cbData)
        {
          DEBUG_printf(("_sspiConnect: send failed: %d", WSAGetLastError()));
          FreeContextBuffer(outBuffers[0].pvBuffer);
          ok = FALSE;
          goto cleanup;
        }

        DEBUG_printf(("_sspiConnect: %d bytes of handshake data sent", cbData));
      }
    }

    /* Release the token whether or not it was sent. */
    if (outBuffers[0].pvBuffer)
    {
      FreeContextBuffer(outBuffers[0].pvBuffer);
      outBuffers[0].pvBuffer = NULL;
    }

    if (scRet == SEC_E_INCOMPLETE_MESSAGE)
      continue;

    if (scRet == SEC_E_OK)
    {
      DEBUG_printf(("_sspiConnect: Handshake was successful"));

      if (inBuffers[1].BufferType == SECBUFFER_EXTRA)
      {
        /* SECBUFFER_EXTRA counts unconsumed bytes at the end of the input
           token, which already lives in decryptBuffer, so the data always
           fits and only needs to be moved to the front. */
        memmove(conn->decryptBuffer,
                conn->decryptBuffer + (conn->decryptBufferUsed - inBuffers[1].cbBuffer),
                inBuffers[1].cbBuffer);

        conn->decryptBufferUsed = inBuffers[1].cbBuffer;

        DEBUG_printf(("_sspiConnect: %d bytes of app data was bundled with handshake data",
                      conn->decryptBufferUsed));
      }
      else
        conn->decryptBufferUsed = 0;

      break;
    }

    if (FAILED(scRet))
    {
      DEBUG_printf(("_sspiConnect: InitializeSecurityContext(2) failed: %x", scRet));
      ok = FALSE;
      break;
    }

    if (scRet == SEC_I_INCOMPLETE_CREDENTIALS)
    {
      DEBUG_printf(("_sspiConnect: server requested client credentials"));
      ok = FALSE;
      break;
    }

    /* Keep any bytes belonging to the next handshake message. */
    if (inBuffers[1].BufferType == SECBUFFER_EXTRA)
    {
      memmove(conn->decryptBuffer,
              conn->decryptBuffer + (conn->decryptBufferUsed - inBuffers[1].cbBuffer),
              inBuffers[1].cbBuffer);

      conn->decryptBufferUsed = inBuffers[1].cbBuffer;
    }
    else
    {
      conn->decryptBufferUsed = 0;
    }
  }

  if (ok)
  {
    conn->contextInitialized = TRUE;

    scRet = QueryContextAttributes(&conn->context, SECPKG_ATTR_REMOTE_CERT_CONTEXT, (VOID*) &serverCert );
    if (scRet != SEC_E_OK)
    {
      DEBUG_printf(("_sspiConnect: QueryContextAttributes failed(SECPKG_ATTR_REMOTE_CERT_CONTEXT): %x", scRet));
      ok = FALSE;
      goto cleanup;
    }

    scRet = sspi_verify_certificate(serverCert, hostname, conn->certFlags);
    if (scRet != SEC_E_OK)
    {
      DEBUG_printf(("_sspiConnect: sspi_verify_certificate failed: %x", scRet));
      ok = FALSE;
      goto cleanup;
    }

    scRet = QueryContextAttributes(&conn->context, SECPKG_ATTR_STREAM_SIZES, &conn->streamSizes);
    if (scRet != SEC_E_OK)
    {
      DEBUG_printf(("_sspiConnect: QueryContextAttributes failed(SECPKG_ATTR_STREAM_SIZES): %x", scRet));
      ok = FALSE;
    }
  }

cleanup:

  if (serverCert)
    CertFreeCertificateContext(serverCert);

  /* _sspiFree() only deletes contexts flagged as initialized, so a handshake
     that failed before that point must release the context here. */
  if (!ok && haveContext && !conn->contextInitialized)
    DeleteSecurityContext(&conn->context);

  return (ok);
}
/*
 * '_sspiAccept()' - Perform the server side of the TLS handshake.
 *
 * Reads handshake records from conn->sock, polling with Sleep(1) while the
 * socket reports WSAEWOULDBLOCK. It feeds them to AcceptSecurityContext()
 * with conn->creds and sends any response tokens back to the client.
 *
 * On success conn->contextInitialized is set and conn->streamSizes is
 * filled in. Any application data that arrived with the last handshake
 * message is left at the start of conn->decryptBuffer.
 *
 * Returns TRUE on success, FALSE on failure.
 */
BOOL
_sspiAccept(_sspi_struct_t *conn)
{
  DWORD           dwSSPIFlags;
  DWORD           dwSSPIOutFlags = 0;
  TimeStamp       tsExpiry;
  SECURITY_STATUS scRet;
  SecBufferDesc   inBuffer;
  SecBuffer       inBuffers[2];
  SecBufferDesc   outBuffer;
  SecBuffer       outBuffers[1];
  DWORD           num = 0;
  BOOL            fInitContext = TRUE;
  BOOL            ok = TRUE;

  if (!conn)
    return (FALSE);

  dwSSPIFlags = ASC_REQ_SEQUENCE_DETECT |
                ASC_REQ_REPLAY_DETECT |
                ASC_REQ_CONFIDENTIALITY |
                ASC_REQ_EXTENDED_ERROR |
                ASC_REQ_ALLOCATE_MEMORY |
                ASC_REQ_STREAM;

  conn->decryptBufferUsed = 0;

  outBuffer.cBuffers  = 1;
  outBuffer.pBuffers  = outBuffers;
  outBuffer.ulVersion = SECBUFFER_VERSION;

  scRet = SEC_I_CONTINUE_NEEDED;

  while (scRet == SEC_I_CONTINUE_NEEDED ||
         scRet == SEC_E_INCOMPLETE_MESSAGE ||
         scRet == SEC_I_INCOMPLETE_CREDENTIALS)
  {
    if ((conn->decryptBufferUsed == 0) || (scRet == SEC_E_INCOMPLETE_MESSAGE))
    {
      if (conn->decryptBufferLength <= conn->decryptBufferUsed)
      {
        /* Grow via a temporary so a failed realloc neither leaks the old
           buffer nor leaves a length that does not match the allocation. */
        BYTE *newBuffer = (BYTE *) realloc(conn->decryptBuffer,
                                           conn->decryptBufferLength + 4096);

        if (!newBuffer)
        {
          DEBUG_printf(("_sspiAccept: unable to allocate %d byte decrypt buffer",
                        conn->decryptBufferLength + 4096));
          SetLastError(E_OUTOFMEMORY);
          ok = FALSE;
          goto cleanup;
        }

        conn->decryptBuffer        = newBuffer;
        conn->decryptBufferLength += 4096;
      }

      /* Non-blocking sockets: poll until data (or a real error) arrives. */
      for (;;)
      {
        num = recv(conn->sock,
                   conn->decryptBuffer + conn->decryptBufferUsed,
                   (int)(conn->decryptBufferLength - conn->decryptBufferUsed),
                   0);

        if ((num == SOCKET_ERROR) && (WSAGetLastError() == WSAEWOULDBLOCK))
          Sleep(1);
        else
          break;
      }

      if (num == SOCKET_ERROR)
      {
        DEBUG_printf(("_sspiAccept: recv failed: %d", WSAGetLastError()));
        ok = FALSE;
        goto cleanup;
      }
      else if (num == 0)
      {
        DEBUG_printf(("_sspiAccept: client disconnected"));
        ok = FALSE;
        goto cleanup;
      }

      DEBUG_printf(("_sspiAccept: received %d (handshake) bytes from client",
                    num));

      conn->decryptBufferUsed += num;
    }

    inBuffers[0].pvBuffer   = conn->decryptBuffer;
    inBuffers[0].cbBuffer   = (unsigned long) conn->decryptBufferUsed;
    inBuffers[0].BufferType = SECBUFFER_TOKEN;

    inBuffers[1].pvBuffer   = NULL;
    inBuffers[1].cbBuffer   = 0;
    inBuffers[1].BufferType = SECBUFFER_EMPTY;

    inBuffer.cBuffers  = 2;
    inBuffer.pBuffers  = inBuffers;
    inBuffer.ulVersion = SECBUFFER_VERSION;

    outBuffers[0].pvBuffer   = NULL;
    outBuffers[0].BufferType = SECBUFFER_TOKEN;
    outBuffers[0].cbBuffer   = 0;

    dwSSPIOutFlags = 0;

    /* The first call creates conn->context; later calls continue it. */
    scRet = AcceptSecurityContext(&conn->creds, (fInitContext?NULL:&conn->context),
                                  &inBuffer, dwSSPIFlags, SECURITY_NATIVE_DREP,
                                  (fInitContext?&conn->context:NULL), &outBuffer,
                                  &dwSSPIOutFlags, &tsExpiry);

    fInitContext = FALSE;

    /* AcceptSecurityContext reports ASC_RET_* flags (ISC_RET_EXTENDED_ERROR
       has the value of ASC_RET_THIRD_LEG_FAILED here). */
    if (scRet == SEC_E_OK ||
        scRet == SEC_I_CONTINUE_NEEDED ||
        (FAILED(scRet) && ((dwSSPIOutFlags & ASC_RET_EXTENDED_ERROR) != 0)))
    {
      if (outBuffers[0].cbBuffer && outBuffers[0].pvBuffer)
      {
        num = send(conn->sock, outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0);

        if ((num == SOCKET_ERROR) || (num == 0))
        {
          DEBUG_printf(("_sspiAccept: handshake send failed: %d", WSAGetLastError()));
          FreeContextBuffer(outBuffers[0].pvBuffer);
          ok = FALSE;
          goto cleanup;
        }

        DEBUG_printf(("_sspiAccept: send %d handshake bytes to client",
                      outBuffers[0].cbBuffer));
      }
    }

    /* Release the token whether or not it was sent. */
    if (outBuffers[0].pvBuffer)
    {
      FreeContextBuffer(outBuffers[0].pvBuffer);
      outBuffers[0].pvBuffer = NULL;
    }

    if (scRet == SEC_E_OK)
    {
      /* Source and destination are in the same buffer and may overlap. */
      if (inBuffers[1].BufferType == SECBUFFER_EXTRA)
      {
        memmove(conn->decryptBuffer,
                (LPBYTE) (conn->decryptBuffer + (conn->decryptBufferUsed - inBuffers[1].cbBuffer)),
                inBuffers[1].cbBuffer);

        conn->decryptBufferUsed = inBuffers[1].cbBuffer;
      }
      else
      {
        conn->decryptBufferUsed = 0;
      }

      ok = TRUE;
      break;
    }
    else if (FAILED(scRet) && (scRet != SEC_E_INCOMPLETE_MESSAGE))
    {
      DEBUG_printf(("_sspiAccept: AcceptSecurityContext failed: %x", scRet));
      ok = FALSE;
      break;
    }

    if (scRet != SEC_E_INCOMPLETE_MESSAGE &&
        scRet != SEC_I_INCOMPLETE_CREDENTIALS)
    {
      if (inBuffers[1].BufferType == SECBUFFER_EXTRA)
      {
        memmove(conn->decryptBuffer,
                (LPBYTE) (conn->decryptBuffer + (conn->decryptBufferUsed - inBuffers[1].cbBuffer)),
                inBuffers[1].cbBuffer);

        conn->decryptBufferUsed = inBuffers[1].cbBuffer;
      }
      else
      {
        conn->decryptBufferUsed = 0;
      }
    }
  }

  if (ok)
  {
    conn->contextInitialized = TRUE;

    scRet = QueryContextAttributes(&conn->context, SECPKG_ATTR_STREAM_SIZES, &conn->streamSizes);

    if (scRet != SEC_E_OK)
    {
      DEBUG_printf(("_sspiAccept: QueryContextAttributes failed: %x", scRet));
      ok = FALSE;
    }
  }

cleanup:

  return (ok);
}
/*
 * '_sspiSetAllowsAnyRoot()' - Toggle acceptance of unknown root CAs.
 *
 * Sets SECURITY_FLAG_IGNORE_UNKNOWN_CA in conn->certFlags when "allow" is
 * nonzero, and clears it otherwise.  _sspiConnect() passes conn->certFlags
 * to sspi_verify_certificate(), which uses it as the SSL policy fdwChecks.
 */
void
_sspiSetAllowsAnyRoot(_sspi_struct_t *conn,
                      BOOL           allow)
{
  if (!allow)
  {
    conn->certFlags &= ~SECURITY_FLAG_IGNORE_UNKNOWN_CA;
    return;
  }

  conn->certFlags |= SECURITY_FLAG_IGNORE_UNKNOWN_CA;
}
/*
 * '_sspiSetAllowsExpiredCerts()' - Toggle acceptance of out-of-date certs.
 *
 * Sets SECURITY_FLAG_IGNORE_CERT_DATE_INVALID in conn->certFlags when
 * "allow" is nonzero, and clears it otherwise.  _sspiConnect() passes
 * conn->certFlags to sspi_verify_certificate(), which uses it as the SSL
 * policy fdwChecks.
 */
void
_sspiSetAllowsExpiredCerts(_sspi_struct_t *conn,
                           BOOL           allow)
{
  if (!allow)
  {
    conn->certFlags &= ~SECURITY_FLAG_IGNORE_CERT_DATE_INVALID;
    return;
  }

  conn->certFlags |= SECURITY_FLAG_IGNORE_CERT_DATE_INVALID;
}
/*
 * '_sspiWrite()' - Encrypt and send application data.
 *
 * Splits "buf" into chunks of at most conn->streamSizes.cbMaximumMessage
 * bytes. Each chunk is encrypted in place with EncryptMessage() as
 * header + data + trailer, and the resulting record is written to
 * conn->sock.
 *
 * Returns "len" on success. Returns SOCKET_ERROR (with WSAGetLastError()
 * set) on bad arguments, allocation failure or encryption failure. If
 * send() fails or returns 0, that value is returned.
 */
int
_sspiWrite(_sspi_struct_t *conn,
           void           *buf,
           size_t         len)
{
  SecBufferDesc message;
  SecBuffer     buffers[4] = { 0 };
  BYTE          *buffer = NULL;
  int           bufferLen;
  size_t        bytesLeft;
  size_t        offset = 0;
  int           num = 0;

  if (!conn || !buf || !len)
  {
    WSASetLastError(WSAEINVAL);
    num = SOCKET_ERROR;
    goto cleanup;
  }

  /* streamSizes is only filled in once a handshake has completed; with a
     zero maximum message size each chunk would be empty and the loop below
     would never make progress. */
  if (conn->streamSizes.cbMaximumMessage == 0)
  {
    DEBUG_printf(("_sspiWrite: stream sizes unknown, handshake not completed"));
    WSASetLastError(WSAENOTCONN);
    num = SOCKET_ERROR;
    goto cleanup;
  }

  bufferLen = conn->streamSizes.cbMaximumMessage +
              conn->streamSizes.cbHeader +
              conn->streamSizes.cbTrailer;

  buffer = (BYTE *) malloc(bufferLen);

  if (!buffer)
  {
    DEBUG_printf(("_sspiWrite: buffer alloc of %d bytes failed", bufferLen));
    WSASetLastError(E_OUTOFMEMORY);
    num = SOCKET_ERROR;
    goto cleanup;
  }

  bytesLeft = len;

  while (bytesLeft)
  {
    size_t          chunk = min(conn->streamSizes.cbMaximumMessage, bytesLeft);
    SECURITY_STATUS scRet;
    const char      *out;
    int             toSend;

    memcpy(buffer + conn->streamSizes.cbHeader,
           ((BYTE *) buf) + offset,
           chunk);

    message.ulVersion = SECBUFFER_VERSION;
    message.cBuffers  = 4;
    message.pBuffers  = buffers;

    buffers[0].pvBuffer   = buffer;
    buffers[0].cbBuffer   = conn->streamSizes.cbHeader;
    buffers[0].BufferType = SECBUFFER_STREAM_HEADER;
    buffers[1].pvBuffer   = buffer + conn->streamSizes.cbHeader;
    buffers[1].cbBuffer   = (unsigned long) chunk;
    buffers[1].BufferType = SECBUFFER_DATA;
    buffers[2].pvBuffer   = buffer + conn->streamSizes.cbHeader + chunk;
    buffers[2].cbBuffer   = conn->streamSizes.cbTrailer;
    buffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
    buffers[3].BufferType = SECBUFFER_EMPTY;

    scRet = EncryptMessage(&conn->context, 0, &message, 0);

    if (FAILED(scRet))
    {
      DEBUG_printf(("_sspiWrite: EncryptMessage failed: %x", scRet));
      WSASetLastError(WSASYSCALLFAILURE);
      num = SOCKET_ERROR;
      goto cleanup;
    }

    /* Header, data and trailer are contiguous in "buffer".  send() may
       accept only part of the record, so keep writing until all of it has
       gone out; a truncated TLS record would corrupt the stream. */
    out    = (const char *) buffer;
    toSend = (int) (buffers[0].cbBuffer + buffers[1].cbBuffer + buffers[2].cbBuffer);

    while (toSend > 0)
    {
      num = send(conn->sock, out, toSend, 0);

      if ((num == SOCKET_ERROR) || (num == 0))
      {
        DEBUG_printf(("_sspiWrite: send failed: %d", WSAGetLastError()));
        goto cleanup;
      }

      out    += num;
      toSend -= num;
    }

    bytesLeft -= chunk;
    offset    += chunk;
  }

  num = (int) len;

cleanup:

  free(buffer);

  return (num);
}
/*
 * '_sspiRead()' - Read and decrypt application data.
 *
 * If "buf" is non-NULL and previously decrypted bytes are waiting in
 * conn->readBuffer, up to "len" of them are returned without touching the
 * socket.
 *
 * Otherwise the next TLS record in conn->decryptBuffer is decrypted with
 * DecryptMessage(). While the record is incomplete, more data is read from
 * conn->sock, but only when "buf" is non-NULL. Up to "len" plaintext bytes
 * are copied to "buf"; the rest is appended to conn->readBuffer, and any
 * bytes of the following record are kept in conn->decryptBuffer.
 *
 * With "buf" == NULL (as used by _sspiPending()) the socket is never read,
 * nothing is copied out, and the number of buffered plaintext bytes is
 * returned.
 *
 * Returns the number of bytes copied (or buffered), 0 if the peer closed
 * the connection, or SOCKET_ERROR with WSAGetLastError() set.
 */
int
_sspiRead(_sspi_struct_t *conn,
          void           *buf,
          size_t         len)
{
  SecBufferDesc message;
  SecBuffer     buffers[4] = { 0 };
  int           num = 0;

  if (!conn)
  {
    WSASetLastError(WSAEINVAL);
    num = SOCKET_ERROR;
    goto cleanup;
  }

  if (buf && (conn->readBufferUsed > 0))
  {
    int bytesToCopy = (int) min(conn->readBufferUsed, len);

    memcpy(buf, conn->readBuffer, bytesToCopy);
    conn->readBufferUsed -= bytesToCopy;

    if (conn->readBufferUsed > 0)
      memmove(conn->readBuffer,
              conn->readBuffer + bytesToCopy,
              conn->readBufferUsed);

    num = bytesToCopy;
  }
  else
  {
    PSecBuffer      pDataBuffer;
    PSecBuffer      pExtraBuffer;
    SECURITY_STATUS scRet;
    int             i;

    message.ulVersion = SECBUFFER_VERSION;
    message.cBuffers  = 4;
    message.pBuffers  = buffers;

    do
    {
      if (conn->decryptBufferLength <= conn->decryptBufferUsed)
      {
        /* Grow via a temporary so a failed realloc neither leaks the old
           buffer nor leaves a length that does not match the allocation. */
        BYTE *newBuffer = (BYTE *) realloc(conn->decryptBuffer,
                                           conn->decryptBufferLength + 4096);

        if (!newBuffer)
        {
          DEBUG_printf(("_sspiRead: unable to allocate %d byte buffer",
                        conn->decryptBufferLength + 4096));
          WSASetLastError(E_OUTOFMEMORY);
          num = SOCKET_ERROR;
          goto cleanup;
        }

        conn->decryptBuffer        = newBuffer;
        conn->decryptBufferLength += 4096;
      }

      buffers[0].pvBuffer   = conn->decryptBuffer;
      buffers[0].cbBuffer   = (unsigned long) conn->decryptBufferUsed;
      buffers[0].BufferType = SECBUFFER_DATA;
      buffers[1].BufferType = SECBUFFER_EMPTY;
      buffers[2].BufferType = SECBUFFER_EMPTY;
      buffers[3].BufferType = SECBUFFER_EMPTY;

      scRet = DecryptMessage(&conn->context, &message, 0, NULL);

      if (scRet == SEC_E_INCOMPLETE_MESSAGE)
      {
        if (buf)
        {
          num = recv(conn->sock,
                     conn->decryptBuffer + conn->decryptBufferUsed,
                     (int)(conn->decryptBufferLength - conn->decryptBufferUsed),
                     0);

          if (num == SOCKET_ERROR)
          {
            DEBUG_printf(("_sspiRead: recv failed: %d", WSAGetLastError()));
            goto cleanup;
          }
          else if (num == 0)
          {
            DEBUG_printf(("_sspiRead: server disconnected"));
            goto cleanup;
          }

          conn->decryptBufferUsed += num;
        }
        else
        {
          /* Pending check: report what is already decrypted, don't block. */
          num = (int) conn->readBufferUsed;
          goto cleanup;
        }
      }
    }
    while (scRet == SEC_E_INCOMPLETE_MESSAGE);

    if (scRet == SEC_I_CONTEXT_EXPIRED)
    {
      DEBUG_printf(("_sspiRead: context expired"));
      WSASetLastError(WSAECONNRESET);
      num = SOCKET_ERROR;
      goto cleanup;
    }
    else if (scRet != SEC_E_OK)
    {
      DEBUG_printf(("_sspiRead: DecryptMessage failed: %lx", scRet));
      WSASetLastError(WSASYSCALLFAILURE);
      num = SOCKET_ERROR;
      goto cleanup;
    }

    /* DecryptMessage rewrites buffers[1..3] to describe the plaintext and
       any trailing bytes of the next record. */
    pDataBuffer  = NULL;
    pExtraBuffer = NULL;

    for (i = 1; i < 4; i++)
    {
      if (buffers[i].BufferType == SECBUFFER_DATA)
        pDataBuffer = &buffers[i];
      else if (!pExtraBuffer && (buffers[i].BufferType == SECBUFFER_EXTRA))
        pExtraBuffer = &buffers[i];
    }

    if (pDataBuffer)
    {
      /* Compare in size_t so a "len" above INT_MAX is not turned negative,
         and never copy into a NULL destination. */
      int bytesToCopy = buf ? (int) min((size_t) pDataBuffer->cbBuffer, len) : 0;
      int bytesToSave = (int) pDataBuffer->cbBuffer - bytesToCopy;

      if (bytesToCopy)
        memcpy(buf, pDataBuffer->pvBuffer, bytesToCopy);

      if (bytesToSave)
      {
        if ((int)(conn->readBufferLength - conn->readBufferUsed) < bytesToSave)
        {
          void *newRead = realloc(conn->readBuffer,
                                  conn->readBufferUsed + bytesToSave);

          if (!newRead)
          {
            DEBUG_printf(("_sspiRead: unable to allocate %d bytes",
                          conn->readBufferUsed + bytesToSave));
            WSASetLastError(E_OUTOFMEMORY);
            num = SOCKET_ERROR;
            goto cleanup;
          }

          conn->readBuffer       = newRead;
          conn->readBufferLength = conn->readBufferUsed + bytesToSave;
        }

        memcpy(((BYTE*) conn->readBuffer) + conn->readBufferUsed,
               ((BYTE*) pDataBuffer->pvBuffer) + bytesToCopy,
               bytesToSave);

        conn->readBufferUsed += bytesToSave;
      }

      num = (buf) ? bytesToCopy : (int) conn->readBufferUsed;
    }
    else
    {
      DEBUG_printf(("_sspiRead: unable to find data buffer"));
      WSASetLastError(WSASYSCALLFAILURE);
      num = SOCKET_ERROR;
      goto cleanup;
    }

    if (pExtraBuffer)
    {
      memmove(conn->decryptBuffer, pExtraBuffer->pvBuffer, pExtraBuffer->cbBuffer);
      conn->decryptBufferUsed = pExtraBuffer->cbBuffer;
    }
    else
    {
      conn->decryptBufferUsed = 0;
    }
  }

cleanup:

  return (num);
}
/*
 * '_sspiPending()' - Report how many decrypted bytes are buffered.
 *
 * Delegates to _sspiRead() with no destination buffer. That decrypts any
 * complete record already held in conn->decryptBuffer into conn->readBuffer
 * without reading from the socket, and yields the buffered byte count (or
 * SOCKET_ERROR).
 */
int
_sspiPending(_sspi_struct_t *conn)
{
  int bufferedBytes;

  bufferedBytes = _sspiRead(conn, NULL, 0);

  return (bufferedBytes);
}
/*
 * '_sspiFree()' - Shut down and release an SSPI connection object.
 *
 * If a security context was established, a SCHANNEL_SHUTDOWN control token
 * is applied and the resulting close_notify token is sent over conn->sock
 * (best effort: failures are only logged). The context is then deleted.
 * Finally the decrypt/read buffers are freed, the socket is closed and
 * "conn" itself is freed. A NULL "conn" is ignored.
 *
 * NOTE(review): conn->creds is not released here (no FreeCredentialsHandle
 * call); confirm whether callers free the credentials handle themselves.
 */
void
_sspiFree(_sspi_struct_t *conn)
{
  if (!conn)
    return;

  if (conn->contextInitialized)
  {
    SecBufferDesc message;
    SecBuffer     buffers[1] = { 0 };
    DWORD         dwType;
    DWORD         status;

    dwType = SCHANNEL_SHUTDOWN;

    buffers[0].pvBuffer   = &dwType;
    buffers[0].BufferType = SECBUFFER_TOKEN;
    buffers[0].cbBuffer   = sizeof(dwType);

    message.cBuffers  = 1;
    message.pBuffers  = buffers;
    message.ulVersion = SECBUFFER_VERSION;

    status = ApplyControlToken(&conn->context, &message);

    if (SUCCEEDED(status))
    {
      PBYTE     pbMessage;
      DWORD     cbMessage;
      DWORD     cbData;
      DWORD     dwSSPIFlags;
      DWORD     dwSSPIOutFlags;
      TimeStamp tsExpiry;

      dwSSPIFlags = ASC_REQ_SEQUENCE_DETECT |
                    ASC_REQ_REPLAY_DETECT |
                    ASC_REQ_CONFIDENTIALITY |
                    ASC_REQ_EXTENDED_ERROR |
                    ASC_REQ_ALLOCATE_MEMORY |
                    ASC_REQ_STREAM;

      buffers[0].pvBuffer   = NULL;
      buffers[0].BufferType = SECBUFFER_TOKEN;
      buffers[0].cbBuffer   = 0;

      message.cBuffers  = 1;
      message.pBuffers  = buffers;
      message.ulVersion = SECBUFFER_VERSION;

      /* NOTE(review): AcceptSecurityContext() is the server-side call; a
         context created by _sspiConnect() would normally generate its
         close_notify through InitializeSecurityContext().  Confirm which
         side(s) reach this path. */
      status = AcceptSecurityContext(&conn->creds, &conn->context, NULL,
                                     dwSSPIFlags, SECURITY_NATIVE_DREP, NULL,
                                     &message, &dwSSPIOutFlags, &tsExpiry);

      if (SUCCEEDED(status))
      {
        pbMessage = buffers[0].pvBuffer;
        cbMessage = buffers[0].cbBuffer;

        if (pbMessage && cbMessage)
        {
          cbData = send(conn->sock, (const char *) pbMessage, cbMessage, 0);

          if ((cbData == SOCKET_ERROR) || (cbData == 0))
          {
            status = WSAGetLastError();
            DEBUG_printf(("_sspiFree: sending close notify failed: %d", status));
          }
        }

        /* The token was allocated by SSPI; release it even if send failed. */
        if (pbMessage)
          FreeContextBuffer(pbMessage);
      }
      else
      {
        DEBUG_printf(("_sspiFree: AcceptSecurityContext failed: %x", status));
      }
    }
    else
    {
      DEBUG_printf(("_sspiFree: ApplyControlToken failed: %x", status));
    }

    DeleteSecurityContext(&conn->context);
    conn->contextInitialized = FALSE;
  }

  free(conn->decryptBuffer);
  conn->decryptBuffer = NULL;

  free(conn->readBuffer);
  conn->readBuffer = NULL;

  if (conn->sock != INVALID_SOCKET)
  {
    closesocket(conn->sock);
    conn->sock = INVALID_SOCKET;
  }

  free(conn);
}
/*
 * 'sspi_verify_certificate()' - Validate a server certificate chain.
 *
 * Builds the chain for "serverCert", requesting any of the server-auth /
 * server-gated-crypto usages. It then runs the SSL chain policy for
 * "serverName" (converted to UTF-16 using the ANSI code page). "dwCertFlags"
 * is passed through as the policy's fdwChecks.
 *
 * Returns:
 * - SEC_E_OK when the chain is acceptable.
 * - SEC_E_WRONG_PRINCIPAL for a missing certificate or name, or a failed
 *   name conversion.
 * - SEC_E_INSUFFICIENT_MEMORY if the wide name cannot be allocated.
 * - Otherwise the Win32 error from the chain APIs, or the policy's dwError.
 */
static DWORD
sspi_verify_certificate(PCCERT_CONTEXT serverCert,
                        const CHAR     *serverName,
                        DWORD          dwCertFlags)
{
  LPSTR                    usages[] = { szOID_PKIX_KP_SERVER_AUTH,
                                        szOID_SERVER_GATED_CRYPTO,
                                        szOID_SGC_NETSCAPE };
  PCCERT_CHAIN_CONTEXT     chain    = NULL;
  PWSTR                    wideName = NULL;
  CERT_CHAIN_PARA          chainParams;
  HTTPSPolicyCallbackData  sslPolicy;
  CERT_CHAIN_POLICY_PARA   policyParams;
  CERT_CHAIN_POLICY_STATUS policyResult;
  DWORD                    wideLen;
  DWORD                    result;

  /* Nothing has been allocated yet, so bail out directly. */
  if (!serverCert || !serverName || !serverName[0])
    return (SEC_E_WRONG_PRINCIPAL);

  wideLen  = MultiByteToWideChar(CP_ACP, 0, serverName, -1, NULL, 0);
  wideName = LocalAlloc(LMEM_FIXED, wideLen * sizeof(WCHAR));

  if (!wideName)
    return (SEC_E_INSUFFICIENT_MEMORY);

  /* Single-pass block: "break" jumps to the shared cleanup below. */
  do
  {
    if (MultiByteToWideChar(CP_ACP, 0, serverName, -1, wideName, wideLen) == 0)
    {
      result = SEC_E_WRONG_PRINCIPAL;
      break;
    }

    memset(&chainParams, 0, sizeof(chainParams));
    chainParams.cbSize                                   = sizeof(chainParams);
    chainParams.RequestedUsage.dwType                    = USAGE_MATCH_TYPE_OR;
    chainParams.RequestedUsage.Usage.cUsageIdentifier     = sizeof(usages) / sizeof(usages[0]);
    chainParams.RequestedUsage.Usage.rgpszUsageIdentifier = usages;

    if (!CertGetCertificateChain(NULL, serverCert, NULL, serverCert->hCertStore,
                                 &chainParams, 0, NULL, &chain))
    {
      result = GetLastError();
      DEBUG_printf(("CertGetCertificateChain returned 0x%x\n", result));
      break;
    }

    memset(&sslPolicy, 0, sizeof(HTTPSPolicyCallbackData));
    sslPolicy.cbStruct       = sizeof(HTTPSPolicyCallbackData);
    sslPolicy.dwAuthType     = AUTHTYPE_SERVER;
    sslPolicy.fdwChecks      = dwCertFlags;
    sslPolicy.pwszServerName = wideName;

    memset(&policyParams, 0, sizeof(policyParams));
    policyParams.cbSize             = sizeof(policyParams);
    policyParams.pvExtraPolicyPara = &sslPolicy;

    memset(&policyResult, 0, sizeof(policyResult));
    policyResult.cbSize = sizeof(policyResult);

    if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, chain,
                                          &policyParams, &policyResult))
    {
      result = GetLastError();
      DEBUG_printf(("CertVerifyCertificateChainPolicy returned %d", result));
      break;
    }

    result = policyResult.dwError ? policyResult.dwError : SEC_E_OK;
  }
  while (0);

  if (chain)
    CertFreeCertificateChain(chain);

  LocalFree(wideName);

  return (result);
}