#!/usr/bin/env python

# note: this should be launched with "dmtcp_launch --new-coordinator runenergyplus_cr.py <input.idf> <weather.epw> <controller port> <dmtcp listener port>"
# If run via SSH, the following need to be done in advance:
# Prepare <input.idf> and variables.cfg using MLE+ and make sure MLE+ can run successfully
# with them using runenergyplus or RunEP.bat on the local machine.
# Copy <input.idf> and variables.cfg to the remote machine.
# Copy bcvtb directory from MLE+ installation on local machine to a location on the remote machine.
# The bcvtb directory should be placed in the same directory as <input.idf> or you should 
# set the BCVTB_HOME environment variable on the remote machine to point to it. 


# TODO: this could improve the interaction with matlab by making the connection to 
# matlab immediately (instead of waiting for a connection from energy plus first),
# and by closing sockets when an error occurs (so the matlab script will stop running too).

import socket
import os
import sys
import shutil
import tempfile
from textwrap import dedent
import subprocess
import threading
import glob
import datetime
import signal
import ctypes
import time
    
# keep a (short) list of dead sockets so we don't try to shut them down after they're closed
# (e.g., when the master socket disconnects, we close it and tell the eplus socket to shutdown;
#  then when the eplus socket shuts down, we don't want to tell the master to shutdown again).
closed_sockets = []

# energy plus process, needs to be killed at restart time
eplus_process = None

# the sockets are kept as global variables, because the master socket needs to be replaced
# after it is disconnected during each checkpoint or restart (there's no way to reuse sockets directly)
# Using a global variable is a simple way to transmit this change to other users.
# (Alternatively, they could both be part of a shared dictionary or object.)
eplus_socket = None
master_socket = None
eplus_addr = None
master_addr = None
master_local_addr = None

# dmtcp state. These are given explicit defaults here so that every function
# that reads them sees a defined value even if it runs before the function
# that normally assigns them:
#   libdmtcp and dmtcp_running are assigned by load_libdmtcp();
#   dmtcp_listener_port is assigned by main() from the command line.
libdmtcp = None
dmtcp_running = False
dmtcp_listener_port = None

def main(argv):
    """Run Energy Plus behind a socket relay and return its exit code.

    argv must hold exactly five items: script name, input file, weather file,
    master controller port, and dmtcp listener port. If it does not, a usage
    message is printed and the process exits with status 2.

    The values are read from the argv parameter itself (not from sys.argv),
    so the function honours whatever argument list the caller passes in.

    Returns the exit code reported by the Energy Plus subprocess. The
    temporary working directory is removed only when that code is 0.
    """
    global eplus_process, eplus_socket, master_socket, eplus_addr, master_addr, master_local_addr, dmtcp_listener_port

    # called with 5 arguments: script name, input file, weather file, port number, dmtcp listener port
    if len(argv) != 5:
        # fall back to sys.argv[0] so an empty argv still produces a usage message
        script_name = argv[0] if argv else sys.argv[0]
        # parenthesised single-argument form prints identically on Python 2 and 3
        print("Usage: {name} <input_file> <weather_file> <controller port> <dmtcp listener port>".format(name=script_name))
        sys.exit(2)
    idf_file = str(argv[1])
    epw_file = str(argv[2])
    master_port = int(argv[3])
    # kept as a string because it is passed straight through to os.execlp() in restart()
    dmtcp_listener_port = str(argv[4])

    # load functions from libdmtcp, if available
    load_libdmtcp()

    if dmtcp_running:
        log("Loaded with dmtcp_launch; checkpoint and restart are available.")
    else:
        log("Warning: not loaded with dmtcp_launch; checkpoint and restart are not available.")

    # the listener must exist before setup_eplus_dir(), which writes its address to socket.cfg
    make_eplus_socket()
    eplus_addr = eplus_socket.getsockname()
    master_addr = ('localhost', master_port)

    temp_dir = setup_eplus_dir(idf_file, epw_file)
    eplus_process = launch_eplus(temp_dir, idf_file, epw_file)
    # blocks until both relay threads have finished
    process_messages()

    # log("Waiting for Energy Plus to shutdown")
    exit_code = eplus_process.wait()
    if exit_code == 0:
        # closed successfully; remove temporary directory
        shutil.rmtree(temp_dir, ignore_errors=True)
        log("Energy Plus completed successfully. Temporary directory has been removed.")
    else:
        log("Energy Plus completed with an error. Temporary directory '{dir}' has been retained.".format(dir=temp_dir))

    log("{name} complete.".format(name=argv[0]))

    return exit_code


# connect to the dmtcp library, if it was used to load this script
# note: libdmtcp code in this script is based on dmtcp.py provided 
# with dmtcp package, but with some bug fixes, and integrated to 
# remove the external dependency. Functions in libdmtcp are listed
# in https://github.com/dmtcp/dmtcp/blob/master/include/dmtcp.h
def load_libdmtcp():
    """Probe the current process for the dmtcp library and record its status.

    Sets the globals libdmtcp (a ctypes handle onto the symbols already loaded
    into this process) and dmtcp_running (True only when the coordinator
    status call reports 1 in its second output argument). The outcome is
    reported through log().
    """
    global libdmtcp, dmtcp_running

    # CDLL(None) exposes the symbols of everything already linked into this process
    libdmtcp = ctypes.CDLL(None)
    dmtcp_loaded = hasattr(libdmtcp, "dmtcp_get_ckpt_filename")

    if not dmtcp_loaded:
        dmtcp_running = False
    else:
        # ctypes assumes an int return unless told otherwise; this one returns a C string
        libdmtcp.dmtcp_get_ckpt_filename.restype = ctypes.c_char_p
        # two int output arguments; only the second is inspected here
        # (presumably numPeers and isRunning as declared in dmtcp.h -- confirm)
        first_out = ctypes.c_int()
        running_flag = ctypes.c_int(0)
        libdmtcp.dmtcp_get_coordinator_status(ctypes.byref(first_out), ctypes.byref(running_flag))
        dmtcp_running = (running_flag.value == 1)

    load_text = "loaded" if dmtcp_loaded else "not loaded"
    run_text = "running" if dmtcp_running else "not running"
    log("dmtcp library is {load_status} and {run_status}.".format(
        load_status=load_text, run_status=run_text
    ))


def make_eplus_socket():
    """Create the listening TCP socket that Energy Plus will connect to.

    The socket is stored in the global eplus_socket, bound to localhost on a
    port chosen by the operating system (port 0), and put into listening mode.
    """
    global eplus_socket
    eplus_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener = eplus_socket
    # port 0 lets the OS pick a free port; callers read it back with getsockname()
    listener.bind(('localhost', 0))
    # backlog of 0: don't queue additional connections
    listener.listen(0)

def connect_master_socket():
    """Open a fresh TCP connection to the master controller.

    Replaces the global master_socket with a new socket connected to
    master_addr. On the first call the local end of the connection is
    remembered in master_local_addr; later calls only log it.
    """
    global master_socket, master_addr, master_local_addr

    master_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # allow reuse of the port number, to conserve port numbers
    # in case many instances are checkpointed and/or restarted in a short time
    # (doesn't actually work!)
    master_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

    first_connection = master_local_addr is None
    if not first_connection:
        log("connecting to master: local_addr={l}, master_addr={m}".format(m=master_addr, l=master_local_addr))
        # NOTE: if the socket is specifically bound to the old port, the reconnection fails.
        # (you can get one or two reuses if the scripts are setup so the master drops the
        # connection and then this script reconnects, instead of the script dropping the
        # connection). Without binding to the old port, there may be a few minutes' delay
        # before the port will be available for reuse (at least for this master_addr).
        # Until then, every checkpoint or restart will consume one port number.
        # DOESN'T WORK: master_socket.bind(master_local_addr)

    master_socket.connect(master_addr)

    if first_connection:
        master_local_addr = master_socket.getsockname()
        log("connected to master: master_addr={m}, local_addr={l}".format(m=master_addr, l=master_local_addr))


def log(message):
    """Write message to stdout wrapped in '>>>> [ ... ] <<<<' markers.

    The markers make this script's messages stand out from the Energy Plus
    output that may be scrolling by on the same stream. The stream is flushed
    after every message.
    """
    line = ">>>> [{msg}] <<<<\n".format(msg=message)
    stream = sys.stdout
    stream.write(line)
    stream.flush()


def _with_extension(name, extension):
    # return name unchanged if it already ends with extension (case-insensitive),
    # otherwise return name with extension appended
    if name.lower().endswith(extension.lower()):
        return name
    return name + extension


def setup_eplus_dir(idf_file, epw_file):
    """Create a temporary run directory for Energy Plus and return its path.

    The directory is created inside the current working directory. The input
    file, "variables.cfg" and the weather file are copied into it when they
    can be found, and a "socket.cfg" describing the global eplus_addr
    (host, port) is written there.

    idf_file and epw_file may be given with or without their ".idf" / ".epw"
    extension; the extension is appended only when it is missing.
    """
    temp_dir = tempfile.mkdtemp(dir=os.getcwd())

    # energy plus will probably fail if variables.cfg or idf_file are missing, but it may find them elsewhere(?)
    soft_copy_file(_with_extension(idf_file, ".idf"), temp_dir, warn=True)
    soft_copy_file("variables.cfg", temp_dir, warn=True)
    # it's usually OK if the epw file is missing, because it is often pulled from the energy plus installation
    soft_copy_file(_with_extension(epw_file, ".epw"), temp_dir)

    # create socket.cfg
    # The backslash after the opening quotes keeps the first line on the same
    # margin as the rest, so dedent() can strip the common indentation.
    # format() is applied after dedent() so substituted values cannot affect the margin.
    socket_xml = dedent("""\
        <?xml version="1.0" encoding="ISO-8859-1"?>
        <BCVTB-client>
            <ipc>
                <socket port="{port}" hostname="{host}"/>
            </ipc>
        </BCVTB-client>""").format(host=eplus_addr[0], port=eplus_addr[1])
    with open(os.path.join(temp_dir, "socket.cfg"), "w") as f:
        f.write(socket_xml)

    return temp_dir

def soft_copy_file(file, dir, warn=False):
    """Best-effort copy of file into dir.

    An IOError from the copy is swallowed (energy plus may find the file
    elsewhere); when warn is true a warning is logged instead of staying
    silent. Nothing is returned.
    """
    try:
        shutil.copy(file, dir)
    except IOError:
        copy_failed = True
    else:
        copy_failed = False

    if copy_failed and warn:
        log("Warning: file {file} could not be copied to {dir}.".format(file=file, dir=dir))


def launch_eplus(temp_dir, idf_file, epw_file):
    """Start "runenergyplus" in temp_dir and return the Popen object.

    The child reads socket.cfg from temp_dir and opens a socket that is then
    serviced by process_messages(). Its output is interleaved with this
    script's own output on stdout.
    """
    # make sure BCVTB_HOME environment variable is set for the child
    child_env = os.environ.copy()
    if "BCVTB_HOME" not in child_env:
        default_home = os.path.join(os.getcwd(), "bcvtb")
        if not os.path.isdir(default_home):
            log("WARNING: BCVTB_HOME is not set and there is no {dir} directory. Energy Plus will probably fail.".format(dir=default_home))
        else:
            child_env["BCVTB_HOME"] = default_home

    log("Launching runenergyplus in {dir}".format(dir=temp_dir))

    # os.setsid makes the child the leader of a new session, and therefore of a
    # new process group whose ID equals its pid. This will make it easy to kill
    # runenergyplus and its children later.
    # (based on http://pymotw.com/2/subprocess/ )
    command = ["runenergyplus", idf_file, epw_file]
    return subprocess.Popen(command, cwd=temp_dir, env=child_env, preexec_fn=os.setsid)



def process_messages():
    """Relay traffic between Energy Plus and the master controller until both sides close.

    Accepts one connection on the (already listening) global eplus_socket,
    replaces the listener with that connection, connects to the master
    controller, then runs recv_eplus and recv_master in their own threads and
    waits for both to finish.
    """
    global eplus_socket, master_socket, eplus_addr, master_addr, master_local_addr

    listen_port = eplus_socket.getsockname()[1]
    log("waiting for connection from Energy Plus on port {port}".format(port=listen_port))

    listener = eplus_socket
    accepted = listener.accept()
    eplus_addr = accepted[1]
    close_socket(listener)      # no longer need the original listener socket
    eplus_socket = accepted[0]  # now refer to the connected socket as eplus_socket

    log("connecting to master controller at {host}:{port}".format(host=master_addr[0], port=master_addr[1]))
    connect_master_socket()

    log("waiting for data from Energy Plus")
    relay_threads = [launch_thread(recv_eplus), launch_thread(recv_master)]

    # wait for communication to finish before returning
    # note: this could get stuck if the sockets are never closed correctly
    # from the other end; but the alternative is to let these run in the
    # background while we wait for energy plus to terminate; the problem with
    # that is that we keep terminating energy plus on purpose, just before
    # each restart, which would create a race condition between the main
    # thread finishing and the recv_master thread prompting a restart.
    for relay in relay_threads:
        relay.join()


def launch_thread(func, *args, **kwargs):
    """Start func(*args, **kwargs) in a daemon thread and return the Thread.

    The thread is named after func and is marked as a daemon so it closes
    automatically when the main program ends.
    """
    worker = threading.Thread(name=func.__name__, target=func, args=args, kwargs=kwargs)
    worker.daemon = True
    worker.start()
    return worker

def close_socket(sock):
    """Close sock and record it in the module-level closed_sockets list."""
    # the list is only mutated in place, so no global declaration is required
    sock.close()
    closed_sockets.append(sock)

def shutdown_socket(sock):
    """Send a shutdown (SHUT_RDWR) on sock and log the attempt.

    This will cause the other end to disconnect gracefully; the relevant recv
    thread then sees the disconnect and closes the socket. A socket.error
    raised by the shutdown is logged rather than propagated.
    """
    # work out a human-readable name for the log messages;
    # identity comparison, since we are asking "is this that very socket object"
    if sock is eplus_socket:
        sockname = "energy plus"
    elif sock is master_socket:
        sockname = "master"
    else:
        sockname = "unknown"
    #if sock not in closed_sockets:
    try:
        log("shutting down {sockname} socket.".format(sockname=sockname))
        sock.shutdown(socket.SHUT_RDWR)
    except socket.error as e:
        # "except ... as ..." is accepted by Python 2.6+ and Python 3,
        # unlike the older comma form
        log("received error message during shutdown of {sockname} socket.".format(sockname=sockname))
        log(str(e))
            

def recv_eplus():
    """Forward everything received from Energy Plus to the master controller.

    Runs until recv() on eplus_socket returns no data, which is treated as
    Energy Plus having shut the socket down. The Energy Plus socket is then
    closed and a shutdown is sent on the master socket.
    """
    # note: python docs recommend that the receive buffer size should be
    # "a relatively small power of 2, for example, 4096."
    # The global sockets are looked up afresh on every pass because
    # master_socket is replaced after each checkpoint or restart.
    while True:
        chunk = eplus_socket.recv(4096)
        if chunk:
            #log("Data from Energy Plus: {data}".format(data=chunk))
            master_socket.sendall(chunk)
            continue

        # socket was shutdown by Energy Plus
        # close this end to finish the job, and also shutdown the master socket
        log("Socket shutdown by Energy Plus; shutting down connection to master controller.")
        close_socket(eplus_socket)
        shutdown_socket(master_socket)
        # note: this tells the master to drop the connection; after that
        # the recv_master thread will close the master_socket
        return

def recv_master():
    """Handle traffic arriving from the master controller.

    Ordinary data is relayed to Energy Plus. A packet that is exactly
    "checkpoint\\n" or "restart\\n" is treated as a command: the master
    connection is shut down and closed, checkpoint() or restart() is run, and
    a new master connection is made. An empty recv() means the master shut
    the socket down; the loop then closes the master socket, sends a shutdown
    on the Energy Plus socket and ends.

    Raises RuntimeError if the master sends more data instead of closing the
    socket after a command.
    """
    # Note: we assume that all checkpoint/restart commands from the controller will come in a single packet.
    control_packets = ("checkpoint\n", "restart\n")

    while True:
        data = master_socket.recv(4096)

        if len(data) == 0:
            # socket was shutdown by the master controller
            # close this end to finish the job, and also shutdown the Energy Plus socket
            log("Socket shutdown by master controller; shutting down connection to Energy Plus.")
            close_socket(master_socket)
            shutdown_socket(eplus_socket)
            # note: this tells Energy Plus to drop the connection; after that
            # the recv_eplus thread will close the eplus_socket
            break

        if data not in control_packets:
            # just received normal data; relay to Energy Plus
            #log("Data from master controller: {data}".format(data=data))
            eplus_socket.sendall(data)
            continue

        # checkpoint or restart this script and Energy Plus
        command = data[:-1]    # drop the trailing new line
        log("{msg} requested.".format(msg=command))

        # shutdown and close the master socket connection
        log("closing socket to master controller.")
        master_socket.shutdown(socket.SHUT_RDWR)
        # the master is expected to answer the shutdown by closing its end (empty read)
        if master_socket.recv(4096) != "":
            raise RuntimeError("Master controller sent data instead of closing socket during checkpoint/restart.")
        master_socket.close()

        if command == "checkpoint":
            checkpoint()
        else:
            restart()

        log("checkpoint or restart completed; re-connecting to master controller.")
        connect_master_socket()
        log("reconnected to master controller.")


def find_restart_script():
    """Return the path of the newest dmtcp restart script for this process, or None.

    Looks at every 'dmtcp_restart_script_*.sh' in the directory holding the
    current checkpoint file, keeps those whose given_ckpt_files entry lists
    that checkpoint file, and picks the one with the latest ckpt_timestamp.
    The returned path is joined onto the current working directory. Progress
    is reported through log().
    """
    # based loosely on dmtcp.createSessionList()
    ckpt_file = libdmtcp.dmtcp_get_ckpt_filename()
    script_pattern = os.path.join(os.path.dirname(ckpt_file), 'dmtcp_restart_script_*.sh')

    # (timestamp, script path) for every script that mentions the current checkpoint file
    candidates = []
    for script_path in glob.glob(script_pattern):
        stamp = None
        listed_files = None
        with open(script_path) as handle:
            for text in handle:
                if text.startswith('ckpt_timestamp="'):
                    # note: the quoted value looks like "Thu Apr 16 09:58:23 2015"
                    # which is identical to the "Locale's appropriate date and time
                    # representation", at least on ubuntu 14.04 in the U.S.
                    # It seems more likely this will remain locale-specific on other
                    # systems, rather than always be formatted this way. So we use
                    # the locale-specific flag %c.
                    stamp = datetime.datetime.strptime(text.split('"')[1], "%c")
                elif text.startswith('given_ckpt_files="'):
                    # the first space-separated item is discarded
                    listed_files = text.split('"')[1].split(' ')[1:]
                else:
                    continue
                # stop reading as soon as both entries have been seen
                if stamp is not None and listed_files is not None:
                    break

        # check whether this script matches the current process
        if stamp is not None and listed_files is not None and ckpt_file in listed_files:
            candidates.append((stamp, script_path))

    log("{num} checkpoint(s) available for restart.".format(num=len(candidates)))

    if not candidates:
        log('WARNING: no restart script available.')
        return None

    # tuples compare by timestamp first, so the maximum is the most recent script
    newest_path = max(candidates)[1]
    latest_script = os.path.join(os.getcwd(), newest_path)
    log('{latest} is the most recent restart script for this process.'.format(latest=latest_script))
    return latest_script


def checkpoint():
    """Ask dmtcp to checkpoint the current process.

    Raises RuntimeError when dmtcp_running is false.
    """
    if not dmtcp_running:
        raise RuntimeError("Checkpoint requested, but script was not loaded with dmtcp.")

    log("creating checkpoint.")
    libdmtcp.dmtcp_checkpoint()
    log("just finished checkpoint or restore.")
    
    
def restart():
    """Replace the current process with the most recent dmtcp checkpoint of it.

    Finds the latest restart script for this process, kills the runenergyplus
    process group, closes the Energy Plus socket, and then exec's the restart
    script through dmtcp_nocheckpoint. On success this function does not
    return, because os.execlp() replaces the running process.

    Raises RuntimeError if dmtcp is not running or no restart script exists.
    """
    # based loosely on dmtcp.restore()
    global eplus_process, dmtcp_listener_port

    if not dmtcp_running:
        raise RuntimeError("Restart requested, but script was not loaded with dmtcp.")

    restart_script = find_restart_script()
    if restart_script is None:
        raise RuntimeError("No restart scripts found for the current session.")

    log('Terminating runenergyplus subprocess.')
    log("BEFORE process tree:")
    os.system("ps xjf")

    # translate the pid once and reuse it for both the log message and the kill
    killpid = libdmtcp.dmtcp_virtual_to_real_pid(eplus_process.pid)
    log('Killing PID: {thispid}'.format(thispid=killpid))
    # launch_eplus() starts the child with os.setsid, so its pid is also its
    # process-group ID; killpg takes down runenergyplus and its children
    os.killpg(killpid, signal.SIGKILL)
    eplus_process.wait()    # get first child's exit status, so it won't become a zombie

    # close the socket too for good measure
    try:
        eplus_socket.shutdown(socket.SHUT_RDWR)
    except socket.error:
        # sometimes the socket will close quickly from the other end, which is OK
        pass
    eplus_socket.close()
    log("Closed runenergyplus.")
    log("AFTER process tree:")
    os.system("ps xjf")

    log('Using script {script} to restart session.'.format(script=restart_script))

    # first argument is the program to exec; the second becomes its argv[0]
    os.execlp('dmtcp_nocheckpoint', 'sh', restart_script, '--host', 'localhost', '--port', dmtcp_listener_port)
    #os.execlp('dmtcp_nocheckpoint', 'sh', 'dmtcp_restart', '--new-coordinator', *restart_files)

if __name__ == "__main__":
    # main() returns the Energy Plus exit code; use it as this script's exit status
    sys.exit(main(sys.argv))
