#!/usr/bin/env python3
# vim: set ts=4 sw=4 tw=99 et:

# iongraph -- Translate IonMonkey JSON to GraphViz.
# Originally by Sean Stangl. See LICENSE.

import argparse
import html
import json
import glob
import os
import re
import shutil
import subprocess
import sys
import tempfile

def quote(s):
    """Return the string form of `s` wrapped in double quotes."""
    return '"' + str(s) + '"'

# Simple classes for the used subset of GraphViz' Dot format.
# There are more complicated constructors out there, but they all
# pull in annoying dependencies (and are annoying dependencies themselves).
class GraphWidget:
    """Common base for Dot entities: a name plus a dictionary of properties."""

    def __init__(self):
        self.props = {}
        self.name = ''

    def addprops(self, propdict):
        """Copy every entry of `propdict` into this widget's properties.

        Existing keys are overwritten.
        """
        self.props.update(propdict)


class Node(GraphWidget):
    """A graph node; its name is the string form of the given `name`."""

    def __init__(self, name):
        super().__init__()
        self.name = str(name)

class Edge(GraphWidget):
    """A directed edge between two nodes, each stored by its name as a string."""

    def __init__(self, nfrom, nto):
        super().__init__()
        self.nfrom, self.nto = str(nfrom), str(nto)

class Graph(GraphWidget):
    """A Dot graph: a title, graph-level properties, nodes and edges.

    `name` is the pass name shown in the title, `func` is the function object
    (its 'name' entry is shown in the title) and `type` is the Dot keyword
    written at the top of the output (e.g. 'digraph').
    """

    def __init__(self, name, func, type):
        # GraphWidget.__init__ already creates the empty `props` dictionary.
        GraphWidget.__init__(self)
        self.name = name
        self.func = func
        self.type = str(type)
        self.nodes = []
        self.edges = []

    def addnode(self, n):
        """Append node `n` to the graph."""
        self.nodes.append(n)

    def addedge(self, e):
        """Append edge `e` to the graph."""
        self.edges.append(e)

    def writeprops(self, f, o):
        """Write the properties of widget `o` to `f` as a Dot attribute list.

        Nothing is written when `o` has no properties.
        """
        if not o.props:
            return

        print('[', end=' ', file=f)
        for key, value in o.props.items():
            print(str(key) + '=' + str(value), end=' ', file=f)
        print(']', end=' ', file=f)

    def write(self, f):
        """Write the whole graph to the file object `f` in Dot syntax."""
        print(self.type, '{', file=f)

        # Use the pass name as the graph title (at the top).  The title is an
        # HTML-like label, so both names are escaped: a raw '<', '>' or '&'
        # in a function or pass name would otherwise be read as markup.
        legend = '<font color="blue">movable</font>, <u>guard</u>, <font color="red">in worklist</font>, <font color="gray50">recovered on bailout</font>'
        funcName = html.escape(str(self.func['name']))
        passName = html.escape(str(self.name))
        print('labelloc = t;', file=f)
        print('labelfontsize = 30;', file=f)
        print('label = <<b>%s - %s</b><br/>%s<br/>&nbsp;>;' % (funcName, passName, legend), file=f)

        # Output graph properties.
        for key, value in self.props.items():
            print('  ' + str(key) + '=' + str(value), file=f)
        print('', file=f)

        # Output node list.
        for n in self.nodes:
            print('  ' + n.name, end=' ', file=f)
            self.writeprops(f, n)
            print(';', file=f)
        print('', file=f)

        # Output edge list.
        for e in self.edges:
            print('  ' + e.nfrom, '->', e.nto, end=' ', file=f)
            self.writeprops(f, e)
            print(';', file=f)

        print('}', file=f)


# block obj -> node string with quotations
def getBlockNodeName(b):
    """Return the quoted Dot node name for block object `b`, based on b['id']."""
    blockId = b['id']
    return blockNumToNodeName(blockId)

# int -> node string with quotations
def blockNumToNodeName(i):
    """Return the quoted Dot node name for block number `i` ('Block' + number)."""
    nodeName = 'Block%s' % str(i)
    return quote(nodeName)

# resumePoint obj -> HTML-formatted string
def getResumePointRow(rp, mode):
    """Format resume point `rp` as a single HTML table row.

    Returns '' when `mode` is not None and differs from rp['mode'].
    Otherwise the row has three cells: the caller (when rp has one), the
    operands in order, and an unused empty cell.
    """
    if mode is not None and rp['mode'] != mode:
        return ''

    # Left cell: the caller wrapped in "((...))" (written as entities), or empty.
    if 'caller' in rp:
        callerCell = '<td align="left">&#40;&#40;%s&#41;&#41;</td>' % str(rp['caller'])
    else:
        callerCell = '<td align="left"></td>'

    # Middle cell: every operand followed by a single space.
    operandText = ''.join(['%s ' % t for t in rp['operands']])
    contentsCell = '<td align="left"><font color="grey50">resumepoint %s</font></td>' % operandText

    # Right cell: unused.
    unusedCell = '<td></td>'

    return '<tr>' + callerCell + contentsCell + unusedCell + '</tr>'

# memInputs obj -> HTML-formatted string
def getMemInputsRow(list):
    """Format an instruction's memory inputs as a single HTML table row.

    Returns '' when `list` is empty.  Otherwise the row has an empty left
    cell, a middle cell listing the inputs in order, and an unused right cell.
    """
    if len(list) == 0:
        return ''

    # Middle cell: every memory input followed by a single space.
    inputText = ''.join(['%s ' % str(t) for t in list])

    cells = (
        '<td align="left"></td>',  # Left cell: always empty.
        '<td align="left"><font color="grey50">memory %s</font></td>' % inputText,
        '<td></td>',  # Right cell: unused.
    )
    return '<tr>' + ''.join(cells) + '</tr>'

# Outputs a single row for an instruction, excluding MResumePoints.
# instruction -> HTML-formatted string
def getInstructionRow(inst):
    """Format instruction `inst` as a single HTML table row.

    The row holds the instruction id (also used as the cell's port name),
    the HTML-escaped opcode decorated according to inst['attributes'], and,
    when inst['type'] is present and not the string "None", the type.
    """
    # Left cell: instruction ID.
    idText = str(inst['id'])
    cells = ['<td align="right" port="i%s">%s</td>' % (idText, idText)]

    # Middle cell: opcode, with ASCII arrows swapped for Unicode ones before escaping.
    markup = html.escape(inst['opcode'].replace('->', '→').replace('<-', '←'))
    attrs = inst['attributes'] if 'attributes' in inst else ()
    # Recovered-on-bailout takes precedence over movable for the base colour.
    if 'RecoveredOnBailout' in attrs:
        markup = '<font color="gray50">%s</font>' % markup
    elif 'Movable' in attrs:
        markup = '<font color="blue">%s</font>' % markup
    # Guard and worklist decorations wrap whatever was produced above.
    if 'Guard' in attrs:
        markup = '<u>%s</u>' % markup
    if 'InWorklist' in attrs:
        markup = '<font color="red">%s</font>' % markup
    cells.append('<td align="left">%s</td>' % markup)

    # Right cell: instruction MIRType, omitted entirely when there is none.
    if 'type' in inst and inst['type'] != "None":
        cells.append('<td align="right">%s</td>' % html.escape(inst['type']))

    return '<tr>' + ''.join(cells) + '</tr>'

# block obj -> HTML-formatted string
def getBlockLabel(b):
    """Build the HTML-like Dot label (a table) for block object `b`.

    The table starts with a title row (block id, optional use count and
    attribute tag), then the block's own resume point if present, then for
    every instruction: its 'At' resume point row, the instruction row, its
    memory inputs row and its 'After' resume point row.
    """
    useCount = (" (Count: %s)" % str(b['blockUseCount'])) if 'blockUseCount' in b else ""

    # Only one attribute tag is shown; split edge wins over loop header,
    # which wins over backedge.
    attrs = b['attributes'] if 'attributes' in b else ()
    if 'splitedge' in attrs:
        attrTag = " (split edge)"
    elif 'loopheader' in attrs:
        attrTag = ' <font color="lightgreen">(loop header)</font>'
    elif 'backedge' in attrs:
        attrTag = ' <font color="lightpink">(backedge)</font>'
    else:
        attrTag = ""

    title = '<font color="white">Block %s%s%s</font>' % (str(b['id']), useCount, attrTag)
    titleCell = '<td align="center" bgcolor="black" colspan="3">%s</td>' % title
    rows = ['<tr>%s</tr>' % titleCell]

    if 'resumePoint' in b:
        rows.append(getResumePointRow(b['resumePoint'], None))

    for inst in b['instructions']:
        hasResumePoint = 'resumePoint' in inst

        if hasResumePoint:
            rows.append(getResumePointRow(inst['resumePoint'], 'At'))

        rows.append(getInstructionRow(inst))

        if 'memInputs' in inst:
            rows.append(getMemInputsRow(inst['memInputs']))

        if hasResumePoint:
            rows.append(getResumePointRow(inst['resumePoint'], 'After'))

    return '<<table border="0" cellborder="0" cellpadding="1">' + ''.join(rows) + '</table>>'

# str -> function obj -> ir obj -> ir obj -> (Graph OR None)
# 'ir' is the IR to be used.
# 'mir' is always the MIR.
#  This is because the LIR graph does not contain successor information.
def buildGraphForIR(name, func, ir, mir):
    """Build the Graph for one IR of one pass, or None when `ir` has no blocks.

    `name` is the pass name and `func` the function object; both are handed
    to the Graph.  Block contents (labels) come from `ir`, while block
    attributes and successor edges always come from the block at the same
    index in `mir`.
    """
    if len(ir['blocks']) == 0:
        return None

    g = Graph(name, func, 'digraph')
    g.addprops({'rankdir':'TB', 'splines':'true'})

    for i, bactive in enumerate(ir['blocks']):
        # bactive is used for block contents, b for drawing blocks and edges.
        b = mir['blocks'][i]
        nodeName = getBlockNodeName(bactive)

        node = Node(nodeName)
        node.addprops({'shape':'box', 'label':getBlockLabel(bactive)})

        # 'attributes' is treated as optional, as getBlockLabel does.
        attributes = b.get('attributes', ())
        if 'backedge' in attributes:
            node.addprops({'color':'crimson'})
        if 'loopheader' in attributes:
            node.addprops({'color':'limegreen'})
        if 'splitedge' in attributes:
            node.addprops({'style':'dashed'})

        g.addnode(node)

        successors = b['successors'] # which are integers
        for position, succ in enumerate(successors):
            edge = Edge(nodeName, blockNumToNodeName(succ))

            # With exactly two successors, label the first edge '1' and the
            # second '0'.  The label is chosen by position rather than by
            # value so that two edges to the same block still get distinct
            # labels.
            if len(successors) == 2:
                edge.addprops({'label': '1' if position == 0 else '0'})

            g.addedge(edge)

    return g

# pass obj -> function obj -> (Graph OR None, Graph OR None)
# The return value is (MIR, LIR); either one may be absent.
def buildGraphsForPass(p, func):
    """Build the (MIR graph, LIR graph) tuple for pass object `p`.

    Either element is None when buildGraphForIR returns None for that IR.
    The MIR is passed as the edge source for both graphs.
    """
    passName = p['name']
    mirIR, lirIR = p['mir'], p['lir']
    mirGraph = buildGraphForIR(passName, func, mirIR, mirIR)
    lirGraph = buildGraphForIR(passName, func, lirIR, mirIR)
    return (mirGraph, lirGraph)

# function obj -> (Graph OR None, Graph OR None) list
# First entry in each tuple corresponds to MIR; second, to LIR.
def buildGraphs(func):
    """Return one (MIR graph, LIR graph) tuple per entry of func['passes'], in order."""
    return [buildGraphsForPass(p, func) for p in func['passes']]

# function obj -> (Graph OR None, Graph OR None) list
# Only builds the final pass.
def buildOnlyFinalPass(func):
    """Return a list holding the (MIR graph, LIR graph) tuple of the last pass.

    When func['passes'] is empty the result is an empty list, the same
    value buildGraphs produces for a function without passes, so callers
    can detect "no passes" by checking the length of the result.
    """
    passes = func['passes']
    if not passes:
        return []
    return [buildGraphsForPass(passes[-1], func)]

# Write out a graph, constructing a nice filename.
# output dir -> function id -> pass id -> IR string -> Graph -> void
def outputPass(dir, fnum, pnum, irname, g):
    """Write graph `g` to a .gv file inside directory `dir`.

    The file name is built from the function number and pass number (each
    zero-padded to two digits), the graph's name and `irname`.
    """
    funcid = str(fnum).zfill(2)
    passid = str(pnum).zfill(2)

    filename = os.path.join(dir, 'func%s-pass%s-%s-%s.gv' % (funcid, passid, g.name, str(irname)))
    # Instruction labels may contain the non-ASCII arrows substituted by
    # getInstructionRow, so write UTF-8 explicitly instead of relying on the
    # locale's default encoding (which may be unable to encode them).
    with open(filename, 'w', encoding='utf-8') as fd:
        g.write(fd)

# Add in closing } and ] braces to close a JSON file in case of error.
def parenthesize(s):
    """Return `s` with the closing '}' and ']' needed to balance it appended.

    This is meant for JSON text that was cut off before all of its objects
    and arrays were closed.  Braces and brackets inside string literals are
    ignored, and backslash escapes (such as an escaped double quote) inside
    strings are honoured.  A closing brace or bracket with nothing left to
    close is skipped rather than raising.  An unterminated string literal is
    not repaired.
    """
    closers = {'{': '}', '[': ']'}
    stack = []
    inString = False
    escaped = False

    for c in s:
        if inString:
            if escaped:
                # The character after a backslash never ends the string.
                escaped = False
            elif c == '\\':
                escaped = True
            elif c == '"':
                inString = False
        elif c == '"':
            inString = True
        elif c in closers:
            stack.append(c)
        elif c == '}' or c == ']':
            if stack:
                stack.pop()

    # Close the innermost open container first.
    return s + ''.join(closers[c] for c in reversed(stack))


def genfiles(format, indir, outdir):
    """Render every .gv file in `indir` with `dot` into `outdir`.

    Files are processed in sorted order; each output is named after the
    input file with '.<format>' appended, and `format` is passed to dot as
    its -T argument.  Raises if dot exits with a non-zero status.
    """
    for gv in sorted(glob.glob(os.path.join(indir, '*.gv'))):
        target = os.path.join(outdir, '%s.%s' % (os.path.basename(gv), format))
        with open(target, 'w') as outfile:
            sys.stderr.write(' writing %s\n' % (outfile.name))
            # dot writes its output straight to the opened file.
            subprocess.run(['dot', gv, '-T%s' % (format)], stdout=outfile, check=True)


def gengvs(indir, outdir):
    """Copy every .gv file in `indir` into `outdir`, in sorted order."""
    for gv in sorted(glob.glob(os.path.join(indir, '*.gv'))):
        sys.stderr.write(' writing %s\n' % (os.path.basename(gv)))
        shutil.copy(gv, outdir)


def genmergedpdfs(indir, outdir):
    """Render every .gv file in `indir` to PDF, then merge the PDFs per function.

    The individual PDFs are written next to the .gv files in `indir`.  They
    are then grouped by the 'func...' prefix of their file name (everything
    before the first '-') and each group is combined, in sorted order, into
    '<prefix>.pdf' in `outdir` using pdftk if available, otherwise qpdf.
    Raises if neither tool is found while there is something to merge, or if
    a subprocess exits with a non-zero status.
    """
    for gv in sorted(glob.glob(os.path.join(indir, '*.gv'))):
        pdfpath = os.path.join(indir, '%s.pdf' % (os.path.basename(gv)))
        with open(pdfpath, 'w') as outfile:
            sys.stderr.write(' writing pdf %s\n' % (outfile.name))
            subprocess.run(['dot', gv, '-Tpdf'], stdout=outfile, check=True)

    sys.stderr.write('combining pdfs...\n')

    # Prefer pdftk; fall back to qpdf.
    if shutil.which('pdftk'):
        which = 'pdftk'
    elif shutil.which('qpdf'):
        which = 'qpdf'
    else:
        which = None

    # One merged document per distinct 'func...' prefix.
    prefixes = sorted({
        re.match(r'func[^-]*', os.path.basename(path)).group(0)
        for path in glob.glob(os.path.join(indir, 'func*.pdf'))
    })

    for prefix in prefixes:
        pages = sorted(glob.glob(os.path.join(indir, '%s-*.pdf' % (prefix))))
        outfile = os.path.join(outdir, '%s.pdf' % (prefix))
        sys.stderr.write(' writing pdf %s\n' % (outfile))
        if which == 'pdftk':
            command = ['pdftk'] + pages + ['cat', 'output', outfile]
        elif which == 'qpdf':
            command = ['qpdf', '--empty', '--pages'] + pages + ['--', outfile]
        else:
            raise Exception("unknown pdf program")
        subprocess.run(command, check=True)


def parsenums(numstr):
    """Parse a comma-separated string of integers into a list of ints.

    None is passed through unchanged.
    """
    if numstr is not None:
        return [int(piece) for piece in numstr.split(',')]
    return None


def parsenames(namestr):
    """Split a comma-separated string of names into a list.

    A falsy argument (None or '') is returned unchanged.
    """
    if not namestr:
        return namestr
    return namestr.split(',')


def validate(args):
    """Check that the tools and output directory needed by `args` exist.

    Writes an error to stderr and exits with status 1 when:
    - the output format needs `dot` and it is not on the PATH (the 'gv'
      format only copies the generated files, so it does not need it);
    - the format is 'pdf' and neither pdftk nor qpdf is on the PATH;
    - args.outdir is not an existing directory.
    """
    if args.format != 'gv' and not shutil.which('dot'):
        sys.stderr.write("ERROR: graphviz (dot) is not installed\n")
        sys.exit(1)
    if args.format == 'pdf':
        if not (shutil.which('pdftk') or shutil.which('qpdf')):
            sys.stderr.write("ERROR: either pdftk or qpdf must be installed in order to combine the generated PDFs.\n")
            sys.stderr.write("If you don't care and just want individual PDFs, use `--format pdfs`.\n")
            sys.exit(1)
    if not os.path.isdir(args.outdir):
        sys.stderr.write("ERROR: could not find a directory at %s\n" % (args.outdir))
        sys.exit(1)


def gen(args):
    """Read the JSON log named by args.input and generate the requested output.

    The graphs are first written as .gv files into a temporary directory and
    then converted (or copied) into args.outdir according to args.format.
    args.funcnum, args.funcname and args.passnum restrict which functions
    and passes are processed; args.final limits each function to its final
    pass; args.out_mir / args.out_lir are optional open files that receive
    the Dot text of the selected passes directly.
    """
    validate(args)

    with open(args.input, 'r') as fd:
        s = fd.read()

    sys.stderr.write("loading %s...\n" % (args.input))
    # parenthesize() appends missing closing braces/brackets so that a log
    # which was not closed properly can still be parsed.
    ion = json.loads(parenthesize(s))

    sys.stderr.write("generating graphviz...\n")

    # Each filter is None (or falsy) when the corresponding option was not given.
    funcnums = parsenums(args.funcnum)
    funcnames = parsenames(args.funcname)
    passnums = parsenums(args.passnum)
    with tempfile.TemporaryDirectory() as tmpdir:
        for i in range(0, len(ion['functions'])):
            func = ion['functions'][i]

            # Skip functions excluded by the index or name filters.
            if funcnums and i not in funcnums:
                continue
            if funcnames and func['name'] not in funcnames:
                continue

            # gtl: list of (MIR graph, LIR graph) tuples for this function.
            gtl = buildOnlyFinalPass(func) if args.final else buildGraphs(func)

            if len(gtl) == 0:
                sys.stderr.write(" function %d (%s): abort during SSA construction.\n" % (i, func['name']))
            else:
                sys.stderr.write(" function %d (%s): success; %d passes.\n" % (i, func['name'], len(gtl)))

            # NOTE(review): j is the index into gtl.  With args.final set, gtl
            # comes from buildOnlyFinalPass, so j does not appear to be the
            # pass's position in func['passes'] -- confirm this is intended
            # for the pass filter and the output file names.
            for j in range(0, len(gtl)):
                gt = gtl[j]
                if gt == None:
                    continue

                mir = gt[0]
                lir = gt[1]

                if passnums and j in passnums:
                    # A selected pass: write its Dot text to the explicitly
                    # requested output files, if any.
                    if lir != None and args.out_lir:
                        lir.write(args.out_lir)
                    if mir != None and args.out_mir:
                        mir.write(args.out_mir)
                    # When both output files were requested, stop processing
                    # this function's passes; otherwise fall through and also
                    # write the pass into the temporary directory below.
                    if args.out_lir and args.out_mir:
                        break
                elif passnums:
                    # A pass filter is active and this pass is not in it.
                    continue

                # If only the final pass is requested, output both MIR and LIR.
                if args.final:
                    if lir != None:
                        outputPass(tmpdir, i, j, 'lir', lir)
                    if mir != None:
                        outputPass(tmpdir, i, j, 'mir', mir)
                    continue

                # Normally, only output one of (MIR, LIR), preferring LIR.
                if lir != None:
                    outputPass(tmpdir, i, j, 'lir', lir)
                elif mir != None:
                    outputPass(tmpdir, i, j, 'mir', mir)

        # Convert the collected .gv files while the temporary directory still exists.
        if args.format == 'pdf':
            sys.stderr.write("generating pdfs...\n")
            genmergedpdfs(tmpdir, args.outdir)
        elif args.format == 'gv':
            sys.stderr.write("copying gvs...\n")
            gengvs(tmpdir, args.outdir)
        else:
            # 'pdfs' means one unmerged PDF per graph; other names are passed
            # through as the output format.
            format = 'pdf' if args.format == 'pdfs' else args.format
            sys.stderr.write("generating %ss...\n" % (format))
            genfiles(format, tmpdir, args.outdir)


def js_and_gen(args):
    """Run the js shell with 'logs' in IONFLAGS, then generate graphs.

    The shell is args.jspath (or 'js' when that is not set) and receives
    args.remaining without its first element.  Afterwards args.input is set
    to '/tmp/ion.json' and gen() is run on it.  Raises if the shell exits
    with a non-zero status.
    """
    validate(args)

    # Keep the caller's IONFLAGS, adding 'logs' only when it is missing.
    env = os.environ.copy()
    ionflags = env.get('IONFLAGS', 'logs').split(',')
    if 'logs' not in ionflags:
        ionflags.append('logs')
    env['IONFLAGS'] = ','.join(ionflags)

    # The first remaining argument is dropped (presumably the '--' separator).
    command = [args.jspath or 'js'] + list(args.remaining[1:])
    subprocess.run(command, env=env, check=True)

    args.input = '/tmp/ion.json'
    gen(args)


def jittest_and_gen(args):
    """Run jit_test.py with 'logs' in IONFLAGS, then generate graphs.

    jit_test.py is looked up at js/src/jit-test/jit_test.py relative to the
    current directory and run with the current Python interpreter, the `js`
    executable found on the PATH, '--one', and args.remaining without its
    first element.  Afterwards args.input is set to '/tmp/ion.json' and
    gen() is run on it.  Exits with status 1 when no `js` executable is
    found; raises if the test run exits with a non-zero status.
    """
    validate(args)

    # The first remaining argument is dropped (presumably the '--' separator).
    testargs = args.remaining[1:]

    # Keep the caller's IONFLAGS, adding 'logs' only when it is missing.
    testenv = os.environ.copy()
    flags = testenv.get('IONFLAGS', 'logs').split(',')
    if 'logs' not in flags:
        flags.append('logs')
    testenv['IONFLAGS'] = ','.join(flags)

    js_path = shutil.which("js")
    if js_path is None:
        sys.stderr.write("ERROR: could not find a js executable on the PATH\n")
        sys.exit(1)

    jittest_path = os.path.join("js", "src", "jit-test", "jit_test.py")
    # Run the command as an argument list without a shell, so arguments that
    # contain spaces or shell metacharacters reach jit_test.py unchanged.
    cmd = [sys.executable, jittest_path, js_path, "--one", *testargs]
    print("Command:", " ".join(cmd))
    subprocess.run(cmd, env=testenv, check=True)

    args.input = '/tmp/ion.json'
    gen(args)


def add_main_arguments(parser):
    """Register the options shared by every iongraph subcommand on `parser`."""
    add = parser.add_argument

    # Where the output goes.
    add('-o', '--outdir', default='.',
        help='The directory in which to store the output file(s).')

    # Filters selecting which functions and passes are processed.
    add('-f', '--funcnum',
        help='Only operate on the specified function(s), by index. Multiple functions can be separated by commas, e.g. `1,5,234`.')
    add('-n', '--funcname',
        help='Only operate on the specified function(s), by name. Multiple functions can be separated by commas, e.g. `foo,bar,baz`.')
    add('-p', '--passnum',
        help='Only operate on the specified pass(es), by index. Multiple passes can be separated by commas, e.g. `1,5,234`.')

    # Output format and scope.
    add('--format', default='pdf', choices=['gv', 'pdf', 'pdfs', 'png', 'svg'],
        help='The output file format (pdf by default). `pdf` will merge all the graphs for each function into a single PDF; all other formats will produce a single file per graph.')
    add('--final', action='store_true',
        help='Only generate the final optimized MIR/LIR graphs.')

    # Optional files receiving the Dot text directly.
    add('--out-mir', type=argparse.FileType('w'),
        help='Select the file where the MIR output would be written to.')
    add('--out-lir', type=argparse.FileType('w'),
        help='Select the file where the LIR output would be written to.')


def main():
    """Parse the command line and run the selected subcommand.

    Three subcommands are available: 'json' (graph an existing log), 'js'
    (run js, then graph) and 'jit-test' (run jit-test, then graph).  When
    no subcommand is given, a usage error is reported instead of failing
    on the missing handler.
    """
    parser = argparse.ArgumentParser(description='Visualize Ion graphs using GraphViz.')
    subparsers = parser.add_subparsers()
    # Named json_parser so the imported json module is not shadowed.
    json_parser = subparsers.add_parser('json', formatter_class=argparse.RawDescriptionHelpFormatter,
                                                help='Generate a graph from a JSON file.',
                                                description='Generate a graph from a JSON file.')
    js = subparsers.add_parser('js', formatter_class=argparse.RawDescriptionHelpFormatter,
                                     help='Run js and iongraph together in one call.',
                                     description='Run js and iongraph together in one call. Arguments before the -- separator are for iongraph, while arguments after the -- will be passed to js.\n\nexample:\n  iongraph js --funcnum 1 -- -m mymodule.mjs')
    jittest = subparsers.add_parser('jit-test', formatter_class=argparse.RawDescriptionHelpFormatter,
                                                help='Run jit-test and iongraph together in one call.',
                                                description='Run jit-test and iongraph together in one call. Arguments before the -- separator are for iongraph, while arguments after the -- will be passed to jit-test.\n\nexample:\n  iongraph jit-test --funcnum 1 -- category/test.js')

    add_main_arguments(json_parser)
    json_parser.add_argument('input', help='The JSON log file generated by IonMonkey. (Default: /tmp/ion.json)',
                                      nargs='?', default='/tmp/ion.json')
    json_parser.set_defaults(func=gen)

    js.add_argument('--jspath', help='The path to the js executable.')
    add_main_arguments(js)
    js.add_argument('remaining', nargs=argparse.REMAINDER)
    js.set_defaults(func=js_and_gen)

    add_main_arguments(jittest)
    jittest.add_argument('remaining', nargs=argparse.REMAINDER)
    jittest.set_defaults(func=jittest_and_gen)

    args = parser.parse_args()
    # `func` is only set by the subparsers, so it is absent when the command
    # line names no subcommand.
    if getattr(args, 'func', None) is None:
        parser.error('a subcommand is required (one of: json, js, jit-test)')
    args.func(args)


# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()
