#!/usr/bin/env python3
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import gc
import io
import json
import os
import sys
import time
from timeit import timeit

import numpy
import psutil
import rapidjson
import simplejson
from memory_profiler import memory_usage
from tabulate import tabulate

import orjson

# Restrict this process to CPUs 0 and 1 for the whole benchmark run.
_BENCH_CPUS = {0, 1}
os.sched_setaffinity(os.getpid(), _BENCH_CPUS)


# The benchmark kind is the first command-line argument. sys.argv[0] is the
# script itself, so an argument exists only when there are at least two
# entries; the previous `>= 1` check was always true and raised IndexError on
# a bare invocation. Falling back to "" lets the dtype selection below print
# the usage message instead.
kind = sys.argv[1] if len(sys.argv) >= 2 else ""


# Floating-point kinds: all are generated as float64 samples first.
_FLOAT_KINDS = {
    "float16": numpy.float16,
    "float32": numpy.float32,
    "float64": numpy.float64,
}

# Integer kinds: (dtype, bound handed to numpy.random.randint). With a single
# bound argument randint draws from [0, bound).
_INT_KINDS = {
    "int8": (numpy.int8, (2**7) - 1),
    "int16": (numpy.int16, (2**15) - 1),
    "int32": (numpy.int32, (2**31) - 1),
    "uint8": (numpy.uint8, (2**8) - 1),
    "uint16": (numpy.uint16, (2**16) - 1),
    "uint32": (numpy.uint32, (2**31) - 1),
}

# Build the benchmark input (`array`) and remember its element type (`dtype`)
# according to the kind requested on the command line.
if kind in _FLOAT_KINDS:
    dtype = _FLOAT_KINDS[kind]
    array = numpy.random.random(size=(50000, 100))
    if kind == "float64":
        # numpy.random.random already yields float64; no conversion needed.
        assert array.dtype == numpy.float64
    else:
        array = array.astype(dtype)
elif kind == "bool":
    dtype = numpy.bool_
    array = numpy.random.choice((True, False), size=(100000, 200))
elif kind in _INT_KINDS:
    dtype, _bound = _INT_KINDS[kind]
    array = numpy.random.randint(_bound, size=(100000, 100), dtype=dtype)
else:
    # Unknown or missing kind: show the accepted values and bail out.
    print(
        "usage: pynumpy (bool|int16|int32|float16|float32|float64|int8|uint8|uint16|uint32)"
    )
    sys.exit(1)

# Handle on the current process, used later to read its RSS.
proc = psutil.Process()


def default(__obj):
    """Fallback hook for serializers that cannot encode numpy arrays natively.

    Returns the array converted to nested Python lists via ``tolist()``.

    Raises:
        TypeError: if ``__obj`` is not a ``numpy.ndarray``.
    """
    if not isinstance(__obj, numpy.ndarray):
        raise TypeError
    return __obj.tolist()


# Column titles of the results table.
headers = (
    "Library",
    "Latency (ms)",
    "RSS diff (MiB)",
    "vs. orjson",
)

# Libraries to report on, in order; each one is resolved to a module-level
# "<name>_dumps" callable (or None) by the benchmark loop.
LIBRARIES = (
    "orjson",
    "ujson",
    "rapidjson",
    "simplejson",
    "json",
)

# Number of serialization calls timed per library.
ITERATIONS = 10


def orjson_dumps():
    """Serialize the benchmark array with orjson's numpy option enabled."""
    encoded = orjson.dumps(array, option=orjson.OPT_SERIALIZE_NUMPY)
    return encoded


# No ujson benchmark function is defined. The None placeholder keeps the
# "ujson_dumps" name resolvable, and the benchmark loop treats None as
# "nothing to measure", producing a row with empty result cells.
ujson_dumps = None


def rapidjson_dumps():
    """Serialize the benchmark array with rapidjson and return UTF-8 bytes.

    The array is converted through the ``default`` hook.
    """
    text = rapidjson.dumps(array, default=default)
    return text.encode("utf-8")


def simplejson_dumps():
    """Serialize the benchmark array with simplejson and return UTF-8 bytes.

    The array is converted through the ``default`` hook.
    """
    text = simplejson.dumps(array, default=default)
    return text.encode("utf-8")


def json_dumps():
    """Serialize the benchmark array with the stdlib json module as UTF-8 bytes.

    The array is converted through the ``default`` hook.
    """
    text = json.dumps(array, default=default)
    return text.encode("utf-8")


# Report how large the orjson-encoded document is, in MiB.
_encoded_size = len(orjson_dumps())
output_in_mib = _encoded_size / 1024 / 1024

print(f"{output_in_mib:,.1f}MiB {kind} output (orjson)")

# Collect garbage first, then record the resident set size (MiB) that later
# per-library memory readings are compared against.
gc.collect()
_baseline_rss = proc.memory_full_info().rss
mem_before = _baseline_rss / 1024 / 1024


def per_iter_latency(val):
    """Convert a total time in seconds into milliseconds per iteration.

    ``val`` is the total for ``ITERATIONS`` runs; ``None`` is passed through
    unchanged.
    """
    return None if val is None else (val * 1000) / ITERATIONS


def test_correctness(func):
    """Return whether ``func``'s output decodes back to the benchmark array.

    The bytes returned by ``func()`` are parsed with ``orjson.loads``, rebuilt
    as an array of the benchmark ``dtype`` and compared with ``array``.
    """
    decoded = orjson.loads(func())
    round_tripped = numpy.array(decoded, dtype=dtype)
    return numpy.array_equal(array, round_tripped)


# One row per library: name, latency, RSS difference, ratio to orjson.
table = []
for lib_name in LIBRARIES:
    gc.collect()

    print(f"{lib_name}...")
    # Resolve the module-level "<name>_dumps" entry; None means "not available".
    func = locals()[f"{lib_name}_dumps"]
    if func is not None:
        total_latency = timeit(func, number=ITERATIONS)
        latency = per_iter_latency(total_latency)
        time.sleep(1)
        # Peak memory observed while running one serialization call.
        mem = max(memory_usage((func,), interval=0.001, timeout=latency * 2))
        correct = test_correctness(func)
    else:
        total_latency = latency = mem = None
        correct = False

    # orjson is the reference; every other library is expressed relative to it.
    if lib_name == "orjson":
        compared_to_orjson = 1
        orjson_latency = latency
    else:
        compared_to_orjson = latency / orjson_latency if latency else None

    # Results from a library that failed the round-trip check are blanked out.
    if not correct:
        latency = None
        mem = 0

    mem_diff = mem - mem_before

    latency_cell = f"{latency:,.0f}" if latency else ""
    mem_cell = f"{mem_diff:,.0f}" if mem else ""
    ratio_cell = (
        f"{compared_to_orjson:,.1f}" if (latency and compared_to_orjson) else ""
    )
    table.append((lib_name, latency_cell, mem_cell, ratio_cell))

# Render the table as GitHub-flavoured markdown and print it.
buf = io.StringIO()
buf.write(tabulate(table, headers, tablefmt="github") + "\n")

print(buf.getvalue())
