#include "lib.h"
#include "llist.h"
#include "str.h"
#include "network.h"
#include "istream.h"
#include "ostream.h"
#include "eacces-error.h"
#include "dict-private.h"
#include "dict-client.h"
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
/* Delay passed to timeout_add() in client_dict_add_timeout() for the
   idle-disconnect timeout. While this is 0, timeout_reset() is compiled out
   of client_dict_add_timeout().
   NOTE(review): presumably 0 means "disconnect as soon as the ioloop gets
   to run the timeout" - confirm against timeout_add() semantics. */
#define DICT_CLIENT_TIMEOUT_MSECS 0
/* Total number of seconds client_dict_read_timeout() waits for data before
   giving up (reported as a read timeout). */
#define DICT_CLIENT_READ_TIMEOUT_SECS 30
/* Successful reads that took at least this many seconds log a warning. */
#define DICT_CLIENT_READ_WARN_TIMEOUT_SECS 5
/* Connection state for one "proxy" dict instance. The embedded struct dict
   is the first member, so the struct dict pointers handed to the driver
   callbacks are cast back to struct client_dict. */
struct client_dict {
struct dict dict;
/* pool holding this struct and the strings below */
pool_t pool;
/* UNIX socket fd, or -1 while not connected */
int fd;
/* part of the init URI after the first ':', sent in the HELLO line */
const char *uri;
const char *username;
/* UNIX socket path of the dict server */
const char *path;
enum dict_data_type value_type;
/* ioloop_time of the last failed connect; client_dict_connect() refuses
   to retry while ioloop_time still has this value */
time_t last_failed_connect;
struct istream *input;
struct ostream *output;
/* IO_READ watcher added when the first async commit is queued */
struct io *io;
/* idle-disconnect timeout, see client_dict_add_timeout() */
struct timeout *to_idle;
/* doubly-linked list (DLLIST_*) of transactions still tracked here */
struct client_dict_transaction_context *transactions;
/* incremented on every disconnect, so transactions can detect that the
   connection they sent BEGIN on is gone */
unsigned int connect_counter;
/* last transaction id handed out; reset on each connect */
unsigned int transaction_id_counter;
/* number of async commits whose reply hasn't been read yet */
unsigned int async_commits;
/* TRUE while an iteration is active (only one is allowed) */
unsigned int in_iteration:1;
/* TRUE once the HELLO line was sent successfully on this connection */
unsigned int handshaked:1;
};
/* State of the (single) active iteration on a client_dict. */
struct client_dict_iterate_context {
struct dict_iterate_context ctx;
/* holds the key/value strings returned by the latest iterate call;
   cleared before each new row is parsed */
pool_t pool;
/* set on send/read failure or a failure reply; makes deinit return -1 */
bool failed;
};
/* One transaction. It is linked into client_dict.transactions while it is
   open and while its async commit reply is still pending. */
struct client_dict_transaction_context {
struct dict_transaction_context ctx;
struct client_dict_transaction_context *prev, *next;
/* async commit completion callback and its context */
dict_transaction_commit_callback_t *callback;
void *context;
/* id sent with every command of this transaction */
unsigned int id;
/* client_dict.connect_counter at the time BEGIN was sent */
unsigned int connect_counter;
unsigned int failed:1;
/* BEGIN is sent lazily together with the first command */
unsigned int sent_begin:1;
};
/* Forward declarations: the query-sending helpers below connect and
   disconnect before these functions are defined. */
static int client_dict_connect(struct client_dict *dict);
static void client_dict_disconnect(struct client_dict *dict);
/* Escape a string for the dict protocol: TAB, LF and \001 are each
   replaced with \001 followed by 't', 'n' or '1'. If src contains none of
   them, src itself is returned without copying. Otherwise the result is
   built with t_str_new() (presumably valid only until the caller's
   T_BEGIN/T_END frame ends - confirm). */
const char *dict_client_escape(const char *src)
{
	const char *cur = src;
	string_t *escaped;

	/* fast path: find the first character that needs escaping */
	while (*cur != '\0' && *cur != '\t' && *cur != '\n' && *cur != '\001')
		cur++;
	if (*cur == '\0')
		return src;

	/* copy the clean prefix verbatim, then escape the remainder */
	escaped = t_str_new(256);
	str_append_n(escaped, src, cur - src);
	while (*cur != '\0') {
		char escape_code = '\0';

		if (*cur == '\t')
			escape_code = 't';
		else if (*cur == '\n')
			escape_code = 'n';
		else if (*cur == '\001')
			escape_code = '1';

		if (escape_code != '\0') {
			str_append_c(escaped, '\001');
			str_append_c(escaped, escape_code);
		} else {
			str_append_c(escaped, *cur);
		}
		cur++;
	}
	return str_c(escaped);
}
/* Reverse of dict_client_escape(): \001 followed by '1', 't' or 'n' becomes
   \001, TAB or LF. A \001 followed by any other character drops both
   characters, and a trailing lone \001 is dropped. If src contains no \001,
   src itself is returned without copying. */
const char *dict_client_unescape(const char *src)
{
	const char *cur = src;
	string_t *plain;

	/* fast path: nothing to decode without an escape marker */
	while (*cur != '\0' && *cur != '\001')
		cur++;
	if (*cur == '\0')
		return src;

	plain = t_str_new(256);
	str_append_n(plain, src, cur - src);
	for (; *cur != '\0'; cur++) {
		if (*cur != '\001') {
			str_append_c(plain, *cur);
			continue;
		}
		if (cur[1] == '\0') {
			/* lone marker at the very end: skip it */
			continue;
		}
		cur++;
		if (*cur == '1')
			str_append_c(plain, '\001');
		else if (*cur == 't')
			str_append_c(plain, '\t');
		else if (*cur == 'n')
			str_append_c(plain, '\n');
	}
	return str_c(plain);
}
/* Send one complete protocol line to the dict server, connecting first if
   there's no connection. If the write fails on a handshaked connection,
   reconnect once and resend the line.

   Returns 0 on success, -1 on failure. Write failures are logged. */
static int client_dict_send_query(struct client_dict *dict, const char *query)
{
	if (dict->output == NULL) {
		/* not connected currently */
		if (client_dict_connect(dict) < 0)
			return -1;
	}

	if (o_stream_send_str(dict->output, query) < 0 ||
	    o_stream_flush(dict->output) < 0) {
		if (!dict->handshaked) {
			/* Reached while client_dict_connect() is sending the
			   HELLO line. Don't reconnect from inside the connect
			   attempt, but do report why it failed instead of
			   failing silently. */
			i_error("write(%s) failed: %m", dict->path);
			return -1;
		}

		/* The connection died. Reconnect and try once more. */
		client_dict_disconnect(dict);
		if (client_dict_connect(dict) < 0)
			return -1;

		if (o_stream_send_str(dict->output, query) < 0 ||
		    o_stream_flush(dict->output) < 0) {
			i_error("write(%s) failed: %m", dict->path);
			return -1;
		}
	}
	return 0;
}
/* Send BEGIN for a transaction. On success, remember which connection it
   was sent on (connect_counter), so a later reconnect can be detected. A
   send failure marks the transaction failed. Returns 0 on success, -1 if
   the transaction is (or becomes) failed. */
static int
client_dict_transaction_send_begin(struct client_dict_transaction_context *ctx)
{
	struct client_dict *dict = (struct client_dict *)ctx->ctx.dict;

	if (!ctx->failed) T_BEGIN {
		const char *begin_cmd =
			t_strdup_printf("%c%u\n", DICT_PROTOCOL_CMD_BEGIN,
					ctx->id);

		if (client_dict_send_query(dict, begin_cmd) == 0)
			ctx->connect_counter = dict->connect_counter;
		else
			ctx->failed = TRUE;
	} T_END;

	return ctx->failed ? -1 : 0;
}
/* Send one command belonging to a transaction, sending BEGIN first if that
   hasn't been done yet. The command is never resent on another connection.
   If a reconnect happened since BEGIN, or the transaction already failed,
   -1 is returned without writing. A write failure marks the transaction
   failed and disconnects. Returns 0 on success, -1 on failure. */
static int
client_dict_send_transaction_query(struct client_dict_transaction_context *ctx,
				   const char *query)
{
	struct client_dict *dict = (struct client_dict *)ctx->ctx.dict;

	/* open the transaction on the server lazily */
	if (!ctx->sent_begin) {
		if (client_dict_transaction_send_begin(ctx) < 0)
			return -1;
		ctx->sent_begin = TRUE;
	}

	/* the server-side transaction is gone if we've reconnected */
	if (ctx->failed || ctx->connect_counter != dict->connect_counter)
		return -1;
	if (dict->output == NULL)
		return -1;

	if (o_stream_send_str(dict->output, query) >= 0 &&
	    o_stream_flush(dict->output) >= 0)
		return 0;

	/* Send failed. The server-side transaction died with the
	   connection, so don't try to resend the command. */
	ctx->failed = TRUE;
	client_dict_disconnect(dict);
	return -1;
}
/* Look up a tracked transaction by id. Returns NULL if none matches. */
static struct client_dict_transaction_context *
client_dict_transaction_find(struct client_dict *dict, unsigned int id)
{
	struct client_dict_transaction_context *cur = dict->transactions;

	while (cur != NULL && cur->id != id)
		cur = cur->next;
	return cur;
}
/* Complete an async commit whose reply (ret: 1 OK, 0 NOTFOUND, -1 FAIL)
   was just read. Calls the stored callback, unlinks and frees the
   transaction, and drops the read watcher once no async commits remain.
   An unknown id is only logged. */
static void
client_dict_finish_transaction(struct client_dict *dict,
			       unsigned int id, int ret)
{
	struct client_dict_transaction_context *ctx;

	ctx = client_dict_transaction_find(dict, id);
	if (ctx == NULL) {
		i_error("dict-client: Unknown transaction id %u", id);
		return;
	}
	if (ctx->callback != NULL)
		ctx->callback(ret, ctx->context);

	DLLIST_REMOVE(&dict->transactions, ctx);
	i_free(ctx);

	i_assert(dict->async_commits > 0);
	if (--dict->async_commits == 0 && dict->io != NULL) {
		/* dict_async_input() removes the watcher after a read error
		   or timeout without touching async_commits, so it may
		   already be gone. Guard like client_dict_disconnect() does. */
		io_remove(&dict->io);
	}
}
/* Blocking read from the dict server with an overall deadline of
   DICT_CLIENT_READ_TIMEOUT_SECS, enforced with alarm().

   A read that returns 0 (presumably interrupted by SIGALRM or another
   signal - this relies on a SIGALRM handler being installed elsewhere;
   confirm) is retried until the deadline passes.

   Returns the last i_stream_read() result: >0 on data, 0 on timeout,
   negative on error. A successful read that took at least
   DICT_CLIENT_READ_WARN_TIMEOUT_SECS logs a warning. */
static ssize_t client_dict_read_timeout(struct client_dict *dict)
{
	time_t start, now, timeout;
	unsigned int diff;
	ssize_t ret;

	start = now = time(NULL);
	timeout = start + DICT_CLIENT_READ_TIMEOUT_SECS;

	do {
		alarm(timeout - now);
		ret = i_stream_read(dict->input);
		alarm(0);
		if (ret != 0)
			break;

		/* interrupted, most likely by the alarm; check the clock */
		now = time(NULL);
	} while (now < timeout);

	if (ret > 0) {
		/* Measure from the start of the whole wait. `now` is
		   reassigned after every interrupted read, so using it would
		   under-report slow lookups. */
		diff = time(NULL) - start;
		if (diff >= DICT_CLIENT_READ_WARN_TIMEOUT_SECS) {
			i_warning("read(%s): dict lookup took %u seconds",
				  dict->path, diff);
		}
	}
	return ret;
}
/* Read one line from the server, reading more data with
   client_dict_read_timeout() while no complete line is buffered.

   Lines starting with DICT_PROTOCOL_REPLY_ASYNC_COMMIT are consumed here:
   the matching transaction is finished and 0 is returned. A malformed
   async commit line is logged and also returns 0. Any other line is
   stored in *line_r and 1 is returned. On read error, disconnect, timeout
   or an i_stream_read() result of -2, an error is logged and -1 is
   returned. *line_r is NULL whenever the return value isn't 1. */
static int client_dict_read_one_line(struct client_dict *dict, char **line_r)
{
	unsigned int id;
	char *line;
	ssize_t ret;

	*line_r = NULL;

	for (;;) {
		line = i_stream_next_line(dict->input);
		if (line != NULL)
			break;

		ret = client_dict_read_timeout(dict);
		if (ret == -1) {
			if (dict->input->stream_errno == 0) {
				i_error("read(%s) failed: Remote disconnected",
					dict->path);
			} else {
				i_error("read(%s) failed: %m", dict->path);
			}
			return -1;
		}
		if (ret == -2) {
			i_error("read(%s) returned too much data", dict->path);
			return -1;
		}
		if (ret == 0) {
			i_error("read(%s) failed: Timeout after %u seconds",
				dict->path, DICT_CLIENT_READ_TIMEOUT_SECS);
			return -1;
		}
		i_assert(ret > 0);
	}

	if (*line != DICT_PROTOCOL_REPLY_ASYNC_COMMIT) {
		/* ordinary reply: hand it to the caller */
		*line_r = line;
		return 1;
	}

	/* async commit reply: <ASYNC_COMMIT><status><transaction id> */
	if (line[1] == DICT_PROTOCOL_REPLY_OK) {
		ret = 1;
	} else if (line[1] == DICT_PROTOCOL_REPLY_NOTFOUND) {
		ret = 0;
	} else if (line[1] == DICT_PROTOCOL_REPLY_FAIL) {
		ret = -1;
	} else {
		i_error("dict-client: Invalid async commit line: %s",
			line);
		return 0;
	}

	if (str_to_uint(line + 2, &id) < 0) {
		i_error("dict-client: Invalid ID");
		return 0;
	}
	client_dict_finish_transaction(dict, id, ret);
	return 0;
}
static bool client_dict_is_finished(struct client_dict *dict)
{
return dict->transactions == NULL && !dict->in_iteration &&
dict->async_commits == 0;
}
/* Idle timeout callback: disconnect only if the connection is still idle
   when the timeout fires. */
static void client_dict_timeout(struct client_dict *dict)
{
	if (!client_dict_is_finished(dict))
		return;
	client_dict_disconnect(dict);
}
/* Arm or refresh the idle timeout that calls client_dict_timeout().
   If a timeout already exists, it is reset, but only when
   DICT_CLIENT_TIMEOUT_MSECS > 0. Otherwise a new one is added, but only
   when client_dict_is_finished() says the connection is idle. */
static void client_dict_add_timeout(struct client_dict *dict)
{
if (dict->to_idle != NULL) {
#if DICT_CLIENT_TIMEOUT_MSECS > 0
/* restart the idle period; compiled out for a zero timeout */
timeout_reset(dict->to_idle);
#endif
} else if (client_dict_is_finished(dict)) {
dict->to_idle = timeout_add(DICT_CLIENT_TIMEOUT_MSECS,
client_dict_timeout, dict);
}
}
/* Read the next non-async-commit reply line. Async commit replies
   encountered on the way are processed and skipped. Returns NULL on
   failure. Afterwards the idle timeout is (re)armed. */
static char *client_dict_read_line(struct client_dict *dict)
{
	char *line;
	int ret;

	do {
		ret = client_dict_read_one_line(dict, &line);
	} while (ret == 0);

	client_dict_add_timeout(dict);
	return line;
}
/* Connect to the dict server socket and send the HELLO line (protocol
   version, value type, username and URI). At most one attempt is made per
   ioloop_time value after a failure. Returns 0 on success, -1 on failure
   (connection left closed). */
static int client_dict_connect(struct client_dict *dict)
{
	const char *hello;

	i_assert(dict->fd == -1);

	/* a connect already failed during this ioloop second */
	if (dict->last_failed_connect == ioloop_time)
		return -1;

	dict->fd = net_connect_unix(dict->path);
	if (dict->fd == -1) {
		dict->last_failed_connect = ioloop_time;
		if (errno != EACCES) {
			i_error("net_connect_unix(%s) failed: %m",
				dict->path);
		} else {
			i_error("%s", eacces_error_get("net_connect_unix",
						       dict->path));
		}
		return -1;
	}

	/* Blocking I/O: replies are read synchronously, with an alarm()
	   based timeout in client_dict_read_timeout(). */
	net_set_nonblock(dict->fd, FALSE);
	dict->input = i_stream_create_fd(dict->fd, (size_t)-1, FALSE);
	dict->output = o_stream_create_fd(dict->fd, 4096, FALSE);
	dict->transaction_id_counter = 0;
	dict->async_commits = 0;

	hello = t_strdup_printf("%c%u\t%u\t%d\t%s\t%s\n",
				DICT_PROTOCOL_CMD_HELLO,
				DICT_CLIENT_PROTOCOL_MAJOR_VERSION,
				DICT_CLIENT_PROTOCOL_MINOR_VERSION,
				dict->value_type, dict->username, dict->uri);
	if (client_dict_send_query(dict, hello) == 0) {
		dict->handshaked = TRUE;
		return 0;
	}

	dict->last_failed_connect = ioloop_time;
	client_dict_disconnect(dict);
	return -1;
}
/* Tear down the server connection: remove the idle timeout and read
   watcher, destroy both streams and close the socket. Every release is
   guarded, so calling this while already disconnected is harmless.
   connect_counter is bumped so transactions begun on this connection
   notice it's gone. */
static void client_dict_disconnect(struct client_dict *dict)
{
	dict->connect_counter++;
	dict->handshaked = FALSE;

	if (dict->to_idle != NULL)
		timeout_remove(&dict->to_idle);
	if (dict->io != NULL)
		io_remove(&dict->io);

	if (dict->input != NULL)
		i_stream_destroy(&dict->input);
	if (dict->output != NULL)
		o_stream_destroy(&dict->output);

	if (dict->fd == -1)
		return;
	if (close(dict->fd) < 0)
		i_error("close(%s) failed: %m", dict->path);
	dict->fd = -1;
}
/* Create a proxy dict. The URI has the form "[<socket path>]:<server uri>".
   With an empty socket path, the path is built from base_dir and
   DEFAULT_DICT_SERVER_SOCKET_FNAME. The part after the first ':' is passed
   to the server in HELLO. Returns NULL (after logging) if the URI has no
   ':'. No connection is made here. */
static struct dict *
client_dict_init(struct dict *driver, const char *uri,
		 enum dict_data_type value_type, const char *username,
		 const char *base_dir)
{
	struct client_dict *dict;
	const char *colon;
	pool_t pool;

	colon = strchr(uri, ':');
	if (colon == NULL) {
		i_error("dict-client: Invalid URI: %s", uri);
		return NULL;
	}

	pool = pool_alloconly_create("client dict", 1024);
	dict = p_new(pool, struct client_dict, 1);
	dict->pool = pool;
	dict->dict = *driver;
	dict->value_type = value_type;
	dict->username = p_strdup(pool, username);
	dict->fd = -1;

	if (colon == uri) {
		/* no socket path given */
		dict->path = p_strconcat(pool, base_dir,
			"/"DEFAULT_DICT_SERVER_SOCKET_FNAME, NULL);
	} else {
		dict->path = p_strdup_until(pool, uri, colon);
	}
	dict->uri = p_strdup(pool, colon + 1);
	return &dict->dict;
}
static void client_dict_deinit(struct dict *_dict)
{
struct client_dict *dict = (struct client_dict *)_dict;
client_dict_disconnect(dict);
pool_unref(&dict->pool);
}
/* Block until the replies to all outstanding async commits have been read
   (each one completes its transaction via client_dict_read_one_line()).
   Returns 0 on success, and -1 if not handshaked, on read failure, or if a
   reply other than an async commit reply arrives (the stream is then out
   of sync, so silently skipping it would hide the problem). */
static int client_dict_wait(struct dict *_dict)
{
	struct client_dict *dict = (struct client_dict *)_dict;
	char *line;
	int ret;

	if (!dict->handshaked)
		return -1;

	while (dict->async_commits > 0) {
		ret = client_dict_read_one_line(dict, &line);
		if (ret < 0)
			return -1;
		if (ret > 0) {
			i_error("dict-client: Unexpected reply waiting for "
				"async commits: %s", line);
			return -1;
		}
	}
	return 0;
}
/* Synchronous LOOKUP of key. Returns 1 and sets *value_r (allocated from
   pool) when found, 0 with *value_r = NULL when not found, and -1 on
   failure (*value_r is set to NULL only if a reply was read). */
static int client_dict_lookup(struct dict *_dict, pool_t pool,
			      const char *key, const char **value_r)
{
	struct client_dict *dict = (struct client_dict *)_dict;
	const char *reply;
	int ret;

	T_BEGIN {
		ret = client_dict_send_query(dict,
			t_strdup_printf("%c%s\n", DICT_PROTOCOL_CMD_LOOKUP,
					dict_client_escape(key)));
	} T_END;
	if (ret < 0)
		return -1;

	reply = client_dict_read_line(dict);
	if (reply == NULL)
		return -1;

	switch (*reply) {
	case DICT_PROTOCOL_REPLY_OK:
		*value_r = p_strdup(pool, dict_client_unescape(reply + 1));
		return 1;
	case DICT_PROTOCOL_REPLY_NOTFOUND:
		*value_r = NULL;
		return 0;
	default:
		*value_r = NULL;
		return -1;
	}
}
/* Start an iteration by sending ITERATE with the flags and the escaped,
   TAB-separated paths. Only one iteration may be active per dict (panics
   otherwise). A send failure is remembered in ctx->failed and surfaces
   from the iterate/deinit calls. */
static struct dict_iterate_context *
client_dict_iterate_init(struct dict *_dict, const char *const *paths,
			 enum dict_iterate_flags flags)
{
	struct client_dict *dict = (struct client_dict *)_dict;
	struct client_dict_iterate_context *ctx;

	if (dict->in_iteration)
		i_panic("dict-client: Only one iteration supported");
	dict->in_iteration = TRUE;

	ctx = i_new(struct client_dict_iterate_context, 1);
	ctx->ctx.dict = _dict;
	ctx->pool = pool_alloconly_create("client dict iteration", 512);

	T_BEGIN {
		string_t *cmd = t_str_new(256);
		const char *const *pathp;

		str_printfa(cmd, "%c%d", DICT_PROTOCOL_CMD_ITERATE, flags);
		for (pathp = paths; *pathp != NULL; pathp++) {
			str_append_c(cmd, '\t');
			str_append(cmd, dict_client_escape(*pathp));
		}
		str_append_c(cmd, '\n');

		if (client_dict_send_query(dict, str_c(cmd)) < 0)
			ctx->failed = TRUE;
	} T_END;

	return &ctx->ctx;
}
/* Fetch the next key/value row. Rows look like "<OK><key>\t<value>" and
   an empty line ends the iteration. Returns TRUE with *key_r/*value_r
   (valid until the next call) or FALSE at the end or on failure.
   Failures (read error, FAIL reply, malformed row) set ctx->failed. */
static bool client_dict_iterate(struct dict_iterate_context *_ctx,
				const char **key_r, const char **value_r)
{
	struct client_dict_iterate_context *ctx =
		(struct client_dict_iterate_context *)_ctx;
	struct client_dict *dict = (struct client_dict *)_ctx->dict;
	char *line, *tab;

	if (ctx->failed)
		return FALSE;

	line = client_dict_read_line(dict);
	if (line == NULL) {
		ctx->failed = TRUE;
		return FALSE;
	}
	if (line[0] == '\0') {
		/* end of iteration */
		return FALSE;
	}

	/* the previous row's strings are no longer needed */
	p_clear(ctx->pool);

	if (line[0] == DICT_PROTOCOL_REPLY_FAIL) {
		ctx->failed = TRUE;
		return FALSE;
	}

	tab = NULL;
	if (line[0] == DICT_PROTOCOL_REPLY_OK) {
		line++;
		tab = strchr(line, '\t');
	}
	if (tab == NULL) {
		i_error("dict client (%s) sent broken reply", dict->path);
		ctx->failed = TRUE;
		return FALSE;
	}

	*tab = '\0';
	*key_r = p_strdup(ctx->pool, dict_client_unescape(line));
	*value_r = p_strdup(ctx->pool, dict_client_unescape(tab + 1));
	return TRUE;
}
/* Finish the iteration: free its context, allow a new iteration and
   re-arm the idle timeout. Returns -1 if the iteration failed, else 0. */
static int client_dict_iterate_deinit(struct dict_iterate_context *_ctx)
{
	struct client_dict_iterate_context *ctx =
		(struct client_dict_iterate_context *)_ctx;
	struct client_dict *dict = (struct client_dict *)_ctx->dict;
	int result = ctx->failed ? -1 : 0;

	pool_unref(&ctx->pool);
	i_free(ctx);

	dict->in_iteration = FALSE;
	client_dict_add_timeout(dict);
	return result;
}
/* Create a transaction with the next id and track it in
   dict->transactions. Nothing is sent yet; BEGIN goes out with the first
   command. */
static struct dict_transaction_context *
client_dict_transaction_init(struct dict *_dict)
{
	struct client_dict *dict = (struct client_dict *)_dict;
	struct client_dict_transaction_context *ctx =
		i_new(struct client_dict_transaction_context, 1);

	dict->transaction_id_counter++;
	ctx->id = dict->transaction_id_counter;
	ctx->ctx.dict = _dict;

	DLLIST_PREPEND(&dict->transactions, ctx);
	return &ctx->ctx;
}
/* IO_READ callback while async commits are outstanding. Processes async
   commit replies for as long as buffered data remains.

   - On read failure the watcher is removed. It may already be gone if the
     last outstanding commit was finished earlier in this same loop.
   - An ordinary (non-async-commit) reply means the stream is out of sync,
     so it's logged and the connection is dropped. */
static void dict_async_input(struct client_dict *dict)
{
	char *line;
	size_t size;
	int ret;

	i_assert(!dict->in_iteration);

	do {
		ret = client_dict_read_one_line(dict, &line);
		(void)i_stream_get_data(dict->input, &size);
	} while (ret == 0 && size > 0);

	if (ret < 0) {
		if (dict->io != NULL)
			io_remove(&dict->io);
	} else if (ret > 0) {
		i_error("dict-client: Unexpected reply waiting for "
			"async commits: %s", line);
		client_dict_disconnect(dict);
	}
}
/* Commit a transaction.

   If BEGIN was never sent (no changes), nothing goes to the server: 1 is
   returned, or -1 if the transaction had already failed. Otherwise COMMIT
   (or COMMIT_ASYNC when async is TRUE) is sent:
   - sync: the reply is read and mapped to 1 (OK), 0 (NOTFOUND) or -1.
   - async: callback/context are stored, a read watcher is added for the
     first outstanding async commit, and 1 is returned. The context is
     freed later by client_dict_finish_transaction() once the reply is
     read, and the callback is called from there.

   Whenever the commit was not actually queued asynchronously, the context
   is unlinked and freed here. The callback is not called in that case;
   the result is only returned, as for a failed async send.
   NOTE(review): confirm callers of async commits don't rely on the
   callback for empty transactions. */
static int
client_dict_transaction_commit(struct dict_transaction_context *_ctx,
			       bool async,
			       dict_transaction_commit_callback_t *callback,
			       void *context)
{
	struct client_dict_transaction_context *ctx =
		(struct client_dict_transaction_context *)_ctx;
	struct client_dict *dict = (struct client_dict *)_ctx->dict;
	/* TRUE only when the context is now owned by the async reply path */
	bool queued = FALSE;
	int ret = ctx->failed ? -1 : 1;

	if (ctx->sent_begin && !ctx->failed) T_BEGIN {
		const char *query, *line;

		query = t_strdup_printf("%c%u\n", !async ?
					DICT_PROTOCOL_CMD_COMMIT :
					DICT_PROTOCOL_CMD_COMMIT_ASYNC,
					ctx->id);
		if (client_dict_send_transaction_query(ctx, query) < 0)
			ret = -1;
		else if (async) {
			ctx->callback = callback;
			ctx->context = context;
			if (dict->async_commits++ == 0) {
				dict->io = io_add(dict->fd, IO_READ,
						  dict_async_input, dict);
			}
			queued = TRUE;
		} else {
			/* sync commit: read the reply now */
			line = client_dict_read_line(dict);
			if (line == NULL)
				ret = -1;
			else if (*line == DICT_PROTOCOL_REPLY_OK)
				ret = 1;
			else if (*line == DICT_PROTOCOL_REPLY_NOTFOUND)
				ret = 0;
			else
				ret = -1;
		}
	} T_END;

	if (!queued) {
		/* Previously an async commit of an empty transaction kept the
		   context linked forever, leaking it and preventing the idle
		   disconnect. */
		DLLIST_REMOVE(&dict->transactions, ctx);
		i_free(ctx);
		client_dict_add_timeout(dict);
	}
	return ret;
}
/* Abort a transaction. If BEGIN was sent, a ROLLBACK command is sent on a
   best-effort basis (its result is ignored). The context is then unlinked
   and freed, and the idle timeout re-armed. */
static void
client_dict_transaction_rollback(struct dict_transaction_context *_ctx)
{
	struct client_dict_transaction_context *ctx =
		(struct client_dict_transaction_context *)_ctx;
	struct client_dict *dict = (struct client_dict *)_ctx->dict;

	if (ctx->sent_begin) T_BEGIN {
		(void)client_dict_send_transaction_query(ctx,
			t_strdup_printf("%c%u\n", DICT_PROTOCOL_CMD_ROLLBACK,
					ctx->id));
	} T_END;

	DLLIST_REMOVE(&dict->transactions, ctx);
	i_free(ctx);
	client_dict_add_timeout(dict);
}
/* Queue SET <id> <key> <value> (both escaped) within the transaction.
   The send result is ignored here; a failed transaction makes the later
   commit return -1. */
static void client_dict_set(struct dict_transaction_context *_ctx,
			    const char *key, const char *value)
{
	struct client_dict_transaction_context *ctx =
		(struct client_dict_transaction_context *)_ctx;

	T_BEGIN {
		(void)client_dict_send_transaction_query(ctx,
			t_strdup_printf("%c%u\t%s\t%s\n",
					DICT_PROTOCOL_CMD_SET, ctx->id,
					dict_client_escape(key),
					dict_client_escape(value)));
	} T_END;
}
/* Queue UNSET <id> <key> (escaped) within the transaction. The send result
   is ignored here; a failed transaction makes the later commit return -1. */
static void client_dict_unset(struct dict_transaction_context *_ctx,
			      const char *key)
{
	struct client_dict_transaction_context *ctx =
		(struct client_dict_transaction_context *)_ctx;

	T_BEGIN {
		(void)client_dict_send_transaction_query(ctx,
			t_strdup_printf("%c%u\t%s\n",
					DICT_PROTOCOL_CMD_UNSET, ctx->id,
					dict_client_escape(key)));
	} T_END;
}
/* Queue ATOMIC_INC <id> <key> <diff> within the transaction. The send
   result is ignored here; a failed transaction makes the later commit
   return -1. */
static void client_dict_atomic_inc(struct dict_transaction_context *_ctx,
				   const char *key, long long diff)
{
	struct client_dict_transaction_context *ctx =
		(struct client_dict_transaction_context *)_ctx;

	T_BEGIN {
		(void)client_dict_send_transaction_query(ctx,
			t_strdup_printf("%c%u\t%s\t%lld\n",
					DICT_PROTOCOL_CMD_ATOMIC_INC,
					ctx->id, dict_client_escape(key),
					diff));
	} T_END;
}
/* Driver registration for the "proxy" dict. After the designated .name
   initializer, the brace-enclosed list initializes the member following
   `name` positionally.
   NOTE(review): presumably that is the dict vfuncs table - confirm the
   function order matches its declaration in dict-private.h. */
struct dict dict_driver_client = {
.name = "proxy",
{
client_dict_init,
client_dict_deinit,
client_dict_wait,
client_dict_lookup,
client_dict_iterate_init,
client_dict_iterate,
client_dict_iterate_deinit,
client_dict_transaction_init,
client_dict_transaction_commit,
client_dict_transaction_rollback,
client_dict_set,
client_dict_unset,
client_dict_atomic_inc
}
};