#include "ssl.h"
#include "sslContext.h"
#include "sslSession.h"
#include "sslMemory.h"
#include "sslUtils.h"
#include "sslDebug.h"
#include "cipherSpecs.h"
#include "appleSession.h"
#include <assert.h>
#include <string.h>
#include <stddef.h>
/*
 * Serialized form of a resumable session. SSLAddSessionData fills one of
 * these in a scratch buffer and hands the raw bytes to sslAddSession;
 * SSLRetrieveSessionID, SSLRetrieveSessionProtocolVersion and
 * SSLInstallSessionFromData cast cached bytes back to this struct.
 * Because the raw struct is stored, field sizes, padding and integer byte
 * order are whatever the host compiler produces.
 * NOTE(review): presumably the cache never leaves this process/build — confirm
 * before persisting or sharing these blobs.
 */
typedef struct
{ int sessionIDLen;                 /* number of valid bytes in sessionID[] */
UInt8 sessionID[32];                /* session ID bytes; fixed 32-byte capacity */
SSLProtocolVersion protocolVersion; /* ctx->negProtocolVersion at save time */
UInt16 cipherSuite;                 /* ctx->selectedCipher at save time */
UInt16 padding;                     /* explicitly zeroed when the record is written */
UInt8 masterSecret[48];             /* copy of ctx->masterSecret (48 bytes) */
int certCount;                      /* number of peer-cert records in certs[] */
/*
 * Variable-length tail: certCount records, each a 4-byte length written with
 * SSLEncodeInt followed by that many DER bytes. Declared [1] as a pre-C99
 * flexible array; record sizes are computed from offsetof(ResumableSession,
 * certs), so this placeholder byte is not counted.
 */
UInt8 certs[1];
} ResumableSession;
OSStatus
SSLAddSessionData(const SSLContext *ctx)
{ OSStatus err;
uint32 sessionIDLen;
SSLBuffer sessionID;
ResumableSession *session;
int certCount;
SSLCertificate *cert;
uint8 *certDest;
if (ctx->peerID.data == 0)
return errSSLSessionNotFound;
sessionIDLen = offsetof(ResumableSession, certs);
cert = ctx->peerCert;
certCount = 0;
while (cert)
{ ++certCount;
sessionIDLen += 4 + cert->derCert.length;
cert = cert->next;
}
if ((err = SSLAllocBuffer(sessionID, sessionIDLen, ctx)) != 0)
return err;
session = (ResumableSession*)sessionID.data;
session->sessionIDLen = ctx->sessionID.length;
memcpy(session->sessionID, ctx->sessionID.data, session->sessionIDLen);
session->protocolVersion = ctx->negProtocolVersion;
session->cipherSuite = ctx->selectedCipher;
memcpy(session->masterSecret, ctx->masterSecret, 48);
session->certCount = certCount;
session->padding = 0;
certDest = session->certs;
cert = ctx->peerCert;
while (cert)
{ certDest = SSLEncodeInt(certDest, cert->derCert.length, 4);
memcpy(certDest, cert->derCert.data, cert->derCert.length);
certDest += cert->derCert.length;
cert = cert->next;
}
err = sslAddSession(ctx->peerID, sessionID);
SSLFreeBuffer(sessionID, ctx);
return err;
}
/*
 * Look up the cached session blob for ctx->peerID.
 *
 * On return sessionData->data is NULL when nothing was found, in which case
 * errSSLSessionNotFound is reported. Otherwise the status from sslGetSession
 * is passed through. A context with no peer ID also reports
 * errSSLSessionNotFound, and sessionData is left untouched in that case.
 */
OSStatus
SSLGetSessionData(SSLBuffer *sessionData, const SSLContext *ctx)
{
    OSStatus status;

    if (ctx->peerID.data == 0) {
        return errSSLSessionNotFound;
    }

    sessionData->data = 0;
    status = sslGetSession(ctx->peerID, sessionData);

    /* A miss is detected by the data pointer, not by the lookup status. */
    return (sessionData->data == 0) ? errSSLSessionNotFound : status;
}
/*
 * Remove the cached session for ctx->peerID.
 *
 * Returns errSSLSessionNotFound when ctx has no peer ID (there is no cache
 * key). Otherwise it returns whatever sslDeleteSession reports.
 */
OSStatus
SSLDeleteSessionData(const SSLContext *ctx)
{
    if (ctx->peerID.data != 0) {
        return sslDeleteSession(ctx->peerID);
    }
    return errSSLSessionNotFound;
}
/*
 * Copy the session ID out of a cached ResumableSession blob into a freshly
 * allocated *identifier.
 *
 * The blob is validated first. A NULL blob, one too short to hold the fixed
 * header, or a stored ID length outside 0..32 yields errSSLProtocol instead
 * of an out-of-bounds read. Allocation failures from SSLAllocBuffer are
 * passed through.
 * NOTE(review): errSSLProtocol is reused here for "corrupt cache entry";
 * confirm callers treat it as a hard failure.
 */
OSStatus
SSLRetrieveSessionID(
		const SSLBuffer sessionData,
		SSLBuffer *identifier,
		const SSLContext *ctx)
{   OSStatus            err;
    ResumableSession    *session;

    if (sessionData.data == 0 ||
        sessionData.length < offsetof(ResumableSession, certs))
        return errSSLProtocol;

    session = (ResumableSession*) sessionData.data;

    /* sessionIDLen is a plain int read from the blob; never trust it to be
       within the 32-byte array it describes. */
    if (session->sessionIDLen < 0 ||
        (size_t)session->sessionIDLen > sizeof(session->sessionID))
        return errSSLProtocol;

    if ((err = SSLAllocBuffer(*identifier, session->sessionIDLen, ctx)) != 0)
        return err;
    if (session->sessionIDLen > 0)
        memcpy(identifier->data, session->sessionID, session->sessionIDLen);
    return noErr;
}
/*
 * Report the protocol version recorded in a cached ResumableSession blob.
 * ctx is accepted for interface symmetry but is not used.
 * Always returns noErr. The blob is assumed to be a valid record.
 */
OSStatus
SSLRetrieveSessionProtocolVersion(
		const SSLBuffer sessionData,
		SSLProtocolVersion *version,
		const SSLContext *ctx)
{
    const ResumableSession *record = (const ResumableSession *) sessionData.data;

    *version = record->protocolVersion;
    return noErr;
}
/*
 * Policy switch for SSLInstallSessionFromData on non-SSLv2 resumes. When
 * nonzero, a negotiated cipher suite that differs from the cached one is only
 * logged as a warning. When zero, the resume fails with errSSLProtocol.
 * The SSLv2 path ignores this switch.
 */
#define ALLOW_CIPHERSPEC_CHANGE 1
/*
 * Restore a cached ResumableSession blob into ctx:
 *  - reconcile the cached cipher suite with ctx->selectedCipher (the SSLv2
 *    client adopts it, the SSLv2 server requires a match, and other versions
 *    follow ALLOW_CIPHERSPEC_CHANGE);
 *  - rebuild the cipher spec via FindCipherSpec;
 *  - copy the 48-byte master secret;
 *  - rebuild the peer certificate chain as a linked list in ctx->peerCert.
 *
 * Every read is bounds-checked against sessionData.length. A truncated or
 * corrupt blob (short header, negative certCount, or a cert length running
 * past the end) yields errSSLProtocol instead of an out-of-bounds read.
 * Returns noErr, errSSLProtocol, memFullErr, or an error from
 * FindCipherSpec / SSLAllocBuffer. On a failure while decoding certificates,
 * any certificates already linked into ctx->peerCert stay there.
 */
OSStatus
SSLInstallSessionFromData(const SSLBuffer sessionData, SSLContext *ctx)
{   OSStatus            err;
    ResumableSession    *session;
    uint8               *storedCertProgress;
    uint8               *storedEnd;
    SSLCertificate      *cert, *lastCert;
    int                 certCount;
    uint32              certLen;

    /* Validate the fixed header before touching ctx. */
    if (sessionData.data == 0 ||
        sessionData.length < offsetof(ResumableSession, certs))
        return errSSLProtocol;
    session = (ResumableSession*)sessionData.data;
    if (session->certCount < 0)
        return errSSLProtocol;

    assert(ctx->negProtocolVersion == session->protocolVersion);

    if(ctx->negProtocolVersion == SSL_Version_2_0) {
        if(ctx->protocolSide == SSL_ClientSide) {
            /* SSLv2 client: the server echoes the cached cipher; adopt it. */
            assert(ctx->selectedCipher == 0);
            ctx->selectedCipher = session->cipherSuite;
        }
        else {
            if(ctx->selectedCipher != session->cipherSuite) {
                sslErrorLog("+++SSL2: CipherSpec change from %d to %d on session "
                    "resume\n",
                    session->cipherSuite, ctx->selectedCipher);
                return errSSLProtocol;
            }
        }
    }
    else {
        assert(ctx->selectedCipher != 0);
        if(ctx->selectedCipher != session->cipherSuite) {
            #if ALLOW_CIPHERSPEC_CHANGE
            sslErrorLog("+++WARNING: CipherSpec change from %d to %d "
                    "on session resume\n",
                    session->cipherSuite, ctx->selectedCipher);
            #else
            sslErrorLog("+++SSL: CipherSpec change from %d to %d on session resume\n",
                session->cipherSuite, ctx->selectedCipher);
            return errSSLProtocol;
            #endif
        }
    }
    if ((err = FindCipherSpec(ctx)) != 0) {
        return err;
    }
    memcpy(ctx->masterSecret, session->masterSecret, 48);

    lastCert = 0;
    storedCertProgress = session->certs;
    storedEnd = sessionData.data + sessionData.length;
    certCount = session->certCount;
    while (certCount--)
    {
        /* Each record: 4-byte length prefix, then that many DER bytes.
           Check both pieces fit before reading, and before allocating so
           nothing needs to be unwound. */
        if ((size_t)(storedEnd - storedCertProgress) < 4)
            return errSSLProtocol;
        certLen = SSLDecodeInt(storedCertProgress, 4);
        storedCertProgress += 4;
        if ((size_t)(storedEnd - storedCertProgress) < certLen)
            return errSSLProtocol;

        cert = (SSLCertificate *)sslMalloc(sizeof(SSLCertificate));
        if(cert == NULL) {
            return memFullErr;
        }
        cert->next = 0;
        if ((err = SSLAllocBuffer(cert->derCert, certLen, ctx)) != 0)
        {
            sslFree(cert);
            return err;
        }
        memcpy(cert->derCert.data, storedCertProgress, certLen);
        storedCertProgress += certLen;

        /* Append in stored order; the first record becomes the chain head. */
        if (lastCert == 0)
            ctx->peerCert = cert;
        else
            lastCert->next = cert;
        lastCert = cert;
    }
    return noErr;
}