# Reference: STSW-STUSB004 Documentation
# Accession: G00084

import logging
import argparse
import asyncio

from ...interface.i2c_initiator_deprecated import I2CInitiatorApplet
from ... import *


class StUsb4500NvmInterface:
    """Access the STUSB4500 NVM through its FTP register window over I²C.

    ``interface`` is the I²C initiator used for every transfer; it must provide
    ``write(addr, data, stop=...)`` and ``read(addr, length, stop=...)`` coroutines.
    ``i2c_address`` is the target address of the STUSB4500.

    The public coroutines (``enable``, ``disable``, ``erase``, ``read_sector``,
    ``write_sector``) return ``None`` as soon as the device fails to acknowledge a
    transfer. The remaining register writes of that operation are then skipped.
    """

    def __init__(self, interface, logger, i2c_address):
        self.lower     = interface
        self._logger   = logger
        # Log at DEBUG when driven by this module's own logger. Otherwise use logging.TRACE,
        # a non-stdlib level that must have been registered on the logging module elsewhere.
        self._level    = logging.DEBUG if self._logger.name == __name__ else logging.TRACE
        self._i2c_addr = i2c_address

    def _log(self, message, *args):
        self._logger.log(self._level, "stusb4500_nvm: " + message, *args)

    # FTP register addresses (see the STSW-STUSB004 reference code for the full map).
    FTP_DATA_BASE = 0x53    # start of the 8-byte sector data buffer
    FTP_KEY       = 0x95    # FTP access key; disable() writes 0x00 here to relock
    FTP_CTRL_0    = 0x96    # control register; sector number lives in its low bits
    FTP_CTRL_1    = 0x97    # opcode register

    FTP_KEY_VALUE = 0x47    # key written to FTP_KEY by enable()

    async def _read_regs(self, addr, length):
        """Read ``length`` bytes starting at register ``addr``.

        Returns the bytes as a list of ints, or ``None`` if either the address write
        or the read was not acknowledged.
        """
        self._log("i2c-addr=%#04x reg-addr=%#04x", self._i2c_addr, addr)
        # Unlike _write_regs, no stop=True here: presumably so the read below is
        # issued as a repeated start after setting the register pointer.
        result = await self.lower.write(self._i2c_addr, addr.to_bytes(1, "little"))
        if result is False:
            self._log("unacked")
            return None

        self._log("read=")
        chunk = await self.lower.read(self._i2c_addr, length, stop=True)
        if chunk is None:
            self._log("unacked")
            # Report the NAK to callers instead of crashing in list(None) below.
            return None

        self._log("<%s>", chunk.hex())
        return list(chunk)

    async def _write_regs(self, addr, data):
        """Write ``data`` to consecutive registers starting at ``addr``.

        ``data`` is a single byte value (int) or an iterable of byte values
        (list, tuple, bytes, ...). Returns ``True`` on success, ``None`` on NAK.
        """
        self._log("i2c-addr=%#04x reg-addr=%#04x", self._i2c_addr, addr)

        if isinstance(data, int):
            data = [data]
        chunk = addr.to_bytes(1, "little") + bytes(data)
        self._log("write=<%s>", chunk[1:].hex())

        result = await self.lower.write(self._i2c_addr, chunk, stop=True)
        if result is False:
            self._log("unacked")
            return None

        return True

    async def _exec_cmd(self, cmd, c0_lsb=0):
        """Issue FTP opcode ``cmd`` and give it time to complete.

        ``c0_lsb`` is ORed into the value written to FTP_CTRL_0; callers use it to
        pass the sector number. Completion is not polled; a fixed 5 ms delay is used.
        """
        result =            await self._write_regs(self.FTP_CTRL_1, cmd)
        # 0x50 = 0x40 (the value enable() leaves in FTP_CTRL_0) | 0x10, presumably the
        # execute-request bit -- confirm against the STSW-STUSB004 reference.
        result = result and await self._write_regs(self.FTP_CTRL_0, 0x50 | c0_lsb)
        await asyncio.sleep(0.005)
        return result

    async def enable(self):
        """Unlock FTP access and prepare the NVM controller.

        Writes the key, clears the first data-buffer byte, then writes 0x40, 0x00,
        (1 ms pause) 0x40 to FTP_CTRL_0. Returns ``True`` on success, ``None`` on NAK.
        """
        result =            await self._write_regs(self.FTP_KEY, self.FTP_KEY_VALUE)
        result = result and await self._write_regs(self.FTP_DATA_BASE, 0x00)
        result = result and await self._write_regs(self.FTP_CTRL_0, 0x40)
        result = result and await self._write_regs(self.FTP_CTRL_0, 0x00)
        await asyncio.sleep(0.001)
        result = result and await self._write_regs(self.FTP_CTRL_0, 0x40)
        return result

    async def disable(self):
        """Leave FTP mode and clear the key. Returns ``True`` on success, ``None`` on NAK."""
        # Two-byte write starting at FTP_CTRL_0; presumably the second byte lands in
        # FTP_CTRL_1 through register auto-increment.
        result =            await self._write_regs(self.FTP_CTRL_0, [0x40, 0x00])
        result = result and await self._write_regs(self.FTP_KEY, 0x00)
        return result

    async def erase(self):
        """Erase the NVM. Returns ``True`` on success, ``None`` on NAK."""
        # Opcode sequence from ST's reference flashing flow (TODO confirm): load the
        # all-sectors erase mask (0xFA), soft-program (0x07), then erase (0x05).
        result =            await self._exec_cmd(0xFA)
        result = result and await self._exec_cmd(0x07)
        result = result and await self._exec_cmd(0x05)
        return result

    async def read_sector(self, sector):
        """Return the 8 bytes of NVM ``sector`` as a list of ints, or ``None`` on NAK."""
        result = await self._exec_cmd(0x00, c0_lsb=sector)
        if not result:
            return None
        return await self._read_regs(self.FTP_DATA_BASE, 8)

    async def write_sector(self, sector, data):
        """Program ``data`` into NVM ``sector``. Returns ``True`` on success, ``None`` on NAK.

        ``data`` is the sector contents. The sector should have been erased first
        (see ``erase``).
        """
        result =            await self._write_regs(self.FTP_DATA_BASE, data)
        # Presumably 0x01 loads the buffer and 0x06 programs it into the selected sector.
        result = result and await self._exec_cmd(0x01)
        result = result and await self._exec_cmd(0x06, c0_lsb=sector)
        return result


class StUsb4500NvmApplet(I2CInitiatorApplet):
    logger = logging.getLogger(__name__)
    help = "read and write STUSB4500 NVM"
    description = """
    Read and write the NVM inside the STUSB4500 USB-PD sink controller.

    The file format used is the .txt generated by the STSW-STUSB002 GUI utility provided by ST.
    """

    # NVM geometry: 5 sectors of 8 bytes each. The data file labels each sector with an
    # address starting at 0xC0 and advancing by the sector size.
    NVM_BASE_ADDR    = 0xC0
    NVM_SECTOR_COUNT = 5
    NVM_SECTOR_SIZE  = 8

    @classmethod
    def add_run_arguments(cls, parser, access):
        super().add_run_arguments(parser, access)

        def address(arg):
            # Accept any Python integer literal (0x28, 0b0101000, 40, ...).
            return int(arg, 0)

        parser.add_argument(
            "-A", "--i2c-address", type=address, metavar="I2C-ADDR", default=0b0101000,
            help="I²C address of the STUSB4500; typically 0b0101000 "
                 "(default: 0b0101000)")

    async def run(self, device, args):
        """Return an ``StUsb4500NvmInterface`` over the I²C interface set up by the base class."""
        i2c_iface = await super().run(device, args)
        return StUsb4500NvmInterface(
            i2c_iface, self.logger, args.i2c_address)

    @classmethod
    def add_interact_arguments(cls, parser):
        p_operation = parser.add_subparsers(dest="operation", metavar="OPERATION", required=True)

        # --file is mandatory for both operations: interact() always uses it, and
        # omitting it would otherwise crash on a None file object.
        p_read = p_operation.add_parser(
            "read", help="read NVM")
        p_read.add_argument(
            "-f", "--file", metavar="FILENAME", type=argparse.FileType("w"), required=True,
            help="write NVM contents to FILENAME")

        p_write = p_operation.add_parser(
            "write", help="write NVM")
        p_write.add_argument(
            "-f", "--file", metavar="FILENAME", type=argparse.FileType("r"), required=True,
            help="write NVM with contents of FILENAME")

    def _read_data_file(self, fh):
        """Parse an NVM dump into ``{address: [byte, ...]}``.

        Only lines starting with ``0x`` are data lines. Each must have the form
        ``0xAA:<whitespace-separated hex bytes>`` with exactly ``NVM_SECTOR_SIZE`` bytes.
        All other lines (comments, blank lines) are ignored.

        Raises ``GlasgowAppletError`` on a malformed data line. This happens before the
        device is touched, so a bad file never results in a half-written NVM.
        """
        data = {}
        for line_no, line in enumerate(fh, start=1):
            if not line.startswith("0x"):
                continue

            addr_text, colon, bytes_text = line.partition(":")
            try:
                addr   = int(addr_text, 16)
                values = [int(x, 16) for x in bytes_text.split()]
            except ValueError:
                addr, values = None, None

            if (not colon or values is None
                    or len(values) != self.NVM_SECTOR_SIZE
                    or not all(0 <= value <= 0xFF for value in values)):
                raise GlasgowAppletError(
                    f"Malformed NVM data on line {line_no:d} (expected "
                    f"{self.NVM_SECTOR_SIZE:d} byte values): {line.strip()!r}")

            data[addr] = values
        return data

    def _write_data_file(self, fh, data):
        """Write ``data`` (``{address: [byte, ...]}``) in the ST .txt format.

        Writes one tab-separated line per sector, sorted by address, followed by a
        short trailer. Lines end in CRLF, presumably to match files produced by ST's GUI.
        """
        for a in sorted(data.keys()):
            fh.write(f"0x{a:02X}:\t" + "\t".join([f"0x{x:02X}" for x in data[a]]) + "\t\r\n")
        fh.write("\r\n")
        fh.write("# NVM memory map : STUSBxx \r\n")
        fh.write("\r\n")

    async def _with_nvm_access(self, iface, operation, *args):
        """Run ``await operation(iface, *args)`` with NVM access enabled.

        Access is disabled afterwards. If ``operation`` fails, disabling is attempted
        on a best-effort basis and the original error is re-raised.
        """
        if not await iface.enable():
            raise GlasgowAppletError("Could not enable NVM access")

        try:
            result = await operation(iface, *args)
        except GlasgowAppletError:
            # Don't leave the FTP key unlocked after a failed read/erase/write. The
            # original error is more useful than a secondary disable failure, so the
            # disable result is deliberately ignored here.
            await iface.disable()
            raise

        if not await iface.disable():
            raise GlasgowAppletError("Could not disable NVM access")
        return result

    async def _read_nvm(self, iface):
        """Read every NVM sector and return ``{address: [byte, ...]}``."""
        data = {}
        for sector in range(self.NVM_SECTOR_COUNT):
            sector_data = await iface.read_sector(sector)
            if sector_data is None:
                raise GlasgowAppletError(f"Could not read NVM sector {sector:d}")

            data[self.NVM_BASE_ADDR + self.NVM_SECTOR_SIZE * sector] = sector_data
        return data

    async def _write_nvm(self, iface, data):
        """Erase the NVM, then program every sector for which ``data`` has an entry."""
        if not await iface.erase():
            raise GlasgowAppletError("Could not erase NVM")

        for sector in range(self.NVM_SECTOR_COUNT):
            addr = self.NVM_BASE_ADDR + self.NVM_SECTOR_SIZE * sector
            if addr not in data:
                # The whole NVM was erased above, so this sector stays blank.
                self.logger.warning("no data for NVM sector %d (0x%02X); leaving it erased",
                                    sector, addr)
                continue

            if not await iface.write_sector(sector, data[addr]):
                raise GlasgowAppletError(f"Could not write NVM sector {sector:d}")

    async def interact(self, device, args, iface):
        if args.operation == "read":
            data = await self._with_nvm_access(iface, self._read_nvm)
            self._write_data_file(args.file, data)

        elif args.operation == "write":
            # Parse and validate the file before touching the device.
            data = self._read_data_file(args.file)
            await self._with_nvm_access(iface, self._write_nvm, data)
