#include <mach/mach.h>
#include <mach/mach_error.h>
#include <servers/bootstrap.h>
#include <gssapi/gssapi.h>
#include <gssapi/gssapi_krb5.h>
#include <pthread.h>
#include <pwd.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "gssd_mach.h"
#include "gssd_machServer.h"
#include "gssd.h"
/* An event counter shared between threads; `total` is guarded by `lock`. */
typedef struct {
	int total;
	pthread_mutex_t lock[1];
} counter, *counter_t;

/* Slots in counters[]; each name below expands to a counter_t. */
#define MAXCOUNTERS 6
#define gss_init_errors (&counters[0])
#define gss_accept_errors (&counters[1])
#define server_errors (&counters[2])
#define server_deaths (&counters[3])
#define key_missmatches (&counters[4])
#define successes (&counters[5])

#define DISPLAY_ERRS(name, major, minor) \
	CGSSDisplay_errs((name), (major), (minor))
#define MAXHOSTNAME 256
/* Default for -M: how many times a call is retried after the server dies. */
#define MAXRETRIES 3
/* Default for -t, in microseconds (see waittime()). */
#define TIMEOUT 100

counter counters[MAXCOUNTERS];

static void waittime(void);
void *server(void *);
void *client(void *);
static void deallocate(void *, uint32_t);
/* (void), not (): an empty list declares no prototype before C23. */
static void server_done(void);
static void waitall(void);
static void inc_counter(counter_t);
static void report_errors(void);
static void CGSSDisplay_errs(char* rtnName, OM_uint32 maj, OM_uint32 min);
static void HexLine(const char *, uint32_t *, char [80]);
void HexDump(const char *, uint32_t, int);

/*
 * Rendezvous shared by one server() thread and the client() thread it
 * starts.  The two sides take turns under `lock`/`cv`: read_channel(d)
 * waits until `client == d`, and write_channel(d) hands the turn to !d.
 */
typedef struct s_channel {
	int client;		/* whose turn it is: 1 = client side, 0 = server side */
	int failure;		/* nonzero on error; close_channel() ORs in CHANNEL_CLOSED */
	pthread_mutex_t lock[1];
	pthread_cond_t cv[1];
	byte_buffer ctoken;	/* output token of mach_gss_init_sec_context() */
	mach_msg_type_number_t ctokenCnt;
	byte_buffer stoken;	/* output token of mach_gss_accept_sec_context() */
	mach_msg_type_number_t stokenCnt;
	byte_buffer clnt_skey;	/* client's session key, compared by server() */
	mach_msg_type_number_t clnt_skeyCnt;
} *channel_t;

#define CHANNEL_CLOSED 0x1000000
/* True when a failure other than the plain "closed" marker is recorded. */
#define CHANNEL_FAILED(c) ((c)->failure & (~CHANNEL_CLOSED))
int read_channel(int d, channel_t chan);
int write_channel(int d, channel_t chan);
int close_channel(int d, channel_t chan);
#define ERR(...) fprintf(stderr, __VA_ARGS__)
#define DEBUG(...) fprintf(stdout, __VA_ARGS__)
/*
 * Help text printed by Usage(), one line per entry.  The table is read-only,
 * so both the pointers and the strings are const.
 */
static const char *const optstrs[] = {
	" if no host is specified, use the local host",
	"[-C] don't canonicalize the host name",
	"[-D] don't use the default credential",
	"[-e] exit on mach rpc errors",
	/* -f takes a numeric value (main() reads optarg). */
	"[-f flags] flags for init sec context",
	"[-h] print this usage message",
	"[-H] don't access home directory",
	"[-i] run interactively",
	"[-k] use kerberos service principal name, otherwise",
	" use host base service name",
	"[-M retries] max retries before giving up on server death",
	"[-m krb5 | spnego] mech to use, defaults to krb5",
	"[-n n] number of experiments to run",
	"[-p principal] use principal for client",
	"[-r realm] use realm for kerberos",
	"[-s n] number of concurrent servers (and clients) to run",
	"[-t usecs] average time to wait in the client",
	" This is a random time between 0 and 2*usecs",
	"[-u user] credentials to run as",
	"[-U] don't bring up UI.",
	"[-v] verbose flag. May be repeated",
#ifdef TASK_GSSD_PORT
	"",
#else
	"[-N bootstrap label] bootstrap name",
#endif
};
/*
 * Print the command synopsis followed by one tab-indented line per entry
 * in optstrs[] to stderr, then terminate with EXIT_FAILURE.
 */
static void
Usage(const char *prog)
{
	size_t idx;
	size_t nlines = sizeof(optstrs) / sizeof(optstrs[0]);

	ERR("Usage: %s [options] [host]\n", prog);
	for (idx = 0; idx < nlines; ++idx)
		ERR("\t%s\n", optstrs[idx]);
	exit(EXIT_FAILURE);
}
/* -t: average delay in microseconds used by waittime(); 0 disables it. */
int timeout = TIMEOUT;
/* -v: verbosity level; each extra -v enables more DEBUG output. */
int verbose = 0;
/* -M: retries of a Mach call after MIG_SERVER_DIED while no context exists. */
int max_retries = MAXRETRIES;
/* -e: exit(1) on any Mach RPC error. */
int exitonerror = 0;
/* -i: wait for Enter before starting and before exiting. */
int interactive = 0;
/* Uid passed to mach_gss_init_sec_context(); getuid() unless -u is given. */
uint32_t uid;
/* GSSD_* option bits plus the low 16 bits given with -f. */
uint32_t flags;
/* -p: client principal; "" means none was specified. */
char *principal="";
/* Target service name built by main() ("nfs/host[@realm]" or "nfs@host"). */
char svcname[1024];
/* Send right to the gssd service, obtained in main(). */
mach_port_t mp;
/* Signalled by server_done() when num_servers reaches 0; waited on by waitall(). */
pthread_cond_t num_servers_cv[1];
/* Guards num_servers. */
pthread_mutex_t num_servers_lock[1];
/* Number of server() threads in the current round that have not finished. */
int num_servers;
/* -m: GSS mechanism to request. */
mechtype mech = DEFAULT_MECH;
/*
 * gsstest driver.  Parses options, builds the service name, obtains a send
 * right to gssd, then runs `num` rounds of `Servers` concurrent
 * server()/client() thread pairs.  Each round is waited for before the next
 * one starts.  Finally it prints the counters and releases the port.
 */
int main(int argc, char *argv[])
{
	char *bname = NULL;
	int i, j, ch;
	int error;
	int num = 1;
	int Servers = 1;
	int use_kerberos = 0;
	pthread_t thread;
	pthread_attr_t attr[1];
	char hostbuf[MAXHOSTNAME];
	char *host = hostbuf;
	char *realm = NULL;
	char *prog;
	struct passwd *pent;
	kern_return_t kr;

	uid = getuid();
	prog = strrchr(argv[0], '/');
	prog = prog ? prog + 1 : argv[0];

	/* "f:" -- -f takes a value; without the colon optarg is NULL or stale. */
	while ((ch = getopt(argc, argv, "CDef:hHikN:n:M:m:p:r:s:t:u:Uv")) != -1) {
		switch (ch) {
		case 'C':
			flags |= GSSD_NO_CANON;
			break;
		case 'D':
			flags |= GSSD_NO_DEFAULT;
			break;
		case 'e':
			exitonerror = 1;
			break;
		case 'f':
			flags |= (atoi(optarg) & 0xffff);
			break;
		case 'H':
			flags |= GSSD_NO_HOME_ACCESS;
			break;
		case 'i':
			interactive = 1;
			break;
		case 'k':
			use_kerberos = 1;
			break;
		case 'N':
			bname = optarg;
			break;
		case 'n':
			num = atoi(optarg);
			break;
		case 'M':
			max_retries = atoi(optarg);
			break;
		case 'm':
			if (strcmp(optarg, "krb5") == 0)
				mech = KRB5_MECH;
			else if (strcmp(optarg, "spnego") == 0)
				mech = SPNEGO_MECH;
			else {
				ERR("Unavailable gss mechanism %s\n", optarg);
				exit(EXIT_FAILURE);
			}
			break;
		case 'p':
			principal = optarg;
			break;
		case 'r':
			realm = optarg;
			break;
		case 's':
			Servers = atoi(optarg);
			break;
		case 't':
			timeout = atoi(optarg);
			break;
		case 'u':
			pent = getpwnam(optarg);
			if (pent)
				uid = pent->pw_uid;
			else
				ERR("Could not find user %s\n", optarg);
			break;
		case 'U':
			flags |= GSSD_NO_UI;
			break;
		case 'v':
			verbose++;
			break;
		default:
			Usage(prog);
			break;
		}
	}
	argc -= optind;
	argv += optind;

	if (argc == 0) {
		if (gethostname(hostbuf, sizeof(hostbuf)) != 0) {
			perror("gethostname");
			exit(EXIT_FAILURE);
		}
		/* gethostname() need not NUL-terminate a truncated name. */
		hostbuf[sizeof(hostbuf) - 1] = '\0';
	} else if (argc == 1) {
		host = argv[0];
	} else {
		Usage(prog);
	}

	/* principal defaults to "", so "no principal" means empty, not NULL. */
	if (principal != NULL && *principal != '\0')
		printf("Using creds for %s host=%s\n", principal, host);
	else
		printf("Creds for uid=%u host=%s\n", uid, host);

	if (use_kerberos) {
		strlcpy(svcname, "nfs/", sizeof(svcname));
		strlcat(svcname, host, sizeof(svcname));
		if (realm) {
			strlcat(svcname, "@", sizeof(svcname));
			strlcat(svcname, realm, sizeof(svcname));
		}
	} else {
		strlcpy(svcname, "nfs@", sizeof(svcname));
		strlcat(svcname, host, sizeof(svcname));
	}
	printf("Service name = %s\n", svcname);

	if (!bname) {
		bname = "com.apple.gssd-agent";
	}
	if (interactive) {
		printf("Hit enter to start ");
		(void) getchar();
	}
#ifdef TASK_GSSD_PORT
	kr = task_get_gssd_port(mach_task_self(), &mp);
	if (kr != KERN_SUCCESS) {
		ERR("task_get_gssd_port(): %s\n", mach_error_string(kr));
		exit(EXIT_FAILURE);
	}
#else
	kr = bootstrap_look_up(bootstrap_port, bname, &mp);
	if (kr != KERN_SUCCESS) {
		ERR("bootstrap_look_up(): %s\n", mach_error_string(kr));
		exit(EXIT_FAILURE);
	}
#endif
	if (!MACH_PORT_VALID(mp)) {
		ERR("Could not get a valid port (%u)\n", mp);
		exit(EXIT_FAILURE);
	}

	pthread_attr_init(attr);
	pthread_attr_setdetachstate(attr, PTHREAD_CREATE_DETACHED);
	for (i = 0; i < MAXCOUNTERS; i++) {
		pthread_mutex_init(counters[i].lock, NULL);
		counters[i].total = 0;
	}
	pthread_mutex_init(num_servers_lock, NULL);
	pthread_cond_init(num_servers_cv, NULL);

	for (i = 0; i < num; i++) {
		/*
		 * Publish this round's server count before any server runs, and
		 * loop on the local Servers: running servers decrement
		 * num_servers concurrently, so it must not be the loop bound.
		 */
		pthread_mutex_lock(num_servers_lock);
		num_servers = Servers;
		pthread_mutex_unlock(num_servers_lock);
		for (j = 0; j < Servers; j++) {
			error = pthread_create(&thread, attr, server, NULL);
			if (error) {
				ERR("Could not start server: %s\n",
				    strerror(error));
				/* Account for the server that never ran, or waitall() hangs. */
				server_done();
			}
		}
		waitall();
	}
	report_errors();
	pthread_attr_destroy(attr);
	kr = mach_port_deallocate(mach_task_self(), mp);
	if (kr != KERN_SUCCESS) {
		ERR("Could not delete send right!\n");
	}
	if (interactive) {
		printf("Hit enter to stop\n");
		(void) getchar();
	}
	return (0);
}
/*
 * Sleep a pseudo-random interval in [0, 2*timeout) microseconds (timeout is
 * the -t value).  A timeout <= 0 disables the delay.  random() is bounded by
 * 2^31-1, so in practice a single delay never exceeds about 2.1 seconds.
 */
static void
waittime(void)
{
	struct timespec to;
	long long range_ns;
	long long delay_ns;

	if (timeout <= 0)
		return;
	/* Widen first: 2*1000*timeout overflows int for large -t values. */
	range_ns = 2LL * 1000LL * (long long) timeout;
	delay_ns = (long long) random() % range_ns;
	/* nanosleep() rejects tv_nsec >= 1e9, so carry whole seconds over. */
	to.tv_sec = (time_t) (delay_ns / 1000000000LL);
	to.tv_nsec = (long) (delay_ns % 1000000000LL);
	nanosleep(&to, NULL);
}
/* Atomically (with respect to count->lock) add one to count->total. */
static void
inc_counter(counter_t count)
{
	pthread_mutex_t *mtx = &count->lock[0];

	pthread_mutex_lock(mtx);
	count->total += 1;
	pthread_mutex_unlock(mtx);
}
static void
report_errors(void)
{
printf("gss_init_errors %d\n", gss_init_errors->total);
printf("gss_accept_errors %d\n", gss_accept_errors->total);
printf("server_errors %d\n", server_errors->total);
printf("server_deaths %d\n", server_deaths->total);
}
/*
 * Record that one server thread of the current round has finished; the
 * last one to finish wakes the thread blocked in waitall().
 */
static void
server_done(void)
{
	int remaining;

	pthread_mutex_lock(num_servers_lock);
	remaining = --num_servers;
	if (remaining == 0)
		pthread_cond_signal(num_servers_cv);
	pthread_mutex_unlock(num_servers_lock);
}
/* Block until num_servers drops to zero (or below). */
static void
waitall(void)
{
	pthread_mutex_lock(num_servers_lock);
	for (;;) {
		if (num_servers <= 0)
			break;
		pthread_cond_wait(num_servers_cv, num_servers_lock);
	}
	pthread_mutex_unlock(num_servers_lock);
}
/*
 * Release `size` bytes at `addr` from this task's address space with
 * vm_deallocate(), ignoring the result.  NULL or zero-length buffers are
 * left alone.
 */
static void
deallocate(void *addr, uint32_t size)
{
	if (addr != NULL && size != 0)
		(void) vm_deallocate(mach_task_self(), (vm_address_t)addr,
		    (vm_size_t)size);
}
/*
 * Block until it is side `d`'s turn on `chan` (chan->client == d) or a
 * failure value is recorded in chan->failure, then call waittime().
 *
 * Lock protocol: on success this returns 0 with chan->lock STILL HELD; the
 * callers in this file release it through write_channel().  If chan->failure
 * is nonzero (including the CHANNEL_CLOSED bit set by close_channel()) the
 * lock is released here and -1 is returned.
 */
int read_channel(int d, channel_t chan)
{
pthread_mutex_lock(chan->lock);
/* Wait until the peer hands the turn to `d` or reports a failure. */
while (chan->client != d && !chan->failure)
pthread_cond_wait(chan->cv, chan->lock);
/* Random delay taken while the lock is held (presumably to perturb timing). */
waittime();
if (chan->failure) {
pthread_mutex_unlock(chan->lock);
return (-1);
}
/* Lock is left held; write_channel() releases it. */
return (0);
}
/*
 * Pass the turn from side `d` to the other side (!d), signal chan->cv, and
 * release chan->lock.  Must be called with chan->lock held, i.e. after a
 * successful read_channel(d, chan).  If it is not currently d's turn it
 * logs "Writing out of turn" but proceeds anyway.  Always returns 0.
 */
int write_channel(int d, channel_t chan)
{
if (chan->client != d)
ERR("Writing out of turn\n");
chan->client = !d;
pthread_cond_signal(chan->cv);
/* Releases the lock acquired by the preceding read_channel(). */
pthread_mutex_unlock(chan->lock);
return (0);
}
/*
 * Close `chan` from side `d`.  Waits for d's turn (or a recorded failure),
 * then ORs CHANNEL_CLOSED into chan->failure, sets the turn to `d` and wakes
 * the peer.  Returns chan->failure as it was before CHANNEL_CLOSED was added.
 * Acquires and releases chan->lock itself.
 */
int close_channel(int d, channel_t chan)
{
	int prior_failure;

	pthread_mutex_lock(chan->lock);
	for (;;) {
		if (chan->client == d || chan->failure)
			break;
		pthread_cond_wait(chan->cv, chan->lock);
	}
	prior_failure = chan->failure;
	chan->failure = prior_failure | CHANNEL_CLOSED;
	chan->client = d;
	pthread_cond_signal(chan->cv);
	pthread_mutex_unlock(chan->lock);
	return prior_failure;
}
/*
 * Client thread: the initiator side of one experiment, started by server()
 * with `arg` pointing at the shared channel.
 *
 * On each turn it frees the previous ctoken, passes the server's stoken to
 * mach_gss_init_sec_context(), stores the reply token in channel->ctoken and
 * the session key in channel->clnt_skey, and hands the turn back.  It loops
 * while the call returns GSS_S_CONTINUE_NEEDED, then closes its side of the
 * channel.  On a Mach RPC error it marks the channel failed and returns
 * without calling close_channel().
 */
void *client(void *arg)
{
channel_t channel = (channel_t)arg;
uint32_t major_stat;
uint32_t minor_stat;
uint32_t cred_handle = (uint32_t) GSS_C_NO_CREDENTIAL;
uint32_t gss_context = (uint32_t) GSS_C_NO_CONTEXT;
gssd_verifier verifier;
kern_return_t kr;
int gss_error = 0;
int retry_count = 0;
do {
/* On success read_channel() returns with channel->lock held. */
if (read_channel(1, channel)) {
ERR("Bad read from server\n");
return (NULL);
}
if (verbose)
DEBUG("Calling mach_gss_init_sec_context from %p\n",
pthread_self());
/* Drop our previous output token before requesting a new one. */
deallocate(channel->ctoken, channel->ctokenCnt);
channel->ctoken = (byte_buffer)GSS_C_NO_BUFFER;
channel->ctokenCnt = 0;
retry:
kr = mach_gss_init_sec_context(
mp,
mech,
channel->stoken, channel->stokenCnt,
uid,
principal,
svcname,
flags,
&verifier,
&gss_context,
&cred_handle,
&channel->clnt_skey, &channel->clnt_skeyCnt,
&channel->ctoken, &channel->ctokenCnt,
&major_stat,
&minor_stat);
if (kr != KERN_SUCCESS) {
inc_counter(server_errors);
ERR("gsstest: %s\n", mach_error_string(kr));
if (exitonerror)
exit(1);
if (kr == MIG_SERVER_DIED) {
inc_counter(server_deaths);
/* Retry only while no context has been established yet. */
if (gss_context == (uint32_t)GSS_C_NO_CONTEXT &&
retry_count < max_retries) {
retry_count++;
goto retry;
}
}
/* Report failure to the server; write_channel() drops the lock. */
channel->failure = 1;
write_channel(1, channel);
return (NULL);
}
gss_error = (major_stat != GSS_S_COMPLETE &&
major_stat != GSS_S_CONTINUE_NEEDED);
if (verbose > 1) {
DEBUG("\tcred = 0x%0x\n", (int) cred_handle);
DEBUG("\tclnt_gss_context = 0x%0x\n",
(int) gss_context);
DEBUG("\ttokenCnt = %d\n", (int) channel->ctokenCnt);
if (verbose > 2)
HexDump((char *) channel->ctoken,
(uint32_t) channel->ctokenCnt, 1);
}
/* Publish the GSS result and hand the turn to the server. */
channel->failure = gss_error;
write_channel(1, channel);
} while (major_stat == GSS_S_CONTINUE_NEEDED);
if (gss_error) {
inc_counter(gss_init_errors);
DISPLAY_ERRS("mach_gss_init_sec_context: ",
major_stat, minor_stat);
}
close_channel(1, channel);
return (NULL);
}
void *server(void *arg __attribute__((unused)))
{
struct s_channel args;
channel_t channel = &args;
pthread_t client_thr;
int error;
uint32_t major_stat;
uint32_t minor_stat;
gssd_verifier verifier;
uint32_t cred_handle;
uint32_t gss_context;
uint32_t clnt_uid;
uint32_t clnt_gids[NGROUPS_MAX];
uint32_t clnt_ngroups;
byte_buffer svc_skey;
mach_msg_type_number_t svc_skeyCnt;
kern_return_t kr;
int retry_count = 0;
channel->client = 1;
channel->failure = 0;
pthread_mutex_init(channel->lock, NULL);
pthread_cond_init(channel->cv, NULL);
channel->ctoken = (byte_buffer) GSS_C_NO_BUFFER;
channel->ctokenCnt = 0;
channel->stoken = (byte_buffer) GSS_C_NO_BUFFER;
channel->stokenCnt = 0;
channel->clnt_skey = (byte_buffer) GSS_C_NO_BUFFER;
channel->clnt_skeyCnt = 0;
error = pthread_create(&client_thr, NULL, client, channel);
if (error) {
ERR("Could not start client: %s\n", strerror(error));
return NULL;
}
do {
if (read_channel(0, channel) == -1) {
ERR("Bad read from client\n");
goto out;
}
deallocate(channel->stoken, channel->stokenCnt);
channel->stoken = (byte_buffer)GSS_C_NO_BUFFER;
channel->stokenCnt = 0;
if (verbose)
DEBUG("Calling mach_gss_accept_sec_contex %p\n",
pthread_self());
retry:
kr = mach_gss_accept_sec_context(
mp,
channel->ctoken, channel->ctokenCnt,
svcname,
flags,
&verifier,
&gss_context,
&cred_handle,
&clnt_uid,
clnt_gids,
&clnt_ngroups,
&svc_skey, &svc_skeyCnt,
&channel->stoken, &channel->stokenCnt,
&major_stat,
&minor_stat);
if (kr != KERN_SUCCESS) {
inc_counter(server_errors);
ERR("gsstest: %s\n", mach_error_string(kr));
if (exitonerror)
exit(1);
if (kr == MIG_SERVER_DIED) {
inc_counter(server_deaths);
if (gss_context == (uint32_t)GSS_C_NO_CONTEXT &&
retry_count < max_retries) {
retry_count++;
goto retry;
}
}
channel->failure = 1;
write_channel(0, channel);
goto out;
}
error = (major_stat != GSS_S_COMPLETE &&
major_stat != GSS_S_CONTINUE_NEEDED);
channel->failure = error;
write_channel(0, channel);
} while (major_stat == GSS_S_CONTINUE_NEEDED);
if (error) {
inc_counter(gss_accept_errors);
DISPLAY_ERRS("mach_gss_accept_sec_context: ",
major_stat, minor_stat);
}
out:
close_channel(0, channel);
pthread_join(client_thr, NULL);
if (major_stat == GSS_S_COMPLETE && !CHANNEL_FAILED(channel)) {
if (svc_skeyCnt != channel->clnt_skeyCnt ||
memcmp(svc_skey, channel->clnt_skey, svc_skeyCnt)) {
ERR("Session keys don't match!\n");
ERR("\tClient key length = %d\n",
channel->clnt_skeyCnt);
HexDump((char *) channel->clnt_skey,
(uint32_t) channel->clnt_skeyCnt, 1);
ERR("\tServer key length = %d\n", svc_skeyCnt);
HexDump((char *) svc_skey, (uint32_t) svc_skeyCnt, 0);
if (uid != clnt_uid)
ERR("Wrong uid. got %d expected %d\n",
clnt_uid, uid);
}
else if (verbose) {
DEBUG("\tSession key length = %d\n", svc_skeyCnt);
HexDump((char *) svc_skey, (uint32_t) svc_skeyCnt, 1);
DEBUG("\tReturned uid = %d\n", uid);
}
} else if (verbose > 1) {
DEBUG("Failed major status = %d\n", major_stat);
DEBUG("Channel failure = %x\n", channel->failure);
}
deallocate(svc_skey, svc_skeyCnt);
deallocate(channel->ctoken, channel->ctokenCnt);
deallocate(channel->stoken, channel->stokenCnt);
deallocate(channel->clnt_skey, channel->clnt_skeyCnt);
pthread_mutex_destroy(channel->lock);
pthread_cond_destroy(channel->cv);
server_done();
return (NULL);
}
/*
 * Print every message gss_display_status() yields for one status `code` of
 * the given `code_type` (GSS_C_GSS_CODE or GSS_C_MECH_CODE) to stderr.
 * Each message is labelled "<kind> error N".  Each message buffer is
 * released after printing.  If gss_display_status() fails, the failure is
 * noted and iteration stops.
 */
static void
CGSSDisplay_status(const char *kind, OM_uint32 code, int code_type)
{
	OM_uint32 msg_context = 0;
	OM_uint32 min_stat = 0;
	OM_uint32 maj_stat;
	gss_buffer_desc errBuf = GSS_C_EMPTY_BUFFER;
	int count = 1;

	do {
		maj_stat = gss_display_status(&min_stat, code, code_type,
		    GSS_C_NULL_OID, &msg_context, &errBuf);
		if (GSS_ERROR(maj_stat)) {
			/* errBuf is not valid here; don't print or release it. */
			ERR("\t%s error %d: <gss_display_status failed>\n",
			    kind, count);
			break;
		}
		/* GSS buffers are length-counted and need not be NUL-terminated. */
		ERR("\t%s error %d: %.*s\n", kind, count, (int) errBuf.length,
		    errBuf.value ? (char *) errBuf.value : "");
		(void) gss_release_buffer(&min_stat, &errBuf);
		count++;
	} while (msg_context != 0);
}

/*
 * Print the name of the failing routine followed by the text of both the
 * GSS major status `maj` and the mechanism minor status `min` to stderr.
 */
static void
CGSSDisplay_errs(char* rtnName, OM_uint32 maj, OM_uint32 min)
{
	ERR("Error returned by %s:\n", rtnName);
	CGSSDisplay_status("major", maj, GSS_C_GSS_CODE);
	CGSSDisplay_status("minor", min, GSS_C_MECH_CODE);
}
static const char HexChars[16] = { '0','1','2','3','4','5','6','7','8','9','A','B','C','D','E','F' };

/*
 * Format up to 16 bytes of `buf` into `linebuf` as one hex-dump line:
 * "XX " per byte (blank-padded to 16 columns), three spaces, the printable
 * ASCII rendering ('.' for bytes outside 0x20..0x7e), then "\n" and a NUL.
 * The longest line is 69 bytes, which fits the 80-byte buffer.
 * *bufSize is decreased by the number of bytes consumed.
 */
static void
HexLine(const char *buf, uint32_t *bufSize, char linebuf[80])
{
	/* Unsigned bytes: no sign extension or implementation-defined shifts. */
	const unsigned char *bytes = (const unsigned char *) buf;
	char *cptr = linebuf;
	uint32_t limit;
	uint32_t i;

	/* linebuf is a pointer here, so sizeof(linebuf) would not be 80. */
	memset(linebuf, 0, 80);
	limit = (*bufSize > 16) ? 16 : *bufSize;
	*bufSize -= limit;

	for (i = 0; i < 16; i++) {
		if (i < limit) {
			*cptr++ = HexChars[(bytes[i] >> 4) & 0x0f];
			*cptr++ = HexChars[bytes[i] & 0x0f];
		} else {
			*cptr++ = ' ';
			*cptr++ = ' ';
		}
		*cptr++ = ' ';
	}

	*cptr++ = ' ';
	*cptr++ = ' ';
	*cptr++ = ' ';
	for (i = 0; i < limit; i++)
		*cptr++ = (bytes[i] > 0x1f && bytes[i] < 0x7f) ? (char) bytes[i] : '.';
	*cptr++ = '\n';
	*cptr = '\0';
}
/*
 * Hex-dump `inLength` bytes of `inBuffer`, 16 bytes per line (see HexLine()).
 * Output goes to stdout via DEBUG when `debug` is nonzero, otherwise to
 * stderr via ERR.  A NULL buffer prints nothing.
 */
void
HexDump(const char *inBuffer, uint32_t inLength, int debug)
{
	uint32_t currentSize = inLength;
	uint32_t consumed;
	char linebuf[80];

	if (inBuffer == NULL)
		return;
	while (currentSize > 0) {
		consumed = currentSize;
		HexLine(inBuffer, &currentSize, linebuf);
		consumed -= currentSize;
		if (debug)
			DEBUG("%s", linebuf);
		else
			ERR("%s", linebuf);
		/* Advance only past consumed bytes; never form a pointer beyond the end. */
		inBuffer += consumed;
	}
}