# sgude 1_22_2021

# Performed modifications for NVDV5

import ctypes

import struct

import jtag

import re

import math

import time

import os

import sys

import argparse


class NVDV_Jtag:
    """JTAG driver for FTDI FT2232x modules in MPSSE mode, via the D2XX library.

    Provides raw TAP access (``write_ir``/``write_dr``/``reset``/``idle``) and
    the JTAG2HOST register protocol of the NVDV test chip
    (``nvdv_read``/``nvdv_write``). TAP state is tracked in ``self.state``
    using the ``jtag`` module's state objects; while it is
    ``jtag.states.unknown`` TMS paths are sent but not tracked.

    The D2XX handle is opened in ``__init__``; call ``close()`` (or use the
    instance as a context manager) to release it.
    """

    # TAP controller state. The constructor replaces it on the instance with
    # jtag.states.reset after forcing the TAP into Test-Logic-Reset.
    state = jtag.states.unknown

    # Optional instruction table, empty by default. Entries are tuples
    # (name, register_spec, <not used by this class>, ir_value), e.g.
    #     ('idcode', 'com_32', None, '0b11111101')
    # For each entry __init__ adds w_<name>() (load ir_value into IR) and
    # r_<name>() (shift zeros through DR; width taken from 'com_<bits>').
    instr = []

    def __init__(self, device_index=0):
        """Open FTDI device ``device_index``, enter MPSSE mode and reset the TAP.

        Args:
            device_index: D2XX device index passed to ``FT_Open``.

        Raises:
            Exception: if a D2XX call returns a non-zero status. If the failure
                happens after ``FT_Open`` succeeded, the handle is closed
                before the exception propagates.
        """
        self.device_handle = ctypes.c_void_p()
        if os.name == 'posix':
            # NOTE: the kernel ftdi_sio/usbserial drivers may have to be
            # unloaded (rmmod) before D2XX can claim the device.
            self._libraries = ctypes.CDLL('/usr/local/lib/libftd2xx.so')
        else:
            self._libraries = ctypes.WinDLL('ftd2xx')

        self._check_status(
            self._libraries.FT_Open(device_index, ctypes.byref(self.device_handle)))
        try:
            self._configure_device()
        except Exception:
            # Best-effort release so a failed configuration does not leak the
            # handle; the original error is re-raised unchanged.
            self._libraries.FT_Close(self.device_handle)
            raise

        # Generate the w_<name>/r_<name> helpers for the instruction table.
        for entry in self.instr:
            setattr(self, 'w_' + entry[0], self.make_write_instr(entry))
            setattr(self, 'r_' + entry[0], self.make_read_instr(entry))

    def _configure_device(self):
        """Read device info, apply D2XX settings, enable MPSSE and reset the TAP."""
        lib = self._libraries
        handle = self.device_handle

        dev = ctypes.c_void_p()
        dev_id = ctypes.c_void_p()
        sn = ctypes.create_string_buffer(16)
        desc = ctypes.create_string_buffer(64)
        self._check_status(lib.FT_GetDeviceInfo(handle, ctypes.byref(dev), ctypes.byref(dev_id),
                                                sn, desc, None))
        self.device = dev.value
        self.id = dev_id.value
        self.serial = sn.value
        self.description = desc.value

        FT_PURGE_RX = 1
        FT_PURGE_TX = 2
        self._check_status(lib.FT_ResetDevice(handle))
        self._check_status(lib.FT_Purge(handle, FT_PURGE_RX | FT_PURGE_TX))
        # 64 KiB USB IN / OUT transfer sizes.
        self._check_status(lib.FT_SetUSBParameters(handle, 65536, 65536))
        # Event and error characters disabled.
        self._check_status(lib.FT_SetChars(handle, False, 0, False, 0))
        # 500 ms read / write timeouts.
        self._check_status(lib.FT_SetTimeouts(handle, 500, 500))
        # 1 ms latency timer so short MPSSE replies are not held back.
        self._check_status(lib.FT_SetLatencyTimer(handle, 1))
        # 0x200 is FT_FLOW_DTR_DSR in ftd2xx.h; XON/XOFF characters unused.
        self._check_status(lib.FT_SetFlowControl(handle, 0x200, 0, 0))

        # Bit mode 0 resets the controller. Its status is not checked; only the
        # switch to MPSSE (mode 2) below is treated as fatal.
        lib.FT_SetBitMode(handle, 0x00, 0)
        self._check_status(lib.FT_SetBitMode(handle, 0x00, 2))

        # 0x86: TCK divisor, low byte then high byte (0x0000 = fastest TCK for
        # the current clock mode). 0x8B (2232H clock-mode select) is not sent.
        self._ft_write(struct.pack('BBB', 0x86, 0x00, 0x00))
        # 0x80: low-byte GPIO value 0x20, direction 0xFB (every pin an output
        # except bit 2, TDO). Bit 5 high presumably drives the board's TRST.
        self._ft_write(struct.pack('BBB', 0x80, 0x20, 0xFB))
        # 0x85: TDI/TDO loopback off (0x84 would enable it).
        self._ft_write(struct.pack('B', 0x85))

        # Force Test-Logic-Reset. The starting state is unknown, so the path is
        # not tracked by _write_tms; record the resulting state explicitly.
        self._ft_write(self._write_tms('0', self.state.reset))
        self.state = jtag.states.reset

    def close(self):
        """Close the D2XX handle opened by ``__init__``; calling it again is a no-op.

        Raises:
            Exception: if ``FT_Close`` returns a non-zero status.
        """
        if not self.device_handle.value:
            return
        status = self._libraries.FT_Close(self.device_handle)
        self.device_handle = ctypes.c_void_p()
        self._check_status(status)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    @staticmethod
    def _check_status(status):
        """Raise if a D2XX call returned anything other than 0 (FT_OK)."""
        if status:
            raise Exception("USBftStatus was non-zero:", status)

    def make_write_instr(self, instr_tuple):
        """Return a zero-argument callable that loads ``instr_tuple[3]`` into IR."""
        def instr():
            return self.write_ir(instr_tuple[3])

        return instr

    def make_read_instr(self, instr_tuple):
        """Return a zero-argument callable that shifts zeros through DR.

        The register width is parsed from ``instr_tuple[1]`` ('com_<bits>');
        ``bits // 4`` hex zeros are shifted via ``write_dr``.

        The returned callable raises ValueError if ``instr_tuple[1]`` does not
        contain a 'com_<bits>' width.
        """
        m = re.search(r'com_([0-9]*)', instr_tuple[1])
        num_nibbles = int(m.group(1)) // 4 if m and m.group(1) else None

        def instr():
            if num_nibbles is None:
                raise ValueError("No 'com_<bits>' register width in instruction spec",
                                 instr_tuple)
            return self.write_dr('0x' + num_nibbles * '0')

        return instr

    def _write_tdi_bytes(self, byte_list, read=True):
        """Return the MPSSE packet that clocks whole bytes out on TDI.

        Bytes are written as-is, so they should already be LSb first for JTAG.
        Opcode 0x31 also samples TDO (one reply byte per byte sent); 0x11 only
        writes. Returns None for an empty list.

        Raises:
            Exception: if the list has more than 0xFFFF entries or an entry
                is larger than 0xFF.
        """
        if not byte_list:
            return None

        if len(byte_list) > 0xFFFF:
            raise Exception("Byte input list is too long!")

        for b in byte_list:
            if b > 0xFF:
                raise Exception("Element in list larger than 8-bit:", hex(b))

        # MPSSE length field is (count - 1), little-endian.
        length = len(byte_list) - 1
        opcode = 0x31 if read else 0x11
        return bytes((opcode, length & 0xFF, (length >> 8) & 0xFF)) + bytes(byte_list)

    def _write_tdi_bits(self, bit_list, read=True):
        """Return the MPSSE packet that clocks up to 8 single bits out on TDI.

        Opcode 0x33 also samples TDO; 0x13 only writes. Returns None for an
        empty list.

        Raises:
            ValueError: if more than 8 bits are given.
        """
        if not bit_list:
            return None

        bit_str = ''.join(str(b) for b in bit_list)
        # Left-align in one byte: the MSB-first bit opcodes send bit 7 first.
        bit_str += '0' * (8 - len(bit_str))
        bit_int = int(bit_str, 2)

        length = len(bit_list)
        if length > 8:
            raise ValueError("Input string longer than 8 bits")

        opcode = 0x33 if read else 0x13
        return bytes((opcode, length - 1, bit_int))

    def _write_tms(self, tdi_val, tms, read=False):
        """Return the MPSSE packet that clocks a TMS sequence while holding TDI.

        tdi_val: '0', '1', 0 or 1 -- TDI level held during the TMS clocks.
        tms: str ('10010'), list ([1, 0, 0, 1, 0] or ['1', '0', ...]) or
            jtag.TMSPath; at most 7 values, sent first-element first.
        read: use opcode 0x6B (also samples TDO) instead of 0x4B.

        Side effect: unless ``self.state`` is ``jtag.states.unknown``, it is
        advanced along every TMS value.

        Raises:
            Exception: if tdi_val is not 0/1 or tms has more than 7 values.
        """
        # TODO: support TMS sequences longer than 7 (e.g. by splitting).
        tdi_int = int(tdi_val)
        if str(tdi_val) not in ('0', '1'):
            raise Exception("tdi_val doesn't meet expected format:", tdi_val)

        # Convert tms to a list to walk the state transitions.
        if isinstance(tms, str):
            tms = [int(s) for s in tms]

        tms_len = len(tms)
        if tms_len > 7:
            raise Exception("tms is too long", tms)

        if self.state is not jtag.states.unknown:
            for t in tms:
                self.state = self.state[t]

        if isinstance(tms, (list, jtag.TMSPath)):
            tms = ''.join(str(x) for x in tms)

        # Bits 0..n-1 carry TMS (LSb = first clock); bit 7 carries TDI.
        tms_int = int(tms[::-1], 2)
        byte1 = (tdi_int << 7) | tms_int

        opcode = 0x6B if read else 0x4B
        return bytes((opcode, tms_len - 1, byte1))

    def _write(self, data, write_state=None, next_state=None, read=True):
        """Shift ``data`` through IR or DR and optionally return what TDO gave back.

        data: '0x…' or '0b…' string, see ``to_bytes_bits``.
        write_state: if truthy, first move along ``self.state[write_state]``
            (e.g. to Shift-IR/Shift-DR).
        next_state: state to move to after the shift; the final data bit is
            clocked together with the first TMS transition. If falsy, the
            path ``self.state[self.state]`` is used.
        read: when True, return the captured TDO bits as a '0x…' string;
            otherwise return None.
        """
        byte_list, bit_list, last_bit, _ = self.to_bytes_bits(data)

        outstr = b''
        if write_state:
            outstr += self._write_tms(0, self.state[write_state])

        tdi_bytes = self._write_tdi_bytes(byte_list, read=read)
        if tdi_bytes:
            outstr += tdi_bytes

        tdi_bits = self._write_tdi_bits(bit_list, read=read)
        if tdi_bits:
            outstr += tdi_bits

        # self.state now reflects the shift state entered above.
        target = next_state if next_state else self.state
        outstr += self._write_tms(last_bit, self.state[target], read=read)

        self._ft_write(outstr)

        if not read:
            return None

        # One reply byte per TDI byte, one for the partial-bit command (if
        # any) and one for the TMS command.
        raw_bin, _ = self._read(len(byte_list) + (1 if bit_list else 0) + 1)
        rebuilt_data = self._rebuild_read_data(raw_bin, len(byte_list), len(bit_list))
        return self.bin2hex(rebuilt_data)

    def _read(self, expected_num):
        """Wait until exactly ``expected_num`` bytes are queued, then read them.

        Returns a tuple ``(binary_string, hex_string)`` of the raw bytes
        (8 binary / 2 hex digits per byte, no prefix).

        Raises:
            Exception: if more bytes than expected are queued or a D2XX call
                fails.
        """
        # NOTE(review): this busy-waits with no timeout and hangs if the device
        # never delivers expected_num bytes.
        bytes_avail = ctypes.c_int()
        while bytes_avail.value != expected_num:
            if bytes_avail.value > expected_num:
                raise Exception("More bytes in buffer than expected!")
            self._check_status(self._libraries.FT_GetQueueStatus(self.device_handle,
                                                                 ctypes.byref(bytes_avail)))

        readback = self._ft_read(bytes_avail.value)
        bin_read_str = ''.join(f'{byte:08b}' for byte in readback)
        hex_read_str = readback.hex()
        return (bin_read_str, hex_read_str)

    def _flush(self):
        """Send MPSSE 0x87 ("send immediate") so pending read data reaches the host.

        ``_ft_write`` appends another 0x87, so two are sent. Returns the byte
        count reported by ``FT_Write``.
        """
        opcode = 0x87
        return self._ft_write(struct.pack('B', opcode))

    def _ft_write(self, outstr):
        """Send ``outstr`` with FT_Write and return the number of bytes written.

        0x87 (MPSSE "send immediate") is appended so that any read data the
        command produces is returned to the host without waiting.
        """
        outstr += b'\x87'
        bytes_written = ctypes.c_int()
        self._check_status(self._libraries.FT_Write(self.device_handle, ctypes.c_char_p(outstr),
                                                    len(outstr), ctypes.byref(bytes_written)))
        return bytes_written.value

    def _ft_read(self, numbytes):
        """Read up to ``numbytes`` bytes with FT_Read and return them as bytes."""
        bytes_read = ctypes.c_int()
        inbuf = ctypes.create_string_buffer(numbytes)
        self._check_status(self._libraries.FT_Read(self.device_handle, inbuf, numbytes,
                                                   ctypes.byref(bytes_read)))
        # FT_Read can return fewer bytes than requested (e.g. on timeout); do
        # not hand back the zero-filled tail of the buffer as data.
        return inbuf.raw[:bytes_read.value]

    def get_id(self):
        """Reset the TAP, shift 32 zero bits through DR and return ``(sent, readback)``.

        Per IEEE 1149.1, Test-Logic-Reset selects IDCODE (or BYPASS if the
        device has none), so on compliant parts the readback is the device ID.
        The TAP ends in Run-Test/Idle (``write_dr`` default).
        """
        self._ft_write(self._write_tms('0', self.state.reset))
        return self.write_dr('0x00000000')

    def reset(self):
        """Return to the Test-Logic-Reset state."""
        self._ft_write(self._write_tms('0', self.state.reset))

    def idle(self):
        """Go to Run-Test/Idle using the path ``self.state.idle.pad(minpause=2)``.

        The padding presumably adds extra idle clocks (minpause=2); confirm
        against the jtag module.
        """
        self._ft_write(self._write_tms('0', self.state.idle.pad(minpause=2)))

    def write_ir(self, cmd, next_state=jtag.states.idle, read=True):
        """Shift ``cmd`` ('0x…'/'0b…') into IR and return ``(cmd, readback)``.

        ``next_state`` is the state entered after the shift (default
        Run-Test/Idle). ``readback`` is the captured TDO value as a '0x…'
        string, or None when ``read`` is False.
        """
        cmd_readback = self._write(cmd, write_state=jtag.states.shift_ir,
                                   next_state=next_state, read=read)
        return (cmd, cmd_readback)

    def write_dr(self, data, next_state=jtag.states.idle, read=True):
        """Shift ``data`` ('0x…'/'0b…') through DR and return ``(data, readback)``.

        ``next_state`` is the state entered after the shift (default
        Run-Test/Idle). ``readback`` is the captured TDO value as a '0x…'
        string, or None when ``read`` is False.
        """
        data_readback = self._write(data, write_state=jtag.states.shift_dr,
                                    next_state=next_state, read=read)
        return (data, data_readback)

    def _nvdv_pack(self, addr, data, read):
        """Build the 39-bit JTAG2HOST DR frame as a '0b…' string, MSB first.

        Layout, bit 38 down to bit 0:
            '1' | data[15:0] | addr[15:0] | '0000' | R/W (1 = read) | '0'

        Raises:
            ValueError: if ``addr`` or ``data`` does not fit in 16 bits
                (wider values would change the frame length).
        """
        for name, value in (('addr', addr), ('data', data)):
            if not 0 <= value <= 0xFFFF:
                raise ValueError(f"{name} must fit in 16 bits: {value!r}")
        rw_bit = '1' if read else '0'
        return '0b1' + f'{data:016b}' + f'{addr:016b}' + '0000' + rw_bit + '0'

    def _nvdv_frame_write(self, addr, data):
        """Issue a JTAG2HOST write frame.

        Loads IR 0xA0, then shifts the same write frame through DR twice; the
        second shift's readback carries the host's response. Returns
        ``(ir_result, second_dr_result)``, each a ``(sent, readback)`` tuple.
        """
        pack = self._nvdv_pack(addr, data, read=False)
        temp = self.write_ir('0xA0', next_state=jtag.states.idle)
        self.write_dr(pack, next_state=jtag.states.idle)
        temp3 = self.write_dr(pack, next_state=jtag.states.idle)
        return (temp, temp3)

    def nvdv_write(self, addr, data):
        '''
        nvdv register write using JTAG2HOST_INTFC (NVIDIA test chip specific).

        :param addr: 16-bit register address (int)
        :param data: 16-bit value to write (int)
        :return: the data field echoed back on TDO (equals ``data``)
        :raises Exception: if the acknowledge bit (bit 0 of the readback) is
            clear or the echoed data (bits 37..22) differs from ``data``
        :raises ValueError: if addr or data does not fit in 16 bits
        '''
        res_hex = self._nvdv_frame_write(addr, data)[1][1]
        response = int(res_hex, 16)
        if response & 1 != 1:
            raise Exception("HOST acknowledgement not received")
        res_int = (response >> 22) & 0xFFFF
        if data != res_int:
            raise Exception(f"Data Written to device: {data} is not matching with readbadck Readback:{res_int}")
        return res_int

    def nvdv_read(self, addr):
        '''
        nvdv register read using JTAG2HOST_INTFC (NVIDIA test chip specific).

        :param addr: 16-bit register address (int)
        :return: the 16-bit data field (bits 37..22) of the readback
        :raises Exception: if the acknowledge bit (bit 0 of the readback) is clear
        :raises ValueError: if addr does not fit in 16 bits
        '''
        res_hex = self._nvdv_frame_read(addr)[1][1]
        response = int(res_hex, 16)
        if response & 1 != 1:
            raise Exception(f"HOST acknowledgement not received.  read_back {res_hex}")
        return (response >> 22) & 0xFFFF

    def _nvdv_frame_read(self, addr):
        """Issue a JTAG2HOST read frame (data field filled with 0xAAAA).

        Loads IR 0xA0, then shifts the same read frame through DR twice; the
        second shift's readback carries the register value. Returns
        ``(ir_result, second_dr_result)``, each a ``(sent, readback)`` tuple.
        """
        pack = self._nvdv_pack(addr, 0xAAAA, read=True)
        temp = self.write_ir('0xA0', next_state=jtag.states.idle)
        self.write_dr(pack, next_state=jtag.states.idle)
        temp3 = self.write_dr(pack)
        return (temp, temp3)

    def hex2bin(self, string):
        """Convert '0x…' to '0b…' with 4 binary digits per hex digit ('0xA' -> '0b1010')."""
        return '0b' + bin(int(string, 16))[2:].zfill(4 * len(string[2:]))

    def bin2hex(self, string):
        """Convert '0b…' to '0x…', zero-padded to ceil((n - 1) / 4) + 1 hex digits.

        n is the number of binary digits after '0b' ('0b101' -> '0x05').
        """
        return '0x' + hex(int(string, 2))[2:].zfill(math.ceil((len(string[2:]) - 1) / 4) + 1)

    def to_bytes_bits(self, data):
        """Split a '0x…'/'0b…' string into the pieces shifted over TDI.

        Returns ``(byte_list, bit_list, last_bit, original)``. The bits are
        first reordered LSb first, then grouped into whole bytes (byte_list).
        The remaining bits form bit_list, except for the very last bit, which
        is returned separately as last_bit. If the length is a multiple of 8,
        the last byte is broken up into bit_list + last_bit instead. original
        rebuilds the input from the pieces as a consistency check. A trailing
        Python 2 'L' suffix on the input is ignored.

        e.g. to_bytes_bits('0xabc')        -> ([0x3d], [0, 1, 0], 1, '0xabc')
             to_bytes_bits('0b010100011')  -> ([0xc5], [], 0, '0b010100011')

        Raises:
            ValueError: if data has neither a '0b' nor a '0x' prefix, or
                contains invalid digits.
        """
        if data[-1] == 'L':
            data = data[:-1]

        if data[:2] == '0b':
            base = 2
            bits = data[2:]
        elif data[:2] == '0x':
            base = 16
            bits = ''.join(bin(int(nibble, 16))[2:].zfill(4) for nibble in data[2:])
        else:
            raise ValueError("Data does not match expected format", data)

        length = len(bits)
        bits = bits[::-1]  # data needs to be LSb first

        byte_list = [int(bits[8 * i:8 * (i + 1)], 2) for i in range(length // 8)]
        bit_list = []
        if length % 8:
            bit_list = [int(b, 2) for b in bits[-(length % 8):]]

        if bit_list:
            bit_list_last = bit_list[-1]
            bit_list = bit_list[:-1]
        else:
            # Whole number of bytes: split the final byte into single bits so
            # the last bit can go out together with the TMS transition.
            bit_list = [int(b, 2) for b in bin(byte_list[-1])[2:].zfill(8)]
            byte_list = byte_list[:-1]
            bit_list_last = bit_list[-1]
            bit_list = bit_list[:-1]

        lsb_first = (''.join(f'{b:08b}' for b in byte_list)
                     + ''.join(str(b) for b in bit_list)
                     + str(bit_list_last))
        if base == 2:
            original = '0b' + lsb_first[::-1]
        else:
            original = hex(int(lsb_first[::-1], 2))

        return (byte_list, bit_list, bit_list_last, original)

    def _rebuild_read_data(self, data, num_bytes, num_bits, tms_bit=True):
        """Rebuild a binary string received from the FTDI chip.

        Return the reconstructed string MSb first, prefixed with '0b'.

        ``data`` is laid out as:
            num_bytes bytes of valid data (first received bit first);
            one byte whose low num_bits bits are valid (the MSB-first bit
            read shifts data in at the LSB), present only if num_bits > 0;
            if tms_bit, one byte whose first character (its MSB) is kept.

        NOTE(review): when the TMS read clocks more than one bit, confirm that
        the MSB holds the TDO sample of the first TMS clock.
        """
        rebuilt = ''

        if num_bytes > 0:
            rebuilt = data[:8 * num_bytes]
            data = data[8 * num_bytes:]

        if num_bits > 0:
            rebuilt += data[8 - num_bits:8]
            if len(data) > 8:
                data = data[8:]

        if tms_bit:
            rebuilt += data[0]

        return '0b' + rebuilt[::-1]

    def shift(self, bits):
        """Yield walking-one patterns for a ``bits``-wide register.

        For i in range(bits) yields ``1 << i`` as a '0x…' string zero-padded to
        ceil(bits / 4) digits, e.g. shift(8) -> '0x01', '0x02', …, '0x80'.
        """
        width = math.ceil(bits / 4)
        for i in range(bits):
            yield f"0x{1 << i:0{width}x}"

    def test_to_bytes_bits(self):
        """Self-test: ``to_bytes_bits`` must reconstruct each sample input exactly.

        Raises:
            Exception: "Test failed" on the first mismatch.
        """
        tests = ['0b0101010101',
                 '0b101010101111111',
                 '0x11223344',
                 '0b00010001001000100011001101000100',
                 '0xf102030f102030f102030f1020304f1020304f1020304f1020304444f1020304f',
                 '0xabc',
                 '0xabcd',
                 '0xabcde',
                 '0xabcdef',
                 '0b010100011']

        for t in tests:
            original = self.to_bytes_bits(t)[3]
            if t[:2] in ('0b', '0x') and original != t:
                raise Exception("Test failed")

    def read_mpsse_settings(self):
        """Debug helper: probe MPSSE responses and print them.

        For every value i in 0x00-0xFE, sends the bytes 0x85, i (plus the
        0x87 appended by ``_ft_write``), then reads whatever is queued and
        prints ``i`` and the raw bytes. Per FTDI AN_108, an unrecognised
        opcode is answered with 0xFA followed by that opcode.

        NOTE(review): the queue is sampled immediately after the write, so a
        reply may possibly show up on a later iteration.
        """
        for i in range(0xFF):
            self._ft_write(struct.pack('BB', 0x85, i))  # 0x84 would enable loopback
            bytes_avail = ctypes.c_int()
            self._check_status(self._libraries.FT_GetQueueStatus(self.device_handle,
                                                                 ctypes.byref(bytes_avail)))
            readback = self._ft_read(bytes_avail.value)
            print(i, readback)


def parse():
    """Build the command-line interface and return the parsed arguments.

    -d/--device-index and -a/--address are mandatory, and exactly one of
    -w DATA (write) or -r (read) must be given. DATA and the address accept
    any Python integer literal (decimal, 0x…, 0o…, 0b…).
    """
    cli = argparse.ArgumentParser()

    required = cli.add_argument_group('Required arguments')
    required.add_argument(
        "-d", "--device-index",
        help="index of NVJTAG device",
        required=True,
        type=int,
    )

    # Exactly one operation must be selected.
    operation = cli.add_mutually_exclusive_group(required=True)
    operation.add_argument('-w', metavar="DATA", help="write operation",
                           type=lambda text: int(text, 0))
    operation.add_argument('-r', help="read operation", action='store_true')

    required.add_argument(
        "-a", "--address",
        help="address to write/read from",
        required=True,
        type=lambda text: int(text, 0),
    )

    parsed = cli.parse_args()
    return parsed


if __name__ == "__main__":
    # Command-line entry point: open the adapter, reset the TAP, then perform a
    # single JTAG2HOST register read (-r) or write (-w DATA) at -a ADDRESS.
    args = parse()
    j = NVDV_Jtag(args.device_index)
    j.reset()
    if not args.r:
        j.nvdv_write(args.address, args.w)
    else:
        # The value is printed in hex and terminated by a NUL byte instead of
        # a newline.
        value = j.nvdv_read(args.address)
        print(hex(value), end='\0')

   

