Source code for yzlite.core.tflite_micro.tflite_micro_model
from typing import List, Dict
import re
import collections
import logging
import numpy as np
import msgpack
from yzlite.utils.string_formatting import format_units
from .tflite_micro_accelerator import TfliteMicroAccelerator
[docs]class TfliteMicroModelDetails:
"""TF-Lite Micro Model Details"""
[docs] def __init__(self, wrapper_details:dict):
self._details:dict = wrapper_details
@property
def name(self) -> str:
"""Name of model"""
return self._details['name']
@property
def version(self)-> int:
"""Version of model"""
return self._details['version']
@property
def date(self)-> str:
"""Date of model in ISO8601 format"""
return self._details['date']
@property
def description(self)-> str:
"""Description of model"""
return self._details['description']
@property
def classes(self) -> List[str]:
"""List of class labels"""
return self._details['classes']
@property
def hash(self)-> str:
"""Unique hash of model data"""
return self._details['hash']
@property
def accelerator(self)-> str:
"""Accelerater kernels loaded into TFLM interpreter"""
return self._details['accelerator']
@property
def runtime_memory_size(self)-> int:
"""Total amount of RAM required at runtime to run model"""
return self._details['runtime_memory_size']
def __str__(self):
s = ''
s += f"Name: {self.name}\n"
s += f"Version: {self.version}\n"
s += f"Date: {self.date}\n"
s += f"Description: {self.description}\n"
s += f"Hash: {self.hash}\n"
s += f"Accelerator: {self.accelerator}\n"
s += f"Classes: {', '.join(self.classes)}\n"
s += f"Total runtime memory: {format_units(self.runtime_memory_size)}Bytes\n"
return s
[docs]class TfliteMicroLayerError:
WARNING_RE = re.compile(r'.*Op(\d+)-(\S+)\ not supported: (.*)')
[docs] def __init__(self, index:int, name:str, msg:str):
self._index = index
self._name = name
self._msg = msg
@property
def index(self) -> int:
"""The layer's index"""
return self._index
@property
def name(self) -> int:
"""The layer's full name"""
return self._name
@property
def msg(self) -> int:
"""The layer's error msg"""
return self._msg
@staticmethod
def _parse_error_log(msg):
match = TfliteMicroLayerError.WARNING_RE.match(msg)
if not match:
return None
return TfliteMicroLayerError(
index=int(match.group(1)),
name=msg.split()[0],
msg=match.group(3)
)
def __str__(self) -> str:
return f'{self.name}: {self.msg}'
[docs]class TfliteMicroProfiledLayerResult(collections.defaultdict):
"""Result of profiling a specifc model layer"""
[docs] def __init__(self, result:dict):
collections.defaultdict.__init__(self, lambda: 0, result)
@property
def name(self) -> int:
"""Name of layer"""
return self['name']
@property
def macs(self) -> int:
"""Number of Multiple-Accumulate operations required by this profiler"""
return self['macs']
@property
def ops(self) -> int:
"""Number of operations required by this profiler"""
return self['ops']
@property
def accelerator_cycles(self) -> int:
"""Number of accelerator clock cycles required by this profiler"""
return self['accelerator_cycles']
@property
def time(self) -> float:
"""Time in seconds required by this layer"""
return self['time']
@property
def cpu_cycles(self) -> float:
"""Number of CPU clock cycles required by this layer"""
return self['cpu_cycles']
@property
def energy(self) -> float:
"""Energy in Joules required by this layer
The energy is relative to the 'baseline' energy (i.e. energy used while the device was idling)"""
return self['energy']
[docs]class TfliteMicroModel:
"""This class wrappers the TF-Lite Micro interpreter loaded with a .tflite model"""
[docs] def __init__(
self,
tflm_wrapper,
tflm_accelerator:TfliteMicroAccelerator,
flatbuffer_data:bytes,
enable_profiler:bool=False,
enable_tensor_recorder:bool=False,
force_buffer_overlap:bool=False,
runtime_buffer_size:int=0,
):
# pylint: disable=protected-access
from .tflite_micro import TfliteMicro
# Hide non-critical errors while loading the model
# Any layer errors will be parsed separately
logger = TfliteMicro.get_logger()
saved_log_level = logger.level
logger.setLevel(logging.CRITICAL)
try:
TfliteMicro._clear_logged_errors()
accelerator_wrapper = None if tflm_accelerator is None else tflm_accelerator.accelerator_wrapper
self._model_wrapper = tflm_wrapper.TfliteMicroModelWrapper()
if not self._model_wrapper.load(
flatbuffer_data,
accelerator_wrapper,
enable_profiler,
enable_tensor_recorder,
force_buffer_overlap,
runtime_buffer_size
):
raise Exception(
f'Failed to load model, additional info:\n{TfliteMicro._get_logged_errors_str()}'
)
finally:
logger.setLevel(saved_log_level)
self._tflm_accelerator = tflm_accelerator
self._layer_errors:List[TfliteMicroLayerError] = []
for msg in TfliteMicro._get_logged_errors():
err = TfliteMicroLayerError._parse_error_log(msg)
if err:
self._layer_errors.append(err)
@property
def accelerator(self) -> TfliteMicroAccelerator:
"""Reference to hardware accelerator used by model"""
return self._tflm_accelerator
@property
def layer_errors(self) -> List[TfliteMicroLayerError]:
"""List of error messages triggered by kernels while loading the model.
Typically, these errors indicate that a given model layer is not supported by
a hardware accelerator and had to fallback to a default kernel implementation.
"""
return self._layer_errors
@property
def details(self) -> TfliteMicroModelDetails:
"""Return details about loaded model"""
return TfliteMicroModelDetails(self._model_wrapper.get_details())
@property
def input_size(self) -> int:
"""Number of input tensors"""
return self._model_wrapper.get_input_size()
@property
def output_size(self) -> int:
"""Number of output tensors"""
return self._model_wrapper.get_output_size()
[docs] def output(self, index = 0) -> np.ndarray:
"""Return a reference to a model output tensor's data
"""
if index >= self.output_size:
raise IndexError(f'Output index: {index} >= max size: { self.output_size}')
return self._model_wrapper.get_output(index)
[docs] def invoke(self):
"""Invoke the model to execute one inference"""
# pylint: disable=protected-access
from .tflite_micro import TfliteMicro
TfliteMicro._clear_logged_errors()
if not self._model_wrapper.invoke():
raise Exception(f'Failed to invoke model, additional info:\n{TfliteMicro._get_logged_errors_str()}')
@property
def is_profiler_enabled(self) -> bool:
"""Return if the profiler is enabled"""
return self._model_wrapper.is_profiler_enabled()
[docs] def get_profiling_results(self) -> List[TfliteMicroProfiledLayerResult]:
"""Return the profiling results of each model layer
Returns:
A list where each entry contains the profiling results
of the associated model layer
"""
retval = []
results = self._model_wrapper.get_profiling_results()
for e in results:
retval.append(TfliteMicroProfiledLayerResult(e))
return retval
@property
def is_tensor_recorder_enabled(self) -> bool:
"""Return if the tensor recorder is enabled """
return self._model_wrapper.is_tensor_recorder_enabled()
[docs] def get_recorded_data(self) -> List[Dict[str,object]]:
"""Return the recorded contents of each model layers' tensors
Returns:
A list where each entry contains the input/output tensors
of the associated model layer
"""
results_bin = self._model_wrapper.get_recorded_data()
if results_bin is None:
raise RuntimeError('Failed to retrieve recorded model data from Tensorflow-Lite Micro')
try:
recorded_data = msgpack.loads(results_bin)
except Exception as e:
raise RuntimeError(f'Failed to parse recorded model binary data, err: {e}')
return recorded_data
[docs] def get_layer_error(self, index:int) -> TfliteMicroLayerError:
"""Return the TfliteMicroLayerError at the given layer index if found else return None"""
for err in self._layer_errors:
if err.index == index:
return err
return None
def __str__(self) -> str:
return f'{self.details}'