/*
 * Security Context Mapping Protocol Daemon
 *
 * SCMP Protocol Routines
 *
 * Copyright (c) 2001-2002 James Morris <jmorris@intercode.com.au>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 */
#include <stdlib.h>
#include <netinet/in.h>
#include <stdarg.h>
#include <stdio.h>
#include <syslog.h>
#include <arpa/inet.h>
#include <string.h>
#include <stdio.h>
#include <ss.h>

#include "libflnetlink.h"
#include "server.h"
#include "protocol.h"
#include "scmp.h"
#include "transport.h"
#include "flnetlink.h"
#include "perimtab.h"
#include "server.h"
#include "debug.h"

/*
 * SCMP error code strings.
 */
static char *scmp_errtab[] = {
	[SCMP_ERR_UNSPEC]	"unspecified error",
	[SCMP_ERR_LENGTH]	"invalid message length",
	[SCMP_ERR_ADDRESS]	"invalid payload address",
	[SCMP_ERR_VERSION]	"invalid scmp version",
	[SCMP_ERR_TYPE]		"invalid mesage type",
	[SCMP_ERR_LOCAL]	"local processing error",
	[SCMP_ERR_INTERNAL]	"internal error",
	[SCMP_ERR_RECORDS]	"invalid number of records"
};

/*
 * SCMP message types.
 */
static char *scmp_msgtab[] = {
	[SCMP_MSG_UNSPEC]	"mystery",
	[SCMP_MSG_BASE]		"base",
	[SCMP_MSG_MAP_REQ]	"map request",
	[SCMP_MSG_MAP_RES]	"map response",
	[SCMP_MSG_ERR_RES]	"error response",
};

/*
 * Convert an SCMP error code to a descriptive string.
 *
 * Codes above SCMP_ERR_MAX are reported as SCMP_ERR_INTERNAL.  The result
 * is never NULL: a code that lies outside scmp_errtab, or whose slot in the
 * table was not initialised, yields "unknown error" so the return value can
 * always be handed to a "%s" conversion.
 */
static char *scmp_errstr(u_int16_t errcode)
{
	const size_t nentries = sizeof(scmp_errtab) / sizeof(scmp_errtab[0]);
	char *str = NULL;

	if (errcode > SCMP_ERR_MAX)
		errcode = SCMP_ERR_INTERNAL;

	/* SCMP_ERR_MAX is not tied to the table size, so check that too. */
	if (errcode < nentries)
		str = scmp_errtab[errcode];

	return str ? str : "unknown error";
}

/*
 * Look up the descriptive name of an SCMP message type.  Any type above
 * SCMP_MSG_MAX is described using the SCMP_MSG_UNSPEC entry.
 */
static char *scmp_msgstr(u_int8_t type)
{
	u_int8_t idx = (type > SCMP_MSG_MAX) ? SCMP_MSG_UNSPEC : type;

	return scmp_msgtab[idx];
}

/*
 * Send an SCMP error response message to a peer, encapsulating the
 * base header of the original message.
 *
 * orig    - base header of the message being rejected; copied verbatim into
 *           the response, and its sequence number is echoed back
 * peer    - address the response is sent to
 * errcode - SCMP_ERR_* code (host byte order)
 * pointer - error pointer value (host byte order)
 *
 * A transport failure is logged via syslog and is not reported to the
 * caller.
 */
static void scmp_send_error_response(struct scmp_msg_base *orig,
                                     struct in_addr *peer,
                                     u_int16_t errcode, u_int16_t pointer)
{
	int rc, len;
	struct scmp_msg_err_res msg;
	
	len = SCMP_ALIGN(sizeof(msg));
	
	/*
	 * Clear exactly the object.  len is the aligned length and must not
	 * be used here, since it may be larger than sizeof(msg).
	 */
	memset(&msg, 0, sizeof(msg));
	
	msg.base.version   = SCMP_VERSION;
	msg.base.type      = SCMP_MSG_ERR_RES;
	/*
	 * Advertise the length that is actually transmitted, as the other
	 * senders do; scmp_recv() rejects messages whose header length
	 * differs from the received length.
	 */
	msg.base.len       = htons(len);
	msg.base.peer      = htonl(INADDR_ANY);
	msg.base.sequence  = orig->sequence;	/* echoed back unconverted */
	msg.base.serial    = htonl(current_serial());
	
	msg.errcode        = htons(errcode);
	msg.pointer        = htons(pointer);
	
	memcpy(&msg.orig, orig, sizeof(struct scmp_msg_base));
	
	debug("sending %s message to %s seq %u err %u ptr %u",
	      scmp_msgstr(SCMP_MSG_ERR_RES), inet_ntoa(*peer),
	      ntohl(msg.base.sequence), errcode, pointer);
	
	/*
	 * NOTE(review): len bytes starting at msg are handed to the
	 * transport; this presumes sizeof(msg) is already a multiple of the
	 * SCMP alignment -- confirm against scmp.h.
	 */
	rc = transport_send((char *)&msg, peer, len);
	if (rc < 0)
		syslog(LOG_ERR, "%s: transport_send: %m", __FUNCTION__);

	return;
}

/*
 * Report an SCMP message processing error, first locally, then remotely if
 * the original message was not itself an error response.
 *
 * buf     - start of the offending message; read as a struct scmp_msg_base
 * peer    - address the message was received from
 * errcode - SCMP_ERR_* code describing the failure
 * pointer - error pointer value passed through to the log and the response
 *
 * The local report is a LOG_ERR syslog line containing the decoded base
 * header fields.
 */
static void scmp_report_error(char *buf, struct in_addr *peer,
                              u_int16_t errcode, u_int16_t pointer)
{
	struct scmp_msg_base *orig = (struct scmp_msg_base *)buf;
	
	syslog(LOG_ERR, "error processing scmp message from %s: %s: %s: "
	       "ptr %u ver %u len %u peer %u.%u.%u.%u seq %u ser %u",
	       inet_ntoa(*peer), scmp_msgstr(orig->type), scmp_errstr(errcode),
	       pointer, orig->version, ntohs(orig->len), NIPQUAD(orig->peer),
	       ntohl(orig->sequence), ntohl(orig->serial));
	
	/* Never answer an error response with another error response. */
	if (orig->type != SCMP_MSG_ERR_RES)
		scmp_send_error_response(orig, peer, errcode, pointer);
}

/*
 * Send an SCMP map response message in response to a map request message,
 * mapping each requested SID to a security context.
 *
 * req  - the received map request; its peer, sequence and record count are
 *        echoed into the response
 * peer - address the response is sent to
 *
 * The response is assembled in a zero-filled buffer of SCMP_MSGLEN_MAX
 * bytes.  Each record is a struct scmp_attr_map followed by its context,
 * padded to the SCMP alignment.  If a SID cannot be converted, or the
 * records do not fit in the buffer, the failure is logged, reported to the
 * peer through scmp_report_error() with the 1-based record number as the
 * pointer, and no response is sent.
 */
static void scmp_send_map_response(struct scmp_msg_map_req *req,
                                   struct in_addr *peer)
{
	char *curr;
	int i, rc, len, nrecs;
	struct scmp_msg_map_res *res;
	
	res = malloc(SCMP_MSGLEN_MAX);
	if (res == NULL) {
		syslog(LOG_ERR, "%s: malloc: %m", __FUNCTION__);
		return;
	}
	
	/* Zero fill, so that context padding is transmitted as zeroes. */
	memset(res, 0, SCMP_MSGLEN_MAX);
	len = sizeof(struct scmp_msg_map_res);
	
	res->base.version   = SCMP_VERSION;
	res->base.type      = SCMP_MSG_MAP_RES;
	res->base.len       = htons(0);		/* filled in below */
	res->base.peer      = req->base.peer;
	res->base.sequence  = req->base.sequence;
	res->base.serial    = htonl(current_serial());
	res->records        = req->records;

	nrecs = ntohs(req->records);

	/*
	 * Add an attribute map for each SID.
	 */
	for (i = 0, curr = (char *)res->map; i < nrecs; i++) {
		struct scmp_attr_map *map = (void *)curr;
		int ctxlen, reclen;
		/* Bytes left for this record's context after its header. */
		int room = (int)SCMP_MSGLEN_MAX - len
		           - (int)sizeof(struct scmp_attr_map);
		
		if (room <= 0)
			goto too_big;
		
		/* Never offer more context space than the buffer has left. */
		ctxlen = (room < SCMP_CONTEXT_MAX) ? room : SCMP_CONTEXT_MAX;
		
		map->sid = req->sid[i];
		
		rc = security_sid_to_context(ntohl(req->sid[i]),
		                             map->context, &ctxlen);
		if (rc) {
			syslog(LOG_ERR, "%s: security_sid_to_context: %m",
			       __FUNCTION__);
			scmp_report_error((char *)req, peer,
			                  SCMP_ERR_CONTEXT, i+1);
			free(res);
			return;
		}
		
		debug("lsid %u is %s", ntohl(req->sid[i]), map->context);
		
		/*
		 * Right zero pad to 32-bits
		 */
		ctxlen = SCMP_ALIGN(ctxlen);
		if (ctxlen < 0 || ctxlen > room)
			goto too_big;
		
		map->context_len = htons(ctxlen);
		
		/*
		 * Advance by the size of this record only; len is the running
		 * total for the whole message.
		 */
		reclen = sizeof(struct scmp_attr_map) + ctxlen;
		len += reclen;
		curr += reclen;
	}

	len = SCMP_ALIGN(len);
	res->base.len = htons(len);

	debug("sending %s message to %s seq %u ser %u recs %u",
	      scmp_msgstr(SCMP_MSG_MAP_RES), inet_ntoa(*peer),
	      ntohl(res->base.sequence), ntohl(res->base.serial), nrecs);
	      
	rc = transport_send((unsigned char *)res, peer, len);
	if (rc < 0)
		syslog(LOG_ERR, "%s: transport_send: %m", __FUNCTION__);
	
	free(res);
	return;

too_big:
	syslog(LOG_ERR, "%s: record %d does not fit in response",
	       __FUNCTION__, i + 1);
	scmp_report_error((char *)req, peer, SCMP_ERR_CONTEXT, i+1);
	free(res);
	return;
}

/*
 * Send an SCMP map request message to a peer, effectively proxying the
 * kernel request.
 *
 * flreq - the flnetlink map request; supplies the peer address, the serial
 *         number, the record count and the SIDs to be mapped
 *
 * Requests with fewer than one record or more than SCMP_RECORDS_MAX records
 * are logged and dropped (the receiving side rejects such record counts
 * anyway).  Allocation and transport failures are logged via syslog; nothing
 * is returned to the caller.
 */
void scmp_send_map_request(struct flmsg_map_req *flreq)
{
	int i, len, rc, nrecs;
	struct in_addr peer;
	struct scmp_msg_map_req *msg;
	
	nrecs = flreq->base.count;
	peer.s_addr = flreq->base.peer;
	
	/* Also catches a count that went negative when converted to int. */
	if (nrecs < 1) {
		syslog(LOG_ERR, "%d mappings requested, at least 1 required",
		       nrecs);
		return;
	}
	
	if (nrecs > SCMP_RECORDS_MAX) {
		syslog(LOG_ERR, "%u mappings requested, limit is %u",
		       nrecs, SCMP_RECORDS_MAX);
		return;
	}
	
	len = SCMP_ALIGN(sizeof(struct scmp_msg_map_req)
	                 + nrecs * (sizeof(security_id_t)));
	      
	msg = malloc(len);
	if (msg == NULL) {
		syslog(LOG_ERR, "%s: malloc: %m", __FUNCTION__);
		return;
	}
	
	memset(msg, 0, len);
	
	msg->base.version   = SCMP_VERSION;
	msg->base.type      = SCMP_MSG_MAP_REQ;
	msg->base.len       = htons(len);
	msg->base.peer      = peer.s_addr;
	msg->base.sequence  = htonl(send_sequence++);
	msg->base.serial    = htonl(flreq->base.serial);
	msg->records        = htons(nrecs);

	/* SIDs go on the wire in network byte order. */
	for (i = 0; i < nrecs; i++) {
		debug("requesting mapping for rsid %u", flreq->sid[i]);
		msg->sid[i] = htonl(flreq->sid[i]);
	}
	
	debug("sending %s message to %s seq %u ser %u recs %u",
	      scmp_msgstr(SCMP_MSG_MAP_REQ), inet_ntoa(peer),
	      ntohl(msg->base.sequence), flreq->base.serial, nrecs);

	rc = transport_send((unsigned char *)msg, &peer, len);
	if (rc < 0)
		syslog(LOG_ERR, "%s: transport_send: %m", __FUNCTION__);
	
	free(msg);
	return;
}

/*
 * Handle a received SCMP map request.
 *
 * The message must be at least as long as the aligned request header, carry
 * between 1 and SCMP_RECORDS_MAX records, and be exactly as long as the
 * aligned header plus that many SIDs.  A valid request is answered with a
 * map response; anything else is reported via scmp_report_error().
 */
static void scmp_recv_map_request(unsigned char *buf,
                                  struct in_addr *peer, int len)
{
	struct scmp_msg_map_req *request = (struct scmp_msg_map_req *)buf;
	int count;

	if (len >= SCMP_ALIGN(sizeof(struct scmp_msg_map_req))) {
		count = ntohs(request->records);

		if (count < 1 || count > SCMP_RECORDS_MAX) {
			scmp_report_error(buf, peer, SCMP_ERR_RECORDS, 0);
			return;
		}

		if (len == SCMP_ALIGN(sizeof(struct scmp_msg_map_req)
		                      + count * sizeof(security_id_t))) {
			scmp_send_map_response(request, peer);
			return;
		}
	}

	/* Too short for the header, or wrong size for the record count. */
	scmp_report_error(buf, peer, SCMP_ERR_LENGTH, 0);
}

/*
 * Receive an SCMP map response message, convert remote security contexts to
 * local SIDs, then update kernel mappings via flnetlink.
 */
static void scmp_recv_map_response(unsigned char *buf,
                                   struct in_addr *peer, int len)
{
	char *curr, *end;
	int i, rc, fllen, nrecs;
	struct flmsg_map_res *flres;
	struct scmp_msg_map_res *res = (struct scmp_msg_map_res *)buf;
	
	nrecs = ntohs(res->records);
	
	if (nrecs < 1 || nrecs > SCMP_RECORDS_MAX) {
		scmp_report_error(buf, peer, SCMP_ERR_RECORDS, 0);
		return;
	}
	
	if (peer->s_addr != res->base.peer) {
		scmp_report_error(buf, peer, SCMP_ERR_ADDRESS, 0);
		return;
	}
	
	fllen = sizeof(struct flmsg_map_res) +
	               nrecs * sizeof(struct flmsg_attr_map);
	               
	flres = malloc(fllen);
	if (flres == NULL) {
		syslog(LOG_ERR, "malloc: allocating flres: %m");
		return;
	}
	
	flres->base.serial  = ntohl(res->base.serial);
	flres->base.peer    = res->base.peer;
	flres->base.count   = nrecs;

	for (curr = (char *)res->map, end = (char *)res + len, i = 0;
	     (curr <= end && i < nrecs); i++) {
	     
	     	struct scmp_attr_map *map = (struct scmp_attr_map *)curr;
	     	int ctxlen = ntohs(map->context_len);
		security_id_t lsid;
		
		if (!SCMP_ALIGNED(ctxlen)) {
			scmp_report_error(buf, peer, SCMP_ERR_LENGTH, i+1);
			free(flres);
			return;
		}
		
		flres->map[i].rsid = ntohl(map->sid);
		
		rc = security_context_to_sid(map->context, ctxlen, &lsid);
		if (rc) {
			syslog(LOG_ERR, "%s: security_context_to_sid: %m",
			       __FUNCTION__);
			scmp_report_error(buf, peer, SCMP_ERR_SID, i+1);
			free(flres);
			return;
		}
		
		flres->map[i].lsid = lsid;
		
		debug("mapped rsid %u ctx %s to lsid %u", ntohl(map->sid),
		      map->context, lsid);
	
		curr += sizeof(struct scmp_attr_map) + ctxlen;
	}

	if (curr > end || (end - curr) > SCMP_ALIGN(1)) {
		scmp_report_error(buf, peer, SCMP_ERR_LENGTH, i);
		free(flres);
		return;
	}

	if (i != nrecs) {
		scmp_report_error(buf, peer, SCMP_ERR_RECORDS, i);
		free(flres);
		return;
	}

	rc = flnetlink_send_map_res(flres, fllen);
	if (rc < 0)
		syslog(LOG_ERR, "%s: flnetlink_send_map_res: %m: %s",
		       __FUNCTION__, fln_errstr());
	
	free(flres);
	return;
}

/*
 * Handle an SCMP error response reported by a peer by logging it at
 * LOG_WARNING.  Such a message must never itself be answered over the
 * network.  A message of the wrong length is passed to
 * scmp_report_error() instead.
 */
static void scmp_recv_error_response(unsigned char *buf,
                                     struct in_addr *peer, int len)
{
	struct scmp_msg_err_res *err = (struct scmp_msg_err_res *)buf;

	if (len == SCMP_ALIGN(sizeof(struct scmp_msg_err_res))) {
		syslog(LOG_WARNING,
		       "%s error ptr %u for scmp %s message seq %u",
		       scmp_errstr(ntohs(err->errcode)),
		       ntohs(err->pointer),
		       scmp_msgstr(err->orig.type),
		       ntohl(err->orig.sequence));
		return;
	}

	scmp_report_error(buf, peer, SCMP_ERR_LENGTH, 0);
}

/*
 * Entry point for every received SCMP message.
 *
 * Messages from peers that do not match the perimeter table are dropped
 * without any logging.  Messages shorter than SCMP_MSGLEN_MIN are logged
 * and dropped.  The length and version in the base header are then
 * validated, and the message is dispatched on its type; unknown types are
 * reported as SCMP_ERR_TYPE.
 */
void scmp_recv(unsigned char *buf, struct in_addr *peer, int len)
{
	struct scmp_msg_base *hdr = (struct scmp_msg_base *)buf;

	if (!perimtab_match(peer->s_addr))
		return;

	if (len < SCMP_MSGLEN_MIN) {
		syslog(LOG_ERR, "short message, length %u from %s",
		       len, inet_ntoa(*peer));
		return;
	}

	/* Misaligned, oversized, or disagreeing with the header length. */
	if (!SCMP_ALIGNED(len)) {
		scmp_report_error(buf, peer, SCMP_ERR_LENGTH, 0);
		return;
	}
	if (len > SCMP_MSGLEN_MAX) {
		scmp_report_error(buf, peer, SCMP_ERR_LENGTH, 0);
		return;
	}
	if (len != ntohs(hdr->len)) {
		scmp_report_error(buf, peer, SCMP_ERR_LENGTH, 0);
		return;
	}

	if (hdr->version != SCMP_VERSION) {
		scmp_report_error(buf, peer, SCMP_ERR_VERSION, 0);
		return;
	}

	debug("received scmp %s message from %s seq %u ser %u",
	      scmp_msgstr(hdr->type), inet_ntoa(*peer), ntohl(hdr->sequence),
	      ntohl(hdr->serial));

	if (hdr->type == SCMP_MSG_MAP_REQ)
		scmp_recv_map_request(buf, peer, len);
	else if (hdr->type == SCMP_MSG_MAP_RES)
		scmp_recv_map_response(buf, peer, len);
	else if (hdr->type == SCMP_MSG_ERR_RES)
		scmp_recv_error_response(buf, peer, len);
	else
		scmp_report_error(buf, peer, SCMP_ERR_TYPE, 0);
}
