/* udpfromto.c */


/*
 *   This library is free software; you can redistribute it and/or
 *   modify it under the terms of the GNU Lesser General Public
 *   License as published by the Free Software Foundation; either
 *   version 2.1 of the License, or (at your option) any later version.
 *
 *   This library is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 *   Lesser General Public License for more details.
 *
 *   You should have received a copy of the GNU Lesser General Public
 *   License along with this library; if not, write to the Free Software
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
 *
 *  Helper functions to get/set addresses of UDP packets
 *  based on recvfromto by Miquel van Smoorenburg
 *
 * recvfromto	Like recvfrom, but also stores the destination
 *		IP address. Useful on multihomed hosts.
 *
 *		Should work on Linux and BSD.
 *
 *		Copyright (C) 2002 Miquel van Smoorenburg.
 *
 *		This program is free software; you can redistribute it and/or
 *		modify it under the terms of the GNU Lesser General Public
 *		License as published by the Free Software Foundation; either
 *		version 2 of the License, or (at your option) any later version.
 *	Copyright (C) 2007 Alan DeKok <aland@deployingradius.com>
 *
 * sendfromto	added 18/08/2003, Jan Berkel <jan@sitadelle.com>
 *		Works on Linux and FreeBSD (5.x)
 *
 * Version: $Id$
 */

#include <freeradius-devel/ident.h>
RCSID("$Id$")

#include <freeradius-devel/udpfromto.h>

#ifdef WITH_UDPFROMTO

#ifdef HAVE_SYS_UIO_H
#include <sys/uio.h>
#endif

#include <fcntl.h>

/*
 *	Enable delivery of destination-address ancillary data on socket 's'.
 *
 *	The address family is taken from getsockname(), and the socket
 *	option to set is chosen at compile time from whichever of
 *	IP_PKTINFO, IP_RECVDSTADDR or IPV6_PKTINFO is available.
 *
 *	errno is preset to ENOSYS before any system call is made.
 *
 *	Returns -1 if getsockname() fails or the address family is not
 *	handled here, 0 if no suitable option was compiled in for the
 *	family, and otherwise the result of setsockopt().
 */
int udpfromto_init(int s)
{
	struct sockaddr_storage local;
	socklen_t local_len = sizeof(local);
	int level = -1;		/* stays negative when nothing is supported */
	int optname;
	int enable = 1;

	errno = ENOSYS;

	if (getsockname(s, (struct sockaddr *) &local, &local_len) < 0) {
		return -1;
	}

	switch (local.ss_family) {
	case AF_INET:
#ifdef HAVE_IP_PKTINFO
		/*
		 *	Linux
		 */
		level = SOL_IP;
		optname = IP_PKTINFO;
#endif

#ifdef IP_RECVDSTADDR
		/*
		 *	BSD.  Note: IP_RECVDSTADDR == IP_SENDSRCADDR.
		 *	When both options exist, this one wins.
		 */
		level = IPPROTO_IP;
		optname = IP_RECVDSTADDR;
#endif
		break;

#ifdef AF_INET6
	case AF_INET6:
#ifdef HAVE_IN6_PKTINFO
		/*
		 *	This should actually be standard IPv6
		 */
		level = IPPROTO_IPV6;
		optname = IPV6_PKTINFO;
#endif
		break;
#endif

	default:
		/*
		 *	Unknown AF.
		 */
		return -1;
	}

	/*
	 *	Nothing to enable for this family.  That is not an error.
	 */
	if (level < 0) return 0;

	return setsockopt(s, level, optname, &enable, sizeof(enable));
}

/*
 *	Receive one datagram like recvfrom(), and additionally report the
 *	local address the datagram was sent to.
 *
 *	s	socket to read from.
 *	buf	buffer receiving the payload, 'len' bytes long.
 *	flags	passed unchanged to recvfrom() / recvmsg().
 *	from	if non-NULL, receives the sender address; '*fromlen' is
 *		the buffer size on entry and the address size on return.
 *	to	if non-NULL (together with 'tolen'), receives the
 *		destination address.  It is first filled in from
 *		getsockname(), which supplies the port, and the IP
 *		address is then overwritten from the ancillary data
 *		when the kernel provides it.
 *	tolen	size of the 'to' buffer on entry, size of the stored
 *		address on return.
 *
 *	When 'to' or 'tolen' is NULL, or when no usable socket option
 *	was available at compile time, this is a plain recvfrom().
 *
 *	Returns the result of recvfrom() / recvmsg().  Returns -1 if
 *	getsockname() fails.  Returns -1 with errno set to EINVAL if
 *	'*tolen' is too small for the socket's address family, or if
 *	that family is not handled here.
 */
int recvfromto(int s, void *buf, size_t len, int flags,
	       struct sockaddr *from, socklen_t *fromlen,
	       struct sockaddr *to, socklen_t *tolen)
{
	struct msghdr msgh;
	struct cmsghdr *cmsg;
	struct iovec iov;
	/*
	 *	The control buffer is walked through 'struct cmsghdr'
	 *	pointers, so it has to be aligned for that type.  A bare
	 *	char array gives no such guarantee; the union does.
	 */
	union {
		char buf[256];
		struct cmsghdr align;
	} cbuf;
	int err;
	struct sockaddr_storage si;
	socklen_t si_len = sizeof(si);

#if !defined(HAVE_IP_PKTINFO) && !defined(IP_RECVDSTADDR) && !defined (HAVE_IN6_PKTINFO)
	/*
	 *	If the recvmsg() flags aren't defined, fall back to
	 *	using recvfrom().
	 */
	to = NULL;
#endif

	/*
	 *	Without somewhere to store the destination address there
	 *	is nothing extra to do.
	 */
	if (!to || !tolen) return recvfrom(s, buf, len, flags, from, fromlen);

	/*
	 *	recvmsg doesn't provide sin_port so we have to
	 *	retrieve it using getsockname().
	 */
	if (getsockname(s, (struct sockaddr *)&si, &si_len) < 0) {
		return -1;
	}

	/*
	 *	Initialize the 'to' address.  It may be INADDR_ANY here,
	 *	with a more specific address given by recvmsg(), below.
	 */
	if (si.ss_family == AF_INET) {
		struct sockaddr_in *dst = (struct sockaddr_in *) to;
		struct sockaddr_in *src = (struct sockaddr_in *) &si;

		if (*tolen < sizeof(*dst)) {
			errno = EINVAL;
			return -1;
		}
		*tolen = sizeof(*dst);
		*dst = *src;

#if !defined(HAVE_IP_PKTINFO) && !defined(IP_RECVDSTADDR)
		/*
		 *	recvmsg() flags aren't defined.  Use recvfrom()
		 */
		return recvfrom(s, buf, len, flags, from, fromlen);
#endif
	}

#ifdef AF_INET6
	else if (si.ss_family == AF_INET6) {
		struct sockaddr_in6 *dst = (struct sockaddr_in6 *) to;
		struct sockaddr_in6 *src = (struct sockaddr_in6 *) &si;

		if (*tolen < sizeof(*dst)) {
			errno = EINVAL;
			return -1;
		}
		*tolen = sizeof(*dst);
		*dst = *src;

#if !defined(HAVE_IN6_PKTINFO)
		/*
		 *	recvmsg() flags aren't defined.  Use recvfrom()
		 */
		return recvfrom(s, buf, len, flags, from, fromlen);
#endif
	}
#endif
	/*
	 *	Unknown address family.
	 */
	else {
		errno = EINVAL;
		return -1;
	}

	/* Set up iov and msgh structures. */
	memset(&msgh, 0, sizeof(struct msghdr));
	iov.iov_base = buf;
	iov.iov_len  = len;
	msgh.msg_control = cbuf.buf;
	msgh.msg_controllen = sizeof(cbuf.buf);
	msgh.msg_name = from;
	msgh.msg_namelen = fromlen ? *fromlen : 0;
	msgh.msg_iov  = &iov;
	msgh.msg_iovlen = 1;
	msgh.msg_flags = 0;

	/* Receive one packet. */
	if ((err = recvmsg(s, &msgh, flags)) < 0) {
		return err;
	}

	if (fromlen) *fromlen = msgh.msg_namelen;

	/*
	 *	Process auxiliary received data in msgh.  The payload is
	 *	copied out with memcpy() because the pointer returned by
	 *	CMSG_DATA() is not guaranteed to be aligned for the
	 *	payload type.  The first matching message wins.
	 */
	for (cmsg = CMSG_FIRSTHDR(&msgh);
	     cmsg != NULL;
	     cmsg = CMSG_NXTHDR(&msgh,cmsg)) {

#ifdef HAVE_IP_PKTINFO
		if ((cmsg->cmsg_level == SOL_IP) &&
		    (cmsg->cmsg_type == IP_PKTINFO)) {
			struct in_pktinfo pkt;

			memcpy(&pkt, CMSG_DATA(cmsg), sizeof(pkt));
			((struct sockaddr_in *)to)->sin_addr = pkt.ipi_addr;
			*tolen = sizeof(struct sockaddr_in);
			break;
		}
#endif

#ifdef IP_RECVDSTADDR
		if ((cmsg->cmsg_level == IPPROTO_IP) &&
		    (cmsg->cmsg_type == IP_RECVDSTADDR)) {
			struct in_addr dst_addr;

			memcpy(&dst_addr, CMSG_DATA(cmsg), sizeof(dst_addr));
			((struct sockaddr_in *)to)->sin_addr = dst_addr;
			*tolen = sizeof(struct sockaddr_in);
			break;
		}
#endif

#ifdef HAVE_IN6_PKTINFO
		if ((cmsg->cmsg_level == IPPROTO_IPV6) &&
		    (cmsg->cmsg_type == IPV6_PKTINFO)) {
			struct in6_pktinfo pkt6;

			memcpy(&pkt6, CMSG_DATA(cmsg), sizeof(pkt6));
			((struct sockaddr_in6 *)to)->sin6_addr = pkt6.ipi6_addr;
			*tolen = sizeof(struct sockaddr_in6);
			break;
		}
#endif
	}

	return err;
}

/*
 *	Send one datagram like sendto(), asking the kernel to use 'from'
 *	as the source IP address.
 *
 *	s	socket to send on.
 *	buf	payload, 'len' bytes long.
 *	flags	passed unchanged to sendto() / sendmsg().
 *	from	requested source address (only the IP address is used).
 *	fromlen	size of 'from'.
 *	to	destination address, 'tolen' bytes long.
 *
 *	When 'from' is NULL, 'fromlen' is 0, the family of 'from' is
 *	AF_UNSPEC, or no usable socket option was available at compile
 *	time, this is a plain sendto().
 *
 *	Returns the result of sendto() / sendmsg(), or -1 with errno set
 *	to EINVAL when the family of 'from' is not handled here.
 */
int sendfromto(int s, void *buf, size_t len, int flags,
	       struct sockaddr *from, socklen_t fromlen,
	       struct sockaddr *to, socklen_t tolen)
{
	struct msghdr msgh;
	struct cmsghdr *cmsg;
	struct iovec iov;
	/*
	 *	The control buffer is written through a 'struct cmsghdr'
	 *	pointer, so it has to be aligned for that type.  A bare
	 *	char array gives no such guarantee; the union does.
	 */
	union {
		char buf[256];
		struct cmsghdr align;
	} cbuf;

#if !defined(HAVE_IP_PKTINFO) && !defined(IP_SENDSRCADDR) && !defined(HAVE_IN6_PKTINFO)
	/*
	 *	If the sendmsg() flags aren't defined, fall back to
	 *	using sendto().
	 */
	from = NULL;
#endif

	/*
	 *	No usable source address: plain sendto().
	 */
	if (!from || (fromlen == 0) || (from->sa_family == AF_UNSPEC)) {
		return sendto(s, buf, len, flags, to, tolen);
	}

	/*
	 *	Set up iov and msgh structures.  The control buffer is
	 *	zeroed so that no uninitialised padding bytes are handed
	 *	to the kernel.
	 */
	memset(&msgh, 0, sizeof(struct msghdr));
	memset(&cbuf, 0, sizeof(cbuf));
	iov.iov_base = buf;
	iov.iov_len = len;
	msgh.msg_iov = &iov;
	msgh.msg_iovlen = 1;
	msgh.msg_name = to;
	msgh.msg_namelen = tolen;

	/*
	 *	Payloads are built in a local object and copied in with
	 *	memcpy(): the pointer returned by CMSG_DATA() is not
	 *	guaranteed to be aligned for the payload type.
	 */
	if (from->sa_family == AF_INET) {
		struct sockaddr_in *s4 = (struct sockaddr_in *) from;

#ifdef HAVE_IP_PKTINFO
		struct in_pktinfo pkt;

		msgh.msg_control = cbuf.buf;
		msgh.msg_controllen = CMSG_SPACE(sizeof(pkt));

		cmsg = CMSG_FIRSTHDR(&msgh);
		cmsg->cmsg_level = SOL_IP;
		cmsg->cmsg_type = IP_PKTINFO;
		cmsg->cmsg_len = CMSG_LEN(sizeof(pkt));

		memset(&pkt, 0, sizeof(pkt));
		pkt.ipi_spec_dst = s4->sin_addr;
		memcpy(CMSG_DATA(cmsg), &pkt, sizeof(pkt));
#endif

#ifdef IP_SENDSRCADDR
		/*
		 *	When IP_PKTINFO is available too, this message
		 *	replaces the one built above.
		 */
		msgh.msg_control = cbuf.buf;
		msgh.msg_controllen = CMSG_SPACE(sizeof(s4->sin_addr));

		cmsg = CMSG_FIRSTHDR(&msgh);
		cmsg->cmsg_level = IPPROTO_IP;
		cmsg->cmsg_type = IP_SENDSRCADDR;
		cmsg->cmsg_len = CMSG_LEN(sizeof(s4->sin_addr));

		memcpy(CMSG_DATA(cmsg), &s4->sin_addr, sizeof(s4->sin_addr));
#endif
	}

#ifdef AF_INET6
	else if (from->sa_family == AF_INET6) {
#ifdef HAVE_IN6_PKTINFO
		struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) from;

		struct in6_pktinfo pkt;

		msgh.msg_control = cbuf.buf;
		msgh.msg_controllen = CMSG_SPACE(sizeof(pkt));

		cmsg = CMSG_FIRSTHDR(&msgh);
		cmsg->cmsg_level = IPPROTO_IPV6;
		cmsg->cmsg_type = IPV6_PKTINFO;
		cmsg->cmsg_len = CMSG_LEN(sizeof(pkt));

		memset(&pkt, 0, sizeof(pkt));
		pkt.ipi6_addr = s6->sin6_addr;
		memcpy(CMSG_DATA(cmsg), &pkt, sizeof(pkt));
#endif
	}
#endif

	/*
	 *	Unknown address family.
	 */
	else {
		errno = EINVAL;
		return -1;
	}

	return sendmsg(s, &msgh, flags);
}


#ifdef TESTING
/*
 *	Small test program to test recvfromto/sendfromto
 *
 *	use a virtual IP address as first argument to test
 *
 *	reply packet should originate from virtual IP and not
 *	from the default interface the alias is bound to
 */

#include <stdio.h>
#include <stdlib.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/wait.h>

#define DEF_PORT 20000		/* default port to listen on */
#define DESTIP "127.0.0.1"	/* send packet to localhost per default */
#define TESTSTRING "foo"	/* what to send */
#define TESTLEN 4			/* 4 bytes */

/*
 *	Test driver: forks into a server (parent) and a client (child).
 *
 *	The server binds INADDR_ANY:port, receives one packet with
 *	recvfromto() and replies with sendfromto(), using the address the
 *	packet arrived on as the source.  The client binds port+1, sends
 *	TESTSTRING to destip:port and prints the reply it receives.
 *
 *	argv[1]	destination IP for the client (default DESTIP).
 *	argv[2]	server port (default DEF_PORT).
 *
 *	Always exits with status 0; failures are reported with perror().
 */
int main(int argc, char **argv)
{
	struct sockaddr_in from, to, in;
	char buf[TESTLEN];
	const char *destip = DESTIP;
	int port = DEF_PORT;
	int n, server_socket, client_socket, pid;
	socklen_t fl, tl;	/* recvfromto() takes socklen_t *, not int * */

	if (argc > 1) destip = argv[1];
	if (argc > 2) port = atoi(argv[2]);

	memset(&in, 0, sizeof(in));
	in.sin_family = AF_INET;
	in.sin_addr.s_addr = INADDR_ANY;
	in.sin_port = htons(port);
	fl = tl = sizeof(struct sockaddr_in);
	memset(&from, 0, sizeof(from));
	memset(&to,   0, sizeof(to));

	switch(pid = fork()) {
		case -1:
			perror("fork");
			return 0;
		case 0:
			/* child: give the server time to bind first */
			usleep(100000);
			goto client;
	}

	/* parent: server */
	server_socket = socket(PF_INET, SOCK_DGRAM, 0);
	if (server_socket < 0) {
		perror("server: socket");
		waitpid(pid, NULL, WNOHANG);
		return 0;
	}

	if (udpfromto_init(server_socket) != 0) {
		perror("udpfromto_init");
		waitpid(pid, NULL, WNOHANG);
		return 0;
	}

	if (bind(server_socket, (struct sockaddr *)&in, sizeof(in)) < 0) {
		perror("server: bind");
		waitpid(pid, NULL, WNOHANG);
		return 0;
	}

	printf("server: waiting for packets on INADDR_ANY:%d\n", port);
	if ((n = recvfromto(server_socket, buf, sizeof(buf), 0,
	    (struct sockaddr *)&from, &fl,
	    (struct sockaddr *)&to, &tl)) < 0) {
		perror("server: recvfromto");
		waitpid(pid, NULL, WNOHANG);
		return 0;
	}

	/*
	 *	The payload is not guaranteed to be NUL terminated, so
	 *	bound the print by the number of bytes received.
	 */
	printf("server: received a packet of %d bytes [%.*s] ", n, n, buf);
	printf("(src ip:port %s:%d ",
		inet_ntoa(from.sin_addr), ntohs(from.sin_port));
	printf(" dst ip:port %s:%d)\n",
		inet_ntoa(to.sin_addr), ntohs(to.sin_port));

	printf("server: replying from address packet was received on to source address\n");

	if ((n = sendfromto(server_socket, buf, n, 0,
		(struct sockaddr *)&to, tl,
		(struct sockaddr *)&from, fl)) < 0) {
		perror("server: sendfromto");
	}

	waitpid(pid, NULL, 0);
	return 0;

client:
	/*
	 *	The child gets here before the server socket exists, so
	 *	there is no descriptor to close; 'server_socket' is still
	 *	unassigned at this point and must not be touched.
	 */
	client_socket = socket(PF_INET, SOCK_DGRAM, 0);
	if (client_socket < 0) {
		perror("client: socket");
		_exit(0);
	}

	if (udpfromto_init(client_socket) != 0) {
		perror("udpfromto_init");
		_exit(0);
	}
	/* bind client on different port */
	in.sin_port = htons(port+1);
	if (bind(client_socket, (struct sockaddr *)&in, sizeof(in)) < 0) {
		perror("client: bind");
		_exit(0);
	}

	in.sin_port = htons(port);
	in.sin_addr.s_addr = inet_addr(destip);

	printf("client: sending packet to %s:%d\n", destip, port);
	if (sendto(client_socket, TESTSTRING, TESTLEN, 0,
			(struct sockaddr *)&in, sizeof(in)) < 0) {
		perror("client: sendto");
		_exit(0);
	}

	printf("client: waiting for reply from server on INADDR_ANY:%d\n", port+1);

	if ((n = recvfromto(client_socket, buf, sizeof(buf), 0,
	    (struct sockaddr *)&from, &fl,
	    (struct sockaddr *)&to, &tl)) < 0) {
		perror("client: recvfromto");
		_exit(0);
	}

	printf("client: received a packet of %d bytes [%.*s] ", n, n, buf);
	printf("(src ip:port %s:%d",
		inet_ntoa(from.sin_addr), ntohs(from.sin_port));
	printf(" dst ip:port %s:%d)\n",
		inet_ntoa(to.sin_addr), ntohs(to.sin_port));

	_exit(0);
}

#endif /* TESTING */
#endif /* WITH_UDPFROMTO */