tlssocket.c   [plain text]


//
//  tlssocket.c
//  tlsnke
//
//  Created by Fabrice Gautier on 1/6/12.
//  Copyright (c) 2012 Apple, Inc. All rights reserved.
//

#include <Security/SecureTransportPriv.h>
#include <string.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

#include <net/kext_net.h>

#include "tlssocket.h"
#include "tlsnke.h"

#include <AssertMacros.h>
#include <errno.h>

/* TLSSocket functions */

/*
 * Read one record from the socket behind `ref`.
 *
 * The record context is the socket descriptor cast to a pointer-sized value.
 * The function peeks one byte to find out whether data is pending, asks the
 * socket how many bytes the next packet holds (SO_NREAD), allocates a buffer
 * of that size with malloc() and receives the packet with recvmsg(). The
 * content type and protocol version are taken from the first control message
 * attached to the packet.
 *
 * On success rec->contents.data points to a malloc()'d buffer that the caller
 * must eventually release with free(). On every failure path taken after the
 * allocation the buffer is released here and rec->contents is reset to
 * { NULL, 0 }.
 *
 * Returns:
 *   0                        a record was read
 *   errSSLRecordWouldBlock   recv()/recvmsg() reported EAGAIN, or SO_NREAD
 *                            reported 0 bytes available
 *   errSSLRecordInternal     the SO_NREAD query failed or the buffer could
 *                            not be allocated
 *   an errno value           any other recv()/recvmsg() failure
 */
static
int TLSSocket_Read(SSLRecordContextRef ref,
                        SSLRecord *rec)
{
    int socket = (int)ref;
    int rc;
    int err;
    ssize_t sz;
    struct sockaddr_in client_addr;
    int avail;
    socklen_t avail_size;
    struct cmsghdr *cmsg;
    tls_record_hdr_t hdr;
    struct msghdr msg;
    struct iovec iov;
    /* Room for the record header control message plus 1024 spare bytes. */
    int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024;
    uint8_t cbuf[cbuf_len];

    //    printf("%s: Waiting for some data...\n", __FUNCTION__);
    /* PEEK only... */
    char b;
    rc = (int)recv(socket, &b, 1, MSG_PEEK);

    if(rc==-1)
    {
        /* Capture errno first: perror() is allowed to modify it. */
        err = errno;
        if(err==EAGAIN)
            return errSSLRecordWouldBlock;
        perror("recv");
        return err;
    }

    /* get the next packet size */
    avail_size = sizeof(avail);
    rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);

    check_noerr(rc);
    check(avail_size==sizeof(avail));

    if(rc || (avail_size !=sizeof(avail)))
        return errSSLRecordInternal;

    //    printf("%s: Available = %d\n", __FUNCTION__, avail);

    if(avail==0)
        return errSSLRecordWouldBlock;

    /* Allocate a buffer */
    rec->contents.data = malloc(avail);
    if(rec->contents.data==NULL) {
        rec->contents.length = 0;
        return errSSLRecordInternal;
    }
    rec->contents.length = avail;

    /* read the message */
    iov.iov_base = rec->contents.data;
    iov.iov_len = rec->contents.length;
    msg.msg_name = &client_addr;
    msg.msg_namelen = sizeof(client_addr);
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = cbuf;
    msg.msg_controllen = cbuf_len;

    sz = recvmsg(socket, &msg, 0);
    if(sz<0) {
        /*
         * Nothing was received: do not hand the caller a buffer with a
         * bogus (size_t)-1 length, and do not parse the control buffer,
         * which recvmsg() never filled in.
         */
        err = errno;
        free(rec->contents.data);
        rec->contents.data = NULL;
        rec->contents.length = 0;
        if(err==EAGAIN)
            return errSSLRecordWouldBlock;
        errno = err;    /* free() may have changed it */
        perror("recvmsg");
        return err;
    }
    check(sz==avail);

    //    printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
    rec->contents.length = sz;

    cmsg = CMSG_FIRSTHDR(&msg);
    check(cmsg);
    /*
     * NOTE(review): without a control message the record is still reported
     * as success, but contentType and protocolVersion are left untouched.
     * Kept as in the original; confirm whether callers rely on this.
     */
    if(!cmsg)
        return 0;

    check(cmsg->cmsg_type == SCM_TLS_HEADER);
    check(cmsg->cmsg_level == SOL_SOCKET);
    check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
    check(hdr);

    /* print msg info */
    /*
    printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
           iov.iov_len,
           cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
           hdr->content_type, hdr->protocol_version,
           inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port)); 
    */
    rec->contentType = hdr->content_type;
    rec->protocolVersion = hdr->protocol_version;

    if(rec->contentType==SSL_RecordTypeChangeCipher) {
        printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
    }
    return 0;
}

/*
 * Release the data buffer of a record.
 *
 * The record context is not needed for this; the buffer is handed straight
 * to free(). Always returns 0.
 */
static
int TLSSocket_Free(SSLRecordContextRef ref,
                         SSLRecord rec)
{
    void *payload = rec.contents.data;

    (void)ref;  /* unused */

    free(payload);
    return 0;
}

/*
 * Send one record on the socket behind `ref`.
 *
 * The record body travels as the single data buffer of a sendmsg() call;
 * the content type and protocol version ride along in an SCM_TLS_HEADER
 * control message at level SOL_SOCKET.
 *
 * Returns 0 unless sendmsg() fails, in which case its negative result is
 * returned after the error has been reported with perror().
 */
static
int TLSSocket_Write(SSLRecordContextRef ref,
                          SSLRecord rec)
{
    tls_record_hdr_t hdr;
    int ctl_len=CMSG_SPACE(sizeof(*hdr));
    uint8_t ctl[ctl_len];
    struct cmsghdr *ctl_hdr;
    struct iovec payload;
    struct msghdr out;
    ssize_t sz;
    int fd = (int)ref;

    if(rec.contentType==SSL_RecordTypeChangeCipher) {
        printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
    }

    /* Data part: the record contents, sent as one buffer. */
    payload.iov_base = rec.contents.data;
    payload.iov_len = rec.contents.length;

    /* No destination address is supplied with the message. */
    out.msg_name = NULL;
    out.msg_namelen = 0;
    out.msg_iov = &payload;
    out.msg_iovlen = 1;
    out.msg_control = ctl;
    out.msg_controllen = ctl_len;

    /* Ancillary part: the record header. */
    ctl_hdr = CMSG_FIRSTHDR(&out);
    ctl_hdr->cmsg_len = CMSG_LEN(sizeof(*hdr));
    ctl_hdr->cmsg_level = SOL_SOCKET;
    ctl_hdr->cmsg_type = SCM_TLS_HEADER;

    hdr = (tls_record_hdr_t)CMSG_DATA(ctl_hdr);
    hdr->protocol_version = rec.protocolVersion;
    hdr->content_type = rec.contentType;

    sz = sendmsg(fd, &out, 0);

    if(sz<0)
        perror("sendmsg");

    /* A short send is only flagged by the check macro, not reported to the caller. */
    check(sz==rec.contents.length);

    return (sz<0) ? (int)sz : 0;
}


/*
 * Hand the selected cipher suite and key material to the socket.
 *
 * Builds a temporary buffer laid out as
 *     byte 0    high byte of selectedCipher
 *     byte 1    low byte of selectedCipher
 *     byte 2    the `server` flag
 *     byte 3..  key.length bytes copied from key.data
 * and passes it to setsockopt(SO_TLS_INIT_CIPHER) on the socket behind `ref`.
 *
 * Returns the setsockopt() result, or -1 if the temporary buffer could not
 * be allocated.
 */
static
int TLSSocket_InitPendingCiphers(SSLRecordContextRef   ref,
                                       uint16_t              selectedCipher,
                                       bool                  server,
                                       SSLBuffer             key)
{
    int socket = (int)ref;
    int rc;
    char *buf;

    buf = malloc(key.length+3);
    if(buf==NULL) {
        /* Without this check the header writes below dereference NULL. */
        perror("malloc");
        return -1;
    }

    buf[0] = selectedCipher >> 8;
    buf[1] = selectedCipher & 0xff;
    buf[2] = server;
    /* Skip the copy for an empty key: memcpy() from a NULL source is undefined. */
    if(key.length>0)
        memcpy(buf+3, key.data, key.length);

    printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length);

    rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3));

    printf("%s: rc=%d\n", __FUNCTION__, rc);

    free(buf);

    return rc;
}

/*
 * Issue the SO_TLS_ADVANCE_WRITE_CIPHER socket option (no payload) on the
 * socket behind `ref`, print the result and return it.
 */
static
int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
{
    const int fd = (int)ref;
    const int status = setsockopt(fd, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);

    printf("%s: rc=%d\n", __FUNCTION__, status);

    return status;
}

/*
 * Issue the SO_TLS_ROLLBACK_WRITE_CIPHER socket option (no payload) on the
 * socket behind `ref`, print the result and return it.
 */
static
int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
{
    const int fd = (int)ref;
    const int status = setsockopt(fd, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);

    printf("%s: rc=%d\n", __FUNCTION__, status);

    return status;
}

/*
 * Issue the SO_TLS_ADVANCE_READ_CIPHER socket option (no payload) on the
 * socket behind `ref`, print the result and return it.
 */
static
int TLSSocket_AdvanceReadCipher(SSLRecordContextRef    ref)
{
    const int fd = (int)ref;
    const int status = setsockopt(fd, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);

    printf("%s: rc=%d\n", __FUNCTION__, status);

    return status;
}

/*
 * Pass `protocolVersion` to the socket behind `ref` through the
 * SO_TLS_PROTOCOL_VERSION socket option, print the result and return it.
 */
static
int TLSSocket_SetProtocolVersion(SSLRecordContextRef    ref,
                                 SSLProtocolVersion     protocolVersion)
{
    const int fd = (int)ref;
    const int status = setsockopt(fd, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION,
                                  &protocolVersion, sizeof(protocolVersion));

    printf("%s: rc=%d\n", __FUNCTION__, status);

    return status;
}


/*
 * Issue the SO_TLS_SERVICE_WRITE_QUEUE socket option (no payload) on the
 * socket behind `ref` and return the setsockopt() result. Unlike the other
 * option wrappers in this file, nothing is printed.
 */
static
int TLSSocket_ServiceWriteQueue(SSLRecordContextRef    ref)
{
    const int fd = (int)ref;

    return setsockopt(fd, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
}


/*
 * Record-layer callback table wiring every SSLRecordFuncs slot to the
 * socket-backed implementation defined in this file.
 */
const struct SSLRecordFuncs TLSSocket_Funcs = {
    /* record I/O and buffer release */
    .read                = TLSSocket_Read,
    .write               = TLSSocket_Write,
    .free                = TLSSocket_Free,
    .serviceWriteQueue   = TLSSocket_ServiceWriteQueue,

    /* cipher state and protocol version */
    .initPendingCiphers  = TLSSocket_InitPendingCiphers,
    .advanceReadCipher   = TLSSocket_AdvanceReadCipher,
    .advanceWriteCipher  = TLSSocket_AdvanceWriteCipher,
    .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
    .setProtocolVersion  = TLSSocket_SetProtocolVersion,
};


/* TLSSocket SPIs */

int TLSSocket_Attach(int socket)
{
    
    /* Attach the TLS socket filter and return handle */
    struct so_nke so_tlsnke;
    int rc;
    int handle;
    socklen_t len;
    
    memset(&so_tlsnke, 0, sizeof(so_tlsnke));
    so_tlsnke.nke_handle = TLS_HANDLE_IP4;
    rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
    if(rc)
        return rc;

    len = sizeof(handle);
    rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
    if(rc)
        return rc;

    assert(len==sizeof(handle));
    
    return handle;
}