#!/usr/bin/env python3

# Copyright 2017-2021 VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import functools
import glob
import os
import re
import sys

from vyos.ifconfig import Section
from vyos.ifconfig import Interface
from vyos.ifconfig import VRRP
from vyos.util import cmd, call


# Static list of interface name prefixes known to this script.
# NOTE(review): neither `interfaces` nor `glob_ifnames` is referenced
# elsewhere in this script, and `glob_ifnames` mixes regex alternation
# "(a|b)" with a glob "*" -- confirm whether it is still needed.
interfaces = [
    'eno', 'ens', 'enp', 'enx', 'eth',
    'vmnet', 'lo', 'tun', 'wan', 'pppoe',
]
glob_ifnames = '/sys/class/net/(' + '|'.join(interfaces) + ')*'

# Maps an action name to its handler; populated by the @register decorator.
actions = {}
def register(name):
    """
    Decorator that registers a function in `actions` under `name`.

    The stored (and returned) callable wraps the function so that a broken
    pipe -- the reader of stdout going away early, which is likely when a
    lot of interfaces are listed -- silences stdout and exits with status 1
    instead of dumping a traceback. The wrapped function's return value is
    passed through.

    name: key under which the wrapped function is stored in `actions`
    """
    def _register(function):
        @functools.wraps(function)
        def handled_function(*args, **kwargs):
            try:
                result = function(*args, **kwargs)
                # Flush while still inside the try: otherwise a broken pipe
                # only surfaces when the interpreter flushes stdout at exit,
                # outside this handler.
                sys.stdout.flush()
            except BrokenPipeError:
                # Python flushes stdout once more at shutdown; point it at
                # /dev/null so that final flush cannot fail as well.
                devnull = os.open(os.devnull, os.O_WRONLY)
                os.dup2(devnull, sys.stdout.fileno())
                sys.exit(1)
            return result
        actions[name] = handled_function
        return handled_function
    return _register


def filtered_interfaces(ifnames, iftypes, vif, vrrp):
    """
    Yield the interfaces present on the system that match the given filters.

    ifnames: list of interface names to consider; empty means no name filter
    iftypes: an interface type (section), or a list of them; each list
             entry is filtered in turn
    vif:     when true, only yield VLAN sub-interfaces (names with a '.')
    vrrp:    when true, only yield interfaces listed by
             VRRP.active_interfaces()

    Yields a generic, read-only Interface instance per match.
    """
    if isinstance(iftypes, list) and iftypes:
        # Handle each requested type separately, then stop: the list itself
        # must not be handed to Section.interfaces() below.
        for iftype in iftypes:
            yield from filtered_interfaces(ifnames, iftype, vif, vrrp)
        return

    # Query VRRP once instead of once per interface.
    vrrp_interfaces = VRRP.active_interfaces() if vrrp else []

    # An empty iftypes list is passed through unchanged; presumably
    # Section.interfaces() treats a falsy section as "all" -- TODO confirm.
    for ifname in Section.interfaces(iftypes):
        # Bail out early if interface name not part of our search list
        if ifnames and ifname not in ifnames:
            continue

        # VLAN interfaces have a '.' in their name by convention
        if vif and '.' not in ifname:
            continue

        if vrrp and ifname not in vrrp_interfaces:
            continue

        # As we are only "reading" from the interface - we must use the
        # generic base class which exposes all the data via a common API
        yield Interface(ifname, create=False, debug=False)


def split_text(text, used=0, columns=None):
    """
    Split text into lines that fit in the remaining width of the screen.

    text:    the string to split; None is treated as an empty string
    used:    number of characters already used on the screen line
    columns: total screen width; when None it is read from `stty size`
             if stdin is a terminal (per `tty -s`), otherwise 80 is used

    Words are never broken: a word wider than the available space is
    emitted on a line of its own. At least one (possibly empty) line is
    always yielded.
    """
    if columns is None:
        columns = 80
        # call() returns the exit code; `tty -s` succeeds on a terminal
        if not call('tty -s'):
            # `stty size` prints "<rows> <columns>"
            size = cmd('stty size').split()
            if len(size) == 2:
                columns = int(size[1])

    width = columns - used

    line = ''
    for word in (text or '').split():
        # `line` carries a leading space, so len(line) + len(word) is the
        # visible length the line would have once `word` is appended.
        if line and len(line) + len(word) >= width:
            yield line[1:]
            line = ''
        line = f'{line} {word}'

    yield line[1:]


def get_counter_val(clear, now):
    """
    Return how far a counter has moved since it was cleared, allowing for a
    single wrap of a 32-bit counter (logic ported from the perl version).

    clear: counter value recorded at the last clear (0 if never cleared)
    now:   current counter value
    """
    if clear == 0:
        return now

    delta = now - clear
    # Values beyond 32 bits mean a 64-bit counter, assumed never to wrap;
    # a non-negative delta means no wrap happened either.
    if now >> 32 or delta >= 0:
        return delta

    # A 32-bit counter rolled over. Several rollovers since the clear
    # cannot be detected, in which case this result is meaningless.
    return (1 << 32) - clear + now


@register('help')
def usage(*args):
    """
    Print the command-line synopsis plus the valid interface names, types
    and actions, then exit with status 1.

    Positional arguments are accepted and ignored so that this can be
    dispatched like any other registered action.
    """
    # Options are spelled the way argparse in __main__ actually accepts them.
    print(f"Usage: {sys.argv[0]} [--intf=NAME|--intf-type=TYPE|--vif|--vrrp] --action=ACTION")
    print("  NAME = " + ' | '.join(Section.interfaces()))
    print("  TYPE = " + ' | '.join(Section.sections()))
    print("  ACTION = " + ' | '.join(actions))
    sys.exit(1)


@register('allowed')
def run_allowed(*args, **kwargs):
    """
    Write the names of all interfaces from Section.interfaces(), separated
    by spaces and without a trailing newline.

    The dispatcher in __main__ passes the interface filters positionally;
    they are accepted here (previously only **kwargs was, which made
    `--action allowed` raise TypeError) and ignored.
    """
    sys.stdout.write(' '.join(Section.interfaces()))


def pppoe(ifname):
    """
    Report the state of a PPPoE interface by name.

    Returns 'C' (coming up) when a running pppd's command line mentions
    ifname, 'D' (link down) when only a /etc/ppp/peers/<ifname> file
    exists, and '' otherwise.
    """
    try:
        out = cmd('ps -C pppd -f')
    except OSError:
        # NOTE(review): assumes cmd() raises OSError on a non-zero exit;
        # `ps -C` exits non-zero when no pppd process is running at all.
        out = ''

    # Whole-word match so that e.g. 'pppoe1' is not found inside 'pppoe10'.
    if re.search(rf'\b{re.escape(ifname)}\b', out):
        return 'C'

    peers = [os.path.basename(path) for path in glob.glob('/etc/ppp/peers/pppoe*')]
    if ifname in peers:
        return 'D'
    return ''


@register('show')
def run_show_intf(ifnames, iftypes, vif, vrrp):
    """
    Print detailed information for every matching interface.

    For each interface this prints:
    - the `ip addr show` output, with the leading "<index>: " removed; for
      IPv6 tunnels the tunnel parameters are inserted before the
      link/tunnel6 line;
    - the last-clear time, when one is recorded;
    - the description, if set;
    - the formatted statistics.

    Requested pppoe* names that matched no interface are then reported as
    'Coming up' or 'Link down' according to pppoe().
    """
    handled = []
    for interface in filtered_interfaces(ifnames, iftypes, vif, vrrp):
        handled.append(interface.ifname)
        cache = interface.operational.load_counters()

        out = cmd(f'ip addr show {interface.ifname}')
        # Drop the "<index>: " prefix of the first line
        out = re.sub(r'^\d+:\s+', '', out)
        if 'link/tunnel6' in out:
            tunnel = cmd(f'ip -6 tun show {interface.ifname}')
            # tun0: ip/ipv6 remote ::2 local ::1 encaplimit 4 hoplimit 64 tclass inherit flowlabel inherit (flowinfo 0x00000000)
            tunnel = re.sub(r'.*encap', 'encap', tunnel)
            # Put the tunnel parameters on their own line, with the same
            # indentation, just before the link/tunnel6 line. A callable
            # replacement keeps any backslash in `tunnel` literal.
            out = re.sub(
                r'(\n\s+)(link/tunnel6)',
                lambda m: f'{m.group(1)}{tunnel}{m.group(1)}{m.group(2)}',
                out,
            )

        print(out)

        timestamp = int(cache.get('timestamp', 0))
        if timestamp:
            when = interface.operational.strtime(timestamp)
            print(f'    Last clear: {when}')

        description = interface.get_alias()
        if description:
            print(f'    Description: {description}')

        print()
        print(interface.operational.formated_stats())

    # PPPoE sessions that are not up have no interface yet; report them
    for ifname in ifnames:
        if ifname not in handled and ifname.startswith('pppoe'):
            state = pppoe(ifname)
            if not state:
                continue
            string = {
                'C': 'Coming up',
                'D': 'Link down',
            }[state]
            print(f'{ifname}: {string}')


@register('show-brief')
def run_show_intf_brief(ifnames, iftypes, vif, vrrp):
    """
    Print a summary table with one entry per matching interface.

    An interface's first row holds its name, its first address, the
    "<admin>/<oper>" code and the first description line. Remaining
    addresses and description lines go on continuation rows.
    - Link-local fe80:: addresses are left out, and '-' is shown when no
      address remains.
    - An address of 33 characters or more gets a row of its own.

    Requested pppoe* names that matched no interface are listed with a code
    derived from pppoe().
    """
    row_fmt = '%-16s %-33s %-4s %s'
    long_addr_fmt = '%-16s %s'

    print('Codes: S - State, L - Link, u - Up, D - Down, A - Admin Down')
    print(row_fmt % ("Interface", "IP Address", "S/L", "Description"))
    print(row_fmt % ("---------", "----------", "---", "-----------"))

    seen = []
    for interface in filtered_interfaces(ifnames, iftypes, vif, vrrp):
        seen.append(interface.ifname)

        active = ('up', 'unknown')
        link = 'u' if interface.operational.get_state() in active else 'D'
        state = 'u' if interface.get_admin_state() in active else 'A'

        addresses = [
            addr for addr in interface.get_addr()
            if not addr.startswith('fe80::')
        ] or ['-']
        descriptions = list(split_text(interface.get_alias(), 0))

        # Name and S/L code only appear on the interface's first row
        for idx in range(max(len(addresses), len(descriptions), 1)):
            first = idx == 0
            name = interface.ifname if first else ''
            code = f'{state}/{link}' if first else ''
            address = addresses[idx] if idx < len(addresses) else ''
            desc = descriptions[idx] if idx < len(descriptions) else ''

            if len(address) >= 33:
                print(long_addr_fmt % (name, address))
                print(row_fmt % ('', '', code, desc))
            else:
                print(row_fmt % (name, address, code, desc))

    pppoe_codes = {'C': 'u/D', 'D': 'A/D'}
    for ifname in ifnames:
        if ifname in seen or not ifname.startswith('pppoe'):
            continue
        status = pppoe(ifname)
        if status:
            print(row_fmt % (ifname, '', pppoe_codes[status], ''))


@register('show-count')
def run_show_counters(ifnames, iftypes, vif, vrrp):
    """
    Print rx/tx packet and byte counters for each matching interface whose
    operational state is 'up' or 'unknown'.

    Each value is the live counter from get_stats() relative to the cached
    value from load_counters(), passed through get_counter_val().
    """
    layout = '%-12s %10s %10s     %10s %10s'
    counters = ('rx_packets', 'rx_bytes', 'tx_packets', 'tx_bytes')

    print(layout % ('Interface', 'Rx Packets', 'Rx Bytes', 'Tx Packets', 'Tx Bytes'))

    for interface in filtered_interfaces(ifnames, iftypes, vif, vrrp):
        if interface.operational.get_state() not in ('up', 'unknown'):
            continue

        current = interface.operational.get_stats()
        baseline = interface.operational.load_counters()
        row = [interface.ifname]
        row.extend(get_counter_val(baseline[key], current[key]) for key in counters)
        print(layout % tuple(row))


@register('clear')
def run_clear_intf(ifnames, iftypes, vif, vrrp):
    """Announce and clear the counters of every matching interface."""
    for intf in filtered_interfaces(ifnames, iftypes, vif, vrrp):
        name = intf.ifname
        print(f'Clearing {name}')
        intf.operational.clear_counters()


@register('reset')
def run_reset_intf(ifnames, iftypes, vif, vrrp):
    """Call reset_counters() on every matching interface, without output."""
    matching = filtered_interfaces(ifnames, iftypes, vif, vrrp)
    for intf in matching:
        intf.operational.reset_counters()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(add_help=False, description='Show interface information')
    parser.add_argument('--intf', action="store", type=str, default='', help='only show the specified interface(s)')
    parser.add_argument('--intf-type', action="store", type=str, default='', help='only show the specified interface type')
    parser.add_argument('--action', action="store", type=str, default='show', help='action to perform')
    parser.add_argument('--vif', action='store_true', default=False, help="only show vif interfaces")
    parser.add_argument('--vrrp', action='store_true', default=False, help="only show vrrp interfaces")
    # add_help=False above, so --help is declared here and handled below
    parser.add_argument('--help', action='store_true', default=False, help="show help")

    args = parser.parse_args()

    if args.help:
        # usage() prints the help text and exits
        usage()

    def missing(*_):
        """Fallback for an unknown --action: report it, then show usage."""
        # The parameter must not be named `args`, or it would shadow the
        # parsed arguments used in the message below.
        print(f'Invalid action [{args.action}]')
        usage()

    # --intf and --intf-type take space-separated lists
    actions.get(args.action, missing)(
        [_ for _ in args.intf.split(' ') if _],
        [_ for _ in args.intf_type.split(' ') if _],
        args.vif,
        args.vrrp
    )