#include "scod.h"
/* init entry points of the individual mechanisms; defined outside this file */
extern int scod_mech_anonymous_init(scod_mech_t);
extern int scod_mech_digest_md5_init(scod_mech_t);
extern int scod_mech_plain_init(scod_mech_t);
/*
 * NULL-terminated table of mechanism init functions; scod_ctx_new() calls
 * each entry in order until it reaches the NULL sentinel.
 * NOTE(review): only scod_mech_digest_md5_init is listed; the anonymous and
 * plain init functions are declared above but never registered here -
 * confirm that this is intentional.
 */
scod_mech_init_fn mech_inits[] = {
scod_mech_digest_md5_init,
NULL
};
/**
 * Create a scod context and initialise every mechanism listed in mech_inits.
 *
 * Each init function receives a zeroed mech structure whose ctx field points
 * at the new context. A mechanism whose init does not return sd_SUCCESS is
 * discarded; every other mechanism is appended to ctx->mechs and a copy of
 * its name is appended to ctx->names.
 *
 * @param cb     callback stored in the context; must not be NULL
 * @param cbarg  opaque argument stored next to the callback
 * @return the new context, or NULL when the context could not be allocated
 *         or no mechanism could be registered. Release with scod_ctx_free().
 */
scod_ctx_t scod_ctx_new(scod_callback_t cb, void *cbarg) {
    int i;
    scod_ctx_t ctx;
    scod_mech_t mech;
    scod_mech_t *mechs;
    char **names;
    char *name;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(cb != NULL);

    log_debug(ZONE, "creating new scod context");

    ctx = (scod_ctx_t) malloc(sizeof(struct _scod_ctx_st));
    if(ctx == NULL)
        return NULL;
    memset(ctx, 0, sizeof(struct _scod_ctx_st));

    ctx->cb = cb;
    ctx->cbarg = cbarg;

    for(i = 0; mech_inits[i] != NULL; i++) {
        mech = (scod_mech_t) malloc(sizeof(struct _scod_mech_st));
        if(mech == NULL)
            continue;   /* out of memory: skip this mechanism */
        memset(mech, 0, sizeof(struct _scod_mech_st));

        mech->ctx = ctx;

        if((mech_inits[i])(mech) != sd_SUCCESS) {
            log_debug(ZONE, "mech failed to init");
            free(mech);
            continue;
        }

        /* grow both arrays through temporaries so that a failed realloc
         * neither leaks nor overwrites the arrays built so far */
        name = strdup(mech->name);
        mechs = (scod_mech_t *) realloc(ctx->mechs, sizeof(scod_mech_t) * (ctx->nmechs + 1));
        if(mechs != NULL)
            ctx->mechs = mechs;
        names = (char **) realloc(ctx->names, sizeof(char *) * (ctx->nmechs + 1));
        if(names != NULL)
            ctx->names = names;

        if(name == NULL || mechs == NULL || names == NULL) {
            /* out of memory: tear this mechanism down the way scod_ctx_free() does */
            free(name);
            if(mech->free != NULL)
                (mech->free)(mech);
            free(mech);
            continue;
        }

        ctx->mechs[ctx->nmechs] = mech;
        ctx->names[ctx->nmechs] = name;
        ctx->nmechs++;

        log_debug(ZONE, "mech '%s' initialised", mech->name);
    }

    /* a context without any usable mechanism is of no use to the caller */
    if(ctx->nmechs == 0) {
        free(ctx->mechs);
        free(ctx->names);
        free(ctx);
        return NULL;
    }

    return ctx;
}
/**
 * Destroy a context.
 *
 * For every registered mechanism the mechanism's own free hook is called
 * when it is set, then the mech structure and the stored copy of its name
 * are released. Finally the two arrays and the context itself are freed.
 *
 * @param ctx  context to destroy; must not be NULL
 */
void scod_ctx_free(scod_ctx_t ctx) {
    int i;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(ctx != NULL);

    log_debug(ZONE, "freeing scod context");

    for(i = 0; i < ctx->nmechs; i++) {
        log_debug(ZONE, "freeing '%s' mech", ctx->mechs[i]->name);

        /* let the mechanism release whatever it attached to itself first */
        if(ctx->mechs[i]->free != NULL)
            (ctx->mechs[i]->free)(ctx->mechs[i]);

        free(ctx->mechs[i]);
        free(ctx->names[i]);
    }

    free(ctx->mechs);
    free(ctx->names);
    free(ctx);
}
/**
 * Look up the flags of a registered mechanism by name.
 *
 * @param ctx   context to search; must not be NULL
 * @param name  mechanism name, compared with strcmp(); must not be NULL
 * @return the flags field of the first mechanism whose name matches,
 *         or 0 when no registered mechanism has that name
 */
int scod_mech_flags(scod_ctx_t ctx, char *name) {
    int i;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(ctx != NULL);
    assert(name != NULL);

    for(i = 0; i < ctx->nmechs; i++) {
        if(strcmp(ctx->mechs[i]->name, name) == 0)
            return ctx->mechs[i]->flags;
    }

    return 0;
}
/**
 * Create a new, zero-initialised scod bound to a context.
 *
 * @param ctx   context the scod belongs to; must not be NULL
 * @param type  sd_type_CLIENT or sd_type_SERVER
 * @return the new scod, or NULL if it could not be allocated.
 *         Release with scod_free().
 */
scod_t scod_new(scod_ctx_t ctx, scod_type_t type) {
    scod_t sd;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(ctx != NULL);
    assert(type == sd_type_CLIENT || type == sd_type_SERVER);

    log_debug(ZONE, "creating new scod");

    sd = (scod_t) malloc(sizeof(struct _scod_st));
    if(sd == NULL)
        return NULL;    /* do not memset through a failed allocation */
    memset(sd, 0, sizeof(struct _scod_st));

    sd->ctx = ctx;
    sd->type = type;

    return sd;
}
/**
 * Destroy a scod together with the authzid, authnid, pass and realm strings
 * it holds. The context the scod refers to is left untouched.
 *
 * @param sd  scod to destroy; must not be NULL
 */
void scod_free(scod_t sd) {
    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);

    log_debug(ZONE, "freeing scod");

    /* free(NULL) is a no-op, so unset fields need no guard */
    free(sd->authzid);
    free(sd->authnid);
    free(sd->pass);
    free(sd->realm);

    free(sd);
}
/*
 * Find the first registered mechanism whose name equals `name` (strcmp).
 * Returns the mechanism, or NULL when the context holds no such mechanism.
 */
static scod_mech_t _scod_get_mech(scod_ctx_t ctx, char *name) {
    int idx = 0;
    scod_mech_t found = NULL;

    log_debug(ZONE, "looking for mech '%s'", name);

    /* stop at the first hit so the earliest registration wins */
    while(found == NULL && idx < ctx->nmechs) {
        if(strcmp(ctx->mechs[idx]->name, name) == 0)
            found = ctx->mechs[idx];
        idx++;
    }

    return found;
}
/**
 * Begin a client-side exchange using the named mechanism.
 *
 * *resp and *resplen are reset to NULL / 0 before anything else happens.
 * Copies of authzid (when given), authnid and pass are stored on the scod
 * and released later by scod_free().
 *
 * @param sd       client scod; must not be NULL
 * @param name     name of the mechanism to use; must not be NULL
 * @param authzid  authorisation id, may be NULL
 * @param authnid  authentication id; must not be NULL
 * @param pass     password; must not be NULL
 * @param resp     receives the mechanism's response; must not be NULL
 * @param resplen  receives the response length; must not be NULL
 * @return sd_err_WRONG_TYPE if sd is not a client scod,
 *         sd_err_COMPLETED if the scod already succeeded or failed,
 *         sd_err_IN_PROGRESS if a mechanism has already been selected,
 *         sd_err_UNKNOWN_MECH if no registered mechanism has that name,
 *         sd_err_NOT_IMPLEMENTED if the mechanism has no client_start hook,
 *         otherwise whatever the mechanism's client_start hook returns.
 */
int scod_client_start(scod_t sd, char *name, char *authzid, char *authnid, char *pass, char **resp, int *resplen) {
    int ret;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);
    assert(name != NULL);
    assert(authnid != NULL);
    assert(pass != NULL);
    assert(resp != NULL);
    assert(resplen != NULL);

    *resp = NULL;
    *resplen = 0;

    if(sd->type != sd_type_CLIENT)
        return sd_err_WRONG_TYPE;

    if(sd->authd || sd->failed)
        return sd_err_COMPLETED;

    /* the password is deliberately kept out of the log, and a NULL authzid
     * is never handed to %s */
    log_debug(ZONE, "client start; authzid=%s, authnid=%s", authzid != NULL ? authzid : "(none)", authnid);

    if(sd->mech != NULL)
        return sd_err_IN_PROGRESS;

    if((sd->mech = _scod_get_mech(sd->ctx, name)) == NULL)
        return sd_err_UNKNOWN_MECH;

    /* NOTE(review): strdup results are not checked; an out-of-memory
     * condition leaves the corresponding field NULL */
    if(authzid != NULL)
        sd->authzid = strdup(authzid);
    sd->authnid = strdup(authnid);
    sd->pass = strdup(pass);

    if(sd->mech->client_start != NULL)
        ret = (sd->mech->client_start)(sd->mech, sd, resp, resplen);
    else
        ret = sd_err_NOT_IMPLEMENTED;

    /* record the terminal state so later calls report sd_err_COMPLETED */
    if(ret == sd_SUCCESS)
        sd->authd = 1;
    else if((ret & sd_auth_MASK) == sd_auth_MASK)
        sd->failed = 1;

    return ret;
}
/**
 * Feed a server challenge to the client side of an exchange.
 *
 * *resp and *resplen are reset to NULL / 0 before anything else happens.
 * sd->mech is dereferenced without a check, so a mechanism must already
 * have been selected on this scod (scod_client_start() does that).
 *
 * @param sd       client scod; must not be NULL
 * @param chal     challenge data; must not be NULL
 * @param challen  challenge length; must not be 0
 * @param resp     receives the mechanism's response; must not be NULL
 * @param resplen  receives the response length; must not be NULL
 * @return sd_err_WRONG_TYPE if sd is not a client scod,
 *         sd_err_COMPLETED if the scod already succeeded or failed,
 *         sd_err_NOT_IMPLEMENTED if the mechanism has no client_step hook,
 *         otherwise whatever the mechanism's client_step hook returns.
 */
int scod_client_step(scod_t sd, const char *chal, int challen, char **resp, int *resplen) {
    int ret;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);
    assert(chal != NULL);
    assert(challen != 0);
    assert(resp != NULL);
    assert(resplen != NULL);

    *resp = NULL;
    *resplen = 0;

    if(sd->type != sd_type_CLIENT)
        return sd_err_WRONG_TYPE;

    if(sd->authd || sd->failed)
        return sd_err_COMPLETED;

    log_debug(ZONE, "client step");

    if(sd->mech->client_step != NULL)
        ret = (sd->mech->client_step)(sd->mech, sd, chal, challen, resp, resplen);
    else
        ret = sd_err_NOT_IMPLEMENTED;

    /* record the terminal state so later calls report sd_err_COMPLETED */
    if(ret == sd_SUCCESS)
        sd->authd = 1;
    else if((ret & sd_auth_MASK) == sd_auth_MASK)
        sd->failed = 1;

    return ret;
}
/**
 * Begin a server-side exchange using the named mechanism.
 *
 * *chal and *challen are reset to NULL / 0 before anything else happens.
 * A copy of realm (when given) is stored on the scod and released later by
 * scod_free().
 *
 * @param sd       server scod; must not be NULL
 * @param name     name of the mechanism to use; must not be NULL
 * @param realm    realm, may be NULL
 * @param resp     initial client response; must not be NULL
 * @param resplen  length of resp
 * @param chal     receives the mechanism's challenge; must not be NULL
 * @param challen  receives the challenge length; must not be NULL
 * @return sd_err_WRONG_TYPE if sd is not a server scod,
 *         sd_err_COMPLETED if the scod already succeeded or failed,
 *         sd_err_IN_PROGRESS if a mechanism has already been selected,
 *         sd_err_UNKNOWN_MECH if no registered mechanism has that name,
 *         sd_err_NOT_IMPLEMENTED if the mechanism has no server_start hook,
 *         otherwise whatever the mechanism's server_start hook returns.
 */
int scod_server_start(scod_t sd, char *name, char *realm, const char *resp, int resplen, char **chal, int *challen) {
    int ret;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);
    assert(name != NULL);
    assert(resp != NULL);
    assert(chal != NULL);
    assert(challen != NULL);

    *chal = NULL;
    *challen = 0;

    if(sd->type != sd_type_SERVER)
        return sd_err_WRONG_TYPE;

    if(sd->authd || sd->failed)
        return sd_err_COMPLETED;

    log_debug(ZONE, "server start");

    if(sd->mech != NULL)
        return sd_err_IN_PROGRESS;

    if((sd->mech = _scod_get_mech(sd->ctx, name)) == NULL)
        return sd_err_UNKNOWN_MECH;

    if(realm != NULL)
        sd->realm = strdup(realm);

    if(sd->mech->server_start != NULL)
        ret = (sd->mech->server_start)(sd->mech, sd, resp, resplen, chal, challen);
    else
        ret = sd_err_NOT_IMPLEMENTED;

    /* record the terminal state so later calls report sd_err_COMPLETED */
    if(ret == sd_SUCCESS)
        sd->authd = 1;
    else if((ret & sd_auth_MASK) == sd_auth_MASK)
        sd->failed = 1;

    return ret;
}
/**
 * Feed a client response to the server side of an exchange.
 *
 * *chal and *challen are reset to NULL / 0 before anything else happens.
 * sd->mech is dereferenced without a check, so a mechanism must already
 * have been selected on this scod (scod_server_start() does that).
 *
 * @param sd       server scod; must not be NULL
 * @param resp     client response; must not be NULL
 * @param resplen  length of resp
 * @param chal     receives the mechanism's next challenge; must not be NULL
 * @param challen  receives the challenge length; must not be NULL
 * @return sd_err_WRONG_TYPE if sd is not a server scod,
 *         sd_err_COMPLETED if the scod already succeeded or failed,
 *         sd_err_NOT_IMPLEMENTED if the mechanism has no server_step hook,
 *         otherwise whatever the mechanism's server_step hook returns.
 */
int scod_server_step(scod_t sd, const char *resp, int resplen, char **chal, int *challen) {
    int ret;

    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);
    assert(resp != NULL);
    assert(chal != NULL);
    assert(challen != NULL);

    *chal = NULL;
    *challen = 0;

    if(sd->type != sd_type_SERVER)
        return sd_err_WRONG_TYPE;

    if(sd->authd || sd->failed)
        return sd_err_COMPLETED;

    log_debug(ZONE, "server step");

    /* test the hook that is actually invoked; testing client_start here
     * would call through a NULL server_step on client-only mechanisms */
    if(sd->mech->server_step != NULL)
        ret = (sd->mech->server_step)(sd->mech, sd, resp, resplen, chal, challen);
    else
        ret = sd_err_NOT_IMPLEMENTED;

    /* record the terminal state so later calls report sd_err_COMPLETED */
    if(ret == sd_SUCCESS)
        sd->authd = 1;
    else if((ret & sd_auth_MASK) == sd_auth_MASK)
        sd->failed = 1;

    return ret;
}
/**
 * Encode outgoing data with the scod's mechanism.
 *
 * When the mechanism provides an encode hook the call is forwarded to it
 * and its return value is passed back. Otherwise the input is copied
 * unchanged into a freshly malloc'd buffer stored in *out, which the
 * caller has to free.
 *
 * sd->mech is dereferenced without a check, so a mechanism must already
 * have been selected on this scod.
 *
 * NOTE(review): outlen points to a char, so in the pass-through path the
 * stored length is inlen narrowed to char - confirm whether int * was meant.
 * NOTE(review): the malloc result of the pass-through path is not checked.
 *
 * @param sd      scod with a selected mechanism; must not be NULL
 * @param in      data to encode; must not be NULL
 * @param inlen   length of in
 * @param out     receives the output buffer; must not be NULL
 * @param outlen  receives the output length; must not be NULL
 * @return the encode hook's result, or sd_SUCCESS for the pass-through copy
 */
int scod_sasl_encode(scod_t sd, const char *in, int inlen, char **out, char *outlen) {
    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);
    assert(in != NULL);
    assert(out != NULL);
    assert(outlen != NULL);

    log_debug(ZONE, "encode");

    if(sd->mech->encode != NULL)
        return (sd->mech->encode)(sd->mech, sd, in, inlen, out, outlen);

    /* no encode hook: hand back an unmodified copy */
    *out = (char *) malloc(sizeof(char) * inlen);
    memcpy(*out, in, inlen);
    *outlen = inlen;

    return sd_SUCCESS;
}
/**
 * Decode incoming data with the scod's mechanism.
 *
 * When the mechanism provides a decode hook the call is forwarded to it
 * and its return value is passed back. Otherwise the input is copied
 * unchanged into a freshly malloc'd buffer stored in *out, which the
 * caller has to free.
 *
 * sd->mech is dereferenced without a check, so a mechanism must already
 * have been selected on this scod.
 *
 * NOTE(review): outlen points to a char, so in the pass-through path the
 * stored length is inlen narrowed to char - confirm whether int * was meant.
 * NOTE(review): the malloc result of the pass-through path is not checked.
 *
 * @param sd      scod with a selected mechanism; must not be NULL
 * @param in      data to decode; must not be NULL
 * @param inlen   length of in
 * @param out     receives the output buffer; must not be NULL
 * @param outlen  receives the output length; must not be NULL
 * @return the decode hook's result, or sd_SUCCESS for the pass-through copy
 */
int scod_sasl_decode(scod_t sd, const char *in, int inlen, char **out, char *outlen) {
    /* compare with NULL directly: an (int) cast may truncate a valid pointer to 0 */
    assert(sd != NULL);
    assert(in != NULL);
    assert(out != NULL);
    assert(outlen != NULL);

    log_debug(ZONE, "decode");

    if(sd->mech->decode != NULL)
        return (sd->mech->decode)(sd->mech, sd, in, inlen, out, outlen);

    /* no decode hook: hand back an unmodified copy */
    *out = (char *) malloc(sizeof(char) * inlen);
    memcpy(*out, in, inlen);
    *outlen = inlen;

    return sd_SUCCESS;
}