##
# Copyright (c) 2011 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##

from twext.internet.ssl import ChainingOpenSSLContextFactory
from twext.python.log import Logger, LoggingMixIn
from twext.python.log import LoggingMixIn
from twext.web2 import responsecode
from twext.web2.dav import davxml
from twext.web2.dav.noneprops import NonePropertyStore
from twext.web2.dav.resource import DAVResource
from twext.web2.http import Response
from twext.web2.http_headers import MimeType
from twext.web2.server import parsePOSTData
from twisted.application import service
from twisted.internet import reactor, protocol
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.protocol import ClientFactory, ReconnectingClientFactory
from twistedcaldav.extensions import DAVResource, DAVResourceWithoutChildrenMixin
from twistedcaldav.resource import ReadOnlyNoCopyResourceMixIn
import OpenSSL
import struct
import time



# Module-level logger.  The classes below log through the log_debug /
# log_info / log_error methods they inherit from LoggingMixIn instead.
log = Logger()


class ApplePushNotifierService(service.MultiService, LoggingMixIn):
    """
    A MultiService which owns the APN "provider" and "feedback" child
    services.  After the children have connected, enqueue( ) pushes a
    notification to every device token subscribed to the enqueued key.

    The Apple Push Notification protocol is described here:

    http://developer.apple.com/library/ios/#documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingWIthAPS/CommunicatingWIthAPS.html
    """

    @classmethod
    def makeService(cls, settings, store, testConnectorClass=None,
        reactor=None):
        """
        Build the notifier together with one provider service and one
        feedback service for each of CalDAV and CardDAV that has a
        certificate path configured.

        @param settings: The portion of the configuration specific to APN
        @type settings: C{dict}

        @param store: The db store for storing/retrieving subscriptions
        @type store: L{IDataStore}

        @param testConnectorClass: Used for unit testing; implements
            connect( ) and receiveData( ).  One instance is created per
            child service.
        @type testConnectorClass: C{class}

        @param reactor: Used for unit testing; allows tests to advance the
            clock in order to test the feedback polling service.
        @type reactor: L{twisted.internet.task.Clock}

        @return: instance of L{ApplePushNotifierService}
        """

        notifier = cls()
        notifier.store = store
        notifier.providers = {}
        notifier.feedbacks = {}
        notifier.dataHost = settings["DataHost"]

        for protocolName in ("CalDAV", "CardDAV"):
            protocolSettings = settings[protocolName]

            # A protocol without a certificate is simply not configured
            if not protocolSettings["CertificatePath"]:
                continue

            if testConnectorClass is None:
                providerTestConnector = None
                feedbackTestConnector = None
            else:
                providerTestConnector = testConnectorClass()
                feedbackTestConnector = testConnectorClass()

            providerService = APNProviderService(
                settings["ProviderHost"],
                settings["ProviderPort"],
                protocolSettings["CertificatePath"],
                protocolSettings["PrivateKeyPath"],
                testConnector=providerTestConnector,
                reactor=reactor,
            )
            providerService.setServiceParent(notifier)
            notifier.providers[protocolName] = providerService
            notifier.log_info("APNS %s topic: %s" %
                (protocolName, protocolSettings["Topic"]))

            feedbackService = APNFeedbackService(
                notifier.store,
                settings["FeedbackUpdateSeconds"],
                settings["FeedbackHost"],
                settings["FeedbackPort"],
                protocolSettings["CertificatePath"],
                protocolSettings["PrivateKeyPath"],
                testConnector=feedbackTestConnector,
                reactor=reactor,
            )
            feedbackService.setServiceParent(notifier)
            notifier.feedbacks[protocolName] = feedbackService

        return notifier


    @inlineCallbacks
    def enqueue(self, op, id):
        """
        Push a notification to each device token subscribed to this id.

        @param op: The operation that took place, either "create" or "update"
            (ignored in this implementation)
        @type op: C{str}

        @param id: The identifier of the resource that was updated, including
            a prefix indicating whether this is CalDAV or CardDAV related.
            The prefix is separated from the id with "|", e.g.:

            "CalDAV|abc/def"

            The id is an opaque token as far as this code is concerned, and
            is used in conjunction with the prefix and the server hostname
            to build the actual key value that devices subscribe to.
        @type id: C{str}
        """

        prefix, separator, resourceID = id.partition("|")
        if not separator:
            # id has no protocol, so we can't do anything with it
            self.log_error("Notification id '%s' is missing protocol" % (id,))
            return

        provider = self.providers.get(prefix, None)
        if provider is None:
            # No provider was configured for this protocol
            return

        key = "/%s/%s/%s/" % (prefix, self.dataHost, resourceID)

        # Fetch the subscriptions registered for this key
        txn = self.store.newTransaction()
        subscriptions = (yield txn.apnSubscriptionsByKey(key))
        yield txn.commit()

        count = len(subscriptions)
        if count > 0:
            self.log_debug("Sending %d APNS notifications for %s" %
                (count, key))
            for token, _ignore_guid in subscriptions:
                provider.sendNotification(token, key)



class APNProviderProtocol(protocol.Protocol, LoggingMixIn):
    """
    Implements the Provider portion of APNS: writes "enhanced format"
    notifications to the transport and logs any error responses received.
    """

    # Sent by provider
    COMMAND_SIMPLE   = 0
    COMMAND_ENHANCED = 1

    # Received by provider
    COMMAND_ERROR    = 8

    # Wire layout of an error response: command, status, identifier
    ERROR_FORMAT = "!BBI"
    ERROR_LENGTH = struct.calcsize(ERROR_FORMAT)

    # Number of bytes in a binary device token, as packed by sendNotification
    TOKEN_LENGTH = 32

    # Returned only for an error.  Successful notifications get no response.
    STATUS_CODES = {
        0   : "No errors encountered",
        1   : "Processing error",
        2   : "Missing device token",
        3   : "Missing topic",
        4   : "Missing payload",
        5   : "Invalid token size",
        6   : "Invalid topic size",
        7   : "Invalid payload size",
        8   : "Invalid token",
        255 : "None (unknown)",
    }

    # Bytes received so far which do not yet make up a full error response
    _buffer = b""

    def makeConnection(self, transport):
        """
        Reset the per-connection state (notification identifier counter and
        receive buffer) before handing off to the base class.
        """
        self.identifier = 0
        self._buffer = b""
        # self.log_debug("ProviderProtocol makeConnection")
        protocol.Protocol.makeConnection(self, transport)

    def connectionMade(self):
        """
        Publish this protocol instance on the factory and tell the factory
        that the connection is up.
        """
        self.log_debug("ProviderProtocol connectionMade")
        # Store a reference to ourself on the factory so the service can
        # later call us
        self.factory.connection = self
        self.factory.clientConnectionMade()

    def connectionLost(self, reason=None):
        """
        Remove the factory's reference to this (now dead) connection.
        """
        # self.log_debug("ProviderProtocol connectionLost: %s" % (reason,))
        # Clear the reference to us from the factory
        self.factory.connection = None

    def dataReceived(self, data):
        """
        Accumulate incoming bytes and process every complete error response.

        The transport is a byte stream, so a chunk may hold part of a
        response or several of them; incomplete trailing bytes are kept
        until the rest arrives.

        @param data: The bytes just read from the transport
        @type data: C{str}
        """
        self.log_debug("ProviderProtocol dataReceived %d bytes" % (len(data),))
        self._buffer += data
        while len(self._buffer) >= self.ERROR_LENGTH:
            message = self._buffer[:self.ERROR_LENGTH]
            self._buffer = self._buffer[self.ERROR_LENGTH:]
            command, status, identifier = struct.unpack(
                self.ERROR_FORMAT, message)
            if command == self.COMMAND_ERROR:
                self.processError(status, identifier)

    def processError(self, status, identifier):
        """
        Handles an error response received on the provider connection.
        Not much to do here besides logging the error.

        @param status: The status value from the error response
        @type status: C{int}

        @param identifier: The identifier of the outbound push notification
            message which had a problem.
        @type identifier: C{int}
        """
        msg = self.STATUS_CODES.get(status, "Unknown status code")
        self.log_error("Received APN error %d on identifier %d: %s" % (status, identifier, msg))

    def sendNotification(self, token, key):
        """
        Sends a push notification message for the key to the device associated
        with the token.  A token which is not valid hex, or which does not
        decode to exactly TOKEN_LENGTH bytes, is logged and skipped.

        @param token: The device token subscribed to the key
        @type token: C{str}

        @param key: The key we're sending a notification about
        @type key: C{str}
        """

        try:
            binaryToken = token.replace(" ", "").decode("hex")
        except Exception:
            self.log_error("Invalid APN token in database: %s" % (token,))
            return

        # struct's "32s" would silently pad or truncate a wrong-sized token,
        # producing a message addressed to a token nobody registered.
        if len(binaryToken) != self.TOKEN_LENGTH:
            self.log_error("Invalid APN token in database: %s" % (token,))
            return

        self.identifier += 1
        payload = '{"key" : "%s"}' % (key,)
        payloadLength = len(payload)
        self.log_debug("Sending APNS notification to %s: id=%d payload=%s" %
            (token, self.identifier, payload))

        self.transport.write(
            struct.pack("!BIIH%dsH%ds" % (self.TOKEN_LENGTH, payloadLength),
                self.COMMAND_ENHANCED,  # Command
                self.identifier,        # Identifier
                0,                      # Expiry
                self.TOKEN_LENGTH,      # Token Length
                binaryToken,            # Token
                payloadLength,          # Payload Length
                payload,                # Payload in JSON format
            )
        )


class APNProviderFactory(ReconnectingClientFactory, LoggingMixIn):
    """
    Reconnecting factory for the APN provider connection.  The active
    protocol instance (if any) is kept in the C{connection} attribute,
    which the protocol itself sets and clears.
    """

    protocol = APNProviderProtocol

    def __init__(self, service):
        """
        @param service: The service to notify when a connection is made;
            must provide clientConnectionMade( )
        """
        self.service = service
        # No protocol instance until a connection is actually established
        self.connection = None

    def clientConnectionMade(self):
        """
        Called by the protocol once connected.  Resets the reconnection
        back-off so that a later disconnect starts again from the initial
        delay, then lets the service flush anything it has queued.
        """
        self.resetDelay()
        self.service.clientConnectionMade()

    def clientConnectionLost(self, connector, reason):
        """
        Defer to ReconnectingClientFactory, which schedules the reconnect.
        """
        # self.log_info("Connection to APN server lost: %s" % (reason,))
        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

    def clientConnectionFailed(self, connector, reason):
        """
        Log the failure, then defer to ReconnectingClientFactory, which
        schedules the retry.
        """
        self.log_error("Unable to connect to APN server: %s" % (reason,))
        self.connected = False
        ReconnectingClientFactory.clientConnectionFailed(self, connector,
            reason)


class APNConnectionService(service.Service, LoggingMixIn):
    """
    Base class for services which open an SSL connection to an APN server.
    Holds the connection parameters and provides connect( ).
    """

    def __init__(self, host, port, certPath, keyPath, chainPath="",
        sslMethod="TLSv1_METHOD", testConnector=None, reactor=None):
        """
        @param host: Hostname of the APN server to connect to
        @param port: Port of the APN server to connect to
        @param certPath: Path to the certificate file
        @param keyPath: Path to the private key file
        @param chainPath: Path to the certificate chain file, if any
        @param sslMethod: Name of the OpenSSL.SSL method constant to use
        @param testConnector: Used for unit testing; when given, its
            connect( ) is called instead of making a real SSL connection
        @param reactor: The reactor to connect with; defaults to the
            global Twisted reactor
        """

        self.host = host
        self.port = port
        self.certPath = certPath
        self.keyPath = keyPath
        self.chainPath = chainPath
        self.sslMethod = sslMethod
        self.testConnector = testConnector

        if reactor is None:
            from twisted.internet import reactor
        self.reactor = reactor

    def connect(self, factory):
        """
        Connect the given factory, either through the test connector or
        over SSL to self.host:self.port.

        @param factory: The client factory which builds the protocol
        """
        if self.testConnector is not None:
            # For testing purposes
            self.testConnector.connect(self, factory)
        else:
            context = ChainingOpenSSLContextFactory(
                self.keyPath,
                self.certPath,
                certificateChainFile=self.chainPath,
                sslmethod=getattr(OpenSSL.SSL, self.sslMethod)
            )
            # Use the reactor chosen in __init__, not the module-level one,
            # so a reactor supplied by the caller is actually honoured.
            self.reactor.connectSSL(self.host, self.port, factory, context)


class APNProviderService(APNConnectionService):
    """
    Maintains the connection to the APN provider server and sends
    notifications over it, queuing them while no connection is available.
    """

    def __init__(self, host, port, certPath, keyPath, chainPath="",
        sslMethod="TLSv1_METHOD", testConnector=None, reactor=None):
        """
        Takes the same arguments as L{APNConnectionService.__init__}.
        """

        # Pass chainPath through rather than discarding the caller's value
        APNConnectionService.__init__(self, host, port, certPath, keyPath,
            chainPath=chainPath, sslMethod=sslMethod,
            testConnector=testConnector, reactor=reactor)

        self.factory = None
        # (token, key) pairs waiting for a connection
        self.queue = []

    def startService(self):
        """
        Create the provider factory and start connecting.
        """
        self.log_debug("APNProviderService startService")
        self.factory = APNProviderFactory(self)
        self.connect(self.factory)

    def stopService(self):
        """
        Stop the factory from making any further reconnection attempts.
        """
        self.log_debug("APNProviderService stopService")
        if self.factory is not None:
            # Otherwise the reconnecting factory keeps retrying after the
            # service has been shut down.
            self.factory.stopTrying()

    def clientConnectionMade(self):
        """
        Called (via the factory) when the connection is up; sends whatever
        was queued while disconnected.
        """
        # Service the queue
        if self.queue:
            # Copy and clear the queue.  Any notifications that don't get
            # sent will be put back into the queue.
            queued = list(self.queue)
            self.queue = []
            for token, key in queued:
                self.sendNotification(token, key)

    def sendNotification(self, token, key):
        """
        Send a notification now if connected, otherwise queue it (once per
        distinct token/key pair) until the connection is made.

        @param token: The device token subscribed to the key
        @type token: C{str}

        @param key: The key we're sending a notification about
        @type key: C{str}
        """
        # Service has reference to factory has reference to protocol instance
        connection = getattr(self.factory, "connection", None)
        if connection is None:
            self.log_debug("APNProviderService has no connection; queuing: %s %s" % (token, key))
            tokenKeyPair = (token, key)
            if tokenKeyPair not in self.queue:
                self.queue.append(tokenKeyPair)
        else:
            connection.sendNotification(token, key)


class APNFeedbackProtocol(protocol.Protocol, LoggingMixIn):
    """
    Implements the Feedback portion of APNS: reads feedback records from
    the transport and removes the subscriptions they refer to.
    """

    # Wire layout of one feedback record: timestamp, token length, token
    MESSAGE_FORMAT = "!IH32s"
    MESSAGE_LENGTH = struct.calcsize(MESSAGE_FORMAT)

    # Bytes received so far which do not yet make up a full record
    _buffer = b""

    def connectionMade(self):
        self.log_debug("FeedbackProtocol connectionMade")

    @inlineCallbacks
    def dataReceived(self, data):
        """
        Accumulate incoming bytes and process every complete feedback
        record, one after another.

        The transport is a byte stream, so a chunk may hold part of a
        record or several of them; incomplete trailing bytes are kept until
        the rest arrives.

        @param data: The bytes just read from the transport
        @type data: C{str}

        @return: a Deferred which fires once every complete record in the
            buffer has been processed
        """
        self.log_debug("FeedbackProtocol dataReceived %d bytes" % (len(data),))
        self._buffer += data
        while len(self._buffer) >= self.MESSAGE_LENGTH:
            # Take the record off the buffer before yielding, so that a
            # dataReceived call arriving meanwhile cannot process it again.
            message = self._buffer[:self.MESSAGE_LENGTH]
            self._buffer = self._buffer[self.MESSAGE_LENGTH:]
            timestamp, tokenLength, binaryToken = struct.unpack(
                self.MESSAGE_FORMAT, message)
            token = binaryToken.encode("hex").lower()
            yield self.processFeedback(timestamp, token)

    @inlineCallbacks
    def processFeedback(self, timestamp, token):
        """
        Handles a feedback message indicating that the given token is no
        longer active as of the timestamp, and its subscription should be
        removed as long as that device has not re-subscribed since the
        timestamp.

        @param timestamp: Seconds since the epoch
        @type timestamp: C{int}

        @param token: The device token to unsubscribe
        @type token: C{str}
        """

        self.log_debug("FeedbackProtocol processFeedback time=%d token=%s" %
            (timestamp, token))
        txn = self.factory.store.newTransaction()
        subscriptions = (yield txn.apnSubscriptionsByToken(token))

        for key, modified, guid in subscriptions:
            # Only drop subscriptions last modified before the feedback time
            if timestamp > modified:
                self.log_debug("FeedbackProtocol removing subscription: %s %s" %
                    (token, key))
                yield txn.removeAPNSubscription(token, key)
        yield txn.commit()


class APNFeedbackFactory(ClientFactory, LoggingMixIn):
    """
    Client factory for the APN feedback connection.  Carries the store so
    that the protocol it builds can reach it through self.factory.store.
    """

    protocol = APNFeedbackProtocol

    def __init__(self, store):
        """
        @param store: The db store holding the subscriptions
        """
        self.store = store

    def clientConnectionFailed(self, connector, reason):
        """
        Log the failed connection attempt, then let ClientFactory handle it.
        """
        message = "Unable to connect to APN feedback server: %s" % (reason,)
        self.log_error(message)
        self.connected = False
        ClientFactory.clientConnectionFailed(self, connector, reason)


class APNFeedbackService(APNConnectionService):
    """
    Periodically connects to the APN feedback server so that the feedback
    protocol can process whatever records the server sends.
    """

    def __init__(self, store, updateSeconds, host, port, certPath, keyPath,
        chainPath="", sslMethod="TLSv1_METHOD", testConnector=None,
        reactor=None):
        """
        @param store: The db store holding the subscriptions
        @param updateSeconds: Number of seconds between feedback checks

        The remaining arguments are the same as for
        L{APNConnectionService.__init__}.
        """

        # Pass chainPath through rather than discarding the caller's value
        APNConnectionService.__init__(self, host, port, certPath, keyPath,
            chainPath=chainPath, sslMethod=sslMethod,
            testConnector=testConnector, reactor=reactor)

        self.store = store
        self.updateSeconds = updateSeconds
        self.factory = None
        # The pending delayed call for the next check, if one is scheduled.
        # Set here so stopService is safe even if startService never ran.
        self.nextCheck = None

    def startService(self):
        """
        Create the feedback factory and perform the first check, which
        also schedules the following one.
        """
        self.log_debug("APNFeedbackService startService")
        self.factory = APNFeedbackFactory(self.store)
        self.checkForFeedback()

    def stopService(self):
        """
        Cancel the pending feedback check, if any.
        """
        self.log_debug("APNFeedbackService stopService")
        if self.nextCheck is not None:
            self.nextCheck.cancel()
            # Forget the cancelled call so a repeated stop does not try to
            # cancel it a second time.
            self.nextCheck = None

    def checkForFeedback(self):
        """
        Connect to the feedback server and schedule the next check
        self.updateSeconds from now.
        """
        self.nextCheck = None
        self.log_debug("APNFeedbackService checkForFeedback")
        self.connect(self.factory)
        self.nextCheck = self.reactor.callLater(self.updateSeconds,
            self.checkForFeedback)


class APNSubscriptionResource(ReadOnlyNoCopyResourceMixIn,
    DAVResourceWithoutChildrenMixin, DAVResource, LoggingMixIn):
    """
    The DAV resource allowing clients to subscribe to Apple push notifications.
    To subscribe, a client should first determine the key they are interested
    in by examining the "pushkey" DAV property on the home or collection they
    want to monitor.  Next the client sends an authenticated HTTP GET or POST
    request to this resource, passing their device token and the key in either
    the URL params or in the POST body.
    """

    # Number of hex digits in a device token (the provider protocol sends
    # tokens as 32 binary bytes)
    TOKEN_HEX_LENGTH = 64

    def __init__(self, parent, store):
        """
        @param parent: The parent resource; supplies principalCollections( )
        @param store: The db store for saving subscriptions
        """
        DAVResource.__init__(
            self, principalCollections=parent.principalCollections()
        )
        self.parent = parent
        self.store = store

    def deadProperties(self):
        # Create the (empty) property store lazily, once per resource
        if not hasattr(self, "_dead_properties"):
            self._dead_properties = NonePropertyStore(self)
        return self._dead_properties

    def etag(self):
        return None

    def checkPreconditions(self, request):
        return None

    def defaultAccessControlList(self):
        """
        Grant DAV:read and DAV:write to any authenticated principal.
        """
        return davxml.ACL(
            # DAV:Read for authenticated principals
            davxml.ACE(
                davxml.Principal(davxml.Authenticated()),
                davxml.Grant(
                    davxml.Privilege(davxml.Read()),
                ),
                davxml.Protected(),
            ),
            # DAV:Write for authenticated principals
            davxml.ACE(
                davxml.Principal(davxml.Authenticated()),
                davxml.Grant(
                    davxml.Privilege(davxml.Write()),
                ),
                davxml.Protected(),
            ),
        )

    def contentType(self):
        return MimeType.fromString("text/html; charset=utf-8")

    def resourceType(self):
        return None

    def isCollection(self):
        return False

    def isCalendarCollection(self):
        return False

    def isPseudoCalendarCollection(self):
        return False

    @inlineCallbacks
    def http_POST(self, request):
        """
        Handle a subscription request.  The caller must hold DAV:write on
        this resource.  Also used for GET.
        """
        yield self.authorize(request, (davxml.Write(),))
        yield parsePOSTData(request)
        code, msg = (yield self.processSubscription(request))
        returnValue(self.renderResponse(code, body=msg))

    http_GET = http_POST

    def principalFromRequest(self, request):
        """
        Given an authenticated request, return the principal based on
        request.authnUser, or None if no principal collection knows it.
        """
        for collection in self.principalCollections():
            data = request.authnUser.children[0].children[0].data
            principal = collection._principalForURI(data)
            if principal is not None:
                return principal
        return None

    def _isValidToken(self, token):
        """
        Tell whether token (already lowercased, spaces removed) consists of
        exactly TOKEN_HEX_LENGTH hex digits.

        @param token: The normalized device token
        @type token: C{str}

        @return: C{bool}
        """
        if len(token) != self.TOKEN_HEX_LENGTH:
            return False
        return all(character in "0123456789abcdef" for character in token)

    @inlineCallbacks
    def processSubscription(self, request):
        """
        Given an authenticated request, use the token and key arguments
        to add a subscription entry to the database.

        @param request: The request to process
        @type request: L{twext.web2.server.Request}

        @return: a Deferred firing with a (response code, message) tuple;
            the code is BAD_REQUEST if either argument is missing or the
            token is malformed, otherwise OK
        """

        token = request.args.get("token", None)
        key = request.args.get("key", None)
        if key and token:
            # Arguments arrive as lists; only the first value is used
            key = key[0]
            token = token[0].replace(" ", "").lower()
            if self._isValidToken(token):
                principal = self.principalFromRequest(request)
                guid = principal.record.guid
                yield self.addSubscription(token, key, guid)
                code = responsecode.OK
                msg = None
            else:
                # Reject now rather than storing a token which can never
                # be delivered to.
                code = responsecode.BAD_REQUEST
                msg = "Invalid request: bad 'token' %s" % (token,)
        else:
            code = responsecode.BAD_REQUEST
            msg = "Invalid request: both 'token' and 'key' must be provided"

        returnValue((code, msg))

    @inlineCallbacks
    def addSubscription(self, token, key, guid):
        """
        Add a subscription (or update its timestamp if already there).

        @param token: The device token, must be lowercase
        @type token: C{str}

        @param key: The push key
        @type key: C{str}

        @param guid: The GUID of the subscriber principal
        @type guid: C{str}
        """
        now = int(time.time()) # epoch seconds
        txn = self.store.newTransaction()
        yield txn.addAPNSubscription(token, key, now, guid)
        yield txn.commit()

    def renderResponse(self, code, body=None):
        """
        Build a text/html response with the given status code and body.
        """
        response = Response(code, {}, body)
        response.headers.setHeader("content-type", MimeType("text", "html"))
        return response
