sslUtils.cpp   [plain text]


/*
 * Copyright (c) 2000-2001 Apple Computer, Inc. All Rights Reserved.
 * 
 * The contents of this file constitute Original Code as defined in and are
 * subject to the Apple Public Source License Version 1.2 (the 'License').
 * You may not use this file except in compliance with the License. Please obtain
 * a copy of the License at http://www.apple.com/publicsource and read it before
 * using this file.
 * 
 * This Original Code and all software distributed under the License are
 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS
 * OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES, INCLUDING WITHOUT
 * LIMITATION, ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
 * PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT. Please see the License for the
 * specific language governing rights and limitations under the License.
 */


/*
	File:		sslUtils.cpp

	Contains:	Misc. SSL utility functions

	Written by:	Doug Mitchell

	Copyright: (c) 1999 by Apple Computer, Inc., all rights reserved.

*/

#include "sslContext.h"
#include "sslUtils.h"
#include "sslMemory.h"
#include "sslDebug.h"
#include <Security/devrandom.h>

#include <CoreServices/../Frameworks/CarbonCore.framework/Headers/MacTypes.h>
#include <sys/time.h>

/*
 * Decode a big-endian (most significant byte first) unsigned integer of
 * `length` bytes starting at p.
 *
 * Returns 0 when length <= 0. When length > 4, the earliest bytes are
 * shifted out of the 32-bit accumulator and only the last four survive.
 */
UInt32
SSLDecodeInt(const unsigned char *p, int length)
{
	UInt32 val = 0;

	/*
	 * Bounded index loop rather than `while (length--)`: a negative length
	 * would otherwise spin ~2^31 times reading far past the buffer.
	 */
	for (int i = 0; i < length; i++)
		val = (val << 8) | p[i];
	return val;
}

/*
 * Encode `value` as a big-endian integer of `length` bytes at p
 * (least significant byte written last).
 *
 * length must be 1..4 (asserted). Returns a pointer to the byte just past
 * the encoded integer, so calls can be chained.
 */
unsigned char *
SSLEncodeInt(unsigned char *p, UInt32 value, int length)
{
	assert(length > 0 && length <= 4);

	/*
	 * The assert is compiled out under NDEBUG; without this guard a negative
	 * length would make the old `while (length--)` write backwards from p
	 * without bound.
	 */
	if (length < 0)
		length = 0;

	for (int i = length - 1; i >= 0; i--) {	/* assemble backwards */
		p[i] = (UInt8)value;	/* implicit masking to low byte */
		value >>= 8;
	}
	return p + length;	/* char after the encoded int */
}

/*
 * Serialize a two-word sslUint64 as 8 bytes at p: the high word's 4 bytes
 * first, then the low word's 4 bytes, each via SSLEncodeInt.
 * Returns the position just past the 8 bytes written.
 */
UInt8*
SSLEncodeUInt64(UInt8 *p, sslUint64 value)
{
	UInt8 *afterHigh = SSLEncodeInt(p, value.high, 4);
	UInt8 *afterLow = SSLEncodeInt(afterHigh, value.low, 4);
	return afterLow;
}


/*
 * Add one to a two-word sslUint64 in place, carrying into the high word
 * when the low word wraps around to zero.
 */
void
IncrementUInt64(sslUint64 *v)
{
	v->low += 1;
	if (v->low != 0)
		return;		/* no wrap, no carry */
	v->high += 1;
}

/*
 * Count the nodes in a certificate list linked through `next`.
 * A NULL head yields 0.
 */
UInt32
SSLGetCertificateChainLength(const SSLCertificate *c)
{
	UInt32 count = 0;

	for (const SSLCertificate *node = c; node != NULL; node = node->next)
		++count;
	return count;
}

/*
 * Report whether ctx is in an active handshake/session state.
 * Returns false for the three uninitialized states and for the graceful or
 * error close states, and true for any other state.
 */
Boolean sslIsSessionActive(const SSLContext *ctx)
{
	assert(ctx != NULL);

	Boolean active = true;
	switch(ctx->state) {
		case SSL_HdskStateUninit:
		case SSL_HdskStateServerUninit:
		case SSL_HdskStateClientUninit:
		case SSL_HdskStateGracefulClose:
		case SSL_HdskStateErrorClose:
			active = false;
			break;
		default:
			break;
	}
	return active;
}

/*
 * Free every node of the certificate list starting at certs (linked
 * through `next`), releasing each node's derCert buffer via SSLFreeBuffer
 * before the node itself. Always returns noErr. The caller's head pointer
 * is not cleared here.
 */
OSStatus sslDeleteCertificateChain(
    SSLCertificate		*certs,
	SSLContext 			*ctx)
{
	assert(ctx != NULL);

	while(certs != NULL) {
		SSLCertificate *victim = certs;
		certs = victim->next;		/* advance before the node is freed */
		SSLFreeBuffer(victim->derCert, ctx);
		sslFree(victim);
	}
	return noErr;
}

#if	SSL_DEBUG

/*
 * Debug-only: map an SSLProtocolVersion to its enumerator name for logging.
 * Unknown values are reported through sslErrorLog and rendered as
 * "BAD PROTOCOL".
 */
const char *protocolVersStr(SSLProtocolVersion prot)
{
	const char *label;

	switch(prot) {
		case SSL_Version_Undetermined:
			label = "SSL_Version_Undetermined";
			break;
		case SSL_Version_2_0:
			label = "SSL_Version_2_0";
			break;
		case SSL_Version_3_0:
			label = "SSL_Version_3_0";
			break;
		case TLS_Version_1_0:
			label = "TLS_Version_1_0";
			break;
		default:
			sslErrorLog("protocolVersStr: bad prot\n");
			label = "BAD PROTOCOL";
			break;
	}
	return label;
}

#endif	/* SSL_DEBUG */

/*
 * Redirect an SSLBuffer-based read to the user-supplied I/O read callback.
 *
 * The callback gets buf.length as an in/out length. Whatever value it leaves
 * there is copied to *actualLength, and its status is returned unchanged.
 * *actualLength is zeroed before the call.
 * NOTE(review): buf.length is narrowed to UInt32 for the callback; confirm
 * this matches the callback's declared length type.
 */
OSStatus sslIoRead(
 	SSLBuffer 		buf, 
 	size_t 			*actualLength, 
 	SSLContext 		*ctx)
{
	*actualLength = 0;

	UInt32 inOutLength = buf.length;
	OSStatus status = ctx->ioCtx.read(ctx->ioCtx.ioRef, buf.data, &inOutLength);

	*actualLength = inOutLength;
	return status;
}
 
/*
 * Redirect an SSLBuffer-based write to the user-supplied I/O write callback.
 *
 * The callback gets buf.length as an in/out length. Whatever value it leaves
 * there is copied to *actualLength, and its status is returned unchanged.
 * *actualLength is zeroed before the call.
 * NOTE(review): buf.length is narrowed to UInt32 for the callback; confirm
 * this matches the callback's declared length type.
 */
OSStatus sslIoWrite(
 	SSLBuffer 		buf, 
 	size_t 			*actualLength, 
 	SSLContext 		*ctx)
{
	*actualLength = 0;

	UInt32 inOutLength = buf.length;
	OSStatus status = ctx->ioCtx.write(ctx->ioCtx.ioRef, buf.data, &inOutLength);

	*actualLength = inOutLength;
	return status;
}

/*
 * Store the current time() value, truncated to UInt32, in *tim.
 * Always returns noErr.
 */
OSStatus sslTime(UInt32 *tim)
{
	const time_t now = time(NULL);
	*tim = (UInt32)now;
	return noErr;
}

/*
 * Common RNG function: fill buf->length bytes at buf->data with output from
 * Security::DevRandomGenerator.
 *
 * A zero-length request is logged via sslErrorLog but still returns noErr.
 * Any exception thrown while generating is mapped to errSSLCrypto.
 */
OSStatus sslRand(SSLContext *ctx, SSLBuffer *buf)
{
	assert(ctx != NULL);
	assert(buf != NULL);
	assert(buf->data != NULL);

	if(buf->length == 0) {
		sslErrorLog("sslRand: zero buf->length\n");
		return noErr;
	}

	try {
		Security::DevRandomGenerator generator(false);
		generator.random(buf->data, buf->length);
	}
	catch(...) {
		return errSSLCrypto;
	}
	return noErr;
}

/*
 * Given a protocol version sent by the peer, decide whether we accept it.
 *
 * On success *negVersion receives the version to use. That is the peer's
 * version if we have it enabled. Otherwise, on the server side only, it is
 * the highest enabled version below the peer's. The client side cannot
 * downgrade, so it fails with errSSLNegotiation. Unknown peer versions also
 * fail with errSSLNegotiation. *negVersion is written only on success.
 */
OSStatus sslVerifyProtVersion(
	SSLContext 			*ctx,
	SSLProtocolVersion	peerVersion,	// sent by peer
	SSLProtocolVersion 	*negVersion)	// final negotiated version if return success
{
	switch(peerVersion) {
		case TLS_Version_1_0:
			if(ctx->versionTls1Enable) {
				*negVersion = TLS_Version_1_0;
				return noErr;
			}
			if(ctx->protocolSide == SSL_ClientSide) {
				/*
				 * Client side - no more negotiation possible. This implies a
				 * serious server-side violation: it sent back a protocol
				 * version HIGHER than we requested.
				 */
				return errSSLNegotiation;
			}
			/* server: fall back to the newest older version still enabled */
			if(ctx->versionSsl3Enable) {
				*negVersion = SSL_Version_3_0;
				return noErr;
			}
			if(ctx->versionSsl2Enable) {
				*negVersion = SSL_Version_2_0;
				return noErr;
			}
			/* we appear not to support any protocol */
			sslErrorLog("sslVerifyProtVersion: no protocols supported\n");
			return errSSLNegotiation;

		case SSL_Version_3_0:
			if(ctx->versionSsl3Enable) {
				*negVersion = SSL_Version_3_0;
				return noErr;
			}
			if(ctx->protocolSide != SSL_ClientSide && ctx->versionSsl2Enable) {
				/* server downgrading to SSL2 */
				*negVersion = SSL_Version_2_0;
				return noErr;
			}
			/* client cannot downgrade, or neither SSL3 nor SSL2 is enabled */
			return errSSLNegotiation;

		case SSL_Version_2_0:
			if(ctx->versionSsl2Enable) {
				*negVersion = SSL_Version_2_0;
				return noErr;
			}
			/* SSL2 is the best peer can do but we don't support it */
			return errSSLNegotiation;

		default:
			return errSSLNegotiation;
	}
}

/*
 * Determine the maximum enabled protocol, i.e. the one we try to negotiate
 * for. Versions are checked from newest (TLS 1.0) to oldest (SSL 2.0), and
 * the first enabled one is stored in *version.
 *
 * Returns paramErr only if NO protocols are enabled, which can happen through
 * malicious or ignorant use of SSLSetProtocolVersionEnabled(). In that case
 * *version is left untouched.
 */
OSStatus sslGetMaxProtVersion(
	SSLContext 			*ctx,
	SSLProtocolVersion	*version)	// RETURNED
{
	if(ctx->versionTls1Enable) {
		*version = TLS_Version_1_0;
		return noErr;
	}
	if(ctx->versionSsl3Enable) {
		*version = SSL_Version_3_0;
		return noErr;
	}
	if(ctx->versionSsl2Enable) {
		*version = SSL_Version_2_0;
		return noErr;
	}
	return paramErr;
}