/*
  Copyright Mission Critical Linux, 2000

  Kimberlite is free software; you can redistribute it and/or modify it
  under the terms of the GNU General Public License as published by the
  Free Software Foundation; either version 2, or (at your option) any
  later version.

  Kimberlite is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with Kimberlite; see the file COPYING.  If not, write to the
  Free Software Foundation, Inc.,  675 Mass Ave, Cambridge, 
  MA 02139, USA.
*/
/*
 *  $Id: msg.c,v 1.20 2000/11/20 22:27:10 burke Exp $
 *
 *  Copyright (C) 2000 Mission Critical Linux, LLC
 *
 *  author: Jeff Moyer <moyer@mclinux.com>
 *  description: Inter-process messaging interface.
 */
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/errno.h>
#include <sys/socket.h>
#include <malloc.h>
#include <linux/limits.h>
#include <sys/time.h>
#include <sys/un.h>
#include <dirent.h>
#include <string.h>
#include <sys/file.h>
#include <sys/stat.h>
#include <libgen.h>
#include <netinet/in.h>
#include <netdb.h>
#include <fcntl.h>
#include <sys/syslog.h>

#include "fdlist.h"
#include <msgsvc.h>
#include <parseconf.h>
#include <clusterdefs.h>
#include <clucfg.h>
#include <logger.h>
#include <diskstate.h>
#include <disk_proto.h>

/*
 * RCS revision string for this file.  Nothing in this file reads it;
 * __attribute__((unused)) keeps the compiler from warning about that.
 */
static const char *version __attribute__ ((unused)) = "$Revision: 1.20 $";

/*
 * Dprintf(fmt, ...): debug output.  Expands to printf() when DEBUG is
 * defined and to nothing otherwise (so its arguments are not evaluated
 * at all in non-DEBUG builds).
 */
#ifdef DEBUG
#define Dprintf(fmt,args...)  printf(fmt,##args)
#else
#define Dprintf(fmt,args...)
#endif

/*
 *  Configuration file entries we need to parse.
 *
 *  Each key holds the port number (parsed with atoi()) that one cluster
 *  daemon listens on.  They are looked up with CFG_Get() in
 *  msg_svc_init() and msg_listen().
 */
#define CFG_HEARTBEAT_PORT   "heartbeat%addr"
#define CFG_SVCMGR_PORT      "svcmgr%addr"
#define CFG_QUORUMD_PORT     "quorumd%addr"
#define CFG_POWER_PORT       "power%addr"

/*
 *  Default ports, passed to CFG_Get() as the fallback value.
 *  They are strings because CFG_Get() takes a string default.
 *  NOTE(review): presumably used when the key is missing from the
 *  config file; confirm against CFG_Get().
 */
#define DFLT_HEARTBEAT_PORT  "4001"
#define DFLT_SVCMGR_PORT     "4002"
#define DFLT_QUORUMD_PORT    "4003"
#define DFLT_POWER_PORT      "4004"

/*
 * This is the size of our authentication token: the session id carried
 * in every message header (struct msg_struct).  It is the same size as
 * the net block read by getNetBlockData() in msg_svc_init(), from which
 * the token is copied.
 * NOTE(review): SPACE_NET_BLOCK_DATA is believed to be roughly 450
 * bytes, i.e. far too large to guess; confirm against the header that
 * defines it.
 */
#define SESSION_ID_SIZE      SPACE_NET_BLOCK_DATA

/*
 * Global Variables
 */
/* Set to 1 once msg_svc_init() has completed successfully. */
static int msgsvc_initialized = 0;
/*
 * Authentication token copied from the net block by msg_svc_init().
 * It is placed in authenticated outgoing headers and compared against
 * incoming ones.
 */
static unsigned char session_id[SESSION_ID_SIZE];

/*
 * This node's IPv4 endpoint for each daemon, filled in via
 * msg_setup_sockin() with the local node's host name.
 * NOTE(review): these have external linkage and may be referenced from
 * other files; confirm before making them static.
 */
struct sockaddr_in msg_hb_addr,
                   msg_svcmgr_addr,
                   msg_quorumd_addr,
                   msg_power_addr;

/*
 * Per-daemon endpoint table.  It is indexed directly by msg_addr_t
 * process ids (e.g. proc_id_array[dest] in msg_open()), so the row order
 * must line up with those id values.
 * NOTE(review): the ids are presumably defined in msgsvc.h; confirm.
 */
struct msg_id_struct {
    int                msg_fd; /* Listening socket created by msg_listen(); -1 until then */
    struct sockaddr_in *msg_ins; /* IPv4 (TCP) address/port of this daemon on the local node */
    char               *config_file_port_ent; /* config key holding the port */
    char               *dflt_port;            /* default port string for CFG_Get() */
} proc_id_array[] ={ 
    {-1, &msg_hb_addr,        CFG_HEARTBEAT_PORT,  DFLT_HEARTBEAT_PORT },
    {-1, &msg_svcmgr_addr,    CFG_SVCMGR_PORT,     DFLT_SVCMGR_PORT },
    {-1, &msg_quorumd_addr,   CFG_QUORUMD_PORT,    DFLT_QUORUMD_PORT },
    {-1, &msg_power_addr,     CFG_POWER_PORT,      DFLT_POWER_PORT } };

/*
 * Header prepended to every message by msg_create().  It is written to
 * and read from the socket as raw struct bytes, so both ends must agree
 * on its in-memory layout.
 * NOTE(review): `count` is a host-order ssize_t, so the layout depends
 * on word size, padding and endianness.  Both nodes presumably run the
 * same architecture; confirm.
 */
struct msg_struct {
    unsigned char     ver;   /* sender's MSGSVC_VERSION (not checked on receive) */
    ssize_t           count; /* number of bytes in payload */
    unsigned char     session_id[SESSION_ID_SIZE]; /* auth token, or all zeroes if unauthenticated */
};

/*
 * Process id (msg_addr_t) of the running process.  Set by msg_listen()
 * and -1 until then.  msg_open() also reads it when registering a new
 * connection in the descriptor table.
 */
static int proc_id = -1;
/* This node's id (cfg->lid), set by msg_svc_init(). */
static int local_node = -1;
/*
 * Address of the other node (two-node cluster), set by msg_svc_init().
 * msg_open() copies it and fills in the destination daemon's port.
 */
static struct sockaddr_in remote_node;

/*
 *  Local prototypes
 */
/* Lazy one-time setup of the messaging subsystem; 0 on success, -1 on error. */
static int msg_svc_init(void);
/* Fill *ins with host's IPv4 address and port (port in network byte order). */
static void msg_setup_sockin(char *host, int port, struct sockaddr_in *ins);
/* Build header + payload in a malloc()ed *msg; returns the total length,
   or (unsigned long)-1 on failure. */
static unsigned long msg_create(void *payload, ssize_t count,
				void **msg, int auth);
/* Free a message built by msg_create(). */
static void msg_destroy(void *msg);

/*
 *  int msg_svc_init(void)
 *
 *  One-time setup of the messaging subsystem, called lazily by the public
 *  entry points.  It:
 *    - records this node's id (cfg->lid) in local_node,
 *    - resolves the other node's host name into remote_node (this only
 *      works for a two-node cluster),
 *    - fills in every daemon's local endpoint in proc_id_array from the
 *      configuration or the compiled-in default ports,
 *    - reads the session id used to authenticate messages.
 *  Returns 0 and sets msgsvc_initialized on success, -1 on failure.
 */
static int
msg_svc_init(void)
{
    int            i, retval;
    char           *port;
    struct hostent *hp;
    CluCfg         *cfg;
    char           netblk[SPACE_NET_BLOCK_DATA];

    cfg = get_clu_cfg(NULL);
    if (!cfg) {
	clulog(LOG_ERR, "msg_svc_init: Unable to retrieve cluster "
	       "configuration information.\n");
	return -1;
    }

    local_node = cfg->lid;

    /*
     * We need to get information on the hostname of the remote node.
     */
    /* Note: the following only works for two node cluster */
    memset(&remote_node, 0, sizeof(remote_node));
    remote_node.sin_family = AF_INET;
    hp = gethostbyname(cfg->nodes[local_node ^ 1].name);
    if (hp == NULL) {
	clulog(LOG_ERR, "msg_svc_init: Unable to get host information for "
	       "remote node.\n");
	free(cfg);
	return -1;
    }
    /*
     * sin_addr holds exactly one IPv4 address.  Refuse any other address
     * family/length instead of copying h_length bytes into it blindly.
     */
    if (hp->h_addrtype != AF_INET ||
	hp->h_length != (int)sizeof(remote_node.sin_addr)) {
	clulog(LOG_ERR, "msg_svc_init: Remote node address is not an "
	       "IPv4 address.\n");
	free(cfg);
	return -1;
    }
    memcpy(&remote_node.sin_addr, hp->h_addr, sizeof(remote_node.sin_addr));

    /*
     * Initialize our file descriptor table.
     */
    fdlist.head = fdlist.tail = NULL;

    /*
     *  Now we get communication endpoints from the config file.
     */
    for (i = 0; i <= MAX_PROCID; i++) {
	CFG_Get(proc_id_array[i].config_file_port_ent,
		proc_id_array[i].dflt_port, &port);

	msg_setup_sockin(cfg->nodes[local_node].name,
			 atoi(port),proc_id_array[i].msg_ins);
    }

    free(cfg);
    cfg = NULL;
    /*
     * Next we find out what our session id is.  The session id is used
     * to authenticate all cluster messages (both between daemons and 
     * between nodes).  The net block is read under the cluster lock.
     */
    clu_lock();
    retval = getNetBlockData(netblk);
    clu_un_lock();
    if (retval < 0) {
	clulog(LOG_ERR, "msg_svc_init: Unable to read session_id.\n");
	return -1;
    }
    memcpy(session_id, netblk, SPACE_NET_BLOCK_DATA);

    msgsvc_initialized = 1;
    return 0;
}


/*
 *  msg_handle_t msg_listen(msg_addr_t my_proc_id)
 *
 *  Create the listening socket for daemon `my_proc_id` on this node.
 *
 *  Records my_proc_id as this process's id (proc_id) and re-reads the
 *  daemon's port from the configuration.  It then creates a TCP socket
 *  with SO_REUSEADDR and FD_CLOEXEC set, binds it to the daemon's local
 *  address and starts listening.  The socket is registered in the
 *  descriptor table as MSG_LISTENING and stored in proc_id_array.
 *
 *  Returns the listening handle, or -1 on failure.  On failure no
 *  descriptor is left open and proc_id_array[].msg_fd is -1.
 */
msg_handle_t
msg_listen(msg_addr_t my_proc_id)
{
    char         *port;
    CluCfg       *cfg;
    const int     on = 1;
    int           fd;
    int           fdflags; /* for F_GETFD / F_SETFD */

    if (!msgsvc_initialized) {
	if (msg_svc_init() < 0) {
	    clulog(LOG_ERR, "msg_listen: Unable to initialize "
		   "messaging subsystem.\n");
	    return -1;
	}
    }

    /* my_proc_id indexes proc_id_array; reject out-of-range ids. */
    if ((int)my_proc_id < 0 || (int)my_proc_id > MAX_PROCID) {
	clulog(LOG_ERR, "msg_listen: invalid process id %d.\n",
	       (int)my_proc_id);
	errno = EINVAL;
	return -1;
    }

    proc_id = my_proc_id;

    cfg = get_clu_cfg(NULL);
    if (!cfg) {
	clulog(LOG_ERR, "msg_listen: Unable to retrieve cluster "
	       "configuration information.\n");
	return -1;
    }

    CFG_Get(proc_id_array[proc_id].config_file_port_ent,
	    proc_id_array[proc_id].dflt_port, &port);

    msg_setup_sockin(cfg->nodes[local_node].name, atoi(port),
		     proc_id_array[proc_id].msg_ins);
    free(cfg);
    cfg = NULL;

    fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
	clulog(LOG_ERR, "msg_listen: Unable to create socket\n");
	proc_id_array[proc_id].msg_fd = -1;
	return -1;
    }

    /* Best effort: allows a quick re-bind after a restart. */
    (void)setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));

    /* Don't leak the listening socket into exec()ed children. */
    fdflags = fcntl(fd, F_GETFD, 0);
    if (fdflags < 0 || fcntl(fd, F_SETFD, fdflags | FD_CLOEXEC) < 0) {
	clulog(LOG_ERR, "msg_listen: Unable to set the FD_CLOEXEC flag on sock\n");
	goto fail;
    }

    if (bind(fd, (struct sockaddr *)proc_id_array[proc_id].msg_ins,
	     sizeof(struct sockaddr_in)) < 0) {
	clulog(LOG_ERR, "msg_listen: Unable to bind to socket %s\n",
	       strerror(errno));
	goto fail;
    }

    if (listen(fd, MAX_MSGBACKLOG) < 0) {
	clulog(LOG_ERR, "msg_listen: Unable to listen on socket %s\n",
	       strerror(errno));
	goto fail;
    }

    if (fdlist_add(fd, MSG_LISTENING, proc_id_array[proc_id].msg_ins) < 0) {
	clulog(LOG_ERR, "msg_listen: unable to add sock to descriptor table.\n");
	goto fail;
    }

    proc_id_array[proc_id].msg_fd = fd;
    return fd;

 fail:
    close(fd);
    proc_id_array[proc_id].msg_fd = -1;
    return -1;
}

/*
 *  msg_open(msg_addr_t dest, int nodeid)
 *
 *  Open a communications channel to destination address.
 *
 *  dest   - daemon to reach; used as an index into proc_id_array to find
 *           that daemon's port.
 *  nodeid - cluster node to connect to.  If it equals this node's id, the
 *           daemon's local address is used as-is.  Any other id is taken
 *           to mean the remote node (two-node cluster assumption).
 *
 *  The connect is non-blocking and bounded by MSGSVC_CONNECT_TIMEOUT
 *  seconds.  On success the socket (restored to its original blocking
 *  mode) is registered as MSG_CONNECTED|MSG_AUTHENTICATED and returned.
 *
 *  On failure -1 is returned.  errno is ETIMEDOUT if the connect timed
 *  out, the socket's pending error if the connect failed asynchronously,
 *  or EINVAL for an out-of-range dest.
 */
msg_handle_t
msg_open(msg_addr_t dest, int nodeid)
{
    int                 sockfd, flags, ret;
    int                 error = 0;  /* SO_ERROR; stays 0 if connect() succeeds at once */
    socklen_t           len;
    socklen_t           ins_len = sizeof(struct sockaddr_in);
    fd_set              rfds, wfds;
    struct timeval      tv;
    struct sockaddr_in  tmpaddr;
    struct sockaddr_in *self_addr;

    if (nodeid < 0) {
	clulog(LOG_ERR, "msg_open: invalid node id %d.\n", nodeid);
	return -1;
    }
    /* dest indexes proc_id_array; reject out-of-range ids. */
    if ((int)dest < 0 || (int)dest > MAX_PROCID) {
	clulog(LOG_ERR, "msg_open: invalid destination %d.\n", (int)dest);
	errno = EINVAL;
	return -1;
    }
    if (!msgsvc_initialized) {
	if (msg_svc_init() < 0) {
	    clulog(LOG_ERR, "msg_open: unable to initialize msg subsystem.\n");
	    return -1;
	}
    }

    sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd < 0) {
	clulog(LOG_ERR, "msg_open: Unable to create socket. Error: %s\n", 
	       strerror(errno));
	return -1;
    }
    /*
     * Set the socket up for a non-blocking connect.  Normal TCP connects
     * have a timeout of 75 seconds or more.  This is not acceptable for
     * our clustering software, so we bound the wait by
     * MSGSVC_CONNECT_TIMEOUT seconds instead.  This is reasonable,
     * considering these are local connects, for the most part.
     */
    flags = fcntl(sockfd, F_GETFL, 0);
    if (flags < 0 || fcntl(sockfd, F_SETFL, flags | O_NONBLOCK) < 0) {
	clulog(LOG_ERR, "msg_open: Unable to make socket non-blocking. "
	       "Error: %s\n", strerror(errno));
	close(sockfd);
	return -1;
    }

    /*
     * Local connections go straight to the daemon's own address.  For the
     * remote node, use its resolved address with the daemon's port (only
     * correct for a two-node cluster).
     */
    if (nodeid != local_node) {
	memcpy(&tmpaddr, &remote_node, ins_len);
	tmpaddr.sin_port = proc_id_array[dest].msg_ins->sin_port;
    } else {
	memcpy(&tmpaddr, proc_id_array[dest].msg_ins, ins_len);
    }

    ret = connect(sockfd, (struct sockaddr *)&tmpaddr, ins_len);
    if (ret < 0) {
	if (errno != EINPROGRESS) {
	    clulog(LOG_ERR, "msg_open: Unable to connect. Error %s\n",
		   strerror(errno));
	    close(sockfd);
	    return -1;
	}

	/* Wait for the in-progress connect to finish or time out. */
	tv.tv_sec = MSGSVC_CONNECT_TIMEOUT;
	tv.tv_usec = 0;
	do {
	    /* The sets are unspecified after a failed select(); rebuild
	       them before every attempt. */
	    FD_ZERO(&rfds);
	    FD_SET(sockfd, &rfds);
	    wfds = rfds;
	    ret = select(sockfd + 1, &rfds, &wfds, NULL, &tv);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
	    clulog(LOG_ERR, "msg_open: select error: %s\n", strerror(errno));
	    close(sockfd);
	    return -1;
	}
	if (ret == 0) {
	    close(sockfd);
	    errno = ETIMEDOUT;
	    return -1;
	}
	if (!FD_ISSET(sockfd, &rfds) && !FD_ISSET(sockfd, &wfds)) {
	    clulog(LOG_ERR, "msg_open: select error: sockfd not set\n");
	    close(sockfd);
	    return -1;
	}
	/* The connect has finished; SO_ERROR tells whether it succeeded. */
	len = sizeof(error);
	if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &error, &len) < 0) {
	    close(sockfd);
	    return -1;
	}
    }

    /* Restore the original (blocking) file status flags. */
    fcntl(sockfd, F_SETFL, flags);

    if (error) {
	close(sockfd);
	errno = error;
	return -1;
    }

    /*
     * proc_id is only valid once msg_listen() has run.  A process that
     * only opens outgoing connections has no address of its own to
     * record, so pass NULL (as msg_accept() does) instead of indexing
     * proc_id_array with -1.
     */
    self_addr = (proc_id >= 0 && proc_id <= MAX_PROCID) ?
	proc_id_array[proc_id].msg_ins : NULL;
    if (fdlist_add(sockfd, MSG_CONNECTED|MSG_AUTHENTICATED, self_addr) < 0) {
	clulog(LOG_ERR,"msg_open: unable to add sock to descriptor table.\n");
	close(sockfd);
	return -1;
    }

    return sockfd;
}


/*
 * msg_accept - block until a client connects to the listening `handle`.
 *
 * accept() is retried when interrupted by a signal.  The new connection
 * is registered in the descriptor table as MSG_ACCEPTED and returned.
 * Returns -1 on failure.  errno is EBADF for a negative handle, EINVAL
 * if the handle is not in the listening state, or whatever accept()
 * reported.
 */
msg_handle_t
msg_accept(msg_handle_t handle)
{
    int                sockfd=handle, acceptfd;

    /*
     *  Do some sanity checks on the handle passed in.
     */
    if (sockfd < 0) {
	clulog(LOG_ERR, "msg_accept called with bad handle %d\n",handle);
	errno = EBADF;
	return -1;
    }

    if (fdlist_getstate(sockfd) != MSG_LISTENING) {
	clulog(LOG_ERR, "msg_accept: file descriptor not in listen state.\n");
	errno = EINVAL;
	return -1;
    }

    /* The peer's address is not needed here, so none is requested. */
    while ((acceptfd = accept(sockfd, NULL, NULL)) < 0) {
	if (errno == EINTR) {
	    continue;
	}

	clulog(LOG_ERR, "msg_accept: accept error %s.\n",strerror(errno));
	return -1;
    }

    /*
     * An untracked descriptor has no state in fdlist, so the state checks
     * in msg_send()/msg_receive() would presumably reject it.  Don't hand
     * one back.
     */
    if (fdlist_add(acceptfd, MSG_ACCEPTED, NULL) < 0) {
	clulog(LOG_ERR, "msg_accept: unable to add sock to descriptor table.\n");
	close(acceptfd);
	return -1;
    }
    return acceptfd;
}


/*
 * msg_accept_timeout - like msg_accept(), but wait at most `timeout`
 * seconds for a connection to arrive.
 *
 * Returns the accepted handle, 0 if the timeout expired with no pending
 * connection, or -1 on error.  errno is EBADF for a negative handle and
 * EINVAL for a non-listening one.
 */
msg_handle_t
msg_accept_timeout(msg_handle_t handle, int timeout)
{
    int                sockfd = handle;
    struct timeval     tv;
    fd_set             fdset;
    int                ret;

    /*
     *  Do some sanity checks on the handle passed in.
     */
    if (sockfd < 0) {
	clulog(LOG_ERR,"msg_accept_timeout called with bad handle %d\n",handle);
	errno = EBADF;
	return -1;
    }

    if (fdlist_getstate(sockfd) != MSG_LISTENING) {
	clulog(LOG_ERR, "msg_accept_timeout: file descriptor not in "
	       "listen state.\n");
	errno = EINVAL;
	return -1;
    }

    tv.tv_sec = timeout;
    tv.tv_usec = 0;

    do {
	/* The set is unspecified after a failed select(); rebuild it each
	   time round. */
	FD_ZERO(&fdset);
	FD_SET(sockfd, &fdset);
	ret = select(sockfd + 1, &fdset, NULL, NULL, &tv);
    } while (ret < 0 && errno == EINTR);

    if (ret < 0) {
	clulog(LOG_ERR, "msg_accept_timeout: select returned error "
	       "status %s\n", strerror(errno));
	return -1;
    }
    if (ret == 0) {
	return 0;
    }

    /* A connection is pending; msg_accept() does the accept and bookkeeping. */
    return msg_accept(handle);
}

/*
 * __msg_send - frame `count` bytes of `buf` with a message header and
 * write the whole message to `handle`.
 *
 * The header carries our session id only when MSG_SECURE() is true for
 * the handle's state.  Short writes and EINTR are retried until the full
 * message has been written.  No state checks are done here (see
 * msg_send()).
 *
 * Returns the number of payload bytes sent, or -1 on failure.  The
 * temporary message buffer is freed on every path.
 */
int
__msg_send(msg_handle_t handle, void *buf, ssize_t count)
{
    void           *msg = NULL;
    unsigned long  msg_len;
    size_t         off = 0;
    ssize_t        n;
    int            state;

    state = fdlist_getstate(handle);
    msg_len = msg_create(buf, count, &msg, MSG_SECURE(state));
    if (msg == NULL || msg_len == (unsigned long)-1) {
	clulog(LOG_ERR, "__msg_send: unable to build message\n");
	msg_destroy(msg);
	return -1;
    }

    /* A stream socket may accept fewer bytes than asked; keep going. */
    while (off < msg_len) {
	n = write(handle, (char *)msg + off, msg_len - off);
	if (n < 0 && errno == EINTR)
	    continue;
	if (n <= 0) {
	    clulog(LOG_ERR, "__msg_send: Incomplete write. Error: %s\n", 
		   strerror(errno));
	    msg_destroy(msg);
	    return -1;
	}
	off += (size_t)n;
    }

    msg_destroy(msg);
    return (int)(msg_len - sizeof(struct msg_struct));
}
/*
 *  Send a message to another process.
 *
 *  Sends `count` bytes from `buf` as one framed message on `handle`,
 *  initialising the messaging subsystem on first use.  The handle's state
 *  must allow writing (MSG_CANWRITE).
 *
 *  Returns the number of payload bytes sent, or -1 on error.  errno is
 *  EBADF for a negative handle, or EINVAL for a non-writable handle, a
 *  negative count, or a NULL buf with a non-zero count.
 */
int 
msg_send(msg_handle_t handle, void *buf, ssize_t count)
{
    int  sockfd = handle, state = 0;

    if (!msgsvc_initialized) {
	if (msg_svc_init() < 0) {
	    clulog(LOG_ERR, "msg_send: unable to initialize msg subsystem.\n");
	    return -1;
	}
    }

    if (sockfd < 0) {
	clulog(LOG_ERR, "msg_send called with bad handle %d\n",handle);
	errno = EBADF;
	return -1;
    }

    /* A negative count would become a huge size_t in msg_create(). */
    if (count < 0 || (count > 0 && buf == NULL)) {
	clulog(LOG_ERR, "msg_send: invalid buffer or count\n");
	errno = EINVAL;
	return -1;
    }

    state = fdlist_getstate(sockfd);
    if (!(MSG_CANWRITE(state))) {
	clulog(LOG_ERR, "msg_send: Attempt to write to unconnected socket\n");
	errno = EINVAL;
	return -1;
    }

    return(__msg_send(sockfd, buf, count));
}


/*
 * msg_read_exact - read exactly `len` bytes from `fd` into `buf`,
 * retrying on EINTR and short reads.
 *
 * Returns `len` on success, fewer bytes if the peer closed the connection
 * first (EOF), or -1 on a read error.
 */
static ssize_t
msg_read_exact(int fd, void *buf, size_t len)
{
    unsigned char *p = buf;
    size_t         got = 0;
    ssize_t        n;

    while (got < len) {
	n = read(fd, p + got, len - got);
	if (n < 0) {
	    if (errno == EINTR)
		continue;
	    return -1;
	}
	if (n == 0)
	    break;		/* EOF */
	got += (size_t)n;
    }
    return (ssize_t)got;
}

/*
 * __msg_receive - read one framed message from `handle`.
 *
 * Reads the header and sets *auth (if auth is non-NULL) to 1 when the
 * sender's session id matches ours; in that case the handle is also
 * marked MSG_AUTHENTICATED.  Otherwise *auth is set to 0.
 *
 * Exactly min(payload length, count) bytes are then stored in `buf`.
 * Any payload beyond `count` is read and discarded, so the next call
 * starts on a message boundary.
 *
 * Returns the number of payload bytes stored; 0 for an empty message (or
 * count == 0); fewer bytes if the peer closed mid-payload; -1 on error,
 * an incomplete header, or a negative length.
 */
ssize_t
__msg_receive(msg_handle_t handle, void *buf, ssize_t count, int *auth)
{
    struct msg_struct msg_hdr;
    int               sockfd = handle;
    ssize_t           retval;
    size_t            want, excess;
    unsigned char     scratch[256];
    int               cmp;

    if (count < 0) {
	errno = EINVAL;
	return -1;
    }

    retval = msg_read_exact(sockfd, &msg_hdr, sizeof(msg_hdr));
    if (retval < (ssize_t)sizeof(msg_hdr)) {
	return -1;
    }
    if (!msg_hdr.count) {
	clulog(LOG_ERR, "__msg_receive: empty response?\n");
	return 0;
    }
    if (msg_hdr.count < 0) {
	clulog(LOG_ERR, "__msg_receive: invalid payload length %ld\n",
	       (long)msg_hdr.count);
	return -1;
    }

    /*
     * Check for the proper session Id.
     */
    cmp = memcmp(msg_hdr.session_id, session_id, SESSION_ID_SIZE);
    if (cmp == 0) {
	if (auth)
	    *auth = 1;
	fdlist_setstate(handle, MSG_AUTHENTICATED);
    } else {
	clulog(LOG_DEBUG, "__msg_receive: output of memcmp: %d\n", cmp);
	if (auth)
	    *auth = 0;
    }

    /* Read only this message's payload, never into the next header. */
    want = ((size_t)msg_hdr.count < (size_t)count) ?
	(size_t)msg_hdr.count : (size_t)count;
    retval = msg_read_exact(sockfd, buf, want);
    if (retval < (ssize_t)want)
	return retval;		/* -1 on error, short count on EOF */

    /* Discard whatever did not fit in the caller's buffer. */
    excess = (size_t)msg_hdr.count - want;
    while (excess > 0) {
	size_t  chunk = (excess < sizeof(scratch)) ? excess : sizeof(scratch);
	ssize_t n = msg_read_exact(sockfd, scratch, chunk);

	if (n < (ssize_t)chunk)
	    break;
	excess -= chunk;
    }

    return retval;
}


/*
 * msg_receive - read one message from `handle`.
 *
 * Initialises the messaging subsystem on first use and checks that the
 * handle's state allows reading (MSG_CANREAD).  It then returns whatever
 * __msg_receive() returns; that call also reports through *auth whether
 * the sender's session id matched ours.
 *
 * Returns -1 for a bad handle (errno EBADF) or a non-readable handle
 * (errno EINVAL).
 */
ssize_t
msg_receive(msg_handle_t handle, void *buf, ssize_t count, int *auth)
{
    int  sockfd = handle, state=0;

    /*
     * Sanity checks
     */
    if (!msgsvc_initialized) {
	if (msg_svc_init() < 0) {
	    clulog(LOG_ERR, "msg_receive: unable to initialize msg subsystem.\n");
	    return -1;
	}
    }

    if (sockfd < 0) {
	clulog(LOG_DEBUG, "msg_receive called with bad handle %d\n",handle);
	errno = EBADF;
	return -1;
    }

    state = fdlist_getstate(sockfd);
    if (!(MSG_CANREAD(state))) {
	clulog(LOG_DEBUG, "msg_receive: Attempt to read from unconnected socket\n");
	errno = EINVAL;
	return -1;
    }

    return(__msg_receive(sockfd, buf, count, auth));
}


/*
 * msg_receive_timeout - msg_receive() with an upper bound on the wait.
 *
 * After the same sanity checks as msg_receive(), waits up to `timeout`
 * seconds for `handle` to become readable.  If select() fails (-1) or
 * times out (0), that value is returned unchanged.  Otherwise the message
 * is read with __msg_receive() and its result is returned.
 */
ssize_t
msg_receive_timeout(msg_handle_t handle, void *buf, 
		    ssize_t count, int *auth, unsigned int timeout)
{
    int            fd = handle;
    int            fdstate;
    int            nready;
    fd_set         readable;
    struct timeval limit;

    /* Bring the messaging layer up on first use. */
    if (!msgsvc_initialized && msg_svc_init() < 0) {
	clulog(LOG_DEBUG, "msg_receive_timeout: unable to initialize"
	       " msg subsystem.\n");
	return -1;
    }

    if (fd < 0) {
	clulog(LOG_DEBUG, "msg_receive_timeout called with bad "
	       "handle %d\n", handle);
	return -1;
    }

    fdstate = fdlist_getstate(fd);
    if (!(MSG_CANREAD(fdstate))) {
	clulog(LOG_DEBUG, "msg_receive_timeout: Attempt to read from "
	       "unconnected socket\n");
	return -1;
    }

    FD_ZERO(&readable);
    FD_SET(fd, &readable);
    limit.tv_sec = timeout;
    limit.tv_usec = 0;

    nready = select(fd + 1, &readable, NULL, NULL, &limit);
    if (nready > 0)
	return __msg_receive(fd, buf, count, auth);

    /* 0: timed out; -1: select() failed and set errno. */
    return nready;
}


/*
 * __msg_peek - copy up to `count` payload bytes of the next message on
 * `handle` into `buf` without consuming anything from the socket.
 *
 * Returns the number of payload bytes copied, 0 when no payload bytes
 * are available yet (or on EOF), or -1 on a recv()/allocation error or a
 * negative `count`.  The temporary buffer is released on every path.
 */
ssize_t
__msg_peek(msg_handle_t handle, void *buf, ssize_t count)
{
    char    *bigbuf;
    ssize_t  ret;
    int      sockfd = handle;
    size_t   hdrsz = sizeof(struct msg_struct);
    size_t   bigbuf_sz;

    if (count < 0) {
	errno = EINVAL;
	return -1;
    }

    bigbuf_sz = (size_t)count + hdrsz;
    bigbuf = malloc(bigbuf_sz);
    if (bigbuf == NULL) {
	clulog(LOG_DEBUG, "msg_peek: Out of memory\n");
	return -1;
    }

    /*
     * We need to account for the msg header.  So we skip past it
     * and decrement the return value by the number of bytes eaten
     * up by the header.
     */
    ret = recv(sockfd, bigbuf, bigbuf_sz, MSG_PEEK);
    if (ret < 0) {
	ret = -1;
    } else if (ret > (ssize_t)hdrsz) {
	ret -= (ssize_t)hdrsz;
	if (ret > count)
	    ret = count;
	memcpy(buf, bigbuf + hdrsz, (size_t)ret);
    } else {
	ret = 0;
    }

    free(bigbuf);
    return ret;
}


/*
 * msg_peek - look at up to `count` payload bytes of the next message on
 * `handle` without consuming them.
 *
 * Requires a non-negative handle whose state allows reading; returns -1
 * otherwise.  Does not initialise the messaging subsystem.  The result is
 * whatever __msg_peek() returns.
 */
ssize_t
msg_peek(msg_handle_t handle, void *buf, ssize_t count)
{
    int fd = handle;

    if (fd >= 0) {
	int state = fdlist_getstate(fd);

	if (!MSG_CANREAD(state)) {
	    clulog(LOG_ERR, "msg_peek: Attempt to read from unconnected socket.\n");
	    return -1;
	}
	return __msg_peek(fd, buf, count);
    }

    clulog(LOG_ERR, "msg_peek: called with bad handle %d\n", fd);
    return -1;
}


/*
 * msg_close - drop `handle` from the descriptor table and close it.
 *
 * A negative handle is logged and ignored.  A failure to remove the
 * handle from the table is only logged as a warning; the descriptor is
 * closed regardless.
 */
void
msg_close(msg_handle_t handle)
{
    if (handle >= 0) {
	if (fdlist_delete(handle) < 0)
	    clulog(LOG_WARNING, "msg_close: fdlist_delete returned -1\n");
	close(handle);
	return;
    }

    clulog(LOG_ERR, "msg_close called with bad handle %d\n", handle);
}


/*
 *  Internal Helper Functions
 */
/*
 * msg_setup_sockin - fill *ins with an IPv4 address for `host` and `port`.
 *
 * The structure is zeroed first, and the port is stored in network byte
 * order.  If `host` cannot be resolved to an IPv4 address, an error is
 * logged and sin_addr is left all-zero (INADDR_ANY); the old code
 * dereferenced a NULL resolver result instead.
 * NOTE(review): INADDR_ANY means a later bind() listens on every
 * interface.  Revisit if that fallback is undesirable.
 */
static void
msg_setup_sockin(char *host, int port, struct sockaddr_in *ins)
{
    struct hostent *hp;

    memset(ins, 0, sizeof(*ins));
    ins->sin_family = AF_INET;
    ins->sin_port = htons(port);

    hp = gethostbyname(host);
    if (hp == NULL || hp->h_addrtype != AF_INET ||
	hp->h_length != (int)sizeof(ins->sin_addr)) {
	clulog(LOG_ERR, "msg_setup_sockin: Unable to resolve %s to an "
	       "IPv4 address.\n", host ? host : "(null)");
	return;
    }
    /* Copy exactly one IPv4 address; never more than sin_addr can hold. */
    memcpy(&ins->sin_addr, hp->h_addr, sizeof(ins->sin_addr));
}

/*
 * msg_create - build a wire message: a struct msg_struct header followed
 * by `len` bytes copied from `payload`.
 *
 * The header carries MSGSVC_VERSION and the payload length.  When `auth`
 * is non-zero it also carries our session id; otherwise session_id is all
 * zeroes.
 *
 * On success *msg points at a malloc()ed buffer owned by the caller
 * (release it with msg_destroy()), and the total length (header +
 * payload) is returned.  On failure (negative len, NULL payload with
 * len > 0, or out of memory) *msg is NULL and (unsigned long)-1 is
 * returned.
 */
static unsigned long
msg_create(void *payload, ssize_t len, void **msg, int auth)
{
    struct msg_struct msg_hdr;
    unsigned char     *buf;

    *msg = NULL;
    if (len < 0 || (len > 0 && payload == NULL)) {
	clulog(LOG_ERR, "msg_create: invalid payload (length %ld)\n",
	       (long)len);
	return (unsigned long)-1;
    }

    /* Zeroing the whole header also zeroes session_id when !auth. */
    memset(&msg_hdr, 0, sizeof(msg_hdr));
    msg_hdr.ver = (unsigned char)MSGSVC_VERSION;
    msg_hdr.count = len;
    if (auth)
	memcpy(msg_hdr.session_id, session_id, SESSION_ID_SIZE);

    buf = malloc(sizeof(msg_hdr) + (size_t)len);
    if (buf == NULL) {
	clulog(LOG_ERR, "msg_create: unable to allocate memory.  error %s\n",
	       strerror(errno));
	return (unsigned long)-1;
    }
    memcpy(buf, &msg_hdr, sizeof(msg_hdr));
    if (len > 0)
	memcpy(buf + sizeof(msg_hdr), payload, (size_t)len);

    *msg = buf;
    return (unsigned long)(sizeof(msg_hdr) + (size_t)len);
}

/*
 * msg_destroy - release a message buffer built by msg_create().
 * A NULL msg is accepted, since free(NULL) is a no-op.
 */
static void
msg_destroy(void *msg)
{
    free(msg);
}
/*
 * Local variables:
 *   c-indent-level: 4
 *   c-basic-offset: 4
 *   tab-width: 8
 *   compile-command: "gcc -g -Wall -Wstrict-prototypes -I/usr/local/cluster/include -c msg.c"
 * End:
 */
