#!/usr/bin/env python3

'''
Wrapper script for running the Rumur model checker.

This script is intended to be installed alongside the `rumur` binary from the
Rumur model checker. It can then be used to quickly generate and run a model, as
an alternative to having to run the model generation, compilation and execution
steps manually.
'''

import atexit
import os
import platform
import shutil
import subprocess as sp
import sys
import tempfile
from typing import Optional

def which(cmd: str) -> Optional[str]:
  '''
  Equivalent of shell `which`

  Locate an executable, returning its path, or None if it cannot be found.
  `cmd` may be a bare command name (searched for in $PATH) or a path with a
  directory component (checked directly).

  This uses `shutil.which` rather than spawning the external `which` utility.
  That utility is not guaranteed to be installed, and its absence would
  surface as an uncaught `FileNotFoundError` instead of a None result.
  '''
  return shutil.which(cmd)

# C compiler: taken from $CC when it is set to a non-empty value, otherwise
# `cc`, and resolved to a full path via `which` (None if it cannot be found).
# An empty $CC (e.g. `CC= rumur-run ...`) is treated like an unset one rather
# than being looked up as an empty command name.
CC = which(os.environ.get('CC') or 'cc')

def categorise(cc: str) -> str:
  '''
  Determine the vendor of a given C compiler

  Compiles and runs a small probe program with `cc` and returns the vendor
  string it prints: 'clang', 'gcc' or 'unknown'. 'unknown' is also returned
  when the compiler cannot be executed, the probe fails to compile, or the
  resulting binary fails to run.
  '''

  # The probe identifies the vendor from predefined macros. Clang is tested
  # first because it also defines __GNUC__.
  probe = ('#include <stdio.h>\n'
           '#include <stdlib.h>\n'
           'int main(void) {\n'
           '#ifdef __clang__\n'
           '  printf("clang\\n");\n'
           '#elif defined(__GNUC__)\n'
           '  printf("gcc\\n");\n'
           '#else\n'
           '  printf("unknown\\n");\n'
           '#endif\n'
           '  return EXIT_SUCCESS;\n'
           '}\n')

  # A self-cleaning temporary area, so the probe files are removed even when
  # one of the steps below returns early or raises
  with tempfile.TemporaryDirectory() as tmp:

    # Setup the test file
    src = os.path.join(tmp, 'test.c')
    with open(src, 'wt') as f:
      f.write(probe)

    # Compile it
    aout = os.path.join(tmp, 'a.out')
    try:
      cc_proc = sp.run([cc, '-o', aout, src], universal_newlines=True,
        stdout=sp.DEVNULL, stderr=sp.DEVNULL)
    except OSError:
      # the compiler itself could not be executed
      return 'unknown'
    if cc_proc.returncode != 0:
      return 'unknown'

    # Run it
    try:
      return sp.check_output([aout], universal_newlines=True).strip()
    except (sp.CalledProcessError, OSError):
      return 'unknown'

def supports(flag: str) -> bool:
  '''
  Probe whether the C compiler `CC` accepts the command line option `flag`.

  A minimal program is piped to the compiler on stdin with `flag` appended.
  The flag counts as supported if compiling and linking exits with status 0.
  The output binary is discarded to the null device, and compiler
  diagnostics are suppressed.
  '''
  probe_source = 'int main(void) { return 0; }'.encode('utf-8', 'replace')
  argv = [CC, '-o', os.devnull, '-x', 'c', '-', flag]
  result = sp.run(argv, input=probe_source, stderr=sp.DEVNULL)
  return not result.returncode

def needs_libatomic() -> bool:
  '''
  Check whether the compiler needs -latomic for a double-word
  compare-and-swap.

  A probe program performing a compare-and-swap on an integer twice the
  pointer width is compiled and linked with `CC`, plus -mcx16 when the
  compiler accepts it. Returns True if that fails, which is taken to mean
  libatomic must be linked, and False if it succeeds on its own.
  '''

  # CAS program to ask it to compile. <stdint.h> is required for uint64_t.
  # Without it the 32-bit branch fails to compile for a reason unrelated to
  # atomics, which would be misreported as needing libatomic.
  program = '#include <stdint.h>\n' \
            'int main(void) {\n' \
            '#if __SIZEOF_POINTER__ <= 4\n' \
            '  uint64_t target = 0;\n' \
            '#elif __SIZEOF_POINTER__ <= 8\n' \
            '  unsigned __int128 target = 0;\n' \
            '#endif\n' \
            '  return __sync_val_compare_and_swap(&target, 0, 1);\n' \
            '}\n'

  # compile it, using -mcx16 when available as optimisation_flags() does
  args = [CC, '-x', 'c', '-std=c11', '-', '-o', os.devnull]
  if supports('-mcx16'):
    args.append('-mcx16')
  p = sp.run(args, stderr=sp.DEVNULL, input=program.encode('utf-8', 'replace'))

  # a failed compile/link is treated as "needs libatomic"
  return p.returncode != 0

def optimisation_flags() -> [str]:
  '''
  Assemble the C compiler optimisation options to use on this platform.

  The result always starts with -O3. The other options are added in this
  order:
  - host tuning (-march=native, -mtune=native) and link-time optimisation
    (-flto), each only if `supports` reports the compiler accepts it;
  - -fwhole-program, when the compiler is categorised as GCC;
  - -mcx16 (enables CMPXCHG16B), if supported.
  '''

  chosen = ['-O3']

  # host-specific architecture/CPU tuning, then link-time optimisation,
  # probed in that order and kept only when the compiler accepts them
  chosen.extend(opt for opt in ('-march=native', '-mtune=native', '-flto')
                if supports(opt))

  # GCC can do more advanced interprocedural optimisation when told it sees
  # the whole program
  if categorise(CC) == 'gcc':
    chosen.append('-fwhole-program')

  # make CMPXCHG16B available on platforms that need it enabled explicitly
  if supports('-mcx16'):
    chosen.append('-mcx16')

  return chosen

def main(args: [str]) -> int:
  '''
  Generate, compile and run a Rumur model checker.

  `args` is the full command line (as in `sys.argv`); everything after the
  first element is forwarded to `rumur`. Help/version requests replace this
  process with `rumur` itself. Returns:
  - 0 if every step succeeds;
  - the `rumur` exit status if checker generation fails;
  - -1 otherwise.
  '''

  # Find the Rumur binary, first on $PATH and then alongside this script.
  # abspath matters here: dirname(__file__) is '' when the script is run by a
  # bare relative name. Joining '' with 'rumur' would then just search $PATH
  # again instead of looking in the script's own directory.
  rumur_bin = which('rumur')
  if rumur_bin is None:
    here = os.path.dirname(os.path.abspath(__file__))
    rumur_bin = which(os.path.join(here, 'rumur'))
  if rumur_bin is None:
    sys.stderr.write('rumur binary not found\n')
    return -1

  # if the user asked for help or version information, run Rumur directly
  for arg in args[1:]:
    if arg.startswith(('-h', '--h', '--vers')):
      os.execv(rumur_bin, [rumur_bin] + args[1:])

  if CC is None:
    sys.stderr.write('no C compiler found\n')
    return -1

  # Progress messages are flushed explicitly. Otherwise, when stdout is
  # redirected, they sit in Python's buffer and appear after the output of
  # the child processes they introduce.
  print('Generating the checker...', flush=True)
  # stdin is a pipe that subprocess.run closes immediately, so rumur does not
  # inherit this process's stdin (presumably to avoid blocking on a terminal)
  rumur_proc = sp.run([rumur_bin] + args[1:] + ['--output', '/dev/stdout'],
    stdin=sp.PIPE, stdout=sp.PIPE)
  if rumur_proc.returncode != 0:
    return rumur_proc.returncode
  checker_c = rumur_proc.stdout

  # Setup a temporary directory in which to generate the checker
  tmp = tempfile.mkdtemp()
  atexit.register(shutil.rmtree, tmp)

  # Compile the checker from the generated C source fed on stdin
  print('Compiling the checker...', flush=True)
  aout = os.path.join(tmp, 'a.out')
  argv = [CC, '-std=c11'] + optimisation_flags() + ['-o', aout, '-x', 'c',
    '-', '-lpthread']
  if needs_libatomic():
    argv.append('-latomic')
  cc_proc = sp.run(argv, input=checker_c)
  if cc_proc.returncode != 0:
    return -1

  # Run the checker
  print('Running the checker...', flush=True)
  checker_proc = sp.run([aout])
  return 0 if checker_proc.returncode == 0 else -1

if __name__ == '__main__':
  # Exit with main's status. A Ctrl-C interrupt exits with 130, the
  # conventional shell status for termination by SIGINT (128 + 2).
  try:
    exit_code = main(sys.argv)
  except KeyboardInterrupt:
    exit_code = 130
  sys.exit(exit_code)
