#include <config.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <syslog.h>
#include <signal.h>
#ifdef HAVE_STDARG_H
#include <stdarg.h>
#else
#include <varargs.h>
#endif
#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif
#include <sys/types.h>
#include <sys/stat.h>
#include <netinet/in.h>
#ifdef HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif
#include "assert.h"
#include "exitcodes.h"
#include "map.h"
#include "nonblock.h"
#include "prot.h"
#include "util.h"
#include "xmalloc.h"
struct protgroup
{
size_t nalloced;
size_t next_element;
struct protstream **group;
};
struct protstream *prot_new(fd, write)
int fd;
int write;
{
struct protstream *newstream;
newstream = (struct protstream *) xzmalloc(sizeof(struct protstream));
newstream->buf = (unsigned char *)
xmalloc(sizeof(char) * (PROT_BUFSIZE));
newstream->buf_size = PROT_BUFSIZE;
newstream->ptr = newstream->buf;
newstream->maxplain = PROT_BUFSIZE;
newstream->fd = fd;
newstream->write = write;
newstream->logfd = PROT_NO_FD;
newstream->big_buffer = PROT_NO_FD;
if(write)
newstream->cnt = PROT_BUFSIZE;
return newstream;
}
int prot_free(struct protstream *s)
{
if (s->error) free(s->error);
free(s->buf);
if(s->big_buffer != PROT_NO_FD) {
map_free(&(s->bigbuf_base), &(s->bigbuf_siz));
close(s->big_buffer);
}
free((char*)s);
return 0;
}
int prot_setlog(struct protstream *s, int fd)
{
s->logfd = fd;
return 0;
}
#ifdef HAVE_SSL
int prot_settls(struct protstream *s, SSL *tlsconn)
{
s->tls_conn = tlsconn;
SSL_set_mode(tlsconn,
SSL_MODE_ENABLE_PARTIAL_WRITE
| SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
return 0;
}
#endif
int prot_setsasl(s, conn)
struct protstream *s;
sasl_conn_t *conn;
{
const int *ssfp;
int result;
if (s->write && s->ptr != s->buf) {
if(prot_flush_internal(s,0) == EOF)
return EOF;
}
s->conn = conn;
result = sasl_getprop(conn, SASL_SSF, (const void **) &ssfp);
if (result != SASL_OK) {
return -1;
}
s->saslssf = *ssfp;
if (s->write) {
int result;
const unsigned int *maxp;
unsigned int max;
result = sasl_getprop(conn, SASL_MAXOUTBUF, (const void **) &maxp);
max = *maxp;
if (result != SASL_OK) {
return -1;
}
if (max == 0 || max > PROT_BUFSIZE) {
max = PROT_BUFSIZE;
}
s->maxplain = max;
s->cnt = max;
}
else if (s->cnt) {
s->cnt = 0;
}
return 0;
}
int prot_settimeout(struct protstream *s, int timeout)
{
assert(!s->write);
s->read_timeout = timeout;
return 0;
}
int prot_setflushonread(struct protstream *s, struct protstream *flushs)
{
assert(!s->write);
if(flushs) assert(flushs->write);
s->flushonread = flushs;
return 0;
}
int prot_setreadcallback(struct protstream *s,
prot_readcallback_t *proc, void *rock)
{
assert(!s->write);
s->readcallback_proc = proc;
s->readcallback_rock = rock;
return 0;
}
struct prot_waitevent *prot_addwaitevent(struct protstream *s, time_t mark,
prot_waiteventcallback_t *proc,
void *rock)
{
struct prot_waitevent *new, *cur;
if (!proc) return s->waitevent;
new = (struct prot_waitevent *) xmalloc(sizeof(struct prot_waitevent));
new->mark = mark;
new->proc = proc;
new->rock = rock;
new->next = NULL;
if (!s->waitevent)
s->waitevent = new;
else {
cur = s->waitevent;
while (cur && cur->next) cur = cur->next;
cur->next = new;
}
return new;
}
void prot_removewaitevent(struct protstream *s, struct prot_waitevent *event)
{
struct prot_waitevent *prev, *cur;
prev = NULL;
cur = s->waitevent;
while (cur && cur != event) {
prev = cur;
cur = cur->next;
}
if (!cur) return;
if (!prev)
s->waitevent = cur->next;
else
prev->next = cur->next;
free(cur);
}
const char *prot_error(struct protstream *s)
{
if(!s) return "bad protstream passed to prot_error";
else if(s->error) return s->error;
else if(s->eof) return PROT_EOF_STRING;
else return NULL;
}
int prot_rewind(struct protstream *s)
{
assert(!s->write);
if (lseek(s->fd, 0L, 0) == -1) {
s->error = xstrdup(strerror(errno));
return EOF;
}
s->cnt = 0;
s->error = 0;
s->eof = 0;
return 0;
}
int prot_fill(struct protstream *s)
{
int n;
unsigned char *ptr;
int left;
int r;
struct timeval timeout;
fd_set rfds;
int haveinput;
time_t read_timeout;
struct prot_waitevent *event, *next;
assert(!s->write);
errno = 0;
if (s->eof || s->error) return EOF;
do {
haveinput = 0;
#ifdef HAVE_SSL
if (s->tls_conn != NULL) {
haveinput = SSL_pending(s->tls_conn);
}
#endif
if (s->readcallback_proc ||
(s->flushonread && s->flushonread->ptr != s->flushonread->buf)) {
timeout.tv_sec = timeout.tv_usec = 0;
FD_ZERO(&rfds);
FD_SET(s->fd, &rfds);
if (!haveinput &&
(select(s->fd + 1, &rfds, (fd_set *)0, (fd_set *)0,
&timeout) <= 0)) {
if (s->readcallback_proc) {
(*s->readcallback_proc)(s, s->readcallback_rock);
s->readcallback_proc = 0;
s->readcallback_rock = 0;
}
if (s->flushonread)
prot_flush_internal(s->flushonread, !s->dontblock);
}
else {
haveinput = 1;
}
}
if (!haveinput && (s->read_timeout || s->dontblock)) {
time_t now = time(NULL);
time_t sleepfor;
read_timeout = now + (s->dontblock ? 0 : s->read_timeout);
do {
sleepfor = read_timeout - now;
for (event = s->waitevent; event; event = next)
{
next = event->next;
if (now >= event->mark) {
event = (*event->proc)(s, event, event->rock);
}
if (event && sleepfor > (event->mark - now)) {
sleepfor = event->mark - now;
}
}
timeout.tv_sec = sleepfor;
timeout.tv_usec = 0;
FD_ZERO(&rfds);
FD_SET(s->fd, &rfds);
r = select(s->fd + 1, &rfds, (fd_set *)0, (fd_set *)0,
&timeout);
now = time(NULL);
} while ((r == 0 || (r == -1 && errno == EINTR)) &&
(now < read_timeout));
if ((r == 0) ||
(r == -1 && errno == EINTR && now >= read_timeout)) {
if (!s->dontblock) {
s->error = xstrdup("idle for too long");
return EOF;
} else {
errno = EAGAIN;
return EOF;
}
}
else if (r == -1) {
syslog(LOG_ERR, "select() failed: %m");
s->error = xstrdup(strerror(errno));
return EOF;
}
}
do {
#ifdef HAVE_SSL
if (s->tls_conn != NULL) {
n = SSL_read(s->tls_conn, (char *) s->buf, PROT_BUFSIZE);
} else {
n = read(s->fd, s->buf, PROT_BUFSIZE);
}
#else
n = read(s->fd, s->buf, PROT_BUFSIZE);
#endif
} while (n == -1 && errno == EINTR);
if (n <= 0) {
if (n) s->error = xstrdup(strerror(errno));
else s->eof = 1;
return EOF;
}
if (s->saslssf) {
int result;
const char *out;
unsigned outlen;
result = sasl_decode(s->conn, (const char *) s->buf, n,
&out, &outlen);
if (result != SASL_OK) {
char errbuf[256];
const char *ed = sasl_errdetail(s->conn);
snprintf(errbuf, 256, "decoding error: %s; %s",
sasl_errstring(result, NULL, NULL),
ed ? ed : "no detail");
s->error = xstrdup(errbuf);
return EOF;
}
if (outlen > 0) {
if (outlen > s->buf_size) {
s->buf = (unsigned char *)
xrealloc(s->buf, sizeof(char) * (outlen + 4));
s->buf_size = outlen;
}
memcpy(s->buf, out, outlen);
s->ptr = s->buf + 1;
s->cnt = outlen;
} else {
s->cnt = 0;
}
} else {
s->ptr = s->buf+1;
s->cnt = n;
}
if (s->cnt > 0) {
if (s->logfd != -1) {
time_t newtime;
char timebuf[20];
time(&newtime);
snprintf(timebuf, sizeof(timebuf), "<%ld<", newtime);
write(s->logfd, timebuf, strlen(timebuf));
left = s->cnt;
ptr = s->buf;
do {
n = write(s->logfd, ptr, left);
if (n == -1 && errno != EINTR) {
break;
}
if (n > 0) {
ptr += n;
left -= n;
}
} while (left);
}
s->cnt--;
return *s->buf;
}
} while (1);
}
int prot_flush(struct protstream *s)
{
return prot_flush_internal(s, 1);
}
static void prot_flush_log(struct protstream *s)
{
if(s->logfd != PROT_NO_FD) {
unsigned char *ptr = s->buf;
int left = s->ptr - s->buf;
int n;
time_t newtime;
char timebuf[20];
time(&newtime);
snprintf(timebuf, sizeof(timebuf), ">%ld>", newtime);
write(s->logfd, timebuf, strlen(timebuf));
do {
n = write(s->logfd, ptr, left);
if (n == -1 && errno != EINTR) {
break;
}
if (n > 0) {
ptr += n;
left -= n;
}
} while (left);
}
}
static int prot_flush_encode(struct protstream *s,
const char **output_buf,
int *output_len)
{
unsigned char *ptr = s->buf;
int left = s->ptr - s->buf;
if (s->saslssf != 0) {
int result = sasl_encode(s->conn, (char *) ptr, left,
output_buf, output_len);
if (result != SASL_OK) {
char errbuf[256];
const char *ed = sasl_errdetail(s->conn);
snprintf(errbuf, 256, "encoding error: %s; %s",
sasl_errstring(result, NULL, NULL),
ed ? ed : "no detail");
s->error = xstrdup(errbuf);
return EOF;
}
} else {
*output_buf = ptr;
*output_len = left;
}
return 0;
}
static int prot_flush_writebuffer(struct protstream *s,
const char *buf, size_t len)
{
int n;
do {
#ifdef HAVE_SSL
if (s->tls_conn != NULL) {
n = SSL_write(s->tls_conn, (char *)buf, len);
} else {
n = write(s->fd, buf, len);
}
#else
n = write(s->fd, buf, len);
#endif
} while (n == -1 && errno == EINTR);
return n;
}
int prot_flush_internal(struct protstream *s, int force)
{
int n;
int save_dontblock = s->dontblock;
const char *ptr = s->buf;
int left = s->ptr - s->buf;
assert(s->write);
assert(s->cnt >= 0);
if (s->eof || s->error) {
s->ptr = s->buf;
s->cnt = 1;
return EOF;
}
if(force)
s->dontblock = 0;
if(s->dontblock != s->dontblock_isset) {
nonblock(s->fd,s->dontblock);
s->dontblock_isset = s->dontblock;
}
if(!s->dontblock) {
if(s->big_buffer != PROT_NO_FD) {
do {
n = prot_flush_writebuffer(s, s->bigbuf_base + s->bigbuf_pos,
s->bigbuf_len - s->bigbuf_pos);
if(n == -1) {
s->error = xstrdup(strerror(errno));
goto done;
} else if (n > 0) {
s->bigbuf_pos += n;
}
} while(s->bigbuf_len != s->bigbuf_pos);
map_free(&(s->bigbuf_base), &(s->bigbuf_siz));
close(s->big_buffer);
s->bigbuf_len = s->bigbuf_pos = 0;
s->big_buffer = PROT_NO_FD;
}
if(!left) {
goto done;
}
prot_flush_log(s);
if(prot_flush_encode(s, &ptr, &left) == EOF) {
goto done;
}
do {
n = prot_flush_writebuffer(s, ptr, left);
if(n == -1) {
s->error = xstrdup(strerror(errno));
goto done;
} else if (n > 0) {
ptr += n;
left -= n;
}
} while(left);
} else {
if (s->big_buffer != PROT_NO_FD) {
n = prot_flush_writebuffer(s, s->bigbuf_base + s->bigbuf_pos,
s->bigbuf_len - s->bigbuf_pos);
if(n == -1 && errno == EAGAIN) {
n = 0;
} else if(n == -1) {
s->error = xstrdup(strerror(errno));
goto done;
}
if (n > 0) {
s->bigbuf_pos += n;
}
}
if(!left) {
goto done;
}
prot_flush_log(s);
if(prot_flush_encode(s, &ptr, &left) == EOF) {
goto done;
}
if(s->big_buffer == PROT_NO_FD || s->bigbuf_pos == s->bigbuf_len) {
n = prot_flush_writebuffer(s, ptr, left);
if(n == -1 && errno == EAGAIN) {
n = 0;
} else if(n == -1) {
s->error = xstrdup(strerror(errno));
goto done;
}
if(n > 0) {
ptr += n;
left -= n;
}
}
if(left) {
struct stat sbuf;
if(s->big_buffer == PROT_NO_FD) {
int fd = create_tempfile();
if(fd == -1) {
s->error = xstrdup(strerror(errno));
goto done;
}
s->big_buffer = fd;
}
do {
n = write(s->big_buffer, ptr, left);
if (n == -1 && errno != EINTR) {
syslog(LOG_ERR, "write to protstream buffer failed: %s",
strerror(errno));
fatal("write to big buffer failed", EC_OSFILE);
}
if (n > 0) {
ptr += n;
left -= n;
}
} while (left);
if (fstat(s->big_buffer, &sbuf) == -1) {
syslog(LOG_ERR, "IOERROR: fstating temp protlayer buffer: %m");
fatal("failed to fstat protlayer buffer", EC_IOERR);
}
s->bigbuf_len = sbuf.st_size;
map_refresh(s->big_buffer, 0, &(s->bigbuf_base), &(s->bigbuf_siz),
s->bigbuf_len, "temp protlayer buffer", NULL);
}
}
s->ptr = s->buf;
s->cnt = s->maxplain;
done:
if(s->big_buffer != PROT_NO_FD &&
(s->bigbuf_pos == s->bigbuf_len || s->error)) {
map_free(&(s->bigbuf_base), &(s->bigbuf_siz));
close(s->big_buffer);
s->bigbuf_len = s->bigbuf_pos = 0;
s->big_buffer = PROT_NO_FD;
}
if(force) {
s->dontblock = save_dontblock;
}
if(s->error) {
s->ptr = s->buf;
s->cnt = s->maxplain;
return EOF;
}
return 0;
}
int prot_write(struct protstream *s, const char *buf, unsigned len)
{
assert(s->write);
if(s->error || s->eof) return EOF;
if(len == 0) return 0;
while (len >= s->cnt) {
memcpy(s->ptr, buf, s->cnt);
s->ptr += s->cnt;
buf += s->cnt;
len -= s->cnt;
s->cnt = 0;
if (prot_flush_internal(s,0) == EOF) return EOF;
}
memcpy(s->ptr, buf, len);
s->ptr += len;
s->cnt -= len;
if (s->error || s->eof) return EOF;
assert(s->cnt > 0);
return 0;
}
int prot_printf(struct protstream *s, const char *fmt, ...)
{
va_list pvar;
char *percent, *p;
long l;
unsigned long ul;
int i;
unsigned u;
char buf[30];
va_start(pvar, fmt);
assert(s->write);
while ((percent = strchr(fmt, '%')) != 0) {
prot_write(s, fmt, percent-fmt);
switch (*++percent) {
case '%':
prot_putc('%', s);
break;
case 'l':
switch (*++percent) {
case 'd':
l = va_arg(pvar, long);
snprintf(buf, sizeof(buf), "%ld", l);
prot_write(s, buf, strlen(buf));
break;
case 'u':
ul = va_arg(pvar, long);
snprintf(buf, sizeof(buf), "%lu", ul);
prot_write(s, buf, strlen(buf));
break;
default:
abort();
}
break;
case 'd':
i = va_arg(pvar, int);
snprintf(buf, sizeof(buf), "%d", i);
prot_write(s, buf, strlen(buf));
break;
case 'u':
u = va_arg(pvar, int);
snprintf(buf, sizeof(buf), "%u", u);
prot_write(s, buf, strlen(buf));
break;
case 's':
p = va_arg(pvar, char *);
prot_write(s, p, strlen(p));
break;
case 'c':
i = va_arg(pvar, int);
prot_putc(i, s);
break;
default:
abort();
}
fmt = percent+1;
}
prot_write(s, fmt, strlen(fmt));
va_end(pvar);
if (s->error || s->eof) return EOF;
return 0;
}
int prot_read(struct protstream *s, char *buf, unsigned size)
{
int c;
assert(!s->write);
if (!size) return 0;
if (s->cnt) {
if (size > s->cnt) size = s->cnt;
memcpy(buf, s->ptr, size);
s->ptr += size;
s->cnt -= size;
return size;
}
c = prot_fill(s);
if (c == EOF) return 0;
buf[0] = c;
if (--size > s->cnt) size = s->cnt;
memcpy(buf+1, s->ptr, size);
s->ptr += size;
s->cnt -= size;
return size+1;
}
int prot_select(struct protgroup *readstreams, int extra_read_fd,
struct protgroup **out, int *extra_read_flag,
struct timeval *timeout)
{
struct protstream *s, *timeout_prot = NULL;
struct protgroup *retval = NULL;
int max_fd, found_fds = 0;
int i;
fd_set rfds;
int have_readtimeout = 0;
struct timeval my_timeout;
struct prot_waitevent *event;
time_t now = time(NULL);
time_t read_timeout = 0;
assert(readstreams || extra_read_fd != PROT_NO_FD);
assert(extra_read_fd == PROT_NO_FD || extra_read_flag);
assert(out);
errno = 0;
found_fds = 0;
FD_ZERO(&rfds);
max_fd = extra_read_fd;
for(i = 0; i<readstreams->next_element; i++) {
int have_thistimeout = 0;
time_t this_timeout = 0;
s = readstreams->group[i];
assert(!s->write);
have_thistimeout = 0;
for (event = s->waitevent; event; event = event->next)
{
if(!have_thistimeout || event->mark - now < this_timeout) {
this_timeout = event->mark - now;
have_thistimeout = 1;
}
}
if(!have_thistimeout || this_timeout > s->read_timeout)
this_timeout = s->read_timeout;
if(!have_readtimeout && !s->dontblock) {
read_timeout = now + this_timeout;
have_readtimeout = 1;
if(!timeout || read_timeout <= timeout->tv_sec)
timeout_prot = s;
} else if(!s->dontblock) {
time_t new_timeout;
new_timeout = now + this_timeout;
if(new_timeout < read_timeout) {
read_timeout = new_timeout;
if(!timeout || read_timeout <= timeout->tv_sec)
timeout_prot = s;
}
}
FD_SET(s->fd, &rfds);
if(s->fd > max_fd)
max_fd = s->fd;
if(s->cnt > 0) {
found_fds++;
if(!retval)
retval = protgroup_new(readstreams->next_element + 1);
protgroup_insert(retval, s);
}
#ifdef HAVE_SSL
else if(s->tls_conn != NULL && SSL_pending(s->tls_conn)) {
found_fds++;
if(!retval)
retval = protgroup_new(readstreams->next_element + 1);
protgroup_insert(retval, s);
}
#endif
}
if(!retval) {
time_t sleepfor;
if(extra_read_fd != PROT_NO_FD) {
FD_SET(extra_read_fd, &rfds);
}
if(read_timeout < now)
sleepfor = 0;
else
sleepfor = read_timeout - now;
if((!timeout && have_readtimeout)
|| (timeout && read_timeout < timeout->tv_sec)) {
if(!timeout)
timeout = &my_timeout;
timeout->tv_sec = sleepfor;
timeout->tv_usec = 0;
}
if(select(max_fd + 1, &rfds, NULL, NULL, timeout) == -1)
return -1;
now = time(NULL);
if(extra_read_fd != PROT_NO_FD && FD_ISSET(extra_read_fd, &rfds)) {
*extra_read_flag = 1;
found_fds++;
} else if(extra_read_flag) {
*extra_read_flag = 0;
}
for(i = 0; i<readstreams->next_element; i++) {
s = readstreams->group[i];
if(FD_ISSET(s->fd, &rfds)) {
found_fds++;
if(!retval)
retval = protgroup_new(readstreams->next_element + 1);
protgroup_insert(retval, s);
} else if(s == timeout_prot && now >= read_timeout) {
if(!retval)
retval = protgroup_new(readstreams->next_element + 1);
protgroup_insert(retval, s);
}
}
}
*out = retval;
return found_fds;
}
char *prot_fgets(char *buf, unsigned size, struct protstream *s)
{
char *p = buf;
int c;
assert(!s->write);
if (size < 2 || s->eof) return 0;
size--;
while (size && (c = prot_getc(s)) != EOF) {
size--;
*p++ = c;
if (c == '\n') break;
}
if (p == buf) return 0;
*p++ = '\0';
return buf;
}
struct protgroup *protgroup_new(size_t size)
{
struct protgroup *ret = xmalloc(sizeof(struct protgroup));
if(!size) size = PROTGROUP_SIZE_DEFAULT;
ret->nalloced = size;
ret->next_element = 0;
ret->group = xzmalloc(size * sizeof(struct protstream *));
return ret;
}
struct protgroup *protgroup_copy(struct protgroup *src)
{
struct protgroup *dest;
assert(src);
dest = protgroup_new(src->nalloced);
if(src->next_element) {
memcpy(dest->group, src->group,
src->next_element * sizeof(struct protstream *));
}
return dest;
}
void protgroup_reset(struct protgroup *group)
{
if(group) {
memset(group->group, 0,
group->next_element * sizeof(struct protstream *));
group->next_element = 0;
}
}
void protgroup_free(struct protgroup *group)
{
if(group) {
assert(group->group);
free(group->group);
free(group);
}
}
void protgroup_insert(struct protgroup *group, struct protstream *item)
{
assert(group);
assert(item);
if(group->next_element == group->nalloced) {
group->nalloced *= 2;
group->group = xrealloc(group->group,
group->nalloced * sizeof(struct protstream *));
}
group->group[group->next_element++] = item;
}
struct protstream *protgroup_getelement(struct protgroup *group,
size_t element)
{
assert(group);
if(element >= group->next_element) return NULL;
else return group->group[element];
}
#undef prot_getc
#undef prot_ungetc
#undef prot_putc
int prot_getc(struct protstream *s)
{
assert(!s->write);
if (s->cnt-- > 0) {
return *(s->ptr)++;
} else {
return prot_fill(s);
}
}
int prot_ungetc(int c, struct protstream *s)
{
assert(!s->write);
s->cnt++;
*--(s->ptr) = c;
return c;
}
int prot_putc(int c, struct protstream *s)
{
assert(s->write);
assert(s->cnt > 0);
*s->ptr++ = c;
if (--s->cnt == 0) {
return prot_flush_internal(s,0);
} else {
return 0;
}
}