# Copyright 2015 Metaswitch Networks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: the module docstring below is not just documentation -- it is the
# docopt grammar that PolicyParser.parse_line passes to docopt.docopt() as
# __doc__. Any edit to it changes which policy strings are accepted.
# In this grammar, port filters are only available for tcp/udp; icmp accepts
# an optional type/code; with no protocol only label/cidr filters apply.
"""
Usage: parse (allow|deny) [(
      (tcp|udp) [(from [(ports <SRCPORTS>)] [(label <SRCLABEL>)] [(cidr <SRCCIDR>)])]
                [(to   [(ports <DSTPORTS>)] [(label <DSTLABEL>)] [(cidr <DSTCIDR>)])] |
      icmp [(type <ICMPTYPE> [(code <ICMPCODE>)])]
           [(from [(label <SRCLABEL>)] [(cidr <SRCCIDR>)])]
           [(to   [(label <DSTLABEL>)] [(cidr <DSTCIDR>)])] |
      [(from [(label <SRCLABEL>)] [(cidr <SRCCIDR>)])]
      [(to   [(label <DSTLABEL>)] [(cidr <DSTCIDR>)])]
    )]
"""
import docopt
import logging
import re
from pycalico.datastore_datatypes import Rule

_log = logging.getLogger(__name__)

# Regex to extract a key / value from a Kubernetes label of the form
# "key=value". Both halves may contain letters, digits, '/', '.', '-' and '_'.
# A raw string is used so the backslash escapes reach the regex engine as-is
# (in a normal string "\_" and "\." are invalid escapes that warn on 3.6+).
LABEL_REG = re.compile(r"^([a-zA-Z0-9/\.\-\_]+)=([a-zA-Z0-9/\.\-\_]+)$")


class PolicyParser(object):
    """
    Class which parses Kubernetes annotation-based policy strings and
    converts them to libcalico Rule objects.

    Usage:
      parser = PolicyParser(namespace)
      rule = parser.parse_line("allow tcp from label stage=prod")
    """
    def __init__(self, namespace):
        """
        :param namespace: Namespace used to scope the tags generated from
            labels. Pass None if no namespace is available; tags are then
            built from the label alone.
        """
        self.namespace = namespace

    def parse_line(self, policy):
        """
        Takes a single line of policy as defined in the annotations of a
        pod and returns the equivalent libcalico Rule object.
        :param policy: Policy string from annotations.
        :return: A Rule object which represents the given policy.
        :raises ValueError: if the policy does not match the docopt grammar
            in the module docstring, requests an unsupported "deny" action,
            or contains a malformed label.
        """
        _log.info("Parsing policy line: '%s'", policy)
        splits = policy.split()

        try:
            # help=False: the policy text comes from pod annotations. With
            # docopt's default help=True, a "-h"/"--help" token makes docopt
            # print the usage and call sys.exit() (SystemExit, not
            # DocoptExit), which would terminate the caller instead of
            # reporting a bad policy.
            args = docopt.docopt(__doc__, argv=splits, help=False)
        except docopt.DocoptExit:
            raise ValueError("Failed to parse policy: %s" % policy)

        # Generate a rule object from the arguments.
        return self._generate_rule(args)

    def _generate_rule(self, arguments):
        """
        Generates a libcalico Rule object.
        :param arguments: A dictionary of arguments as generated by docopt.
        :return: A libcalico Rule object
        :raises ValueError: if the policy requests a "deny" action or
            contains a malformed label.
        """
        # We only support whitelist rules. The grammar still accepts "deny",
        # so reject it explicitly rather than silently turning it into an
        # "allow" (which would permit exactly the traffic meant to be blocked).
        if arguments.get("deny"):
            raise ValueError("Unsupported policy action 'deny': only "
                             "'allow' rules are supported")
        rule_args = {"action": "allow"}

        # Get arguments.
        if arguments.get("tcp"):
            protocol = "tcp"
        elif arguments.get("udp"):
            protocol = "udp"
        elif arguments.get("icmp"):
            protocol = "icmp"
        else:
            protocol = None

        src_ports = arguments.get("<SRCPORTS>")
        dst_ports = arguments.get("<DSTPORTS>")
        icmp_type = arguments.get("<ICMPTYPE>")
        icmp_code = arguments.get("<ICMPCODE>")
        src_net = arguments.get("<SRCCIDR>")
        src_label = arguments.get("<SRCLABEL>")
        dst_net = arguments.get("<DSTCIDR>")
        dst_label = arguments.get("<DSTLABEL>")

        # Populate rule arguments; only the fields the policy actually
        # specified are passed on to Rule.
        if protocol:
            rule_args["protocol"] = protocol
        # Port arguments are comma-separated lists, e.g. "80,443".
        if src_ports:
            rule_args["src_ports"] = [s.strip() for s in src_ports.split(",")]
        if dst_ports:
            rule_args["dst_ports"] = [s.strip() for s in dst_ports.split(",")]
        if icmp_type:
            rule_args["icmp_type"] = icmp_type
        if icmp_code:
            rule_args["icmp_code"] = icmp_code
        if src_net:
            rule_args["src_net"] = src_net
        if src_label:
            rule_args["src_tag"] = self._validate_label(src_label)
        if dst_net:
            rule_args["dst_net"] = dst_net
        if dst_label:
            rule_args["dst_tag"] = self._validate_label(dst_label)

        return Rule(**rule_args)

    def _validate_label(self, value):
        """
        Takes the given label, validates it, and returns the equivalent tag.
        :param value: Label string of the form "key=value".
        :return: The tag produced by label_to_tag for that key / value.
        :raises ValueError: if value does not match LABEL_REG.
        """
        match = LABEL_REG.search(value)
        if not match:
            raise ValueError("Failed to parse %s, expecting "
                             "label of form X=Y" % value)
        k, v = match.groups()
        return self.label_to_tag(k, v)

    def _escape_chars(self, unescaped_string):
        """
        Calico can only handle 3 special chars, '_.-'.
        Existing '_' characters are first doubled, then every character
        outside [a-zA-Z0-9._-] is replaced with '_'.
        """
        # Character to replace symbols
        swap_char = '_'

        # If swap_char is in string, double it.
        doubled = unescaped_string.replace(swap_char, swap_char * 2)

        # Substitute all invalid chars.
        return re.sub(r'[^a-zA-Z0-9\.\_\-]', swap_char, doubled)

    def label_to_tag(self, label_key, label_value):
        """
        Labels are key-value pairs, tags are single strings. This function
        handles that translation.
        1) Concatenate key and value with '='
        2) Prepend this parser's namespace followed by '/', unless the
           namespace is None
        3) Escape the generated string so it is Calico compatible
        :param label_key: key to label
        :param label_value: value to given key for a label
        :type label_key: str
        :type label_value: str
        :return: single string tag
        :rtype: str
        """
        tag = '%s=%s' % (label_key, label_value)
        if self.namespace is not None:
            tag = '%s/%s' % (self.namespace, tag)
        return self._escape_chars(tag)
