/*
 * Flask Netlink Userspace Library
 *
 * Copyright (c) 2001-2002 James Morris <jmorris@intercode.com.au>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 */

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netdb.h>
#include <netinet/in.h>
#include <unistd.h>

#include "libflnetlink.h"

/****************************************************************************
 *
 * Private interface
 *
 ****************************************************************************/

/*
 * Internal error codes recorded in fln_errno.  Each code doubles as the
 * index of its message in fln_errmap[], so the values are spelled out
 * explicitly and the table below is indexed by them.
 */
enum {
	FLN_ERR_NONE    = 0,
	FLN_ERR_IMPL    = 1,
	FLN_ERR_HANDLE  = 2,
	FLN_ERR_SOCKET  = 3,
	FLN_ERR_BIND    = 4,
	FLN_ERR_BUFFER  = 5,
	FLN_ERR_RECV    = 6,
	FLN_ERR_NLEOF   = 7,
	FLN_ERR_ADDRLEN = 8,
	FLN_ERR_STRUNC  = 9,
	FLN_ERR_RTRUNC  = 10,
	FLN_ERR_NLRECV  = 11,
	FLN_ERR_SEND    = 12,
	FLN_ERR_SUPP    = 13,
	FLN_ERR_RECVBUF = 14,
	FLN_ERR_TIMEOUT = 15
};
/* Highest valid error code: inclusive upper bound for fln_errmap[]. */
#define FLN_MAXERR FLN_ERR_TIMEOUT

/* Error code -> message table; entry N describes error code N. */
struct fln_errmap_t {
	int errcode;
	char *message;
} fln_errmap[] = {
	[FLN_ERR_NONE]    = { FLN_ERR_NONE,    "Unknown error" },
	[FLN_ERR_IMPL]    = { FLN_ERR_IMPL,    "Implementation error" },
	[FLN_ERR_HANDLE]  = { FLN_ERR_HANDLE,  "Unable to create netlink handle" },
	[FLN_ERR_SOCKET]  = { FLN_ERR_SOCKET,  "Unable to create netlink socket" },
	[FLN_ERR_BIND]    = { FLN_ERR_BIND,    "Unable to bind netlink socket" },
	[FLN_ERR_BUFFER]  = { FLN_ERR_BUFFER,  "Unable to allocate buffer" },
	[FLN_ERR_RECV]    = { FLN_ERR_RECV,    "Failed to receive netlink message" },
	[FLN_ERR_NLEOF]   = { FLN_ERR_NLEOF,   "Received EOF on netlink socket" },
	[FLN_ERR_ADDRLEN] = { FLN_ERR_ADDRLEN, "Invalid peer address length" },
	[FLN_ERR_STRUNC]  = { FLN_ERR_STRUNC,  "Sent message truncated" },
	[FLN_ERR_RTRUNC]  = { FLN_ERR_RTRUNC,  "Received message truncated" },
	[FLN_ERR_NLRECV]  = { FLN_ERR_NLRECV,  "Received error from netlink" },
	[FLN_ERR_SEND]    = { FLN_ERR_SEND,    "Failed to send netlink message" },
	[FLN_ERR_SUPP]    = { FLN_ERR_SUPP,    "Operation not supported" },
	[FLN_ERR_RECVBUF] = { FLN_ERR_RECVBUF, "Receive buffer size invalid" },
	[FLN_ERR_TIMEOUT] = { FLN_ERR_TIMEOUT, "Timeout" }
};

/*
 * Most recent error recorded by this library.  It starts as FLN_ERR_NONE
 * and is only ever overwritten, never reset, by the code in this file.
 */
static int fln_errno = FLN_ERR_NONE;

/*
 * Send len bytes of msg (a complete netlink message, header included) to
 * the handle's peer address.
 *
 * Returns the sendto() result.  On failure it returns a negative value,
 * sets fln_errno to FLN_ERR_SEND and leaves errno as set by sendto().
 */
static ssize_t fln_netlink_sendto(const struct fln_handle *h,
                                  const void *msg, size_t len)
{
	/* Keep the full ssize_t byte count instead of narrowing it to int. */
	ssize_t status = sendto(h->fd, msg, len, 0,
	                        (const struct sockaddr *)&h->peer,
	                        sizeof(h->peer));

	if (status < 0)
		fln_errno = FLN_ERR_SEND;
	return status;
}

/*
 * Send a scatter/gather netlink message described by msg on the handle's
 * socket.  msg carries its own destination address.
 *
 * Returns the sendmsg() result.  On failure it returns a negative value,
 * sets fln_errno to FLN_ERR_SEND and leaves errno as set by sendmsg().
 */
static ssize_t fln_netlink_sendmsg(const struct fln_handle *h,
                                   const struct msghdr *msg,
                                   unsigned int flags)
{
	/* Keep the full ssize_t byte count instead of narrowing it to int. */
	ssize_t status = sendmsg(h->fd, msg, (int)flags);

	if (status < 0)
		fln_errno = FLN_ERR_SEND;
	return status;
}

static ssize_t fln_netlink_recvfrom(const struct fln_handle *h,
                                    unsigned char *buf, size_t len,
                                    int timeout)
{
	int addrlen, status;
	struct nlmsghdr *nlh;

	if (len < sizeof(struct nlmsgerr)) {
		fln_errno = FLN_ERR_RECVBUF;
		return -1;
	}
	addrlen = sizeof(h->peer);

	if (timeout != 0) {
		int ret;
		struct timeval tv;
		fd_set read_fds;
		
		if (timeout < 0) {
			/* non-block non-timeout */
			tv.tv_sec = 0;
			tv.tv_usec = 0;
		} else {
			tv.tv_sec = timeout / 1000000;
			tv.tv_usec = timeout % 1000000;
		}

		FD_ZERO(&read_fds);
		FD_SET(h->fd, &read_fds);
		ret = select(h->fd+1, &read_fds, NULL, NULL, &tv);
		if (ret < 0) {
			if (errno == EINTR) {
				return 0;
			} else {
				fln_errno = FLN_ERR_RECV;
				return -1;
			}
		}
		if (!FD_ISSET(h->fd, &read_fds)) {
			fln_errno = FLN_ERR_TIMEOUT;
			return 0;
		}
	}
	status = recvfrom(h->fd, buf, len, 0,
	                      (struct sockaddr *)&h->peer, &addrlen);
	if (status < 0) {
		fln_errno = FLN_ERR_RECV;
		return status;
	}
	if (addrlen != sizeof(h->peer)) {
		fln_errno = FLN_ERR_RECV;
		return -1;
	}
	if (status == 0) {
		fln_errno = FLN_ERR_NLEOF;
		return -1;
	}
	nlh = (struct nlmsghdr *)buf;
	if (nlh->nlmsg_flags & MSG_TRUNC || nlh->nlmsg_len > status) {
		fln_errno = FLN_ERR_RTRUNC;
		return -1;
	}
	return status;
}

/*
 * Translate an internal error code into its message.  Any code outside
 * the range [0, FLN_MAXERR] is reported as an implementation error.
 */
static char *fln_strerror(int errcode)
{
	int in_range = (errcode >= 0 && errcode <= FLN_MAXERR);

	return fln_errmap[in_range ? errcode : FLN_ERR_IMPL].message;
}

/*
 * Fill in a netlink request header for a message carrying len payload
 * bytes.
 *
 * NLM_F_REQUEST is always set, and extraflags are OR-ed in.  NLM_F_ACK is
 * added when the handle has FLN_F_ACKS in its flags.  Each call advances
 * the handle's sequence number and stamps the new value into the header,
 * along with the handle's bound local pid.
 */
static void fln_init_nlhdr(struct fln_handle *h, struct nlmsghdr *nlh,
                           u_int32_t len, u_int16_t type, u_int16_t extraflags)
{
	u_int32_t ack = (h->flags & FLN_F_ACKS) ? NLM_F_ACK : 0;

	h->sequence++;

	nlh->nlmsg_len = NLMSG_LENGTH(len);
	nlh->nlmsg_type = type;
	nlh->nlmsg_flags = NLM_F_REQUEST | extraflags | ack;
	nlh->nlmsg_seq = h->sequence;
	nlh->nlmsg_pid = h->local.nl_pid;
}

/* XXX: these function names suck */
static int fln_perimeter_entry(struct fln_handle *h, u_int32_t addr,
                               u_int32_t mask, u_int16_t type)
{
	struct {
		struct nlmsghdr nlh;
		struct flmsg_perim_entry msg;
	} req;
	
	memset(&req, 0, sizeof(req));
	fln_init_nlhdr(h, &req.nlh, sizeof(req.msg), type, 0);
	req.msg.entry.addr = addr;
	req.msg.entry.mask = mask;
	return fln_netlink_sendto(h, (void *)&req, req.nlh.nlmsg_len);
}

static int fln_base_msg(struct fln_handle *h,
                        u_int16_t type, u_int16_t extraflags)
{
	struct nlmsghdr nlh;
	
	memset(&nlh, 0, sizeof(nlh));
	fln_init_nlhdr(h, &nlh, 0, type, extraflags);
	return fln_netlink_sendto(h, (void *)&nlh, nlh.nlmsg_len);
}

/****************************************************************************
 *
 * Public interface
 *
 ****************************************************************************/

/*
 * Create a flnetlink handle.
 *
 * Opens a PF_NETLINK/NETLINK_FLASK raw socket and binds it to this
 * process's pid, subscribing to the multicast groups in `groups`.  The
 * peer address is set to pid 0 with no groups.  `flags` is stored in the
 * handle; FLN_F_ACKS, for example, makes requests ask for NLM_F_ACK.
 *
 * Returns the new handle, or NULL with fln_errno set.  On failure errno
 * still holds the value from the failing call: cleanup does not clobber it.
 *
 * NOTE(review): binding to getpid() presumably limits a process to one
 * bound handle at a time.  Confirm before creating several handles.
 */
struct fln_handle *fln_create_handle(u_int32_t flags, u_int32_t groups)
{
	struct fln_handle *h;
	int saved_errno;

	/* calloc zero-fills, so local/peer addresses and sequence start at 0. */
	h = calloc(1, sizeof(*h));
	if (h == NULL) {
		fln_errno = FLN_ERR_HANDLE;
		return NULL;
	}

	h->fd = socket(PF_NETLINK, SOCK_RAW, NETLINK_FLASK);
	if (h->fd == -1) {
		/* Nothing to close here: close(-1) would overwrite errno with EBADF. */
		fln_errno = FLN_ERR_SOCKET;
		goto fail;
	}

	h->local.nl_family = AF_NETLINK;
	h->local.nl_pid = getpid();
	h->local.nl_groups = groups;

	if (bind(h->fd, (struct sockaddr *)&h->local, sizeof(h->local)) == -1) {
		fln_errno = FLN_ERR_BIND;
		goto fail;
	}

	h->peer.nl_family = AF_NETLINK;
	h->peer.nl_pid = 0;
	h->peer.nl_groups = 0;

	h->flags = flags;

	return h;

fail:
	/* Preserve the errno of the failing call across cleanup. */
	saved_errno = errno;
	if (h->fd != -1)
		close(h->fd);
	free(h);
	errno = saved_errno;
	return NULL;
}

/*
 * Close the handle's socket and release the handle.  A NULL handle is
 * accepted and ignored.
 *
 * Always returns 0: no error condition is checked at this stage, although
 * one may be reported if/when reliable messaging is implemented.
 */
int fln_destroy_handle(struct fln_handle *h)
{
	if (h == NULL)
		return 0;

	close(h->fd);
	free(h);
	return 0;
}

/* Send an FLMSG_PERIM_ADD request for addr/mask; returns the send result. */
int fln_perimeter_add(struct fln_handle *h, u_int32_t addr, u_int32_t mask)
{
	int rc = fln_perimeter_entry(h, addr, mask, FLMSG_PERIM_ADD);

	return rc;
}

/* Send an FLMSG_PERIM_DEL request for addr/mask; returns the send result. */
int fln_perimeter_del(struct fln_handle *h, u_int32_t addr, u_int32_t mask)
{
	int rc = fln_perimeter_entry(h, addr, mask, FLMSG_PERIM_DEL);

	return rc;
}

/* Send an FLMSG_PERIM_GET request with NLM_F_DUMP; returns the send result. */
int fln_perimeter_dump(struct fln_handle *h)
{
	int rc = fln_base_msg(h, FLMSG_PERIM_GET, NLM_F_DUMP);

	return rc;
}

/* Send an FLMSG_PERIM_FLUSH request; returns the send result. */
int fln_perimeter_flush(struct fln_handle *h)
{
	int rc = fln_base_msg(h, FLMSG_PERIM_FLUSH, 0);

	return rc;
}

void fln_parse_perim_entry(struct fln_handle *h, unsigned char *buf,
                           int status,
                           void (*userfn)(struct flmsg_perim_entry *msg))
{
	struct nlmsghdr *nlh;
	
	
	nlh = (struct nlmsghdr *)buf; 
	
	while (NLMSG_OK(nlh, status)) {
		struct flmsg_perim_entry *entry;
		
		if (nlh->nlmsg_pid != h->local.nl_pid || nlh->nlmsg_seq != h->sequence)
			continue;
		
		entry = NLMSG_DATA(nlh);
		
		if (userfn)
			userfn(entry);

		nlh = NLMSG_NEXT(nlh, status);
	}
}

void fln_parse_cache_map(struct fln_handle *h, unsigned char *buf,
                         int status,
                         void (*userfn)(struct flmsg_map_res *res))
{
	struct nlmsghdr *nlh;
	
	nlh = (struct nlmsghdr *)buf; 
	
	while (NLMSG_OK(nlh, status)) {
		struct flmsg_map_res *res;
		
		if (nlh->nlmsg_pid != h->local.nl_pid || nlh->nlmsg_seq != h->sequence)
			continue;
		
		res = NLMSG_DATA(nlh);
		
		if (userfn)
			userfn(res);

		nlh = NLMSG_NEXT(nlh, status);
	}
}

void fln_parse_queue_entry(struct fln_handle *h, unsigned char *buf,
                           int status,
                           void (*userfn)(struct flmsg_queue_entry *entry))
{
	struct nlmsghdr *nlh;
	
	nlh = (struct nlmsghdr *)buf; 
	
	while (NLMSG_OK(nlh, status)) {
		struct flmsg_queue_entry *entry;
		
		if (nlh->nlmsg_pid != h->local.nl_pid || nlh->nlmsg_seq != h->sequence)
			continue;
		
		entry = NLMSG_DATA(nlh);
		
		if (userfn)
			userfn(entry);

		nlh = NLMSG_NEXT(nlh, status);
	}
}

/*
 * Receive one netlink message from the handle's socket into buf (len
 * bytes).
 *
 * timeout is in microseconds: 0 blocks, a positive value bounds the wait,
 * and a negative value polls without waiting.  Returns the byte count, 0
 * when nothing was read, or -1 on error with fln_errno set.
 */
ssize_t fln_read(const struct fln_handle *h,
                 unsigned char *buf, size_t len, int timeout)
{
	ssize_t nread = fln_netlink_recvfrom(h, buf, len, timeout);

	return nread;
}

/*
 * Return the nlmsg_type of the netlink message at the start of buf.
 * buf must hold at least a struct nlmsghdr.  The header is read through a
 * const pointer, so the caller's const buffer is not cast to writable.
 */
int fln_message_type(const unsigned char *buf)
{
	const struct nlmsghdr *nlh = (const struct nlmsghdr *)buf;

	return nlh->nlmsg_type;
}

/*
 * Return the nlmsg_seq of the netlink message at the start of buf.
 * buf must hold at least a struct nlmsghdr.  The header is read through a
 * const pointer, so the caller's const buffer is not cast to writable.
 */
u_int32_t fln_message_seq(const unsigned char *buf)
{
	const struct nlmsghdr *nlh = (const struct nlmsghdr *)buf;

	return nlh->nlmsg_seq;
}

/*
 * Return the nlmsg_pid of the netlink message at the start of buf.
 * buf must hold at least a struct nlmsghdr.  The header is read through a
 * const pointer, so the caller's const buffer is not cast to writable.
 */
pid_t fln_message_pid(const unsigned char *buf)
{
	const struct nlmsghdr *nlh = (const struct nlmsghdr *)buf;

	return nlh->nlmsg_pid;
}

/*
 * Treat the payload of the netlink message at the start of buf as a
 * struct nlmsgerr and return its error field negated.  A stored -13 is
 * returned as 13.
 */
int fln_get_msgerr(const unsigned char *buf)
{
	const struct nlmsgerr *e =
		(const struct nlmsgerr *)(buf + NLMSG_LENGTH(0));

	return -e->error;
}

/*
 * Return a pointer to the payload that follows the netlink header at the
 * start of buf.  The result points into buf.  As the public prototype
 * dictates, it is returned without a const qualifier.
 */
void *fln_get_msg(const unsigned char *buf)
{
	unsigned char *base = (unsigned char *)buf;

	return base + NLMSG_LENGTH(0);
}

/* Return the message text for the library's most recently recorded error. */
char *fln_errstr(void)
{
	char *text = fln_strerror(fln_errno);

	return text;
}

/*
 * Print a diagnostic line to stderr of the form
 *
 *     <s>[: <flnetlink error>][: <strerror(errno)>]
 *
 * "ERROR" is used as the prefix when s is NULL.  The flnetlink part
 * appears only when fln_errno is non-zero.  The system part appears only
 * when errno was non-zero on entry.
 *
 * errno is captured before any output.  The stdio calls below may change
 * it, which previously made the system part report the wrong error.  It
 * is restored on return.
 */
void fln_perror(const char *s)
{
	int saved_errno = errno;

	fputs(s ? s : "ERROR", stderr);
	if (fln_errno)
		fprintf(stderr, ": %s", fln_errstr());
	if (saved_errno)
		fprintf(stderr, ": %s", strerror(saved_errno));
	fputc('\n', stderr);

	errno = saved_errno;
}

/*
 * Report whether an flnetlink error has been recorded: 1 if fln_errno
 * holds anything other than FLN_ERR_NONE, otherwise 0.
 */
int fln_errno_set(void)
{
	if (fln_errno == FLN_ERR_NONE)
		return 0;
	return 1;
}

/*
 * Return the sequence number stamped on the most recent request built for
 * handle h.  It is 0 if no request has been built yet.
 */
u_int32_t fln_current_seq(struct fln_handle *h)
{
	u_int32_t seq = h->sequence;

	return seq;
}

/*
 * Return the netlink socket descriptor owned by handle h.
 */
int fln_fd(struct fln_handle *h)
{
	int fd = h->fd;

	return fd;
}

/*
 * Send a map response to the kernel peer cache.
 */
int fln_cache_map_res(struct fln_handle *h, struct flmsg_map_res *res, size_t len)
{
	struct nlmsghdr nlh;
	struct iovec iov[2];
	struct msghdr msg;
	
	fln_init_nlhdr(h, &nlh, 0, FLMSG_CACHE_MAP_RES, 0);
	nlh.nlmsg_len = NLMSG_LENGTH(len);
	
	iov[0].iov_base = &nlh;
	iov[0].iov_len = sizeof(nlh);
	
	iov[1].iov_base = res;
	iov[1].iov_len = len;

	msg.msg_name = (void *)&h->peer;
	msg.msg_namelen = sizeof(h->peer);
	msg.msg_iov = iov;
	msg.msg_iovlen = 2;
	msg.msg_control = NULL;
	msg.msg_controllen = 0;
	msg.msg_flags = 0;
		
	return fln_netlink_sendmsg(h, &msg, 0);
}

/*
 * Request a single peer-cache entry for addr/rsid.
 *
 * Not implemented: always returns -1.  It records FLN_ERR_SUPP so that
 * fln_errstr()/fln_perror() report "Operation not supported" rather than
 * whatever error was recorded earlier.
 */
int fln_cache_get(struct fln_handle *h, u_int32_t addr, security_id_t rsid)
{
	(void)h;
	(void)addr;
	(void)rsid;

	fln_errno = FLN_ERR_SUPP;
	return -1;
}

/* Send an FLMSG_CACHE_GET request with NLM_F_DUMP; returns the send result. */
int fln_cache_dump(struct fln_handle *h)
{
	int rc = fln_base_msg(h, FLMSG_CACHE_GET, NLM_F_DUMP);

	return rc;
}

/* Send an FLMSG_CACHE_FLUSH request; returns the send result. */
int fln_cache_flush(struct fln_handle *h)
{
	int rc = fln_base_msg(h, FLMSG_CACHE_FLUSH, 0);

	return rc;
}

/* Send an FLMSG_QUEUE_GET request with NLM_F_DUMP; returns the send result. */
int fln_queue_dump(struct fln_handle *h)
{
	int rc = fln_base_msg(h, FLMSG_QUEUE_GET, NLM_F_DUMP);

	return rc;
}

/* Send an FLMSG_QUEUE_FLUSH request; returns the send result. */
int fln_queue_flush(struct fln_handle *h)
{
	int rc = fln_base_msg(h, FLMSG_QUEUE_FLUSH, 0);

	return rc;
}

/*
 * Helpers
 */
/*
 * Print one perimeter entry to stdout as "a.b.c.d/len".  len is the value
 * ip_masklen() computes from the entry's mask.
 */
void print_perimtab_entry(struct flmsg_perim_entry *entry)
{
	printf("%u.%u.%u.%u/%d\n",
	       NIPQUAD(entry->entry.addr),
	       ip_masklen(entry->entry.mask));
}

/*
 * Print a map request to stdout: peer address and serial, followed by
 * each of the base.count SIDs it carries, then a newline.
 */
void print_mapreq(struct flmsg_map_req *req)
{
	int idx = 0;

	printf("%u.%u.%u.%u serial %u", NIPQUAD(req->base.peer), req->base.serial);

	while (idx < req->base.count) {
		printf(" %u", req->sid[idx]);
		idx++;
	}

	putchar('\n');
}

/*
 * Print a map response to stdout: peer address, serial, and the
 * remote/local SID pair of its first mapping (map[0] only).
 */
void print_mapres(struct flmsg_map_res *res)
{
	printf("%u.%u.%u.%u serial %u rsid %u lsid %u\n",
	       NIPQUAD(res->base.peer),
	       res->base.serial,
	       res->map[0].rsid,
	       res->map[0].lsid);
}

/* Print a peer-cache entry to stdout as "a.b.c.d serial N". */
void print_cache_entry(struct flmsg_base *entry)
{
	printf("%u.%u.%u.%u serial %u\n",
	       NIPQUAD(entry->peer),
	       entry->serial);
}

/*
 * Print a queued-packet entry to stdout on one line.  The line contains:
 *   - peer, serial and TTL;
 *   - the protocol, by name if getprotobynumber() knows it, else by number;
 *   - ports (host byte order) for TCP/UDP;
 *   - each non-zero ssid/msid/dsid, tagged "mapped" or "unmapped".
 */
void print_queue_entry(struct flmsg_queue_entry *entry)
{
	u_int8_t p = entry->proto;
	struct protoent *pe;
	u_int32_t id;

	printf("%u.%u.%u.%u serial %u ttl %u proto ",
	       NIPQUAD(entry->base.peer), entry->base.serial, entry->ttl);

	pe = getprotobynumber(p);
	if (pe != NULL)
		printf("%s ", pe->p_name);
	else
		printf("%u ", p);

	if (p == IPPROTO_TCP || p == IPPROTO_UDP)
		printf("sport %u dport %u ", ntohs(entry->sport), ntohs(entry->dport));

	/* A zero SID means "not present" and is omitted. */
	id = entry->ssid_attr.sid;
	if (id != 0)
		printf("ssid %u %s ", id,
		       entry->ssid_attr.mapped ? "mapped" : "unmapped");

	id = entry->msid_attr.sid;
	if (id != 0)
		printf("msid %u %s ", id,
		       entry->msid_attr.mapped ? "mapped" : "unmapped");

	id = entry->dsid_attr.sid;
	if (id != 0)
		printf("dsid %u %s ", id,
		       entry->dsid_attr.mapped ? "mapped" : "unmapped");

	putchar('\n');
}
