/****************************************************************************
** sslfilter.cpp - simple OpenSSL encryption I/O
** Copyright (C) 2001  Justin Karneges
**
** This program is free software; you can redistribute it and/or
** modify it under the terms of the GNU General Public License
** as published by the Free Software Foundation; either version 2
** of the License, or (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,USA.
**
****************************************************************************/

#include "sslfilter.h"

#ifdef Q_WS_WIN
#include "openssl/ssl.h"
#include "openssl/err.h"
#else
#include <openssl/ssl.h>
#include <openssl/err.h>
#endif


struct SSL_STRUCT
{
	SSL *ssl;
	SSL_METHOD *method;
	SSL_CTX *context;
	BIO *rbio, *wbio;
};
#define d ((SSL_STRUCT *)dp)

// Function pointers to the OpenSSL entry points used by this file. They are
// filled in at runtime by SSLFilter::loadSymbols() from the QLibrary handles,
// and are only valid once loadSymbols() has returned TRUE.
// The casts in loadSymbols() must use exactly these pointer types.
// NOTE(review): identifiers starting with an underscore followed by an
// uppercase letter are reserved in C++; consider renaming in a future cleanup.
void (*_SSL_library_init)();
void (*_SSL_load_error_strings)();
SSL_METHOD * (*_TLSv1_client_method)();
SSL_CTX * (*_SSL_CTX_new)(SSL_METHOD *);
SSL * (*_SSL_new)(SSL_CTX *);
int (*_SSL_set_ssl_method)(SSL *, SSL_METHOD *);
BIO * (*_BIO_new)(BIO_METHOD *);
BIO_METHOD * (*_BIO_s_mem)();
void (*_SSL_set_bio)(SSL *, BIO *, BIO *);
void (*_SSL_set_connect_state)(SSL *);
int (*_SSL_shutdown)(SSL *);
void (*_SSL_free)(SSL *);
void (*_SSL_CTX_free)(SSL_CTX *);
void (*_ERR_free_strings)();
void (*_ERR_remove_state)(unsigned long);
int (*_BIO_write)(BIO *, const void *, int);
int (*_BIO_read)(BIO *, void *, int);
long (*_BIO_ctrl)(BIO *, int, long, void *);
int (*_SSL_connect)(SSL *);
int (*_SSL_get_error)(SSL *, int);
int (*_SSL_do_handshake)(SSL *);
int (*_SSL_read)(SSL *, void *, int);
int (*_SSL_write)(SSL *, const void *, int);

// Number of bytes waiting in BIO b, queried through the runtime-resolved
// _BIO_ctrl pointer (BIO_CTRL_PENDING request).
#define _BIO_pending(b)  (int)_BIO_ctrl(b,BIO_CTRL_PENDING,0,NULL)


bool SSLFilter::loadSymbols()
{
	void *p;
	p = lib1->resolve("SSL_library_init"); if(!p) return FALSE; else _SSL_library_init = (void(*)())p;
	p = lib1->resolve("SSL_load_error_strings"); if(!p) return FALSE; else _SSL_load_error_strings = (void(*)())p;
	p = lib1->resolve("TLSv1_client_method"); if(!p) return FALSE; else _TLSv1_client_method = (SSL_METHOD *(*)())p;
	p = lib1->resolve("SSL_CTX_new"); if(!p) return FALSE; else _SSL_CTX_new = (SSL_CTX *(*)(SSL_METHOD *))p;
	p = lib1->resolve("SSL_new"); if(!p) return FALSE; else _SSL_new = (SSL *(*)(SSL_CTX *))p;
	p = lib1->resolve("SSL_set_ssl_method"); if(!p) return FALSE; else _SSL_set_ssl_method = (int(*)(SSL *, SSL_METHOD *))p;
#ifdef Q_WS_WIN
	p = lib2->resolve("BIO_new"); if(!p) return FALSE; else _BIO_new = (BIO *(*)(BIO_METHOD *))p;
	p = lib2->resolve("BIO_s_mem"); if(!p) return FALSE; else _BIO_s_mem = (BIO_METHOD *(*)())p;
#else
	p = lib1->resolve("BIO_new"); if(!p) return FALSE; else _BIO_new = (BIO *(*)(BIO_METHOD *))p;
	p = lib1->resolve("BIO_s_mem"); if(!p) return FALSE; else _BIO_s_mem = (BIO_METHOD *(*)())p;
#endif
	p = lib1->resolve("SSL_set_bio"); if(!p) return FALSE; else _SSL_set_bio = (void(*)(SSL *, BIO *, BIO *))p;
	p = lib1->resolve("SSL_set_connect_state"); if(!p) return FALSE; else _SSL_set_connect_state = (void(*)(SSL *))p;
	p = lib1->resolve("SSL_shutdown"); if(!p) return FALSE; else _SSL_shutdown = (int(*)(SSL *))p;
	p = lib1->resolve("SSL_free"); if(!p) return FALSE; else _SSL_free = (void(*)(SSL *))p;
	p = lib1->resolve("SSL_CTX_free"); if(!p) return FALSE; else _SSL_CTX_free = (void(*)(SSL_CTX *))p;
#ifdef Q_WS_WIN
	p = lib2->resolve("ERR_free_strings"); if(!p) return FALSE; else _ERR_free_strings = (void(*)())p;
	p = lib2->resolve("ERR_remove_state"); if(!p) return FALSE; else _ERR_remove_state = (void(*)(unsigned long))p;
	p = lib2->resolve("BIO_write"); if(!p) return FALSE; else _BIO_write = (int(*)(BIO *, const void *, int))p;
	p = lib2->resolve("BIO_read"); if(!p) return FALSE; else _BIO_read = (int(*)(BIO *, void *, int))p;
	p = lib2->resolve("BIO_ctrl"); if(!p) return FALSE; else _BIO_ctrl = (long(*)(BIO *, int, long, void *))p;
#else
	p = lib1->resolve("ERR_free_strings"); if(!p) return FALSE; else _ERR_free_strings = (void(*)())p;
	p = lib1->resolve("ERR_remove_state"); if(!p) return FALSE; else _ERR_remove_state = (void(*)(unsigned long))p;
	p = lib1->resolve("BIO_write"); if(!p) return FALSE; else _BIO_write = (int(*)(BIO *, const void *, int))p;
	p = lib1->resolve("BIO_read"); if(!p) return FALSE; else _BIO_read = (int(*)(BIO *, void *, int))p;
	p = lib1->resolve("BIO_ctrl"); if(!p) return FALSE; else _BIO_ctrl = (long(*)(BIO *, int, long, void *))p;
#endif
	p = lib1->resolve("SSL_connect"); if(!p) return FALSE; else _SSL_connect = (int(*)(SSL *))p;
	p = lib1->resolve("SSL_get_error"); if(!p) return FALSE; else _SSL_get_error = (int(*)(SSL *, int))p;
	p = lib1->resolve("SSL_do_handshake"); if(!p) return FALSE; else _SSL_do_handshake = (int(*)(SSL *))p;
	p = lib1->resolve("SSL_read"); if(!p) return FALSE; else _SSL_read = (int(*)(SSL *, void *, int))p;
	p = lib1->resolve("SSL_write"); if(!p) return FALSE; else _SSL_write = (int(*)(SSL *, const void *, int))p;

	return TRUE;
}

/*
  Loads OpenSSL at runtime: libssl32.dll + libeay32.dll on Windows, libssl.so
  elsewhere. Then resolves the required symbols and initialises the library.
  On any failure the libraries are released and `supported` is FALSE, so
  isSupported() reports FALSE and begin() refuses to start.
*/
SSLFilter::SSLFilter()
{
	lib1 = 0;
	lib2 = 0;
	dp = 0;
	// sslAction used to be left uninitialised, but send() reads it even
	// before begin() has been called.
	sslAction = SSL_CONNECT;

#ifdef Q_WS_WIN
	lib1 = new QLibrary("libssl32.dll");
	lib2 = new QLibrary("libeay32.dll");
	if(!lib1->load() || !lib2->load()) {
		delete lib1;
		delete lib2;
		lib1 = 0;
		lib2 = 0;
		supported = FALSE;

		qDebug("SSLFilter: unable to load win32 openssl\n");
		return;
	}

#else
	lib1 = new QLibrary("libssl.so");
	if(!lib1->load()) {
		delete lib1;
		lib1 = 0;
		supported = FALSE;

		qDebug("SSLFilter: unable to load unix openssl\n");
		return;
	}

#endif
	if(!loadSymbols()) {
		delete lib1;
		if(lib2)
			delete lib2;
		lib1 = 0;
		lib2 = 0;
		supported = FALSE;

		qDebug("SSLFilter: unable to load all symbols\n");
		return;
	}
	supported = TRUE;
	qDebug("SSLFilter: successfully loaded ssl libraries\n");

	// init the library
	_SSL_library_init();
	_SSL_load_error_strings();
}

/*
  Ends any active session, releases OpenSSL's error state and unloads the
  libraries. When OpenSSL was never available there is nothing to release.
*/
SSLFilter::~SSLFilter()
{
	if(!supported)
		return;

	reset();

	_ERR_free_strings();
	_ERR_remove_state(0);

	// delete on a null pointer is a no-op (lib2 is null outside Windows)
	delete lib1;
	delete lib2;
}

bool SSLFilter::isSupported()
{
	return supported;
}

void SSLFilter::reset()
{
	if(dp) {
		if(d->ssl) {
			_SSL_shutdown(d->ssl);
			_SSL_free(d->ssl);
		}
		if(d->context) {
			_SSL_CTX_free(d->context);
		}
		delete d;
		dp = 0;
	}

	sendQueue.resize(0);
	recvQueue.resize(0);
}

/*
  Starts a new TLSv1 client session over memory BIOs. Any previous session
  is reset first. Then a context and SSL object are created, the BIOs are
  attached and the handshake is started via sslUpdate().

  Returns FALSE if OpenSSL is unavailable or any handle could not be
  created. A handshake failure detected inside sslUpdate() is reported
  through error(), not through this return value.
*/
bool SSLFilter::begin()
{
	if(!supported)
		return FALSE;
	if(dp)
		reset();

	dp = new SSL_STRUCT;
	d->ssl = 0;
	d->method = 0;
	d->context = 0;

	// get our handles
	d->method = _TLSv1_client_method();
	if(!d->method) {
		reset();
		return FALSE;
	}
	d->context = _SSL_CTX_new(d->method);
	if(!d->context) {
		reset();
		return FALSE;
	}
	d->ssl = _SSL_new(d->context);
	if(!d->ssl) {
		reset();
		return FALSE;
	}
	_SSL_set_ssl_method(d->ssl, d->method); // NOTE(review): return value ignored

	// Set up the memory BIOs. rbio is fed by putIncomingSSLData();
	// wbio is drained by getOutgoingSSLData().
	d->rbio = _BIO_new(_BIO_s_mem());
	d->wbio = _BIO_new(_BIO_s_mem());

	// Attach the BIOs before checking them. No BIO_free is resolved in this
	// file, so the SSL object must own whichever BIO was created; SSL_free()
	// (called by reset()) then releases it.
	_SSL_set_bio(d->ssl, d->rbio, d->wbio);
	_SSL_set_connect_state(d->ssl);
	if(!d->rbio || !d->wbio) {
		reset();
		return FALSE;
	}

	sslAction = SSL_CONNECT;
	sslUpdate();

	return TRUE;
}

void SSLFilter::putIncomingSSLData(const QByteArray &a)
{
	_BIO_write(d->rbio, a.data(), a.size());
	sslUpdate();
}

bool SSLFilter::isOutgoingSSLData()
{
	return (_BIO_pending(d->wbio) > 0) ? TRUE: FALSE;
}

/*
  Drains and returns all encrypted bytes currently pending in the write BIO.
  Returns an empty array when nothing is pending, when the read fails, or
  when there is no active session.
*/
QByteArray SSLFilter::getOutgoingSSLData()
{
	QByteArray a;
	if(!dp)
		return a;

	int size = _BIO_pending(d->wbio);
	if(size <= 0)
		return a;
	a.resize(size);

	int r = _BIO_read(d->wbio, a.data(), size);
	if(r <= 0) {
		a.resize(0);
		return a;
	}
	// keep only the bytes actually read
	if(r != size)
		a.resize(r);

	return a;
}

void SSLFilter::sslUpdate()
{
	if(sslAction == SSL_CONNECT) {
		int ret = _SSL_connect(d->ssl);
		if(ret == 0) {
			reset();
			doError();
			return;
		}
		if(ret > 0) {
			if(_SSL_do_handshake(d->ssl) < 0) {
				reset();
				doError();
				return;
			}

			sslAction = SSL_ACTIVE;
			processSendQueue();
		}
	}

	if(isOutgoingSSLData()) {
		emit outgoingSSLDataReady();
	}

	// try to read incoming unencrypted data
	sslReadAll();

	if(isRecvData())
		emit readyRead();
}

void SSLFilter::send(const QByteArray &a)
{
	int oldsize = sendQueue.size();
	sendQueue.resize(oldsize + a.size());
	memcpy(sendQueue.data() + oldsize, a.data(), a.size());

	if(sslAction == SSL_ACTIVE)
		processSendQueue();
}

// TRUE while decrypted bytes are buffered and waiting for recv().
bool SSLFilter::isRecvData()
{
	return recvQueue.size() != 0;
}

void SSLFilter::sslReadAll()
{
	QByteArray a;

	while(1) {
		a.resize(4096);
		int x = _SSL_read(d->ssl, a.data(), a.size());
		if(x <= 0)
			break;

		if(x != (int)a.size())
			a.resize(x);

		int oldsize = recvQueue.size();
		recvQueue.resize(oldsize + a.size());
		memcpy(recvQueue.data() + oldsize, a.data(), a.size());
	}
}

/*
  Hands all buffered decrypted data to the caller and empties the buffer.
  The result is detached, so clearing recvQueue cannot affect it.
*/
QByteArray SSLFilter::recv()
{
	QByteArray out(recvQueue);
	out.detach();
	recvQueue.resize(0);
	return out;
}

void SSLFilter::processSendQueue()
{
	if(sendQueue.size() > 0) {
		_SSL_write(d->ssl, sendQueue.data(), sendQueue.size());
		sendQueue.resize(0);
		sslUpdate();
	}
}

void SSLFilter::doError()
{
	emit error();
}
