#!/usr/bin/python2
#
# Copyright (c) 2021 Alibaba Group Holding Limited
# Express UDP is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
#

# -*- coding:utf-8 -*-

import struct
import sys
import os
import socket
import argparse
import copy
import time
import signal

def sigint_handle(si, s):
    """SIGINT handler: stop the program with exit status -1.

    Both arguments are supplied by the ``signal`` machinery when the handler
    is invoked and are ignored here.
    """
    raise SystemExit(-1)

class g:
    """Namespace for the script-wide configuration and shared state."""

    # Field list consumed by stats_struct_parse(): one "<type> <name>;" entry
    # per line.  "u64" maps to struct code 'Q' and "int" to 'i'.  A leading
    # '@' marks a field whose name is also added to Stats.sum_keys (the
    # counters that are summed and printed).
    # NOTE(review): presumably mirrors the C stats struct sent by the xudp
    # server, so order and types must match it -- confirm against that struct.
    stats_struct_code = \
    """
	u64 nanosecond;

	int ch_id;
	int g_id;
	int group_num;
	int is_tx;

@	u64 send_ebusy;
@	u64 send_again;
@	u64 send_err;
@	u64 send_success;

@	u64 no_cq;
@	u64 no_tx;

@	u64 rx_npkts;
@	u64 tx_npkts;

@	u64 xsk_rx_dropped;
@	u64 xsk_rx_invalid_descs;
@	u64 xsk_tx_invalid_descs;

@       int kern_tx_num;

	int padding;
    """

    # Network device name (main() sets it from --dev); the raw request socket
    # is bound to it.
    dev                = None
    # struct format string built by stats_struct_parse().
    stats_struct_fmt   = None
    # Field names, in the same order as the codes in stats_struct_fmt.
    stats_struct_names = None
    # UDP socket bound to localport on which the stats replies are received.
    udp_sock           = None
    # Address and UDP port the stats request is sent to.
    dest_ip            = None
    dest_port          = None
    # Not referenced by the code visible in this file.
    nanosecond         = 0
    # Length of the longest field name; print_kv() pads keys to this width.
    max_key_len        = 0
    # Local UDP port: source port of the request and bind port of udp_sock.
    localport          = 1121
    # Group objects collected by the most recent fetch_stats() call.
    groups             = []
    # Request instance used to send the stats requests.
    req                = None

def stats_struct_parse():
    """Parse g.stats_struct_code into a struct format and field-name list.

    Every non-empty line that does not start with '#' is expected to look
    like "<type> <name>;".  A leading '@' marks the field as a summable
    counter.  Supported types are "u64" (struct code 'Q') and "int" ('i').

    Results are published only after the whole description parsed cleanly:
      * g.stats_struct_fmt   -- concatenated struct format codes
      * g.stats_struct_names -- field names in declaration order
      * g.max_key_len        -- raised to the longest field name length
      * Stats.sum_keys       -- replaced (in place) by the '@' field names,
                                so calling this function again does not
                                duplicate entries

    Raises:
        Exception: if a field uses an unsupported type.
        ValueError: if a line is not exactly "<type> <name>".
    """
    type_codes = {'u64': 'Q', 'int': 'i'}

    fmt = []
    names = []
    sum_keys = []

    for line in g.stats_struct_code.split('\n'):
        line = line.strip()
        if not line or line[0] == '#':
            continue

        # Keep only the declaration in front of the terminating ';'.
        line = line.split(';')[0].strip()

        sum_key = False
        if line[:1] == '@':
            sum_key = True
            line = line[1:].strip()

        # Nothing left (e.g. a lone ';' or '@'): ignore the line.
        if not line:
            continue

        tp, name = line.split()

        if tp not in type_codes:
            raise Exception("unknow type " + tp)

        fmt.append(type_codes[tp])
        names.append(name)
        if sum_key:
            sum_keys.append(name)

    g.stats_struct_fmt = ''.join(fmt)
    g.stats_struct_names = names
    g.max_key_len = max([g.max_key_len] + [len(name) for name in names])
    # Slice-assign so existing references to the list see the new content.
    Stats.sum_keys[:] = sum_keys


class Stats(object):
    """One statistics record whose fields are stored as plain attributes.

    ``sum_keys`` is a class-level list with the names of the counters that
    take part in addition and in the printed output.
    """
    sum_keys = []

    def __init__(self, binary = None):
        """Decode *binary* with g.stats_struct_fmt; a falsy value gives an
        empty record with no attributes set."""
        if binary:
            values = struct.unpack(g.stats_struct_fmt, binary)
            names = g.stats_struct_names
            for pos, value in enumerate(values):
                setattr(self, names[pos], value)

    def __add__(self, s):
        """Return a new record holding the per-key sum of both operands.

        Only the keys in ``sum_keys`` are set on the result; a key missing
        on either operand counts as 0.
        """
        total = Stats()
        for key in self.sum_keys:
            setattr(total, key, getattr(self, key, 0) + getattr(s, key, 0))
        return total

    def print_sum_values(self):
        """Print every summable counter through print_kv()."""
        for key in self.sum_keys:
            value = getattr(self, key)
            print_kv(key, value)

    def val_list(self):
        """Return [ch_id, is_tx] followed by the summable counter values."""
        return [self.ch_id, self.is_tx] + [getattr(self, key) for key in self.sum_keys]

class Request(object):
    """Builds the raw Ethernet/IPv4/UDP stats request frame and sends it
    through an AF_PACKET socket bound to g.dev."""

    def __init__(self):
        """Open the raw packet socket and bind it to g.dev."""
        raw_sock = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, 0)
        raw_sock.bind((g.dev, 0))
        self.raw_sock = raw_sock

    def mk_ip_header(self, ip_dest, ip_source):
        """Return a 20-byte IPv4 header (no options) with a valid checksum.

        ip_dest / ip_source are dotted-quad strings.  The total length is
        fixed at 32 bytes: 20 (IP) + 8 (UDP) + 4 (payload of send_req()).
        """
        def checksum(msg):
            # Ones'-complement sum of the 16-bit big-endian words.  Using
            # bytearray gives integers on both Python 2 and Python 3, and
            # summing in network order makes the result independent of the
            # host byte order (no htons() fix-up needed afterwards).
            data = bytearray(msg)
            if len(data) % 2:
                data.append(0)

            total = 0
            for i in range(0, len(data), 2):
                total += (data[i] << 8) | data[i + 1]
                total = (total & 0xffff) + (total >> 16)
            return ~total & 0xffff

        ip_ver         = 4
        ip_ihl         = 5
        ip_dscp        = 0
        ip_total_len   = 32
        ip_id          = 2222
        ip_frag_offset = 0
        ip_ttl         = 255
        ip_protocol    = socket.IPPROTO_UDP
        ip_saddr = socket.inet_pton(socket.AF_INET, ip_source)
        ip_daddr = socket.inet_pton(socket.AF_INET, ip_dest)

        ip_ver_ihl = (ip_ver << 4) + ip_ihl

        def pack_header(ip_checksum):
            return struct.pack('!BBHHHBBH4s4s', ip_ver_ihl, ip_dscp,
                               ip_total_len, ip_id,
                               ip_frag_offset, ip_ttl,
                               ip_protocol, ip_checksum, ip_saddr, ip_daddr)

        # The checksum is computed over the header with its checksum field
        # set to zero, then the header is packed again with the real value.
        return pack_header(checksum(pack_header(0)))

    def mk_udp_header(self, dest, source, payload_len):
        """Return the 8-byte UDP header; the checksum field is left at 0."""
        l = 8 + payload_len
        udp_header = struct.pack('!HHHH', source, dest, l, 0)
        return udp_header

    def mac_to_pk(self, mac):
        """Convert "aa:bb:cc:dd:ee:ff" into its 6 packed bytes."""
        octets = [int(x, 16) for x in mac.split(':')]
        return struct.pack('!6B', *octets)

    def hw_addr(self, dev):
        """Return the 14-byte Ethernet header (dst MAC, src MAC, 0x0800).

        The source MAC is read from sysfs for *dev*.  The destination MAC is
        looked up in the ``arp -n`` table for the next hop reported by
        ``ip route get 8.8.8.8``.  Exits the process with -1 when no MAC is
        found for the next hop.
        """
        with open('/sys/class/net/%s/address' % dev) as f:
            s_mac = f.read().strip()

        cmd = 'ip route get 8.8.8.8'
        route = os.popen(cmd).read().split()
        # The third token is taken as the next hop (assumes output of the
        # form "<dst> via <next-hop> ...").  When the output is too short
        # fall through to the "not found" error below.
        next_hop = route[2] if len(route) > 2 else None

        d_mac = None
        for l in os.popen('arp -n').readlines():
            t = l.split()
            # Skip blank or truncated lines instead of raising IndexError.
            if len(t) < 3:
                continue
            if next_hop == t[0]:
                d_mac = t[2]
                break

        if not d_mac:
            print("not found mac from arp -n for %s" % next_hop)
            sys.exit(-1)

        s_mac = self.mac_to_pk(s_mac)
        d_mac = self.mac_to_pk(d_mac)

        return d_mac + s_mac + struct.pack("!H", 0x800)

    def send_req(self, group_id, ip_dest, dstport, srcport):
        """Send one request frame whose 4-byte payload is *group_id*."""
        hw         = self.hw_addr(g.dev)
        # Source and destination address are both ip_dest.
        # NOTE(review): presumably intentional so the reply returns to this
        # host -- confirm.
        ip_header  = self.mk_ip_header(ip_dest, ip_dest)

        payload    = struct.pack('!I', group_id)

        udp_header = self.mk_udp_header(dstport, srcport, len(payload))

        packet = hw + ip_header + udp_header + payload

        self.raw_sock.send(packet)

    def req(self, gid):
        """Request the stats of group *gid* using the addresses stored in g."""
        self.send_req(gid, g.dest_ip, g.dest_port, g.localport)


class Group(object):
    """Statistics records of one group, fetched when the object is created.

    Creating a Group appends it to g.groups, switches the stats map on,
    fetches the records and switches the map off again.
    """

    def __init__(self, gid):
        self.xsk = []
        self.gid = gid
        g.groups.append(self)

        self.stats_map_switch(1)
        try:
            # Let a failure (e.g. the receive timeout) propagate unchanged so
            # its type and traceback are kept; the map is switched off
            # either way.
            self.fetch()
        finally:
            self.stats_map_switch(0)

    def fetch(self):
        """Send the request and collect replies until a tx record arrives.

        Every reply is decoded into a Stats object and appended to self.xsk;
        the first record with is_tx set ends the loop.
        """
        g.req.req(self.gid)
        while True:
            data, addr = g.udp_sock.recvfrom(1024)
            stats = Stats(data)
            self.xsk.append(stats)

            if stats.is_tx:
                break

    def stats_map_switch(self, on):
        """Set entry 0 of map_stats to 1 (on) or 0 (off) through xudp-map.

        Exits the process with -1 when the command does not succeed.
        """
        v = 1 if on else 0

        cmd = 'xudp-map update map_stats 0 %s' % v
        ret = os.system(cmd)

        # Any non-zero status is a failure: this covers a non-zero exit code
        # as well as a command that was terminated by a signal.
        if ret:
            print("map_stats set value %d err" % v)
            sys.exit(-1)

    def get_group_num(self):
        """Return group_num as reported by the first received record."""
        return self.xsk[0].group_num

    def get_group_id(self):
        """Return g_id as reported by the first received record."""
        return self.xsk[0].g_id

def fetch_stats():
    """Reset g.groups and fetch the statistics of every group.

    Group 0 is queried first; the group count it reports decides how many
    further groups (1 .. count-1) are queried.  Each Group registers itself
    in g.groups when it is created.
    """
    g.groups = []

    first = Group(0)
    total = first.get_group_num()

    for gid in range(1, total):
        Group(gid)

def print_kv(key, value, fmt = "%s: %s"):
    """Print *key* (left-justified to g.max_key_len) and *value* using *fmt*."""
    padded_key = key.ljust(g.max_key_len)
    print(fmt % (padded_key, value))

class Show(object):
    """Output routines for the statistics collected in g.groups."""

    def print_sum(self):
        """Print the sum of the summable counters over all records."""
        total = None
        for gp in g.groups:
            for xsk in gp.xsk:
                total = xsk if total is None else total + xsk

        print("")
        print("=============================")
        total.print_sum_values()

    def print_group(self, args, g):
        """Print the records of one group that pass the rx/tx and channel
        filters in *args*.

        The first table entry holds the labels, every further entry holds the
        values of one record.
        """
        labels = ['channel id', 'is tx'] + copy.deepcopy(Stats.sum_keys)
        table = [labels]

        for stat in g.xsk:
            # direction filters
            if (args.rx and stat.is_tx) or (args.tx and not stat.is_tx):
                continue

            # channel filter
            if args.channel_id is not None and stat.ch_id not in args.channel_id:
                continue

            table.append(stat.val_list())

        self.print_tb(table)

    def print_tb(self, tb):
        """Print *tb* transposed: every entry of *tb* becomes one output
        column.

        Columns are at least 4 characters wide; the first one is
        left-justified, all others are right-justified.  The number of
        printed lines equals the length of the first entry; shorter entries
        are padded with empty cells.
        """
        widths = [max([4] + [len(str(cell)) for cell in column]) for column in tb]

        for row in range(len(tb[0])):
            cells = []
            for pos, column in enumerate(tb):
                text = str(column[row]) if row < len(column) else ''
                if pos == 0:
                    cells.append(text.ljust(widths[pos]))
                else:
                    cells.append(text.rjust(widths[pos]))

            print(' '.join(cells))

    def print_list(self, args):
        """Print one table per group, honouring the --group-id filter."""
        for gp in g.groups:
            gid = gp.get_group_id()
            wanted = not args.group_id or gid in args.group_id
            if not wanted:
                continue

            print("## group: %s" % gid)
            self.print_group(args, gp)
            print('')

def parse_ids(ids):
    """Expand an id list such as "1" or "1,2,4-9" into a list of ints.

    Items are separated by ','.  An item "a-b" stands for the inclusive range
    a..b; only the first two '-'-separated parts of an item are used.
    """
    result = []
    for item in ids.split(','):
        bounds = item.split('-')
        if len(bounds) == 1:
            result.append(int(bounds[0]))
            continue

        first = int(bounds[0])
        last = int(bounds[1])
        result.extend(range(first, last + 1))
    return result

def get_addr_by_dev(dev):
    """Return the first IPv4 address of *dev*, or None if there is none.

    The address is taken from the first line of ``ip addr show dev <dev>``
    whose first field is "inet"; the "/<prefix>" suffix is removed.
    """
    # Query the device that was asked for (not the global g.dev).
    pipe = os.popen('ip addr show dev %s' % dev)
    try:
        lines = pipe.readlines()
    finally:
        pipe.close()

    for line in lines:
        fields = line.split()
        # Blank or one-word lines cannot carry an address.
        if len(fields) >= 2 and fields[0] == 'inet':
            return fields[1].split('/')[0]

    return None

def main():
    """Command-line entry point.

    Parses the options, resolves the server port (from ``xudp-map ipport``
    when --port is not given) and address (from the device when --ip is not
    given), opens the request and reply sockets, then fetches and prints the
    statistics once, or repeatedly every --watch seconds.

    Returns -1 when the port or the address cannot be determined; otherwise
    returns None after the last report.
    """
    stats_struct_parse()

    parser = argparse.ArgumentParser()
    parser.add_argument('-p', "--port",       help = "dest udp port", type=int)
    parser.add_argument('-i', "--ip",         help = "dest ip addr.  default get by from dev")
    parser.add_argument('-d', "--dev",        help = "xudp bind dev. default eth0", default='eth0')
    parser.add_argument('-l', "--list",       help = "list all channel stats", action='store_true')
    parser.add_argument('-r', "--rx",         help = "filter rx", action='store_true')
    parser.add_argument('-t', "--tx",         help = "filter tx", action='store_true')
    parser.add_argument('-c', "--channel-id", help = "channel ids. likes: 1 or 1,2,4-9")
    parser.add_argument('-g', "--group-id",   help = "group ids. likes: 1 or 1,2,4-9")
    parser.add_argument("-T", "--timeout",    help = "set the timeout for recv response", default = 3, type = int)
    parser.add_argument("-w", "--watch",      help = "watch mode. option as the inter second", type = int)

    args = parser.parse_args()
    if not args.port:
        # No port given: take it from the first "<ip>:<port>" line printed
        # by xudp-map.
        cmd = "xudp-map ipport"
        lines = os.popen(cmd).readlines()
        if not lines:
            print("can not get port by xudp-map ipport")
            return -1

        t = lines[0].strip().split(':')
        if len(t) != 2:
            print("can not get port by xudp-map ipport")
            return -1

        try:
            args.port = int(t[1])
        except ValueError:
            # The part after ':' is not a number.
            print("can not get port by xudp-map ipport")
            return -1

    g.dest_ip   = args.ip
    g.dest_port = args.port
    g.dev       = args.dev
    g.req       = Request()

    if g.dest_ip is None:
        g.dest_ip = get_addr_by_dev(g.dev)
        if g.dest_ip is None:
            # Without an address the request cannot be built; stop here
            # instead of failing later while packing the IP header.
            print("can not get ip addr by dev %s" % g.dev)
            return -1

    print("# server %s:%d" % (g.dest_ip, g.dest_port))

    # Replies are received on a regular UDP socket bound to the local port
    # that is used as the source port of the request.
    udp_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    udp_sock.bind(("0.0.0.0", g.localport))
    udp_sock.settimeout(args.timeout)
    g.udp_sock = udp_sock

    if args.channel_id:
        args.channel_id = parse_ids(args.channel_id)

    if args.group_id:
        args.group_id = parse_ids(args.group_id)

    signal.signal(signal.SIGINT, sigint_handle)

    show = Show()

    while True:
        fetch_stats()
        if args.list:
            show.print_list(args)
        else:
            show.print_sum()

        # Without --watch (or with an interval of 0) report once and stop.
        if not args.watch:
            break
        time.sleep(args.watch)

# Run only when executed as a script, and hand main()'s result to the
# interpreter so that an error return (-1) becomes a non-zero exit status.
if __name__ == '__main__':
    sys.exit(main())


