SSLRecordInternal.c   [plain text]


//
//  SSLRecordInternal.c
//  Security
//
//  Created by Fabrice Gautier on 10/25/11.
//  Copyright (c) 2011 Apple, Inc. All rights reserved.
//

/* THIS FILE CONTAINS KERNEL CODE */

#include "sslBuildFlags.h"
#include "SSLRecordInternal.h"
#include "sslDebug.h"
#include "cipherSpecs.h"
#include "symCipher.h"
#include "sslUtils.h"
#include "tls_record.h"

#include <AssertMacros.h>
#include <string.h>

#include <inttypes.h>

#define DEFAULT_BUFFER_SIZE 4096


/*
 * Redirect SSLBuffer-based I/O call to user-supplied I/O.
 */
static
int sslIoRead(SSLBuffer                        buf,
              size_t                           *actualLength,
              struct SSLRecordInternalContext  *ctx)
{
	size_t  dataLength = buf.length;
	int     ortn;

	*actualLength = 0;
	ortn = (ctx->read)(ctx->ioRef,
                       buf.data,
                       &dataLength);
	*actualLength = dataLength;
	return ortn;
}

static
int sslIoWrite(SSLBuffer                       buf,
               size_t                          *actualLength,
               struct SSLRecordInternalContext *ctx)
{
	size_t  dataLength = buf.length;
	int     ortn;

	*actualLength = 0;
	ortn = (ctx->write)(ctx->ioRef,
                        buf.data,
                        &dataLength);
	*actualLength = dataLength;
	return ortn;
}


static int
SSLDisposeCipherSuite(CipherContext *cipher, struct SSLRecordInternalContext *ctx)
{   int      err;

	/* symmetric encryption context */
	if(cipher->symCipher) {
		if ((err = cipher->symCipher->finish(cipher->cipherCtx)) != 0) {
			return err;
		}
	}

	/* per-record hash/hmac context */
	ctx->sslTslCalls->freeMac(cipher);

    return 0;
}



/* common for sslv3 and tlsv1, except for the computeMac callout */
int SSLVerifyMac(uint8_t type,
                 SSLBuffer *data,
                 uint8_t *compareMAC,
                 struct SSLRecordInternalContext *ctx)
{
	int        err;
    uint8_t           macData[SSL_MAX_DIGEST_LEN];
    SSLBuffer       secret, mac;

    secret.data = ctx->readCipher.macSecret;
    secret.length = ctx->readCipher.macRef->hash->digestSize;
    mac.data = macData;
    mac.length = ctx->readCipher.macRef->hash->digestSize;

	check(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->computeMac(type,
                                            *data,
                                            mac,
                                            &ctx->readCipher,
                                            ctx->readCipher.sequenceNum,
                                            ctx)) != 0)
        return err;

    if ((memcmp(mac.data, compareMAC, mac.length)) != 0) {
		sslErrorLog("SSLVerifyMac: Mac verify failure\n");
        return errSSLRecordProtocol;
    }
    return 0;
}

#include "cipherSpecs.h"
#include "symCipher.h"

static const HashHmacReference *sslCipherSuiteGetHashHmacReference(uint16_t selectedCipher)
{
    HMAC_Algs alg = sslCipherSuiteGetMacAlgorithm(selectedCipher);

    switch (alg) {
        case HA_Null:
            return &HashHmacNull;
        case HA_MD5:
            return &HashHmacMD5;
        case HA_SHA1:
            return &HashHmacSHA1;
        case HA_SHA256:
            return &HashHmacSHA256;
        case HA_SHA384:
            return &HashHmacSHA384;
        default:
            sslErrorLog("Invalid hashAlgorithm %d", alg);
            check(0);
            return &HashHmacNull;
    }
}

static const SSLSymmetricCipher *sslCipherSuiteGetSymmetricCipher(uint16_t selectedCipher)
{

    SSL_CipherAlgorithm alg = sslCipherSuiteGetSymmetricCipherAlgorithm(selectedCipher);
    switch(alg) {
        case SSL_CipherAlgorithmNull:
            return &SSLCipherNull;
#if ENABLE_RC2
        case SSL_CipherAlgorithmRC2_128:
            return &SSLCipherRC2_128;
#endif
#if ENABLE_RC4
        case SSL_CipherAlgorithmRC4_128:
            return &SSLCipherRC4_128;
#endif
#if ENABLE_DES
        case SSL_CipherAlgorithmDES_CBC:
            return &SSLCipherDES_CBC;
#endif
        case SSL_CipherAlgorithm3DES_CBC:
            return &SSLCipher3DES_CBC;
        case SSL_CipherAlgorithmAES_128_CBC:
            return &SSLCipherAES_128_CBC;
        case SSL_CipherAlgorithmAES_256_CBC:
            return &SSLCipherAES_256_CBC;
#if ENABLE_AES_GCM
        case SSL_CipherAlgorithmAES_128_GCM:
            return &SSLCipherAES_128_GCM;
        case SSL_CipherAlgorithmAES_256_GCM:
            return &SSLCipherAES_256_GCM;
#endif
        default:
            check(0);
            return &SSLCipherNull;
    }
}

static void InitCipherSpec(struct SSLRecordInternalContext *ctx, uint16_t selectedCipher)
{
    SSLRecordCipherSpec *dst = &ctx->selectedCipherSpec;

    ctx->selectedCipher = selectedCipher;
    dst->cipher = sslCipherSuiteGetSymmetricCipher(selectedCipher);
    dst->macAlgorithm = sslCipherSuiteGetHashHmacReference(selectedCipher);
};

/* Entry points to Record Layer */

static int SSLRecordReadInternal(SSLRecordContextRef ref, SSLRecord *rec)
{   int        err;
    size_t          len, contentLen;
    uint8_t           *charPtr;
    SSLBuffer       readData, cipherFragment;
    size_t          head=5;
    int             skipit=0;
    struct SSLRecordInternalContext *ctx = ref;

    if(ctx->isDTLS)
        head+=8;

    if (!ctx->partialReadBuffer.data || ctx->partialReadBuffer.length < head)
    {   if (ctx->partialReadBuffer.data)
        if ((err = SSLFreeBuffer(&ctx->partialReadBuffer)) != 0)
        {
            return err;
        }
        if ((err = SSLAllocBuffer(&ctx->partialReadBuffer,
                                  DEFAULT_BUFFER_SIZE)) != 0)
        {
            return err;
        }
    }

    if (ctx->negProtocolVersion == SSL_Version_Undetermined) {
        if (ctx->amountRead < 1)
        {   readData.length = 1 - ctx->amountRead;
            readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
            len = readData.length;
            err = sslIoRead(readData, &len, ctx);
            if(err != 0)
            {   if (err == errSSLRecordWouldBlock) {
                ctx->amountRead += len;
                return err;
            }
            else {
                /* abort */
                err = errSSLRecordClosedAbort;
#if 0 // TODO: revisit this in the transport layer
                if((ctx->protocolSide == kSSLClientSide) &&
                   (ctx->amountRead == 0) &&
                   (len == 0)) {
                    /*
                     * Detect "server refused to even try to negotiate"
                     * error, when the server drops the connection before
                     * sending a single byte.
                     */
                    switch(ctx->state) {
                        case SSL_HdskStateServerHello:
                            sslHdskStateDebug("Server dropped initial connection\n");
                            err = errSSLConnectionRefused;
                            break;
                        default:
                            break;
                    }
                }
#endif
                return err;
            }
            }
            ctx->amountRead += len;
        }
    }

    if (ctx->amountRead < head)
    {   readData.length = head - ctx->amountRead;
        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
        len = readData.length;
        err = sslIoRead(readData, &len, ctx);
        if(err != 0)
        {
			switch(err) {
				case errSSLRecordWouldBlock:
					ctx->amountRead += len;
					break;
#if	SSL_ALLOW_UNNOTICED_DISCONNECT
				case errSSLClosedGraceful:
					/* legal if we're on record boundary and we've gotten past
					 * the handshake */
					if((ctx->amountRead == 0) && 				/* nothing pending */
					   (len == 0) &&							/* nothing new */
					   (ctx->state == SSL_HdskStateClientReady)) {	/* handshake done */
					    /*
						 * This means that the server has disconnected without
						 * sending a closure alert notice. This is technically
						 * illegal per the SSL3 spec, but about half of the
						 * servers out there do it, so we report it as a separate
						 * error which most clients - including (currently)
						 * URLAccess - ignore by treating it the same as
						 * a errSSLClosedGraceful error. Paranoid
						 * clients can detect it and handle it however they
						 * want to.
						 */
						SSLChangeHdskState(ctx, SSL_HdskStateNoNotifyClose);
						err = errSSLClosedNoNotify;
						break;
					}
					else {
						/* illegal disconnect */
						err = errSSLClosedAbort;
						/* and drop thru to default: fatal alert */
					}
#endif	/* SSL_ALLOW_UNNOTICED_DISCONNECT */
				default:
					break;
            }
            return err;
        }
        ctx->amountRead += len;
    }

    check(ctx->amountRead >= head);

    charPtr = ctx->partialReadBuffer.data;
    rec->contentType = *charPtr++;
    if (rec->contentType < SSL_RecordTypeV3_Smallest ||
        rec->contentType > SSL_RecordTypeV3_Largest)
        return errSSLRecordProtocol;

    rec->protocolVersion = (SSLProtocolVersion)SSLDecodeInt(charPtr, 2);
    charPtr += 2;

    if(rec->protocolVersion == DTLS_Version_1_0)
    {
        sslUint64 seqNum;
        SSLDecodeUInt64(charPtr, 8, &seqNum);
        charPtr += 8;
        sslLogRecordIo("Read DTLS Record %016llx (seq is: %016llx)",
                       seqNum, ctx->readCipher.sequenceNum);

        /* if the epoch of the record is different of current read cipher, just drop it */
        if((seqNum>>48)!=(ctx->readCipher.sequenceNum>>48)) {
            skipit=1;
        } else {
            ctx->readCipher.sequenceNum=seqNum;
        }
    }

    contentLen = SSLDecodeInt(charPtr, 2);
    charPtr += 2;
    if (contentLen > (16384 + 2048))    /* Maximum legal length of an
										 * SSLCipherText payload */
    {
        return errSSLRecordRecordOverflow;
    }

    if (ctx->partialReadBuffer.length < head + contentLen)
    {   if ((err = SSLReallocBuffer(&ctx->partialReadBuffer, head + contentLen)) != 0)
    {
        return err;
    }
    }

    if (ctx->amountRead < head + contentLen)
    {   readData.length = head + contentLen - ctx->amountRead;
        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
        len = readData.length;
        err = sslIoRead(readData, &len, ctx);
        if(err != 0)
        {   if (err == errSSLRecordWouldBlock)
            ctx->amountRead += len;
            return err;
        }
        ctx->amountRead += len;
    }

    check(ctx->amountRead >= head + contentLen);

    cipherFragment.data = ctx->partialReadBuffer.data + head;
    cipherFragment.length = contentLen;

    ctx->amountRead = 0;        /* We've used all the data in the cache */

    /* We dont decrypt if we were told to skip this record */
    if(skipit) {
        return errSSLRecordUnexpectedRecord;
    }
	/*
	 * Decrypt the payload & check the MAC, modifying the length of the
	 * buffer to indicate the amount of plaintext data after adjusting
	 * for the block size and removing the MAC */
	check(ctx->sslTslCalls != NULL);
    if ((err = ctx->sslTslCalls->decryptRecord(rec->contentType,
                                               &cipherFragment, ctx)) != 0)
        return err;

	/*
	 * We appear to have sucessfully received a record; increment the
	 * sequence number
	 */
    IncrementUInt64(&ctx->readCipher.sequenceNum);

	/* Allocate a buffer to return the plaintext in and return it */
    if ((err = SSLAllocBuffer(&rec->contents, cipherFragment.length)) != 0)
    {
        return err;
    }
    memcpy(rec->contents.data, cipherFragment.data, cipherFragment.length);


    return 0;
}

static int SSLRecordWriteInternal(SSLRecordContextRef ref, SSLRecord rec)
{
    int err;
    struct SSLRecordInternalContext *ctx = ref;

    err=ctx->sslTslCalls->writeRecord(rec, ctx);

    check_noerr(err);

    return err;
}

/* Record Layer Entry Points */

static int
SSLRollbackInternalRecordLayerWriteCipher(SSLRecordContextRef ref)
{
    int err;
    struct SSLRecordInternalContext *ctx = ref;

    if ((err = SSLDisposeCipherSuite(&ctx->writePending, ctx)) != 0)
        return err;

    ctx->writePending = ctx->writeCipher;
    ctx->writeCipher = ctx->prevCipher;

    /* Zero out old data */
    memset(&ctx->prevCipher, 0, sizeof(CipherContext));

    return 0;
}

static int
SSLAdvanceInternalRecordLayerWriteCipher(SSLRecordContextRef ref)
{
    int err;
    struct SSLRecordInternalContext *ctx = ref;

    if ((err = SSLDisposeCipherSuite(&ctx->prevCipher, ctx)) != 0)
        return err;

    ctx->prevCipher = ctx->writeCipher;
    ctx->writeCipher = ctx->writePending;

    /* Zero out old data */
    memset(&ctx->writePending, 0, sizeof(CipherContext));

    return 0;
}

static int
SSLAdvanceInternalRecordLayerReadCipher(SSLRecordContextRef ref)
{
    struct SSLRecordInternalContext *ctx = ref;
    int err;

    if ((err = SSLDisposeCipherSuite(&ctx->readCipher, ctx)) != 0)
        return err;

    ctx->readCipher = ctx->readPending;
    memset(&ctx->readPending, 0, sizeof(CipherContext)); 	/* Zero out old data */

    return 0;
}

static int
SSLInitInternalRecordLayerPendingCiphers(SSLRecordContextRef ref, uint16_t selectedCipher, bool isServer, SSLBuffer key)
{   int        err;
    uint8_t         *keyDataProgress, *keyPtr, *ivPtr;
    CipherContext   *serverPending, *clientPending;

    struct SSLRecordInternalContext *ctx = ref;

    InitCipherSpec(ctx, selectedCipher);

    ctx->readPending.macRef = ctx->selectedCipherSpec.macAlgorithm;
    ctx->writePending.macRef = ctx->selectedCipherSpec.macAlgorithm;
    ctx->readPending.symCipher = ctx->selectedCipherSpec.cipher;
    ctx->writePending.symCipher = ctx->selectedCipherSpec.cipher;
    /* This need to be reinitialized because the whole thing is zeroed sometimes */
    ctx->readPending.encrypting = 0;
    ctx->writePending.encrypting = 1;

    if(ctx->negProtocolVersion == DTLS_Version_1_0)
    {
        ctx->readPending.sequenceNum = (ctx->readPending.sequenceNum & (0xffffULL<<48)) + (1ULL<<48);
        ctx->writePending.sequenceNum = (ctx->writePending.sequenceNum & (0xffffULL<<48)) + (1ULL<<48);
    } else {
        ctx->writePending.sequenceNum = 0;
        ctx->readPending.sequenceNum = 0;
    }

    if (isServer)
    {   serverPending = &ctx->writePending;
        clientPending = &ctx->readPending;
    }
    else
    {   serverPending = &ctx->readPending;
        clientPending = &ctx->writePending;
    }

    /* Check the size of the 'key' buffer - <rdar://problem/11204357> */
    if(key.length != ctx->selectedCipherSpec.macAlgorithm->hash->digestSize*2
                   + ctx->selectedCipherSpec.cipher->params->keySize*2
                   + ctx->selectedCipherSpec.cipher->params->ivSize*2)
    {
        return errSSLRecordInternal;
    }

    keyDataProgress = key.data;
    memcpy(clientPending->macSecret, keyDataProgress,
           ctx->selectedCipherSpec.macAlgorithm->hash->digestSize);
    keyDataProgress += ctx->selectedCipherSpec.macAlgorithm->hash->digestSize;
    memcpy(serverPending->macSecret, keyDataProgress,
           ctx->selectedCipherSpec.macAlgorithm->hash->digestSize);
    keyDataProgress += ctx->selectedCipherSpec.macAlgorithm->hash->digestSize;

    if (ctx->selectedCipherSpec.cipher->params->cipherType == aeadCipherType)
        goto skipInit;

    /* init the reusable-per-record MAC contexts */
    err = ctx->sslTslCalls->initMac(clientPending);
    if(err) {
        goto fail;
    }
    err = ctx->sslTslCalls->initMac(serverPending);
    if(err) {
        goto fail;
    }

    keyPtr = keyDataProgress;
    keyDataProgress += ctx->selectedCipherSpec.cipher->params->keySize;
    /* Skip server write key to get to IV */
    ivPtr = keyDataProgress + ctx->selectedCipherSpec.cipher->params->keySize;
    if ((err = ctx->selectedCipherSpec.cipher->c.cipher.initialize(clientPending->symCipher->params, clientPending->encrypting, keyPtr, ivPtr,
                                                                   &clientPending->cipherCtx)) != 0)
        goto fail;
    keyPtr = keyDataProgress;
    keyDataProgress += ctx->selectedCipherSpec.cipher->params->keySize;
    /* Skip client write IV to get to server write IV */
    ivPtr = keyDataProgress + ctx->selectedCipherSpec.cipher->params->ivSize;
    if ((err = ctx->selectedCipherSpec.cipher->c.cipher.initialize(serverPending->symCipher->params, serverPending->encrypting, keyPtr, ivPtr,
                                                                   &serverPending->cipherCtx)) != 0)
        goto fail;

skipInit:
    /* Ciphers are ready for use */
    ctx->writePending.ready = 1;
    ctx->readPending.ready = 1;

    /* Ciphers get swapped by sending or receiving a change cipher spec message */
    err = 0;

fail:
    return err;
}

static int
SSLSetInternalRecordLayerProtocolVersion(SSLRecordContextRef ref, SSLProtocolVersion negVersion)
{
    struct SSLRecordInternalContext *ctx = ref;

    switch(negVersion) {
        case SSL_Version_3_0:
            ctx->sslTslCalls = &Ssl3RecordCallouts;
            break;
        case TLS_Version_1_0:
        case TLS_Version_1_1:
        case DTLS_Version_1_0:
        case TLS_Version_1_2:
            ctx->sslTslCalls = &Tls1RecordCallouts;
            break;
        case SSL_Version_2_0:
        case SSL_Version_Undetermined:
        default:
            return errSSLRecordNegotiation;
    }
    ctx->negProtocolVersion = negVersion;

    return 0;
}

static int
SSLRecordFreeInternal(SSLRecordContextRef ref, SSLRecord rec)
{
    return SSLFreeBuffer(&rec.contents);
}

static int
SSLRecordServiceWriteQueueInternal(SSLRecordContextRef ref)
{
    int             err = 0, werr = 0;
    size_t          written = 0;
    SSLBuffer       buf;
    WaitingRecord   *rec;
    struct SSLRecordInternalContext *ctx= ref;

    while (!werr && ((rec = ctx->recordWriteQueue) != 0))
    {   buf.data = rec->data + rec->sent;
        buf.length = rec->length - rec->sent;
        werr = sslIoWrite(buf, &written, ctx);
        rec->sent += written;
        if (rec->sent >= rec->length)
        {
            check(rec->sent == rec->length);
            check(err == 0);
            ctx->recordWriteQueue = rec->next;
			sslFree(rec);
        }
        if (err) {
            check_noerr(err);
            return err;
        }
    }

    return werr;
}

/***** Internal Record Layer APIs *****/

SSLRecordContextRef
SSLCreateInternalRecordLayer(bool dtls)
{
    struct SSLRecordInternalContext *ctx;

    ctx = sslMalloc(sizeof(struct SSLRecordInternalContext));
    if(ctx==NULL)
        return NULL;

    memset(ctx, 0, sizeof(struct SSLRecordInternalContext));

    ctx->negProtocolVersion = SSL_Version_Undetermined;

    ctx->sslTslCalls = &Ssl3RecordCallouts;
    ctx->recordWriteQueue = NULL;

    InitCipherSpec(ctx, TLS_NULL_WITH_NULL_NULL);

    ctx->writeCipher.macRef    = ctx->selectedCipherSpec.macAlgorithm;
    ctx->readCipher.macRef     = ctx->selectedCipherSpec.macAlgorithm;
    ctx->readCipher.symCipher  = ctx->selectedCipherSpec.cipher;
    ctx->writeCipher.symCipher = ctx->selectedCipherSpec.cipher;
    ctx->readCipher.encrypting = 0;
    ctx->writeCipher.encrypting = 1;

    ctx->isDTLS = dtls;

    return ctx;

}

int
SSLSetInternalRecordLayerIOFuncs(
                                 SSLRecordContextRef ref,
                                 SSLIOReadFunc    readFunc,
                                 SSLIOWriteFunc   writeFunc)
{
    struct SSLRecordInternalContext *ctx = ref;

    ctx->read = readFunc;
    ctx->write = writeFunc;

    return 0;
}

int
SSLSetInternalRecordLayerConnection(
                                    SSLRecordContextRef ref,
                                    SSLIOConnectionRef ioRef)
{
    struct SSLRecordInternalContext *ctx = ref;

    ctx->ioRef = ioRef;

    return 0;
}

void
SSLDestroyInternalRecordLayer(SSLRecordContextRef ref)
{
    struct SSLRecordInternalContext *ctx = ref;
	WaitingRecord   *waitRecord, *next;

    /* RecordContext cleanup : */
    SSLFreeBuffer(&ctx->partialReadBuffer);
    waitRecord = ctx->recordWriteQueue;
    while (waitRecord)
    {   next = waitRecord->next;
        sslFree(waitRecord);
        waitRecord = next;
    }


    /* Cleanup cipher structs */
    SSLDisposeCipherSuite(&ctx->readCipher, ctx);
    SSLDisposeCipherSuite(&ctx->writeCipher, ctx);
    SSLDisposeCipherSuite(&ctx->readPending, ctx);
    SSLDisposeCipherSuite(&ctx->writePending, ctx);
    SSLDisposeCipherSuite(&ctx->prevCipher, ctx);

    sslFree(ctx);

}

struct SSLRecordFuncs SSLRecordLayerInternal =
{
    .read  = SSLRecordReadInternal,
    .write = SSLRecordWriteInternal,
    .initPendingCiphers = SSLInitInternalRecordLayerPendingCiphers,
    .advanceWriteCipher = SSLAdvanceInternalRecordLayerWriteCipher,
    .advanceReadCipher = SSLAdvanceInternalRecordLayerReadCipher,
    .rollbackWriteCipher = SSLRollbackInternalRecordLayerWriteCipher,
    .setProtocolVersion = SSLSetInternalRecordLayerProtocolVersion,
    .free = SSLRecordFreeInternal,
    .serviceWriteQueue = SSLRecordServiceWriteQueueInternal,
};