ssl3RecordCallouts.c [plain text]
#include <AssertMacros.h>
#include <string.h>
#include <stdint.h>
#include <inttypes.h>
#ifdef KERNEL
#include <sys/types.h>
#else
#include <stddef.h>
#endif
#include "sslDebug.h"
#include "sslMemory.h"
#include "sslUtils.h"
#include "sslRand.h"
#include "tls_record.h"
int ssl3WriteRecord(
SSLRecord rec,
struct SSLRecordInternalContext *ctx)
{
int err;
int padding = 0, i;
WaitingRecord *out = NULL, *queue;
SSLBuffer payload, mac;
uint8_t *charPtr;
uint16_t payloadSize,blockSize = 0;
int head = 5;
switch(rec.protocolVersion) {
case DTLS_Version_1_0:
head += 8;
case SSL_Version_3_0:
case TLS_Version_1_0:
case TLS_Version_1_1:
case TLS_Version_1_2:
break;
default:
check(0);
return errSSLRecordInternal;
}
check(rec.contents.length <= 16384);
sslLogRecordIo("type = %02x, ver = %04x, len = %ld, seq = %016llx",
rec.contentType, rec.protocolVersion, rec.contents.length,
ctx->writeCipher.sequenceNum);
payloadSize = (uint16_t) rec.contents.length;
CipherType cipherType = ctx->writeCipher.symCipher->params->cipherType;
const Cipher *cipher = &ctx->writeCipher.symCipher->c.cipher;
const AEADCipher *aead = &ctx->writeCipher.symCipher->c.aead;
blockSize = ctx->writeCipher.symCipher->params->blockSize;
switch (cipherType) {
case blockCipherType:
payloadSize += ctx->writeCipher.macRef->hash->digestSize;
padding = blockSize - (payloadSize % blockSize) - 1;
payloadSize += padding + 1;
if(ctx->negProtocolVersion >= TLS_Version_1_1) {
payloadSize += blockSize;
}
break;
case streamCipherType:
payloadSize += ctx->writeCipher.macRef->hash->digestSize;
break;
case aeadCipherType:
payloadSize += aead->macSize;
break;
default:
check(0);
return errSSLRecordInternal;
}
out = (WaitingRecord *)sslMalloc(offsetof(WaitingRecord, data) +
head + payloadSize);
out->next = NULL;
out->sent = 0;
out->length = head + payloadSize;
charPtr = out->data;
*(charPtr++) = rec.contentType;
charPtr = SSLEncodeInt(charPtr, rec.protocolVersion, 2);
if(rec.protocolVersion == DTLS_Version_1_0)
charPtr = SSLEncodeUInt64(charPtr,ctx->writeCipher.sequenceNum);
charPtr = SSLEncodeInt(charPtr, payloadSize, 2);
if((ctx->negProtocolVersion >= TLS_Version_1_1) &&
(cipherType == blockCipherType))
{
SSLBuffer randomIV;
randomIV.data = charPtr;
randomIV.length = blockSize;
if((err = sslRand(&randomIV)) != 0)
return err;
charPtr += blockSize;
}
if (cipherType == aeadCipherType) {
charPtr = SSLEncodeUInt64(charPtr,ctx->writeCipher.sequenceNum);
}
memcpy(charPtr, rec.contents.data, rec.contents.length);
payload.data = charPtr;
payload.length = rec.contents.length;
charPtr += rec.contents.length;
if (cipherType != aeadCipherType) {
mac.data = charPtr;
mac.length = ctx->writeCipher.macRef->hash->digestSize;
charPtr += mac.length;
if (mac.length > 0)
{
check(ctx->sslTslCalls != NULL);
if ((err = ctx->sslTslCalls->computeMac(rec.contentType,
payload,
mac,
&ctx->writeCipher,
ctx->writeCipher.sequenceNum,
ctx)) != 0)
goto fail;
}
}
if(ctx->negProtocolVersion >= TLS_Version_1_1 &&
cipherType == blockCipherType)
{
payload.data -= blockSize;
}
payload.length = payloadSize;
switch (cipherType) {
case blockCipherType:
for (i = 1; i <= padding + 1; ++i)
payload.data[payload.length - i] = padding;
case streamCipherType:
if ((err = cipher->encrypt(payload.data,
payload.data, payload.length, ctx->writeCipher.cipherCtx)) != 0)
goto fail;
break;
case aeadCipherType:
check(0);
break;
default:
check(0);
return errSSLRecordInternal;
}
if (ctx->recordWriteQueue == 0)
ctx->recordWriteQueue = out;
else
{ queue = ctx->recordWriteQueue;
while (queue->next != 0)
queue = queue->next;
queue->next = out;
}
IncrementUInt64(&ctx->writeCipher.sequenceNum);
return 0;
fail:
sslFree(out);
return err;
}
static int ssl3DecryptRecord(
uint8_t type,
SSLBuffer *payload,
struct SSLRecordInternalContext *ctx)
{
int err;
SSLBuffer content;
CipherType cipherType = ctx->readCipher.symCipher->params->cipherType;
const Cipher *c = &ctx->readCipher.symCipher->c.cipher;
switch (cipherType) {
case blockCipherType:
if ((payload->length % ctx->readCipher.symCipher->params->blockSize) != 0)
{
return errSSLRecordDecryptionFail;
}
case streamCipherType:
err = c->decrypt(payload->data, payload->data,
payload->length, ctx->readCipher.cipherCtx);
break;
case aeadCipherType:
default:
check(0);
return errSSLRecordInternal;
}
if (err != 0)
{
return errSSLRecordDecryptionFail;
}
content.data = payload->data;
content.length = payload->length - ctx->readCipher.macRef->hash->digestSize;
if (cipherType == blockCipherType)
{
if (payload->data[payload->length - 1] >= ctx->readCipher.symCipher->params->blockSize)
{
sslErrorLog("DecryptSSLRecord: bad padding length (%d)\n",
(unsigned)payload->data[payload->length - 1]);
return errSSLRecordDecryptionFail;
}
content.length -= 1 + payload->data[payload->length - 1];
}
if (ctx->readCipher.macRef->hash->digestSize > 0)
if ((err = SSLVerifyMac(type, &content,
payload->data + content.length, ctx)) != 0)
{
return errSSLRecordBadRecordMac;
}
*payload = content;
return 0;
}
static int ssl3InitMac (
CipherContext *cipherCtx) {
const HashReference *hash;
SSLBuffer *hashCtx;
int serr;
check(cipherCtx->macRef != NULL);
hash = cipherCtx->macRef->hash;
check(hash != NULL);
hashCtx = &cipherCtx->macCtx.hashCtx;
if(hashCtx->data != NULL) {
SSLFreeBuffer(hashCtx);
}
serr = SSLAllocBuffer(hashCtx, hash->contextSize);
if(serr) {
return serr;
}
return 0;
}
static int ssl3FreeMac (
CipherContext *cipherCtx)
{
SSLBuffer *hashCtx;
check(cipherCtx != NULL);
if(cipherCtx->macRef == NULL) {
return 0;
}
hashCtx = &cipherCtx->macCtx.hashCtx;
if(hashCtx->data != NULL) {
sslFree(hashCtx->data);
hashCtx->data = NULL;
}
hashCtx->length = 0;
return 0;
}
static int ssl3ComputeMac (
uint8_t type,
SSLBuffer data,
SSLBuffer mac, CipherContext *cipherCtx, sslUint64 seqNo,
struct SSLRecordInternalContext *ctx)
{
int err;
uint8_t innerDigestData[SSL_MAX_DIGEST_LEN];
uint8_t scratchData[11], *charPtr;
SSLBuffer digest, digestCtx, scratch;
SSLBuffer secret;
const HashReference *hash;
check(cipherCtx != NULL);
check(cipherCtx->macRef != NULL);
hash = cipherCtx->macRef->hash;
check(hash != NULL);
check(hash->macPadSize <= MAX_MAC_PADDING);
check(hash->digestSize <= SSL_MAX_DIGEST_LEN);
digestCtx = cipherCtx->macCtx.hashCtx; secret.data = cipherCtx->macSecret;
secret.length = hash->digestSize;
check(SSLMACPad1[0] == 0x36 && SSLMACPad2[0] == 0x5C);
if ((err = hash->init(&digestCtx)) != 0)
goto exit;
if ((err = hash->update(&digestCtx, &secret)) != 0)
goto exit;
scratch.data = (uint8_t *)SSLMACPad1;
scratch.length = hash->macPadSize;
if ((err = hash->update(&digestCtx, &scratch)) != 0)
goto exit;
charPtr = scratchData;
charPtr = SSLEncodeUInt64(charPtr, seqNo);
*charPtr++ = type;
charPtr = SSLEncodeSize(charPtr, data.length, 2);
scratch.data = scratchData;
scratch.length = 11;
check(charPtr = scratchData+11);
if ((err = hash->update(&digestCtx, &scratch)) != 0)
goto exit;
if ((err = hash->update(&digestCtx, &data)) != 0)
goto exit;
digest.data = innerDigestData;
digest.length = hash->digestSize;
if ((err = hash->final(&digestCtx, &digest)) != 0)
goto exit;
if ((err = hash->init(&digestCtx)) != 0)
goto exit;
if ((err = hash->update(&digestCtx, &secret)) != 0)
goto exit;
scratch.data = (uint8_t *)SSLMACPad2;
scratch.length = hash->macPadSize;
if ((err = hash->update(&digestCtx, &scratch)) != 0)
goto exit;
if ((err = hash->update(&digestCtx, &digest)) != 0)
goto exit;
if ((err = hash->final(&digestCtx, &mac)) != 0)
goto exit;
err = 0;
exit:
return err;
}
const SslRecordCallouts Ssl3RecordCallouts = {
ssl3DecryptRecord,
ssl3WriteRecord,
ssl3InitMac,
ssl3FreeMac,
ssl3ComputeMac,
};