import contextlib
import json
import os
import shutil
import sys
import tempfile

# Names exported by ``from <module> import *``.
__all__ = ['MockCommand', 'assert_calls']

# Directory containing this module. MockCommand._copy_exe() looks here for
# the cli-32.exe / cli-64.exe launchers it copies when running on Windows.
pkgdir = os.path.dirname(__file__)

# Temporary directory shared by all MockCommand instances for their recording
# files. It stays None until MockCommand.__init__ creates it with mkdtemp().
recording_dir = None

def prepend_to_path(dir):
    """Put *dir* at the front of the ``PATH`` environment variable.

    The existing value of ``PATH`` is kept after the new entry, separated by
    ``os.pathsep``. Raises KeyError if ``PATH`` is not set.
    """
    existing = os.environ['PATH']
    os.environ['PATH'] = os.pathsep.join([dir, existing])

def remove_from_path(dir):
    """Drop the first occurrence of *dir* from the ``PATH`` variable.

    Only the first matching entry is removed; later duplicates are kept.
    Raises ValueError if *dir* is not currently an entry of ``PATH``.
    """
    entries = os.environ['PATH'].split(os.pathsep)
    # list.index raises ValueError when the entry is absent, like list.remove.
    del entries[entries.index(dir)]
    os.environ['PATH'] = os.pathsep.join(entries)


# Template for the default mock script written by MockCommand.__enter__.
# str.format() fills in {python} (the interpreter named on the shebang line)
# and {recording_file} (the file calls are logged to); the doubled braces
# become literal braces in the generated script. Every run of the script
# appends one JSON object holding the environment, argv and working directory,
# followed by an ASCII record separator (0x1E) that get_calls() splits on.
_record_run = """#!{python}
import os, sys
import json

with open({recording_file!r}, 'a') as f:
    json.dump({{'env': dict(os.environ),
               'argv': sys.argv,
               'cwd': os.getcwd()}},
              f)
    f.write('\\x1e') # ASCII record separator
"""

# TODO: Overlapping calls to the same command may interleave writes.

class MockCommand(object):
    """Context manager to mock a system command.

    The mock command will be written to a directory at the front of $PATH,
    taking precedence over any existing command with the same name.

    By specifying content as a string, you can determine what running the
    command will do. The default content records each time the command is
    called and exits: you can access these records with mockcmd.get_calls().

    On Windows, the specified content will be run by the Python interpreter in
    use. On Unix, it should start with a shebang (``#!/path/to/interpreter``).

    The same instance may be entered again after it has been exited; calls
    recorded by the default content accumulate across uses.
    """
    def __init__(self, name, content=None):
        """Prepare a mock for the command *name*.

        name: the command name to shadow on $PATH.
        content: optional script text; None selects the recording script.

        Nothing is added to $PATH until the context manager is entered.
        """
        global recording_dir
        self.name = name
        self.content = content

        # All instances share one recording directory, created on first use.
        if recording_dir is None:
            recording_dir = tempfile.mkdtemp()
        fd, self.recording_file = tempfile.mkstemp(dir=recording_dir,
                                                prefix=name, suffix='.json')
        # Only the path is needed; the mock script reopens the file itself.
        os.close(fd)
        self.command_dir = tempfile.mkdtemp()

    def _copy_exe(self):
        """Copy the bundled Windows launcher next to the mock script.

        The launcher matching the bitness of the running interpreter is
        copied from the package directory as ``<name>.exe``.
        """
        bitness = '64' if (sys.maxsize > 2**32) else '32'
        src = os.path.join(pkgdir, 'cli-%s.exe' % bitness)
        dst = os.path.join(self.command_dir, self.name+'.exe')
        shutil.copy(src, dst)

    @property
    def _cmd_path(self):
        """Path of the script file that implements the mocked command."""
        # Can only be used once command_dir has been set
        p = os.path.join(self.command_dir, self.name)
        if os.name == 'nt':
            # On Windows the script sits beside <name>.exe (see _copy_exe).
            p += '-script.py'
        return p

    def __enter__(self):
        """Write the mock script and put its directory at the front of $PATH.

        Raises EnvironmentError if the script file already exists, i.e. the
        instance is entered a second time without being exited first.
        """
        if not os.path.isdir(self.command_dir):
            # __exit__ deleted the previous directory; start from a fresh one
            # so the instance can be reused instead of failing in open().
            self.command_dir = tempfile.mkdtemp()

        if os.path.isfile(self._cmd_path):
            raise EnvironmentError("Command %r already exists at %s" %
                                            (self.name, self._cmd_path))

        if self.content is None:
            self.content = _record_run.format(python=sys.executable,
                                             recording_file=self.recording_file)

        with open(self._cmd_path, 'w') as f:
            f.write(self.content)

        if os.name == 'nt':
            self._copy_exe()
        else:
            os.chmod(self._cmd_path, 0o755) # Set executable bit

        prepend_to_path(self.command_dir)

        return self

    def __exit__(self, etype, evalue, tb):
        """Take the mock off $PATH and delete its directory.

        The directory is removed even if the $PATH entry can no longer be
        found (for instance because the code under test rewrote $PATH); the
        resulting error still propagates. The recording file is kept so that
        get_calls() keeps working after the block has finished.
        """
        try:
            remove_from_path(self.command_dir)
        finally:
            shutil.rmtree(self.command_dir, ignore_errors=True)

    def get_calls(self):
        """Get a list of calls made to this mocked command.

        This relies on the default script content, so it will return an
        empty list if you specified a different content parameter.

        For each time the command was run, the list will contain a dictionary
        with keys argv, env and cwd.
        """
        if recording_dir is None:
            return []
        if not os.path.isfile(self.recording_file):
            return []

        with open(self.recording_file, 'r') as f:
            # 1E is ASCII record separator, last chunk is empty
            chunks = f.read().split('\x1e')[:-1]

        return [json.loads(c) for c in chunks]


@contextlib.contextmanager
def assert_calls(cmd, args=None):
    """Assert that a block of code runs the given command.

    If args is passed, also check that it was called at least once with the
    given arguments (not including the command name). args may be any
    sequence of strings, e.g. a list or a tuple.

    Raises AssertionError if the command was never run, or was never run
    with the requested arguments.

    Use as a context manager, e.g.::

        with assert_calls('git'):
            some_function_wrapping_git()

        with assert_calls('git', ['add', myfile]):
            some_other_function()
    """
    with MockCommand(cmd) as mc:
        yield

    calls = mc.get_calls()
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with -O.
    if not calls:
        raise AssertionError("Command %r was not called" % cmd)

    if args is None:
        return

    # Recorded argv values are lists, so normalise args before comparing;
    # otherwise a tuple would never match.
    expected = list(args)
    seen = [c['argv'][1:] for c in calls]
    if expected not in seen:
        msg = ["Command %r was not called with specified args (%r)" %
                        (cmd, args),
               "It was called with these arguments: "]
        msg.extend('  %r' % actual for actual in seen)
        raise AssertionError('\n'.join(msg))