# -*- coding: utf-8 -*-
# Moovida - Home multimedia server
# Copyright (C) 2006-2009 Fluendo Embedded S.L. (www.fluendo.com).
# All rights reserved.
#
# This file is available under one of two license agreements.
#
# This file is licensed under the GPL version 3.
# See "LICENSE.GPL" in the root of this distribution including a special
# exception to use Moovida with Fluendo's plugins.
#
# The GPL part of Moovida is also available under a commercial licensing
# agreement from Fluendo.
# See "LICENSE.Moovida" in the root directory of this distribution package
# for details on that license.
#
# Authors: Alessandro Decina <alessandro@fluendo.com>

from twisted.trial.unittest import TestCase, SkipTest
from twisted.internet import defer, error

from elisa.plugins.amp.master import Master, SlaveProcessProtocol, StartError
from elisa.plugins.amp.protocol import Ping, MasterFactory, MasterProtocol, \
                                       SlaveFactory, SlaveProtocol

import platform

# useful when debugging the protocol
#import sys
#from twisted.python import log
#log.startLogging(sys.stdout, setStdout=0)

class StubSlaveProcessProtocol(SlaveProcessProtocol):
    """
    Slave process protocol that reports ended processes to the test master.

    Every ended process is appended to C{master.dead_processes} as a
    C{(protocol, reason)} tuple. Once the master has no entry left in
    C{_spawned}, C{master.all_slaves_dead} is fired with the protocol of the
    process that ended last.
    """
    def __init__(self, *args):
        SlaveProcessProtocol.__init__(self, *args)

    def processEnded(self, reason):
        """
        Record the ended process, delegate to the base class and fire
        C{master.all_slaves_dead} when no spawned slave is left.
        """
        self.master.dead_processes.append((self, reason))
        SlaveProcessProtocol.processEnded(self, reason)
        # all_slaves_dead is a single Deferred owned by the master while this
        # method runs once per process: firing it a second time would raise
        # AlreadyCalledError, so only the first "no slave left" event fires it.
        if not self.master._spawned \
                and not self.master.all_slaves_dead.called:
            self.master.all_slaves_dead.callback(self)


class TestMasterProtocol(MasterProtocol):
    """
    Master-side protocol that can be told to leave ping requests unanswered.

    The switch is the C{dead} attribute of C{factory.master}.
    """
    # presumably seconds, overriding the MasterProtocol values -- confirm
    ping_period = 2
    ping_timeout = 1

    def ping(self):
        """
        Answer a Ping request through the base class unless the master is
        flagged as dead.
        """
        if not self.factory.master.dead:
            return MasterProtocol.ping(self)

        # a fresh Deferred that nothing here ever fires
        return defer.Deferred()
    Ping.responder(ping)

class TestMasterFactory(MasterFactory):
    """
    MasterFactory whose protocol class is TestMasterProtocol.
    """
    protocol = TestMasterProtocol

class TestMasterLockProtocol(MasterProtocol):
    """
    Master-side protocol that reports dead connections to the test master.
    """

    def connectionDied(self):
        """
        Let the base class handle the dead connection, then record this
        protocol in C{master.locked_processes} and fire
        C{master.slave_locked}.
        """
        MasterProtocol.connectionDied(self)
        master = self.factory.master
        master.locked_processes.append(self)
        # slave_locked is one Deferred shared by every connection of the
        # master; a second dead connection must not fire it again, which
        # would raise AlreadyCalledError.
        if not master.slave_locked.called:
            master.slave_locked.callback(self)

class TestMasterLockFactory(MasterFactory):
    """
    MasterFactory whose protocol class is TestMasterLockProtocol.
    """
    protocol = TestMasterLockProtocol

class TestMaster(Master):
    """
    Master using the test server factory and the stub slave process
    protocol, plus the bookkeeping attributes inspected by the tests.
    """
    slaveProcessProtocolFactory = StubSlaveProcessProtocol
    serverFactory = TestMasterFactory

    def __init__(self, address=None, slave_runner=None):
        """
        Initialise the base Master, then the test bookkeeping.
        """
        Master.__init__(self, address, slave_runner)
        # read by TestMasterProtocol.ping: when true, pings go unanswered
        self.dead = False
        # not read by the test code visible here -- purpose to be confirmed
        self.timeout = 0
        # fired by StubSlaveProcessProtocol once no spawned slave is left
        self.all_slaves_dead = defer.Deferred()
        # (protocol, reason) tuples, one per ended slave process
        self.dead_processes = []

class TestMasterLock(TestMaster):
    """
    TestMaster variant that keeps track of connections reported as dead.
    """
    serverFactory = TestMasterLockFactory

    def __init__(self, address=None, slave_runner=None):
        """
        Initialise TestMaster, then the lock bookkeeping.
        """
        TestMaster.__init__(self, address, slave_runner)
        # fired by TestMasterLockProtocol.connectionDied
        self.slave_locked = defer.Deferred()
        # protocols whose connectionDied() has been called
        self.locked_processes = []

class TestSlaveProtocol(SlaveProtocol):
    """
    Slave-side protocol that stops answering pings after a configurable
    number of answers (C{factory.timeout}; a false value disables this).
    """
    def ping(self):
        """
        Answer a Ping request through the base class, or leave it unanswered
        once C{factory.timeout} pings have been answered.
        """
        answered = getattr(self, 'pings', 0)
        limit = self.factory.timeout
        if not limit or answered != limit:
            self.pings = answered + 1
            return SlaveProtocol.ping(self)

        # limit reached: hand back a Deferred that nothing here ever fires
        return defer.Deferred()
    Ping.responder(ping)

class TestSlaveFactory(SlaveFactory):
    """
    SlaveFactory building TestSlaveProtocol instances and carrying the
    C{timeout} value that protocol reads.
    """
    protocol = TestSlaveProtocol

    def __init__(self, cookie, timeout):
        """
        Pass C{cookie} on to SlaveFactory and remember C{timeout}.
        """
        SlaveFactory.__init__(self, cookie)
        # read by TestSlaveProtocol.ping as the number of pings to answer
        # before going silent; a false value means "always answer"
        self.timeout = timeout

class TestDeadlockSlaveProtocol(TestSlaveProtocol):
    """
    Slave-side protocol simulating a deadlock: it answers exactly one ping
    and blocks forever when the next one arrives.
    """
    def ping(self):
        """
        Answer the first Ping request; on any later one, never return.
        """
        if getattr(self, 'pings', 0):
            # a ping has been answered already: loop forever to simulate a
            # deadlocked slave
            import time
            while True:
                time.sleep(0.1)

        self.pings = 1
        return TestSlaveProtocol.ping(self)
    Ping.responder(ping)

class TestDeadlockSlaveFactory(TestSlaveFactory):
    """
    TestSlaveFactory whose protocol class is TestDeadlockSlaveProtocol.
    """
    protocol = TestDeadlockSlaveProtocol

def dying_runner(cookie, connection_string):
    """
    Slave runner that returns straight away, without connecting to anything.

    Both arguments are ignored; the return value is always 1.
    """
    result = 1
    return result

def blocked_runner(cookie, connection_string):
    """
    Slave runner that never connects to the master, so it cannot answer any
    ping, not even the first one.

    Both arguments are ignored; the function just runs the reactor.
    """
    from twisted.internet import reactor as idle_reactor

    idle_reactor.run()

def okayish_runner(cookie, connection_string, disconnect=False, timeout=0,
                   slaveFactory=TestSlaveFactory):
    """
    Slave runner that connects to the master and runs the reactor.

    C{connection_string} is either C{tcp:<host>:<port>} or C{unix:<address>}.
    The connection uses C{slaveFactory(cookie, timeout)}. When C{disconnect}
    is true, the connector is disconnected 2 seconds after the connection
    attempt is set up.
    """
    from twisted.internet import reactor

    parts = connection_string.split(':', 3)
    scheme = parts[0]
    assert scheme in ('tcp', 'unix')

    # pick the connect function and its target first, then build the factory
    if scheme == 'unix':
        connect = reactor.connectUNIX
        target = (parts[1],)
    else:
        host, port = parts[1:]
        connect = reactor.connectTCP
        target = (host, int(port))

    connect_args = target + (slaveFactory(cookie, timeout),)
    connector = connect(*connect_args)

    if disconnect:
        reactor.callLater(2, connector.disconnect)

    reactor.run()

def disconnecting_runner(cookie, connection_string):
    """
    Slave runner that connects to the master and disconnects after a while
    (okayish_runner with its disconnect flag set).
    """
    okayish_runner(cookie, connection_string, disconnect=True)

def timeout_runner(cookie, connection_string):
    """
    Slave runner that connects to the master and stops answering pings
    after a while.
    """
    # with the default TestSlaveFactory, timeout=3 means three pings are
    # answered before TestSlaveProtocol goes silent
    okayish_runner(cookie, connection_string, timeout=3)

def deadlock_runner(cookie, connection_string):
    """
    Slave runner that connects to the master with the deadlocking slave
    factory: one ping is answered, the next one blocks the slave forever.
    """
    okayish_runner(cookie, connection_string,
                   slaveFactory=TestDeadlockSlaveFactory)

class MasterMixin(object):
    """
    Fixture helpers shared by the test cases; subclasses must provide
    C{self.master} (they do so in setUp()).
    """

    def tearDown(self):
        """
        Put back the slave runner replaced by setRunner(), if any, and stop
        the master.

        Returns whatever C{self.master.stop()} returns.
        """
        if hasattr(self, "current"):
            self.master._slave_runner = self.current
            del self.current

        return self.master.stop()

    def setRunner(self, runner):
        """
        Make the master use the runner function named C{runner} from this
        module, remembering the previous runner so tearDown() can restore it.

        May only be called once per test; a second call fails an assertion.
        """
        msg = "setRunner may only be called once per test"
        assert not hasattr(self, "current"), msg

        self.current = self.master._slave_runner
        # the runner is handed over as a dotted "module.function" string
        self.master._slave_runner = '%s.%s' % (__name__, runner)

class MasterTestMixin(MasterMixin):
    """
    Master tests that do not depend on the transport; the concrete test
    cases create C{self.master} in setUp().
    """

    def testStartSlavesFail(self):
        """
        Call startSlaves() and make all the slaves fail to start. Check that
        startSlaves() errbacks when all the processes are ended.
        """
        def startSlavesCb(result):
            self.failUnlessEqual(len(self.master.dead_processes), 2)
            return self.master.stopSlaves()

        self.setRunner('dying_runner')
        dfr = self.master.startSlaves(2, 7)
        # turns the expected StartError failure into a success (and a
        # success into a test failure)
        self.failUnlessFailure(dfr, StartError)
        dfr.addCallback(startSlavesCb)

        return dfr

    def testStartStopSlaves(self):
        """
        Start and stop slaves.
        Check that the slaves are started correctly and that they die when
        master.stopSlaves() is called.
        """
        def slavesStoppedCb(result):
            self.failUnlessEqual(len(self.master.dead_processes), 2)
            return self.master.stop()

        def slavesStartedCb(result):
            dfr = self.master.stopSlaves()
            dfr.addCallback(slavesStoppedCb)

            return dfr

        self.setRunner('okayish_runner')
        dfr = self.master.startSlaves(2, 7)
        dfr.addCallback(slavesStartedCb)

        return dfr

    def testStartSlavesSpawnTimeout(self):
        """
        Start slaves that don't connect to the master, resulting in a timeout.
        Check that startSlaves() errbacks and that both processes were killed
        with signal 9.
        """
        def startSlavesCb(result):
            self.fail("startSlaves() succeeded although no slave connected")

        def startSlavesEb(failure):
            self.failUnlessEqual(len(self.master.dead_processes), 2)
            for process, reason in self.master.dead_processes:
                self.failUnlessEqual(reason.value.signal, 9)

            return self.master.stop()

        self.setRunner('blocked_runner')
        dfr = self.master.startSlaves(2, 2)
        # addCallbacks rather than addCallback + addErrback: an unexpected
        # success must fail the test, and that failure must not be swallowed
        # by startSlavesEb
        dfr.addCallbacks(startSlavesCb, startSlavesEb)

        return dfr

    def testSlavesDisconnect(self):
        """
        Start slaves that disconnect from the master after a while.
        """
        def slavesDeadCb(result):
            return self.master.stopSlaves().addCallback(lambda result:
                self.master.stop())

        def slavesStartedCb(result):
            dfr = self.master.all_slaves_dead
            dfr.addCallback(slavesDeadCb)

            return dfr

        self.setRunner('disconnecting_runner')
        dfr = self.master.startSlaves(2, 7)
        dfr.addCallback(slavesStartedCb)

        return dfr

    def testSlavesPingTimeout(self):
        """
        Run a slave that stays connected but stops answering ping requests
        after a while.
        """
        def slavesDeadCb(result):
            dfr = self.master.stopSlaves()
            dfr.addErrback(slavesDeadEb)
            return dfr

        def slavesDeadEb(failure):
            # a lost connection is tolerated here; anything else fails
            failure.trap(error.ConnectionLost)

        def slavesStartedCb(result):
            dfr = self.master.all_slaves_dead
            dfr.addCallback(slavesDeadCb)
            dfr.addErrback(slavesDeadEb)
            return dfr

        self.setRunner('timeout_runner')
        dfr = self.master.startSlaves(2, 7)
        dfr.addCallback(slavesStartedCb)
        return dfr

    def testMasterPingTimeout(self):
        """
        Check that a slave kills itself if the master does not answer to
        pings.
        """
        def slavesStoppedCb(result):
            self.failUnlessEqual(len(self.master.dead_processes), 2)
            return self.master.stop()

        def slavesStartedCb(result):
            dfr = self.master.all_slaves_dead
            dfr.addCallback(slavesStoppedCb)

            return dfr

        self.setRunner('okayish_runner')
        # makes TestMasterProtocol.ping leave every ping unanswered
        self.master.dead = True
        dfr = self.master.startSlaves(2, 7)
        dfr.addCallback(slavesStartedCb)

        return dfr

class UnixMasterTestCase(MasterTestMixin, TestCase):
    """
    Run the shared Master tests over unix sockets.
    """
    def setUp(self):
        """
        Create and start a master on a C{unix:} address; skipped unless the
        platform is Linux.
        """
        if platform.system() != 'Linux':
            raise SkipTest("This is only supported in Linux")

        master = TestMaster(address='unix:')
        self.master = master
        master.start()


class TCPMasterTestCase(MasterTestMixin, TestCase):
    """
    Run the shared Master tests over TCP.
    """
    def setUp(self):
        """
        Create and start a master on a C{tcp:} address.
        """
        master = TestMaster(address='tcp:')
        self.master = master
        master.start()

class MasterLockTestCase(MasterMixin, TestCase):
    """
    Test the Master against a slave that locks up while connected.
    """

    def setUp(self):
        """
        Create and start a lock-tracking master on a C{tcp:} address.
        """
        master = TestMasterLock(address='tcp:')
        self.master = master
        master.start()

    def testLockedSlaves(self):
        """
        Start one slave that locks up after answering its first ping. Check
        that the master reports exactly one dead connection for it.
        """
        def slaveLockedCb(result):
            self.failUnlessEqual(len(self.master.locked_processes), 1)

        def slavesStartedCb(result):
            # wait until the master notices the locked slave
            return self.master.slave_locked.addCallback(slaveLockedCb)

        self.setRunner('deadlock_runner')
        dfr = self.master.startSlaves(1, 2)
        dfr.addCallback(slavesStartedCb)

        return dfr

    testLockedSlaves.timeout = 15
