#!/usr/bin/python

# **********************************************************
# Copyright (c) 2021 Google, Inc.   All rights reserved.
# Copyright (c) 2016 - 2023 ARM Limited. All rights reserved.
# **********************************************************

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of ARM Limited nor the names of its contributors may be
#   used to endorse or promote products derived from this software without
#   specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL ARM LIMITED OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.

# This script reads opnd_defs.txt, codec_<version>.txt files and generates:
# opnd_decode_funcs.h
# opnd_encode_funcs.h
# encode_gen_<version>.h
# decode_gen_<version>.h
# opcode_names.h
# opcode_api.h
# opcode_opnd_pairs.h
# It is automatically run by cmake when opnd_defs.txt or codec_<version>.txt
# change.

import os
import re
import sys

# Width of an AArch64 instruction word, in bits.
N = 32
# All N instruction bits set; used to truncate Python's unbounded ints
# (e.g. the result of '~') to an instruction-word mask.
ONES = 2 ** N - 1

# Maps fallthrough_instr_id() keys to FallthroughDecode objects used to
# resolve overlapping encodings. Entries are created by consistency_check()
# and updated by opndset_naming() and generate_decoder().
FALLTHROUGH = {}

# Banners prepended to the generated C headers.
opnd_header = '/* This file was generated by codec.py from opnd_defs.txt and the codec_<version>.txt files. */\n\n'
opcode_header = '/* This file was generated by codec.py from codec_<version>.txt files. */\n\n'

class Opcode:
    """Per-opcode properties taken from a codec_<version>.txt line.

    name:    opcode name (emitted as OP_<name> in generated code).
    nzcv_rw: NZCV flag access code ('n', 'r', 'w', 'rw', 'wr', 'er', 'ew').
    feat:    required CPU feature name; 'BASE' means no feature check.
    """
    def __init__(self, name, nzcv_rw, feat):
        self.name, self.nzcv_rw, self.feat = name, nzcv_rw, feat

class Opnd:
    """Bit masks describing one operand type from opnd_defs.txt.

    gen:      bits produced by this operand ('x' or '+' in the mask).
    used:     extra bits this operand's codec reads ('?' in the mask).
    non_zero: bits of which at least one must be set ('+' in the mask).
    """
    def __init__(self, gen, used, must_be_set):
        # The constructor argument must_be_set is stored as non_zero.
        self.gen, self.used, self.non_zero = gen, used, must_be_set

class Opndset:
    """An operand set: operand type lists plus the order to encode them in.

    Attributes:
        fixed: Bit mask of instruction bits known before any operand is
            encoded (reorder_opnds() passes the non-operand bits).
        dsts: List of destination operand type names.
        srcs: List of source operand type names.
        enc_order: List of (ds, index, opndtype) tuples, ds being 'dst' or
            'src', giving the order in which operand encoders must run.

    Raises:
        Exception: if an enc_order entry does not name an existing operand
            of the matching type.
    """
    def __init__(self, fixed, dsts, srcs, enc_order):
        for (ds, i, ot) in enc_order:
            if ds == 'dst':
                opnds = dsts
            elif ds == 'src':
                opnds = srcs
            else:
                raise Exception('Invalid operand kind %r in encode order entry %r'
                                % (ds, (ds, i, ot)))
            # Explicit bounds check: reject negative or too-large indices
            # instead of wrapping or leaking an IndexError.
            if not 0 <= i < len(opnds) or opnds[i] != ot:
                raise Exception('Encode order entry %r does not match operands '
                                '%s <- %s' % ((ds, i, ot), dsts, srcs))
        self.fixed = fixed
        self.dsts = dsts
        self.srcs = srcs
        self.enc_order = enc_order

class FallthroughDecode:
    """Bookkeeping for decoding one overlapping ("fallthrough") encoding.

    flag_name:       name of the generated C bool flag: the first
                     constructor argument (consistency_check() passes a
                     fallthrough_instr_id() key) + '_fallthrough_flag'.
    opndset:         opndset name of the overlapping pattern.
    decode_clause:   C condition emitted after the decoder's main tree.
    decode_function: C statement emitted under decode_clause.
    """
    def __init__(self, opcode, opndset='', decode_clause='', decode_function=''):
        self.flag_name = opcode + '_fallthrough_flag'
        self.opndset, self.decode_clause, self.decode_function = (
            opndset, decode_clause, decode_function)

class Pattern:
    """One instruction encoding line from a codec_<version>.txt file.

    Attributes:
        pattern: The raw 32-character bit pattern string.
        opcode_bits: Mask of the fixed '1' bits of the pattern.
        opnd_bits: Mask of the 'x' (operand) bits of the pattern.
        high_soft_bits: Mask of the '^' bits: bits the spec allows to vary
            but that default to 1 (they are included by set_bits()).
        opcode: Opcode name.
        opndset: A (dsts, srcs) tuple of operand type lists, or an opndset
            name string (opndset_naming() replaces it with a name).
        enum: Opcode enumeration value as read from the file.
        feat: Required CPU feature name.
    """
    def __init__(self, pattern, opcode_bits, opnd_bits, high_soft_bits, opcode, opndset, enum, feat):
        self.pattern = pattern
        self.opcode_bits, self.opnd_bits = opcode_bits, opnd_bits
        # '^' bits: allowed to vary by the spec, but default to 1.
        self.high_soft_bits = high_soft_bits
        self.opcode, self.opndset = opcode, opndset
        self.enum, self.feat = enum, feat

    def __iter__(self):
        """Support unpacking as (opcode_bits, opnd_bits, opcode, opndset)."""
        return iter((self.opcode_bits, self.opnd_bits, self.opcode, self.opndset))

    def all_opnds(self):
        """Return the operand names referenced by this pattern as a list."""
        opndset = self.opndset
        if isinstance(opndset, tuple):
            return opndset[0] + opndset[1]
        return [opndset] if isinstance(opndset, str) else opndset

    def ignored_bit_mask(self):
        """Return the mask of fixed bits (neither operand nor '^' bits).

        The result is an unbounded (negative) Python int; callers truncate
        it to the instruction width.
        """
        return ~self.opnd_bits & ~self.high_soft_bits

    def set_bits(self):
        """Return the bits to set when encoding: opcode bits plus '^' bits."""
        return self.high_soft_bits | self.opcode_bits

def codec_header(isa_version):
    """Return the banner comment for headers generated from codec_<isa_version>.txt."""
    banner = '/* This file was generated by codec.py from codec_{}.txt. */'.format(isa_version)
    return banner + '\n\n'

def fallthrough_instr_id(opcode, opcode_bits, opnd_bits):
    """Return the FALLTHROUGH key '<opcode>_<opcode_bits>_<opnd_bits>' (bits as 8 hex digits)."""
    return '{}_{:08x}_{:08x}'.format(opcode, opcode_bits, opnd_bits)

def generate_opndset_decoders(opndsettab, opndtab):
    """Generate the C decode_opnds<name>() functions for every opndset.

    The output starts with one 'bool <flag_name> = false;' line per
    FALLTHROUGH entry. Each generated function decodes every operand from
    'enc' masked to that operand's gen|used bits, returning false if any
    operand decoder fails; otherwise it sets the opcode and operands of
    'instr' and returns true. Functions are emitted in sorted name order.
    """
    out = ['bool {} = false;'.format(entry.flag_name) for entry in FALLTHROUGH.values()]
    out.append('\n')
    for name in sorted(opndsettab):
        opnd_set = opndsettab[name]
        dsts, srcs = opnd_set.dsts, opnd_set.srcs
        # (kind, index, operand type) for every operand, destinations first.
        operands = ([('dst', k, ot) for k, ot in enumerate(dsts)] +
                    [('src', k, ot) for k, ot in enumerate(srcs)])
        out.extend(['/* %s <- %s */' % (dsts, srcs),
                    'static bool',
                    'decode_opnds%s(uint enc, dcontext_t *dcontext, byte *pc, '
                    'instr_t *instr, int opcode)' % name,
                    '{'])
        if operands:
            decoder_calls = ['!decode_opnd_%s(enc & 0x%08x, opcode, pc, &%s%d)'
                             % (ot, opndtab[ot].gen | opndtab[ot].used, kind, k)
                             for kind, k, ot in operands]
            declared = ', '.join('%s%d' % (kind, k) for kind, k, _ in operands)
            out.append('    opnd_t ' + declared + ';')
            out.append('    if (' + ' ||\n        '.join(decoder_calls) + ')')
            out.append('        return false;')
        out.append('    instr_set_opcode(instr, opcode);')
        out.append('    instr_set_num_opnds(dcontext, instr, %d, %d);' %
                   (len(dsts), len(srcs)))
        out.extend('    instr_set_%s(instr, %d, %s%d);' % (kind, k, kind, k)
                   for kind, k, _ in operands)
        out.extend(['    return true;', '}', ''])
    return '\n'.join(out) + '\n'

def generate_decoder(patterns, opndsettab, opndtab, opc_props, curr_isa, next_isa):
    """Generate the C decoder_<curr_isa>() function for one ISA version.

    The generated function tests instruction bits in a tree of nested
    if/else statements (built by gen() below) that ends in short
    if/else-if chains; each matching leaf calls the pattern's
    decode_opnds<opndset>() function with OP_<opcode>.

    Args:
        patterns: Pattern objects whose opndset is a name string (main()
            passes the output of opndset_naming()).
        opndsettab: Maps opndset names to Opndset objects; used to find
            operands whose bits must be non-zero.
        opndtab: Maps operand type names to Opnd objects.
        opc_props: Maps opcode names to Opcode objects; nzcv_rw selects the
            EFLAGS handling emitted for each pattern.
        curr_isa: Suffix of the generated function's name.
        next_isa: Suffix of the decoder to call when no pattern matches, or
            '' to return false instead.

    Returns:
        The generated C code, ending in a newline.

    Side effects:
        Fills in decode_clause and decode_function of FALLTHROUGH entries
        whose key and opndset match a pattern.
    """

    # Recursive function to generate nested conditionals in main decoder.
    # c is the output line list (appended to in place), pats the patterns
    # still possible at this point of the tree, depth the indent level.
    def gen(c, pats, depth):
        # Leaf sort key: (opcode, opndset, opcode_bits, opnd_bits). Unpacking
        # a Pattern yields (opcode_bits, opnd_bits, opcode, opndset).
        def reorder_key(t):
            f, v, m, t = t
            return (m, t, f, v)

        def indent_append(text):
            c.append('{}{}'.format('    ' * depth, text))

        # Leaf: with fewer than 4 patterns left, emit a flat if/else-if chain.
        if len(pats) < 4:
            else_str = ''
            for pattern in sorted(pats, key=reorder_key):

                # OR together the non_zero masks ('+' bits in opnd_defs.txt)
                # of all the pattern's operands; the generated test then also
                # requires at least one of those bits to be set.
                not_zero_mask = 0
                try:
                    opnd_set  = opndsettab[pattern.opndset]
                    for mask in (opndtab[o].non_zero for o in opnd_set.dsts + opnd_set.srcs):
                        not_zero_mask |= mask
                except KeyError:
                    # Named ('_...') opndsets are not in opndsettab as built
                    # by opndset_naming(), so they get no non-zero test.
                    pass

                # Match when every fixed bit (neither operand nor '^' bit)
                # equals the pattern's opcode bits.
                indent_append('%sif ((enc & 0x%08x) == 0x%08x%s)'  % (
                    else_str,
                    ((1 << N) - 1) & pattern.ignored_bit_mask(),
                    pattern.opcode_bits,
                    ' && (enc & 0x%08x) != 0' % not_zero_mask if not_zero_mask else ''))

                if not else_str:
                    else_str = 'else '
                # Opcodes that touch NZCV get a braced body that first
                # records the flag usage in instr->eflags.
                if opc_props[pattern.opcode].nzcv_rw != 'n':
                    c[-1] = c[-1] + ' {'
                    # Uncomment this for debug output in generated code:
                    # indent_append('    // %s->%s' % (pattern.opcode, opc_props[pattern.opcode].nzcv_rw))
                    if opc_props[pattern.opcode].nzcv_rw == 'r':
                        indent_append('    instr->eflags |= EFLAGS_READ_NZCV;')
                    elif opc_props[pattern.opcode].nzcv_rw == 'w':
                        indent_append('    instr->eflags |= EFLAGS_WRITE_NZCV;')
                    elif opc_props[pattern.opcode].nzcv_rw in ['rw', 'wr']:
                        indent_append('    instr->eflags |= (EFLAGS_READ_NZCV | '
                                      'EFLAGS_WRITE_NZCV);')
                    elif opc_props[pattern.opcode].nzcv_rw in ['er', 'ew']:
                        indent_append(
                            '    // instr->eflags handling for %s is '
                            'manually handled in codec.c\'s decode_common().' % pattern.opcode)
                    else:
                        indent_append('    ASSERT(0);')
                # Overlapping encodings recorded in FALLTHROUGH do not return
                # here: they set a flag, and the clause recorded below is
                # emitted after the whole tree (see the end of this function).
                enc_key = fallthrough_instr_id(
                    pattern.opcode, pattern.opcode_bits, pattern.opnd_bits)
                if enc_key in FALLTHROUGH and pattern.opndset == FALLTHROUGH[enc_key].opndset:
                    indent_append('    %s = true;' % FALLTHROUGH[enc_key].flag_name)
                    FALLTHROUGH[enc_key].decode_clause = \
                        'if ((enc & 0x%08x) == 0x%08x && %s == true)' % \
                        (((1 << N) - 1) & pattern.ignored_bit_mask(), pattern.opcode_bits, \
                        FALLTHROUGH[enc_key].flag_name)
                    FALLTHROUGH[enc_key].decode_function = \
                        'return decode_opnds%s(enc, dc, pc, instr, OP_%s);' % \
                        (pattern.opndset, pattern.opcode)
                else:
                    indent_append('    return decode_opnds%s(enc, dc, pc, '
                                  'instr, OP_%s);' % (pattern.opndset, pattern.opcode))
                if opc_props[pattern.opcode].nzcv_rw != 'n':
                    indent_append('}')
            return
        # Look for best bit to test. We aim to reduce the number of patterns
        # remaining.
        # For each bit, count the patterns that would go to the 0-subtree
        # (bit is 0, 'x' or '^') and to the 1-subtree (bit is 1, 'x' or '^');
        # pick the bit minimising the larger of the two counts.
        # NOTE(review): if no bit gives fewer than len(pats) patterns,
        # best_switch_bit stays -1 and '1 << best_switch_bit' below raises
        # ValueError (negative shift count).
        best_switch_bit = -1
        least_patterns_selected = len(pats)
        for switch_bit in range(N):
            bit_not_set_or_variable = 0
            bit_set_or_variable = 0
            for p in pats:
                # In how many patterns is this bit not set or included
                # in the variable bits.
                if (1 << switch_bit) & (~p.opcode_bits | p.opnd_bits | p.high_soft_bits):
                    bit_not_set_or_variable += 1
                # How many patterns have this bit set or in the variable
                # bits.
                if (1 << switch_bit) & (p.opcode_bits | p.opnd_bits | p.high_soft_bits):
                    bit_set_or_variable += 1
            patterns_selected = max(bit_not_set_or_variable, bit_set_or_variable)
            if patterns_selected < least_patterns_selected:
                best_switch_bit = switch_bit
                least_patterns_selected = patterns_selected
        indent_append('if ((enc >> %d & 1) == 0) {' % (best_switch_bit,))
        pats0 = []
        pats1 = []
        # Split the decode tree on this bit, if the bit
        # lies in the operand bits, then it goes in both trees.
        for p in pats:
            if (1 << best_switch_bit) & (~p.opcode_bits | p.opnd_bits | p.high_soft_bits):
                pats0.append(p)
            if (1 << best_switch_bit) & (p.opcode_bits | p.opnd_bits | p.high_soft_bits):
                pats1.append(p)
        gen(c, pats0, depth + 1)
        indent_append('} else {')
        gen(c, pats1, depth + 1)
        indent_append('}')

    c = ['static bool',
         'decoder_' + curr_isa + '(uint enc, dcontext_t *dc, byte *pc, instr_t *instr)',
          '{']
    gen(c, patterns, 1)
    # Deferred decodes for overlapping encodings, tried after the tree.
    # NOTE(review): FALLTHROUGH is module-global and not filtered by ISA
    # version, and an entry whose clause was never filled in by gen() emits
    # blank lines (decode_clause/decode_function default to '').
    for opcode in FALLTHROUGH.values():
        c += ['    %s' % opcode.decode_clause]
        c += ['        %s' % opcode.decode_function]
    # Call the next version of the decoder if defined.
    if next_isa != '':
        c.append('    return decoder_' + next_isa + '(enc, dc, pc, instr);')
    else:
        c.append('    return false;')
    c.append('}')
    return '\n'.join(c) + '\n'

def find_required(fixed, reordered, i, opndtab):
    """Return the C expression giving the bits operand i's encoder needs.

    Operand i's 'used' bits that are not in 'fixed' must come from operands
    encoded before it (reordered[:i]). Those earlier operands are walked in
    order; each one that generates still-missing bits contributes its C
    variable name ('dst0', 'src1', ...). The walk stops as soon as all used
    bits are known.

    Returns:
        'enc' if no earlier operand is needed, else '(enc | <vars>)'.
    """
    needed = opndtab[reordered[i][2]].used
    known = fixed
    providers = []
    for kind, idx, ot in reordered[:i]:
        if not needed & ~known:
            break
        produced = opndtab[ot].gen
        if produced & needed & ~known:
            providers.append('%s%d' % (kind, idx))
            known |= produced
    if not providers:
        return 'enc'
    return '(enc | %s)' % ' | '.join(providers)

def make_enc(n, reordered, f, opndtab):
    """Return the C call that encodes operand n of the encode order.

    The first argument passed to encode_opnd_<type>() is:
      '0' if the operand uses no bits,
      'enc & <used>' if all its used bits are among the fixed bits f,
      otherwise '<find_required(...)> & <used>', which pulls in the outputs
      of earlier operands.
    imm16 operand encoders also receive 'instr'.
    """
    kind, idx, ot = reordered[n]
    used = opndtab[ot].used
    if used == 0:
        enc_arg = '0'
    elif used & ~f == 0:
        enc_arg = 'enc & 0x%08x' % used
    else:
        enc_arg = '%s & 0x%08x' % (find_required(f, reordered, n, opndtab), used)
    extra_arg = 'instr, ' if ot == 'imm16' else ''
    return ('encode_opnd_%s(%s, opcode, pc, instr_get_%s(instr, %d), %s&%s%d)'
            % (ot, enc_arg, kind, idx, extra_arg, kind, idx))

def generate_opndset_encoders(opndsettab, opndtab):
    """Generate the C encode_opnds<name>() functions for every opndset.

    Each generated function checks the operand counts of 'instr' and runs
    the operand encoders in the opndset's enc_order. It ORs the results
    into 'enc' and returns it only if every operand's bits are unchanged in
    the combined value, i.e. nothing else set conflicting bits. Otherwise
    it returns ENCFAIL. Opndsets with no operands simply return 'enc'.
    Functions are emitted in sorted name order.
    """
    out = []
    for name in sorted(opndsettab):
        opnd_set = opndsettab[name]
        dsts, srcs = opnd_set.dsts, opnd_set.srcs
        out.append('/* %s <- %s */' % (dsts, srcs))
        out.append('static uint')
        out.append('encode_opnds%s(byte *pc, instr_t *instr, uint enc, decode_info_t *di)' % name)
        out.append('{')
        if not dsts and not srcs:
            out.append('    return enc;')
        else:
            # (C variable name, operand type) for every operand, dsts first.
            operands = ([('dst%d' % k, ot) for k, ot in enumerate(dsts)] +
                        [('src%d' % k, ot) for k, ot in enumerate(srcs)])
            var_names = [var for var, _ in operands]
            enc_order = opnd_set.enc_order
            encode_calls = [make_enc(k, enc_order, opnd_set.fixed, opndtab)
                            for k in range(len(enc_order))]
            conditions = (['instr_num_dsts(instr) == %d && instr_num_srcs(instr) == %d'
                           % (len(dsts), len(srcs))] + encode_calls)
            round_trip = ['%s == (enc & 0x%08x)' % (var, opndtab[ot].gen)
                          for var, ot in operands]
            out.append('    int opcode = instr->opcode;')
            # The initial values are only required to silence a bad compiler warning:
            out.append('    uint ' + ' = 0, '.join(var_names) + ' = 0;')
            out.append('    if (' + ' &&\n        '.join(conditions) + ') {')
            out.extend('        ASSERT((%s & 0x%08x) == 0);' % (var, ONES & ~opndtab[ot].gen)
                       for var, ot in operands)
            out.append('        enc |= ' + ' | '.join(var_names) + ';')
            out.append('        if (' + ' &&\n            '.join(round_trip) + ')')
            out.extend(['            return enc;', '    }', '    return ENCFAIL;'])
        out.extend(['}', ''])
    return '\n'.join(out) + '\n'

def generate_encoder(patterns, opndsettab, opndtab, opc_props, curr_isa, next_isa):
    """Generate the C encoder_<curr_isa>() function for one ISA version.

    Emits a switch on instr->opcode with one case per opcode (sorted by
    name). In each case, the opcode's encode_opnds<opndset>() candidates
    are tried in sorted (opndset, opcode_bits, opnd_bits) order, and the
    first result other than ENCFAIL is returned. An opcode whose feature
    is not 'BASE' first checks proc_has_feature(), except in
    DR_HOST_NOT_TARGET / STANDALONE_DECODER builds.

    If next_isa is non-empty, anything not encoded here is passed on to
    encoder_<next_isa>(); otherwise ENCFAIL is returned. opndsettab and
    opndtab are not used here.
    """
    def sort_key(pat):
        bits, opnd_mask, name, opndset = pat
        return (name, opndset, bits, opnd_mask)

    by_opcode = {}
    for pat in patterns:
        _, _, name, _ = pat
        by_opcode.setdefault(name, []).append(pat)

    out = ['static uint',
           'encoder_%s(byte *pc, instr_t *instr, decode_info_t *di)' % curr_isa,
           '{',
           '    uint enc;',
           '    (void)enc;',
           '    switch (instr->opcode) {']
    for name in sorted(by_opcode):
        out.append('    case OP_%s:' % name)
        feat = opc_props[name].feat
        if feat != 'BASE':
            out.extend(['#       if !defined(DR_HOST_NOT_TARGET) && !defined(STANDALONE_DECODER)',
                        '        if (!proc_has_feature(FEATURE_%s))' % feat,
                        '            return ENCFAIL;',
                        '#       endif'])
        ordered = sorted(by_opcode[name], key=sort_key)
        for pat in ordered[:-1]:
            out.extend(['        enc = encode_opnds%s(pc, instr, 0x%08x, di);' % (
                            pat.opndset, pat.set_bits()),
                        '        if (enc != ENCFAIL)',
                        '            return enc;'])
        final = ordered[-1]
        final_call = 'encode_opnds%s(pc, instr, 0x%08x, di)' % (final.opndset, final.set_bits())
        if next_isa == '':
            out.append('        return %s;' % final_call)
        else:
            # Break out of the switch so the next version's encoder is tried.
            out.extend(['        enc = %s;' % final_call,
                        '        if (enc != ENCFAIL)',
                        '            return enc;',
                        '        break;'])
    out.append('    }')
    if next_isa == '':
        out.append('    return ENCFAIL;')
    else:
        out.append('    return encoder_%s(pc, instr, di);' % next_isa)
    out.append('}')
    return '\n'.join(out) + '\n'

def generate_opcodes(patterns):
    """Generate the public AArch64 opcode enum header (opcode_api.h).

    Args:
        patterns: Objects with an 'enum' attribute (anything int() accepts)
            and an 'opcode' name. Values 0-5 are reserved for the fixed
            opcodes emitted below. Each enum value >= 6 gets an explicit
            'OP_<opcode> = <enum>' entry. Unused values produce no entry.
            May be empty.

    Returns:
        The header text, ending in a newline.
    """
    c = ['#ifndef _DR_IR_OPCODES_AARCH64_H_',
         '#define _DR_IR_OPCODES_AARCH64_H_ 1',
         '',
         '/****************************************************************************',
         ' * OPCODES',
         ' */',
         '/**',
         ' * @file dr_ir_opcodes_aarch64.h',
         ' * @brief Instruction opcode constants for AArch64.',
         ' */',
         '/** Opcode constants for use in the instr_t data structure. */',
         'enum {',
         '/*   0 */     OP_INVALID,  /* NULL, */ /**< INVALID opcode */',
         '/*   1 */     OP_UNDECODED,  /* NULL, */ /**< UNDECODED opcode */',
         '/*   2 */     OP_CONTD,    /* NULL, */ /**< CONTD opcode */',
         '/*   3 */     OP_LABEL,    /* NULL, */ /**< LABEL opcode */',
         '/*   4 */    OP_xx, /* placeholder for undecoded instructions */',
         '/*   5 */    OP_ldstex, /* single-entry single-exit block with exclusive load/store */',
         '']
    opcode_by_enum = {int(p.enum): p.opcode for p in patterns}
    # Guard max(): with no patterns only the fixed opcodes 0-5 are emitted.
    last_enum = max(opcode_by_enum) if opcode_by_enum else 5
    for i in range(6, last_enum + 1):
        if i in opcode_by_enum:
            c.append('/* {i:>4} */     OP_{opcode} = {i}, /**< AArch64 {opcode} opcode. */'.format(
                i=i, opcode=opcode_by_enum[i]))
    c += ['',

          '',
          '    OP_AFTER_LAST,',
          '    OP_FIRST = OP_LABEL + 1,      /**< First real opcode. */',
          '    OP_LAST  = OP_AFTER_LAST - 1, /**< Last real opcode. */',
          '};',
          '',
          '/* alternative names */',
          '#define OP_jmp       OP_b      '
          '/**< Platform-independent opcode name for jump. */',
          '#define OP_jmp_short OP_b      '
          '/**< Platform-independent opcode name for short jump. */',
          '#define OP_load      OP_ldr    '
          '/**< Platform-independent opcode name for load. */',
          '#define OP_store     OP_str    '
          '/**< Platform-independent opcode name for store. */',
          '',
          '/******************************'
          '**********************************************/',
          '',
          # The trailing comment must name the guard macro defined above.
          '#endif /* _DR_IR_OPCODES_AARCH64_H_ */']
    return '\n'.join(c) + '\n'

def generate_opcode_names(patterns):
    """Generate opcode_names.h, the opcode_names[] string table.

    Entries 0-5 are fixed. Entry i >= 6 is the opcode whose enum is i, or
    "<invalid>" when no pattern has that enum. The table runs up to the
    largest enum. An empty pattern list yields only the fixed entries.

    Args:
        patterns: Objects with 'enum' (anything int() accepts) and 'opcode'.

    Returns:
        The header text, ending in a newline.
    """
    c = ['#ifndef OPCODE_NAMES_H',
         '#define OPCODE_NAMES_H 1',
         '',
         'const char *opcode_names[] = {',
         '/*   0 */ "<invalid>",',
         '/*   1 */ "<undecoded>",',
         '/*   2 */ "<contd>",',
         '/*   3 */ "<label>",',
         '/*   4 */ "xx",',
         '/*   5 */ "ldstex",']
    name_by_enum = {int(p.enum): p.opcode for p in patterns}
    # Guard max() so an empty pattern list does not raise ValueError.
    last_enum = max(name_by_enum) if name_by_enum else 5
    for i in range(6, last_enum + 1):
        c.append('/*{:>4} */ "{}",'.format(i, name_by_enum.get(i, "<invalid>")))
    c += [
          '};',
          '',
          '#endif /* OPCODE_NAMES_H */']
    return '\n'.join(c) + '\n'

# Generates pairs of opcodes and masks for their operands for all the
# side-effect-free and non-branch instructions supported by the decoder.
# The generated file is used by the drstatecmp-fuzz-app.
def generate_opcode_opnd_pairs(patterns):
    """Generate opcode_opnd_pairs.h with {opcode bits, operand mask} pairs.

    Patterns rejected by the encoding-group filter below are skipped. The
    others are emitted in the given order, followed by a DR_FUZZ_INST_CNT
    define holding their count.
    """
    # Exclude instructions with side effects and branch instructions. The
    # regex is matched against the start of the 32-character pattern (bit
    # 31 first), so characters 3-6 are bits 28:25. Excluded are:
    #   i)   loads/stores (bits 28:25 = x1x0);
    #   ii)  branches, exception generating and system instructions
    #        (bits 28:25 = 101x);
    #   iii) SVE memory instructions (bit 31 = 1, bits 28:25 = 0010).
    # What remains is all Data Processing instructions (including the FP
    # and SIMD ones) and the non-memory SVE instructions.
    excluded = re.compile('....1.0|...101.|1..0010')
    kept = [p for p in patterns if not excluded.match(p.pattern)]

    lines = ['#ifndef DR_OPCODE_OPND_PAIRS_H',
             '#define DR_OPCODE_OPND_PAIRS_H 1',
             '',
             '#include <stdint.h>',
             '',
             'typedef struct {',
             '  uint32_t opcode;',
             '  uint32_t opnd;',
             '} dr_opcode_opnd_pair_t;',
             '',
             'const dr_opcode_opnd_pair_t dr_fuzz_opcode_opnd_pairs[] = {']
    lines.extend('/* %s */ {%s, %s},' % (p.opcode, bin(p.opcode_bits), bin(p.opnd_bits))
                 for p in kept)
    lines.extend(['};',
                  '',
                  '#define DR_FUZZ_INST_CNT %d' % len(kept),
                  '',
                  '#endif /* DR_OPCODE_OPND_PAIRS_H */'])
    return '\n'.join(lines) + '\n'

def write_if_changed(file, data):
    """Write data to the text file 'file' unless it already holds exactly data.

    Leaving unchanged files untouched preserves their timestamps, so build
    systems do not rebuild dependents needlessly. Both file handles are
    closed deterministically.
    """
    try:
        with open(file, 'r') as existing:
            if existing.read() == data:
                return
    except IOError:
        # Missing or unreadable: fall through and (re)write it.
        pass
    with open(file, 'w') as out:
        out.write(data)

def read_opnd_defs_file(path):
    """Parse opnd_defs.txt into a table of operand types.

    Each non-blank line (text after '#' is a comment) has the form
    'mask opndtype', where mask is 32 characters, most significant bit
    first, drawn from:
      'x'  bit set in Opnd.gen,
      '+'  bit set in Opnd.gen and in Opnd.non_zero,
      '?'  bit set in Opnd.used,
      '-'  bit set in none of them.

    Returns:
        dict mapping each opndtype name to its Opnd.

    Raises:
        Exception: on an unparseable line, a repeated opndtype, or if the
            file cannot be read.
    """
    opndtab = dict()
    file_msg = 'operand definitions file'
    # Raw string: the original non-raw literal contained invalid escape
    # sequences (\? \- \+), a DeprecationWarning/SyntaxWarning.
    line_re = re.compile(r'^[x\?\-\+]{32} +[a-zA-Z_0-9]+$')

    def bit_mask(mask, chars):
        # 1 wherever mask has one of chars, 0 elsewhere.
        return int(''.join('1' if ch in chars else '0' for ch in mask), 2)

    try:
        with open(path, 'r') as f:
            for line in (raw.split('#')[0].strip() for raw in f):
                if not line:
                    continue
                if not line_re.match(line):
                    raise Exception('Cannot parse line: %s in %s' % (line, file_msg))
                # Syntax: mask opndtype
                mask, opndtype = line.split()
                if opndtype in opndtab:
                    raise Exception('Repeated definition of opndtype %s in %s' % (opndtype, file_msg))
                opndtab[opndtype] = Opnd(bit_mask(mask, 'x+'),
                                         bit_mask(mask, '?'),
                                         bit_mask(mask, '+'))
    except IOError as e:
        raise Exception('Unable to read operand definitions file, {}: {}'.format(path, e.strerror))

    return opndtab

def read_codec_file(path):
    """Parse one codec_<version>.txt instruction definitions file.

    Each non-blank line (text after '#' is a comment) has one of two forms:

      pattern nzcv_rw enum feat opcode dst_opndtype* : src_opndtype*
      pattern nzcv_rw enum feat opcode opndset

    'pattern' is 32 characters, most significant bit first:
      - '0'/'1' are fixed opcode bits;
      - 'x' marks operand bits;
      - '^' marks high soft bits, which the spec allows to vary but which
        default to 1.
    'nzcv_rw' is one of n, r, w, rw, wr, er, ew.

    Returns:
        (patterns, opc_props): the Pattern objects in file order, and a dict
        mapping each opcode name to its Opcode properties. A later line for
        the same opcode replaces an earlier one.

    Raises:
        Exception: if a line matches neither form.

    If the file cannot be read, prints a message and exits with status 1.
    """
    # A real alternation: the old '[n|r|w|rw|wr|er|ew]+' was a character
    # class accepting any run of the characters n, r, w, e and '|'.
    nzcv = r'(?:rw|wr|er|ew|n|r|w)'
    # Explicit operand lists: "... opcode dsts : srcs".
    explicit_re = re.compile(r'^[01x\^]{32} +' + nzcv +
                             r' +[0-9]+ +[a-zA-Z0-9]* +[a-zA-Z_0-9][a-zA-Z_0-9 ]*:[a-zA-Z_0-9 ]*$')
    # Named opndset: "... opcode opndset".
    named_re = re.compile(r'^[01x\^]{32} +' + nzcv +
                          r' +[0-9]+ +[a-zA-Z0-9]* +[a-zA-Z_0-9]+ +[a-zA-Z_0-9]+')

    def bit_mask(pattern, chars):
        # 1 wherever pattern has one of chars, 0 elsewhere.
        return int(''.join('1' if ch in chars else '0' for ch in pattern), 2)

    opc_props = dict()
    patterns = []

    try:
        with open(path, 'r') as f:
            for line in (raw.split('#')[0].strip() for raw in f):
                if not line:
                    continue
                if explicit_re.match(line):
                    # Syntax: pattern opcode opndtype* : opndtype*
                    pattern, nzcv_rw_flag, enum, feat, opcode, args = line.split(None, 5)
                    dsts, srcs = [a.split() for a in args.split(':')]
                    opndset = (dsts, srcs)
                elif named_re.match(line):
                    # Syntax: pattern opcode opndset
                    pattern, nzcv_rw_flag, enum, feat, opcode, opndset = line.split()
                else:
                    raise Exception('Cannot parse line: %s in %s' % (line, path))
                # Decoded identically for both forms. Previously the named
                # form did not handle '^' and int() raised ValueError.
                opcode_bits = bit_mask(pattern, '1')
                opnd_bits = bit_mask(pattern, 'x')
                high_soft_bits = bit_mask(pattern, '^')
                patterns.append(Pattern(pattern, opcode_bits, opnd_bits, high_soft_bits,
                                        opcode, opndset, enum, feat))
                opc_props[opcode] = Opcode(opcode, nzcv_rw_flag, feat)
    except IOError as e:
        print('Unable to read instruction definitions file, %s: %s' % (path, e.strerror))
        sys.exit(1)

    return patterns, opc_props

def pattern_to_str(opcode_bits, opnd_bits, opcode, opndset):
    """Render a pattern as codec-file-like text for diagnostics.

    Operand bits print as 'x' and all other bits as their opcode bit value,
    most significant first. The bits are followed by the opcode, then
    either the opndset name or 'dsts : srcs'.
    """
    bit_chars = []
    for pos in reversed(range(N)):
        if (opnd_bits >> pos) & 1:
            bit_chars.append('x')
        else:
            bit_chars.append(str((opcode_bits >> pos) & 1))
    if type(opndset) is str:
        operands = opndset
    else:
        dsts, srcs = opndset
        operands = ' : '.join([' '.join(dsts), ' '.join(srcs)])
    return '%s %s %s' % (''.join(bit_chars), opcode, operands)

def consistency_check(patterns, opndtab):
    """Sanity-check the patterns of one codec file and record overlaps.

    First, for every pattern with explicit (dsts, srcs) operand lists, it
    checks two things:
      - every operand type is defined in opndtab;
      - every 'x' bit of the pattern is covered by the gen bits of some
        operand.

    Then it compares every pair of patterns. It is meant to report pairs
    whose fixed bits agree wherever both are fixed and which are not
    separated by non-zero operand constraints. Each such pair gets a
    FallthroughDecode entry in the global FALLTHROUGH dict, keyed by the
    later pattern. As written, this second step never reports anything
    (see the NOTE in the loop).

    Raises:
        Exception: on an undefined operand type, on unhandled operand bits,
            or when an overlap is found for a key already in FALLTHROUGH.
    """
    for p in patterns:
        if not isinstance(p.opndset, str):
            dsts, srcs = p.opndset
            unhandled_bits = p.opnd_bits
            for ot in dsts + srcs:
                try:
                    unhandled_bits &= ~opndtab[ot].gen
                except KeyError:
                    raise Exception('Undefined opndtype %s in:\n%s' %
                                    (ot, pattern_to_str(*p)))
            if unhandled_bits:
                raise Exception('Unhandled bits:\n%32s in:\n%s' %
                                (re.sub('1', 'x', re.sub('0', ' ', bin(unhandled_bits)[2:])),
                                 pattern_to_str(*p)))
    # Detect and mark overlapping patterns for special handling. Named as
    # 'fallthrough' because the special handling is done at the end of the
    # decoder's main if/then/else clauses block.
    for i, pattern_a in enumerate(patterns):
        for pattern_b in patterns[:i]:
            # NOTE(review): both accumulators start at 0 and are combined
            # with '&=', so they stay 0 whatever the operands are. As a
            # result zero_overlap below is always True, no overlap is ever
            # reported, and this loop never adds to FALLTHROUGH. '|=' looks
            # intended (compare the not_zero_mask loop in generate_decoder()).
            # Changing it would activate the dormant fallthrough path, so
            # verify against the codec files first. The second clause of
            # zero_overlap also tests pattern_b.opnd_bits, where
            # pattern_a.opnd_bits may be intended.
            non_zero_bits_a = 0
            non_zero_bits_b = 0
            try:
                for opnd in (opndtab[op] for op in pattern_a.all_opnds()):
                    non_zero_bits_a &= opnd.non_zero
            except KeyError:
                # Named opndsets are not operand types: no non-zero info.
                pass
            try:
                for opnd in (opndtab[op] for op in pattern_b.all_opnds()):
                    non_zero_bits_b &= opnd.non_zero
            except KeyError:
                pass

            zero_overlap = (
                non_zero_bits_a & pattern_b.opnd_bits == 0 or
                non_zero_bits_b & pattern_b.opnd_bits == 0)

            # Overlap: the opcode bits differ only where at least one of
            # the two patterns has operand bits.
            if ((pattern_b.opcode_bits ^ pattern_a.opcode_bits) &
                ~pattern_b.opnd_bits & ~pattern_a.opnd_bits == 0 and
                not zero_overlap):
                print('Overlap found between:\n%s\nand\n%s' %
                      (pattern_to_str(*pattern_b),
                      pattern_to_str(*pattern_a)))
                enc_key = fallthrough_instr_id(pattern_a.opcode, pattern_a.opcode_bits, pattern_a.opnd_bits)
                if enc_key in FALLTHROUGH:
                    raise Exception('Error: multiple overlaps detected for '
                                    '%s. Unable to resolve.\n' % enc_key)
                print('Resolving overlap.')
                FALLTHROUGH[enc_key] = FallthroughDecode(enc_key)

# This function reorders the operands for encoding so that no operand encoder
# requires bits that are generated by an operand encoder that has not yet
# been executed.
def reorder_opnds(fixed, dsts, srcs, opndtab):
    """Return an Opndset whose enc_order respects operand bit dependencies.

    Operands are considered in dst-then-src order. At each step the first
    remaining operand whose 'used' bits are all already known is placed.
    The known bits start as 'fixed'; each placed operand adds its 'gen'
    bits to them.

    Raises:
        Exception: 'Cyclic dependency: <remaining opndtypes>' when no
            remaining operand can be placed.
    """
    pending = [(kind, idx, ot, opndtab[ot].gen, opndtab[ot].used)
               for kind, names in (('dst', dsts), ('src', srcs))
               for idx, ot in enumerate(names)]
    known = fixed
    placed = []
    while pending:
        for pos, entry in enumerate(pending):
            if entry[4] & ~known == 0:
                break
        else:
            raise Exception('Cyclic dependency: %s' %
                            ' '.join(item[2] for item in pending))
        kind, idx, ot, gen, _ = pending.pop(pos)
        placed.append((kind, idx, ot))
        known |= gen
    return Opndset(fixed, dsts, srcs, placed)

# Here we give the opndsets names, which will be used in function names.
# Opndsets specified in 'codec_<version>.txt' are prefixed with '_', while
# generated names have the form 'gen_X_Y', where X is the hex representation of
# the smallest matching instruction word and Y that of the operand bit mask.
def opndset_naming(patterns, opndtab):
    """Name every pattern's opndset and build the generated-opndset table.

    A pattern that names its opndset in the codec file gets '_<name>'. A
    pattern with explicit operand lists gets 'gen_<X>_<Y>', where:
      - X is the smallest opcode_bits among patterns with the same operand
        lists and opnd_bits;
      - Y is opnd_bits;
    both formatted with '%08x'.

    Mutates each pattern's opndset to its new name and updates the opndset
    of any FALLTHROUGH entry whose key matches the pattern.

    Returns:
        (patterns, opndsettab): the same pattern list, and a dict mapping
        each 'gen_...' name to the Opndset built by reorder_opnds().
    """
    def signature(dsts, srcs, opnd_bits):
        return (' '.join(dsts), ' '.join(srcs), opnd_bits)

    # Pass 1: smallest opcode_bits seen for each operand signature.
    smallest = {}
    for opcode_bits, opnd_bits, _, opndset in patterns:
        if type(opndset) is str:
            continue
        dsts, srcs = opndset
        key = signature(dsts, srcs, opnd_bits)
        if key not in smallest or opcode_bits < smallest[key]:
            smallest[key] = opcode_bits

    # Pass 2: rename every opndset and record the generated ones.
    opndsettab = {}
    for pat in patterns:
        if type(pat.opndset) is str:
            new_name = '_' + pat.opndset
        else:
            dsts, srcs = pat.opndset
            base = smallest[signature(dsts, srcs, pat.opnd_bits)]
            new_name = 'gen_%08x_%08x' % (base, pat.opnd_bits)
            ordered = reorder_opnds(ONES & ~pat.opnd_bits, dsts, srcs, opndtab)
            # NOTE(review): two different operand signatures whose smallest
            # opcode_bits and opnd_bits coincide get the same name; only
            # the first one's Opndset is kept.
            opndsettab.setdefault(new_name, ordered)
        pat.opndset = new_name
        ft_key = fallthrough_instr_id(pat.opcode, pat.opcode_bits, pat.opnd_bits)
        if ft_key in FALLTHROUGH:
            FALLTHROUGH[ft_key].opndset = new_name
    return (patterns, opndsettab)

def main():
    """Regenerate the AArch64 codec headers.

    Usage: codec.py path/to/input_dir path/to/output_dir

    Reads opnd_defs.txt and codec_<version>.txt from input_dir. Writes
    these files to output_dir:
      - opnd_decode_funcs.h and opnd_encode_funcs.h;
      - decode_gen_<version>.h and encode_gen_<version>.h;
      - opcode_api.h, opcode_names.h and opcode_opnd_pairs.h.
    A file is only rewritten when its content changes (write_if_changed()).
    Exits with status 1 on a wrong argument count.
    """
    if len(sys.argv) != 3:
        print('Usage: codec.py path/to/input_dir path/to/output_dir')
        sys.exit(1)
    input_dir = sys.argv[1]
    output_dir = sys.argv[2]

    # The Arm AArch64's architecture versions supported by the DynamoRIO codec.
    # Currently, v8.0 is fully supported, while v8.1, v8.2, v8.3, v8.4, v8.6, SVE,
    # and SVE2 are partially supported. The null terminator element at the end
    # is required by some generator functions to correctly generate links
    # between each version's decode/encode logic.
    isa_versions = ['v80', 'v81', 'v82', 'v83', 'v84', 'v86', 'sve', 'sve2', '']

    # Read the instruction operand definitions. Used by the codec when
    # generating code to decode and encode instructions.
    opndtab = read_opnd_defs_file(os.path.join(input_dir, 'opnd_defs.txt'))

    # Read all instruction definitions and use the instructions' operand
    # signatures (opndsets) to generate decoder and encoder functions.
    # This pass must run before the next one: consistency_check() records
    # overlapping encodings in the global FALLTHROUGH dict, which
    # opndset_naming() and generate_decoder() then read and update.
    # NOTE(review): FALLTHROUGH is never cleared between versions, and each
    # generate_opndset_decoders() call emits a flag declaration for every
    # entry present at that point.
    opndset_decode_funcs = opnd_header
    opndset_encode_funcs = opnd_header
    for isa_version in isa_versions[:-1]:
        patterns, opc_props = read_codec_file(os.path.join(input_dir, 'codec_' + isa_version + '.txt'))
        # Consistency check within each codec_<version>.txt file. Global check
        # across all codec_<version>.txt files is done by
        # aarch64_check_codec_order.py and codecsort.py.
        consistency_check(patterns, opndtab)
        (patterns, opndsettab) = opndset_naming(patterns, opndtab)
        opndset_decode_funcs += generate_opndset_decoders(opndsettab, opndtab)
        opndset_encode_funcs += generate_opndset_encoders(opndsettab, opndtab)
    write_if_changed(os.path.join(output_dir, 'opnd_decode_funcs.h'),
                     opndset_decode_funcs)
    write_if_changed(os.path.join(output_dir, 'opnd_encode_funcs.h'),
                     opndset_encode_funcs)

    # Read all instruction definitions and use the instructions' bitmask to
    # generate decode and encode logic.
    # The files are re-read because opndset_naming() mutates the patterns'
    # opndset in place. Each version's decoder/encoder chains to the next
    # version's. For the last real version the next entry is '', so the
    # generated code returns false/ENCFAIL and the loop stops before
    # reaching the '' terminator.
    for idx, isa_version in enumerate(isa_versions):
        (patterns, opc_props) = read_codec_file(os.path.join(input_dir, 'codec_' + isa_version + '.txt'))
        (patterns, opndsettab) = opndset_naming(patterns, opndtab)
        write_if_changed(os.path.join(output_dir, 'decode_gen_' + isa_version +'.h'),
                         codec_header(isa_version) + generate_decoder(patterns, opndsettab, opndtab, opc_props, isa_version, isa_versions[idx + 1]))
        write_if_changed(os.path.join(output_dir, 'encode_gen_' + isa_version +'.h'),
                         codec_header(isa_version) + generate_encoder(patterns, opndsettab, opndtab, opc_props, isa_version, isa_versions[idx + 1]))
        if isa_versions[idx + 1] == '':
            break

    # Generate opcode declarations and definitions for the API and fuzz
    # testing.
    # The patterns of all versions are concatenated, in isa_versions order.
    opcode_api_patterns = []
    opcode_names_patterns = []
    opcode_opnd_pairs_patterns = []
    for isa_version in isa_versions[:-1]:
        (patterns, opc_props) = read_codec_file(os.path.join(input_dir, 'codec_' + isa_version + '.txt'))
        opcode_api_patterns += patterns
        opcode_names_patterns += patterns
        opcode_opnd_pairs_patterns += patterns
    write_if_changed(os.path.join(output_dir, 'opcode_api.h'),
                     opcode_header + generate_opcodes(opcode_api_patterns))
    write_if_changed(os.path.join(output_dir, 'opcode_names.h'),
                     opcode_header + generate_opcode_names(opcode_names_patterns))
    # Sort by opcode name so the fuzz table's order is stable.
    opcode_opnd_pairs_patterns.sort(key=lambda pat: pat.opcode)
    write_if_changed(os.path.join(output_dir, 'opcode_opnd_pairs.h'),
                     opcode_header + generate_opcode_opnd_pairs(opcode_opnd_pairs_patterns))


# Only run the generator when executed as a script (e.g. by cmake).
if __name__ == '__main__':
    main()