Commit 3adfe119 by tatsukiishikawa

Adding lab 5, 6, 7

parent 0bd9f570
cmake_minimum_required(VERSION 3.12)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
include(pico_sdk_import.cmake)
# project(pico-tflite-inference-test)
project(pico_tflite_inference_test C CXX ASM)
# initialize the Pico SDK
pico_sdk_init()
add_executable(main main.cpp) # main function to run.
add_executable(main_arena_size_test arena_size_test.cpp) # to check arena size for a given model.
add_executable(main_inference_test inference_test.cpp) # to test hand written digit recognition with static test data.
target_include_directories(main
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
target_include_directories(main_arena_size_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
target_include_directories(main_inference_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
# Note: CMake keeps only the last COMPILE_FLAGS value when the keyword is
# repeated, so the flags are passed as a single string for all to take effect.
set_target_properties(
main
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
set_target_properties(
main_arena_size_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
set_target_properties(
main_inference_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
add_subdirectory(src)
add_subdirectory(models)
add_subdirectory(lib/pico-tflmicro)
add_subdirectory(lib/config)
add_subdirectory(lib/lcd)
add_subdirectory(lib/font)
include_directories(./src)
include_directories(./models)
include_directories(./lib/pico-tflmicro)
include_directories(./lib/config)
include_directories(./lib/lcd)
include_directories(./lib/font)
target_link_libraries(
main
src
lcd
font
config
models
pico_stdlib
hardware_spi
hardware_pwm
pico_multicore
hardware_adc
pico-tflmicro
)
target_link_libraries(
main_arena_size_test
src
config
models
pico_stdlib
hardware_spi
hardware_pwm
pico_multicore
hardware_adc
pico-tflmicro
)
target_link_libraries(
main_inference_test
src
config
models
pico_stdlib
hardware_spi
hardware_pwm
pico_multicore
hardware_adc
pico-tflmicro
)
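# NOTE: pico_set_program_flash is not a standard Pico SDK function, and no
# program1/program2 targets are defined in this file; these two calls are
# assumed to come from a custom helper or from a target defined under src/.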
pico_set_program_flash(program1 0x140000)
pico_set_program_flash(program2 0x180000)
# enable usb and uart output
pico_enable_stdio_usb(main 1)
pico_enable_stdio_uart(main 1)
pico_enable_stdio_usb(main_arena_size_test 1)
pico_enable_stdio_uart(main_arena_size_test 1)
pico_enable_stdio_usb(main_inference_test 1)
pico_enable_stdio_uart(main_inference_test 1)
# create map/bin/hex/uf2 file etc.
pico_add_extra_outputs(main)
pico_add_extra_outputs(main_arena_size_test)
pico_add_extra_outputs(main_inference_test)
"""
Lab 5: Neural Network Model Quantization
In this lab, you will learn how to quantize a trained neural network model using TensorFlow Lite.
Import the exported model and convert it to a TensorFlow Lite model with quantization.
The model-optimization techniques to choose from are:
- weight pruning
- weight clustering
- uint8 quantization
"""
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <stdio.h>
#include "pico/stdlib.h"
#include "pico/multicore.h"
#include "pico/sync.h"
#include "hardware/watchdog.h"
#include "hardware/sync.h"
#include "DEV_Config.h"
#include "model.h"
#include "model_settings.h"
#include "mnist_model_data.h"
using namespace std;
Model ml_model;
int main() {
System_Init();
sleep_ms(1000);
// initialize ML model
if (!ml_model.setup()) {
printf("Failed to initialize ML model!\n");
return -1;
}
printf("Model initialized\n");
uint8_t* test_image_input = ml_model.input_data();
if (test_image_input == nullptr) {
printf("Cannot set input\n");
return -1;
}
int byte_size = ml_model.byte_size();
if (!byte_size) {
printf("Byte size not found\n");
return -1;
}
while(1) {
printf("The tensor arena size: %d\n", ml_model.interpreter->arena_used_bytes());
sleep_ms(1000);
}
return 0;
}
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <stdio.h>
#include "pico/stdlib.h"
#include "pico/multicore.h"
#include "pico/sync.h"
#include "hardware/watchdog.h"
#include "hardware/sync.h"
#include "DEV_Config.h"
#include "inference.h"
using namespace std;
int main(void) {
System_Init();
sleep_ms(5000);
inference_test();
return 0;
}
# Find all source files in the current directory
# and store their names in the DIR_CONFIG_SRCS variable
aux_source_directory(. DIR_CONFIG_SRCS)
# Build them into a library
add_library(config ${DIR_CONFIG_SRCS})
target_link_libraries(config PUBLIC pico_stdlib hardware_spi hardware_pwm)
/*****************************************************************************
* | File : DEV_Config.c
* | Author : Waveshare team
* | Function : Hardware underlying interface
* | Info :
* Provide the hardware underlying interface
*----------------
* | This version: V1.0
* | Date : 2018-01-11
* | Info : Basic version
*
******************************************************************************/
#include "DEV_Config.h"
#include "pico/stdlib.h"
void DEV_Digital_Write(UWORD Pin, UBYTE Value)
{
gpio_put(Pin, Value);
}
UBYTE DEV_Digital_Read(UWORD Pin)
{
return gpio_get(Pin);
}
/**
* GPIO Mode
**/
void DEV_GPIO_Mode(UWORD Pin, UWORD Mode)
{
gpio_init(Pin);
if(Mode == 0 || Mode == GPIO_IN) {
gpio_set_dir(Pin, GPIO_IN);
} else {
gpio_set_dir(Pin, GPIO_OUT);
}
}
void DEV_GPIO_Init(void)
{
DEV_GPIO_Mode(LCD_RST_PIN,GPIO_OUT);
DEV_GPIO_Mode(LCD_DC_PIN, GPIO_OUT);
//DEV_GPIO_Mode(LCD_BKL_PIN, GPIO_OUT);
DEV_GPIO_Mode(LCD_CS_PIN, GPIO_OUT);
DEV_GPIO_Mode(TP_CS_PIN,GPIO_OUT);
DEV_GPIO_Mode(TP_IRQ_PIN,GPIO_IN);
DEV_GPIO_Mode(SD_CS_PIN,GPIO_OUT);
//gpio_set_pulls(TP_IRQ_PIN,true,false);
DEV_Digital_Write(TP_CS_PIN, 1);
DEV_Digital_Write(LCD_CS_PIN, 1);
//DEV_Digital_Write(LCD_BKL_PIN, 0);
DEV_Digital_Write(SD_CS_PIN, 1);
gpio_set_function(LCD_BKL_PIN, GPIO_FUNC_PWM);
}
/********************************************************************************
function: System Init
note:
Initialize the communication method
********************************************************************************/
uint8_t System_Init(void)
{
stdio_init_all();
DEV_GPIO_Init();
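// Bring up the shared SPI bus (LCD, touch panel, SD card) at 5 MHz.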
spi_init(SPI_PORT,5000000);
gpio_set_function(LCD_CLK_PIN,GPIO_FUNC_SPI);
gpio_set_function(LCD_MOSI_PIN,GPIO_FUNC_SPI);
gpio_set_function(LCD_MISO_PIN,GPIO_FUNC_SPI);
return 0;
}
void System_Exit(void)
{
}
/*********************************************
function: Hardware interface
note:
SPI4W_Write_Byte(value) :
Write one byte over hardware SPI and return the byte read back
*********************************************/
uint8_t SPI4W_Write_Byte(uint8_t value)
{
uint8_t rxDat;
spi_write_read_blocking(SPI_PORT, &value, &rxDat, 1);
return rxDat;
}
uint8_t SPI4W_Read_Byte(uint8_t value)
{
return SPI4W_Write_Byte(value);
}
/********************************************************************************
function: Delay function
note:
Driver_Delay_ms(xms) : Delay x ms
Driver_Delay_us(xus) : Delay x us
********************************************************************************/
void Driver_Delay_ms(uint32_t xms)
{
sleep_ms(xms);
}
void Driver_Delay_us(uint32_t xus)
{
// Rough, uncalibrated busy-wait; volatile keeps the compiler from
// optimizing the empty loop away. Prefer the SDK's sleep_us() for accuracy.
volatile uint32_t j;
for(j = xus; j > 0; j--);
}
/*****************************************************************************
* | File : DEV_Config.h
* | Author : Waveshare team
* | Function : GPIO Function
* | Info :
* Provide the hardware underlying interface
*----------------
* | This version: V1.0
* | Date : 2018-01-11
* | Info : Basic version
*
******************************************************************************/
#ifndef TFLITE_INFERENCE_TEST_DEV_CONFIG_H_
#define TFLITE_INFERENCE_TEST_DEV_CONFIG_H_
#ifdef __cplusplus
extern "C" {
#endif
#include "pico/stdlib.h"
#include "hardware/spi.h"
#include "hardware/pwm.h"
#include "stdio.h"
#define UBYTE uint8_t
#define UWORD uint16_t
#define UDOUBLE uint32_t
#define LCD_RST_PIN 15
#define LCD_DC_PIN 8
#define LCD_CS_PIN 9
#define LCD_CLK_PIN 10
#define LCD_BKL_PIN 13
#define LCD_MOSI_PIN 11
#define LCD_MISO_PIN 12
#define TP_CS_PIN 16
#define TP_IRQ_PIN 17
#define SD_CS_PIN 22
#define INPUT_IMAGE_SIZE 28
#define MULTICORE_RUN_INFERENCE_FLAG 123
#define UNKNOWN_PREDICTION 100
#define SPI_PORT spi1
#define MAX_BMP_FILES 25
/*------------------------------------------------------------------------------------------------------*/
void DEV_Digital_Write(UWORD Pin, UBYTE Value);
UBYTE DEV_Digital_Read(UWORD Pin);
void DEV_GPIO_Mode(UWORD Pin, UWORD Mode);
void DEV_GPIO_Init(void);
uint8_t System_Init(void);
void System_Exit(void);
uint8_t SPI4W_Write_Byte(uint8_t value);
uint8_t SPI4W_Read_Byte(uint8_t value);
void Driver_Delay_ms(uint32_t xms);
void Driver_Delay_us(uint32_t xus);
#ifdef __cplusplus
}
#endif
#endif // TFLITE_INFERENCE_TEST_DEV_CONFIG_H_
aux_source_directory(. DIR_font_SRCS)
add_library(font ${DIR_font_SRCS})
target_link_libraries(font PUBLIC)
/**
******************************************************************************
* @file fonts.h
* @author MCD Application Team
* @version V1.0.0
* @date 18-February-2014
* @brief Header for fonts.c file
******************************************************************************
* @attention
*
* <h2><center>&copy; COPYRIGHT(c) 2014 STMicroelectronics</center></h2>
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* 3. Neither the name of STMicroelectronics nor the names of its contributors
* may be used to endorse or promote products derived from this software
* without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************
*/
/* Define to prevent recursive inclusion -------------------------------------*/
#ifndef TFLITE_INFERENCE_TEST_FONT_H_
#define TFLITE_INFERENCE_TEST_FONT_H_
/* Max size of the bitmap is based on Font24 (17x24) */
#define MAX_HEIGHT_FONT 24
#define MAX_WIDTH_FONT 17
#define OFFSET_BITMAP 54
#ifdef __cplusplus
extern "C" {
#endif
/* Includes ------------------------------------------------------------------*/
#include <stdint.h>
typedef struct _tFont
{
const uint8_t *table;
uint16_t Width;
uint16_t Height;
} sFONT;
extern sFONT Font24;
extern sFONT Font20;
extern sFONT Font16;
extern sFONT Font12;
extern sFONT Font8;
#ifdef __cplusplus
}
#endif
#endif // TFLITE_INFERENCE_TEST_FONT_H_
/************************ (C) COPYRIGHT STMicroelectronics *****END OF FILE****/
aux_source_directory(. DIR_LCD_SRCS)
include_directories(../config)
include_directories(../font)
include_directories(../../src)
add_library(lcd ${DIR_LCD_SRCS})
target_link_libraries(
lcd
PUBLIC
config
font
src
pico_stdlib
pico_multicore
)
/*****************************************************************************
* | File : LCD_Driver.h
* | Author : Waveshare team
* | Function : ILI9486 Drive function
* | Info :
* Image scanning:
* Please use progressive scanning to generate images or fonts
*----------------
* | This version: V1.0
* | Date : 2018-01-11
* | Info : Basic version
*
******************************************************************************/
/**************************Intermediate driver layer**************************/
#ifndef TFLITE_INFERENCE_TEST_LCD_DRIVER_H_
#define TFLITE_INFERENCE_TEST_LCD_DRIVER_H_
#ifdef __cplusplus
extern "C" {
#endif
#include "DEV_Config.h"
#define LCD_2_8 0x52
#define LCD_3_5 0x00
#define COLOR uint16_t //The variable type of the color (unsigned short)
#define POINT uint16_t //The type of coordinate (unsigned short)
#define LENGTH uint16_t //The type of coordinate (unsigned short)
/********************************************************************************
function:
Define the full-screen width and height of the display
********************************************************************************/
#define LCD_X_MAXPIXEL 480 //LCD width maximum memory
#define LCD_Y_MAXPIXEL 320 //LCD height maximum memory
#define LCD_X 0
#define LCD_Y 0
#define LCD_3_5_WIDTH (LCD_X_MAXPIXEL - 2 * LCD_X) //LCD width
#define LCD_3_5_HEIGHT LCD_Y_MAXPIXEL //LCD height
#define LCD_2_8_WIDTH 240 //LCD width
#define LCD_2_8_HEIGHT 320
/********************************************************************************
function:
scanning method
********************************************************************************/
typedef enum {
L2R_U2D = 0, // scan left to right, top to bottom
L2R_D2U ,
R2L_U2D ,
R2L_D2U ,
U2D_L2R ,
U2D_R2L ,
D2U_L2R ,
D2U_R2L ,
} LCD_SCAN_DIR;
#define SCAN_DIR_DFT D2U_L2R //Default scan direction
/********************************************************************************
function:
Defines the total number of rows in the display area
********************************************************************************/
typedef struct {
LENGTH LCD_Dis_Column; //COLUMN
LENGTH LCD_Dis_Page; //PAGE
LCD_SCAN_DIR LCD_Scan_Dir;
POINT LCD_X_Adjust; //LCD x actual display position calibration
POINT LCD_Y_Adjust; //LCD y actual display position calibration
} LCD_DIS;
/********************************************************************************
function:
Macro definition variable name
********************************************************************************/
void LCD_Init(LCD_SCAN_DIR LCD_ScanDir, uint16_t LCD_BLval);
void LCD_SetGramScanWay(LCD_SCAN_DIR Scan_dir);
void LCD_WriteReg(uint8_t Reg);
void LCD_WriteData(uint16_t Data);
void LCD_SetWindow(POINT Xstart, POINT Ystart, POINT Xend, POINT Yend);
void LCD_SetCursor(POINT Xpoint, POINT Ypoint);
void LCD_SetColor(COLOR Color ,POINT Xpoint, POINT Ypoint);
void LCD_SetPointlColor(POINT Xpoint, POINT Ypoint, COLOR Color);
void LCD_SetArealColor(POINT Xstart, POINT Ystart, POINT Xend, POINT Yend,COLOR Color);
void LCD_Clear(COLOR Color);
uint8_t LCD_Read_Id(void);
void LCD_SetBackLight(uint16_t value);
#ifdef __cplusplus
}
#endif
#endif // TFLITE_INFERENCE_TEST_LCD_DRIVER_H_
/*****************************************************************************
* | File : LCD_GUI.h
* | Author : Waveshare team
* | Function : Achieve drawing: draw points, lines, boxes, circles and
* their size, solid dotted line, solid rectangle hollow
* rectangle, solid circle hollow circle.
* | Info :
* Achieve display characters: Display a single character, string, number
* Achieve time display: adaptive size display time minutes and seconds
*----------------
* | This version: V1.0
* | Date : 2017-08-16
* | Info : Basic version
*
******************************************************************************/
/****************************Upper application layer**************************/
#ifndef TFLITE_INFERENCE_TEST_LCD_GUI_H_
#define TFLITE_INFERENCE_TEST_LCD_GUI_H_
#ifdef __cplusplus
extern "C" {
#endif
#include "LCD_Driver.h"
#include "fonts.h"
#define LOW_Speed_Show 0
#define HIGH_Speed_Show 1
/********************************************************************************
function:
dot pixel
********************************************************************************/
typedef enum {
DOT_PIXEL_1X1 = 1, // dot pixel 1 x 1
DOT_PIXEL_2X2 , // dot pixel 2 X 2
DOT_PIXEL_3X3 , // dot pixel 3 X 3
DOT_PIXEL_4X4 , // dot pixel 4 X 4
DOT_PIXEL_5X5 , // dot pixel 5 X 5
DOT_PIXEL_6X6 , // dot pixel 6 X 6
DOT_PIXEL_7X7 , // dot pixel 7 X 7
DOT_PIXEL_8X8 , // dot pixel 8 X 8
} DOT_PIXEL;
#define DOT_PIXEL_DFT DOT_PIXEL_1X1 //Default dot pixel
/********************************************************************************
function:
dot Fill style
********************************************************************************/
typedef enum {
DOT_FILL_AROUND = 1, // fill around the center point
DOT_FILL_RIGHTUP , // fill up and to the right of the point
} DOT_STYLE;
#define DOT_STYLE_DFT DOT_FILL_AROUND //Default dot fill style
/********************************************************************************
function:
solid line and dotted line
********************************************************************************/
typedef enum {
LINE_SOLID = 0,
LINE_DOTTED,
} LINE_STYLE;
/********************************************************************************
function:
DRAW Internal fill
********************************************************************************/
typedef enum {
DRAW_EMPTY = 0,
DRAW_FULL,
} DRAW_FILL;
/********************************************************************************
function:
time
********************************************************************************/
typedef struct {
uint16_t Year; //0000
uint8_t Month; //1 - 12
uint8_t Day; //1 - 31
uint8_t Hour; //0 - 23
uint8_t Min; //0 - 59
uint8_t Sec; //0 - 59
} DEV_TIME;
extern DEV_TIME sDev_time;
/********************************************************************************
function:
Defines commonly used colors for the display
********************************************************************************/
#define LCD_BACKGROUND WHITE //Default background color
#define FONT_BACKGROUND WHITE //Default font background color
#define FONT_FOREGROUND GRED //Default font foreground color
#define WHITE 0xFFFF
#define BLACK 0x0000
#define BLUE 0x001F
#define BRED 0XF81F
#define GRED 0XFFE0
#define GBLUE 0X07FF
#define RED 0xF800
#define MAGENTA 0xF81F
#define GREEN 0x07E0
#define CYAN 0x7FFF
#define YELLOW 0xFFE0
#define BROWN 0XBC40
#define BRRED 0XFC07
#define GRAY 0X8430
/********************************************************************************
function:
Macro definition variable name
********************************************************************************/
//Clear
void GUI_Clear(COLOR Color);
//Drawing
void GUI_DrawPoint(POINT Xpoint, POINT Ypoint, COLOR Color, DOT_PIXEL Dot_Pixel, DOT_STYLE Dot_FillWay);
void GUI_DrawLine(POINT Xstart, POINT Ystart, POINT Xend, POINT Yend, COLOR Color, LINE_STYLE Line_Style, DOT_PIXEL Dot_Pixel);
void GUI_DrawRectangle(POINT Xstart, POINT Ystart, POINT Xend, POINT Yend, COLOR Color, DRAW_FILL Filled , DOT_PIXEL Dot_Pixel );
void GUI_DrawCircle(POINT X_Center, POINT Y_Center, LENGTH Radius, COLOR Color, DRAW_FILL Draw_Fill , DOT_PIXEL Dot_Pixel );
//pic
void GUI_Disbitmap(POINT Xpoint, POINT Ypoint, const unsigned char *pMap, POINT Width, POINT Height);
void GUI_DisGrayMap(POINT Xpoint, POINT Ypoint, const unsigned char *pBmp);
//Display string
void GUI_DisChar(POINT Xstart, POINT Ystart, const char Acsii_Char, sFONT* Font, COLOR Color_Background, COLOR Color_Foreground);
void GUI_DisString_EN(POINT Xstart, POINT Ystart, const char * pString, sFONT* Font, COLOR Color_Background, COLOR Color_Foreground );
void GUI_DisNum(POINT Xpoint, POINT Ypoint, int32_t Nummber, sFONT* Font, COLOR Color_Background, COLOR Color_Foreground );
void GUI_Showtime(POINT Xstart, POINT Ystart, POINT Xend, POINT Yend, DEV_TIME *pTime, COLOR Color);
//show
void GUI_Show(void);
#ifdef __cplusplus
}
#endif
#endif // TFLITE_INFERENCE_TEST_LCD_GUI_H_
/*****************************************************************************
* | File : LCD_Touch.h
* | Author : Waveshare team
* | Function : LCD Touch Pad Driver and Draw
* | Info :
* Image scanning
* Please use progressive scanning to generate images or fonts
*----------------
* | This version: V1.0
* | Date : 2017-08-16
* | Info : Basic version
*
******************************************************************************/
#ifndef TFLITE_INFERENCE_TEST_LCD_TOUCH_H_
#define TFLITE_INFERENCE_TEST_LCD_TOUCH_H_
#ifdef __cplusplus
extern "C" {
#endif
#include "DEV_Config.h"
#include "LCD_Driver.h"
#include "LCD_GUI.h"
#include <math.h>
#include <stdio.h>
#include "pico/stdlib.h"
#include "pico/float.h"
#include "pico/multicore.h"
#include "pico/sync.h"
#define TP_PRESS_DOWN 0x80
#define TP_PRESSED 0x40
#define IGNORE_INTERVAL_MS 200
#define DIGIT_INPUT_COUNT 4
#define BOX_SIZE 84
#define BOX_START_Y 120 // Starting Y position for the boxes
#define BORDER_THICKNESS 2
#define BOX_PADDING 4
#define POINT_SPACE 2
#define WINDOW_SIZE 3
#define SPACE_BETWEEN_BOXES 10
//Touch screen structure
typedef struct {
POINT Xpoint0;
POINT Ypoint0;
POINT Xpoint;
POINT Ypoint;
uint8_t chStatus;
uint8_t chType;
int16_t iXoff;
int16_t iYoff;
float fXfac;
float fYfac;
//Select the coordinates of the XPT2046 touch
//screen relative to the scan direction
LCD_SCAN_DIR TP_Scan_Dir;
}TP_DEV;
//Brush structure
typedef struct{
POINT Xpoint;
POINT Ypoint;
COLOR Color;
DOT_PIXEL DotPixel;
}TP_DRAW;
typedef struct{
uint8_t InputData[INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE]; // 784 uint8_t array
int8_t PredictedDigit;
} USER_INPUT;
typedef struct{
// semaphore_t Semaphore;
bool IsProcessing;
USER_INPUT UserInputs[DIGIT_INPUT_COUNT]; // 4 inputs
} INFERENCE;
typedef struct{
int start_x;
int start_y;
int end_x;
int end_y;
uint8_t content[BOX_SIZE][BOX_SIZE];
} BOX_REFERENCE;
void TP_GetAdFac(void);
void TP_Adjust(void);
void TP_Dialog(void);
void TP_DrawBoard(void);
void TP_Init( LCD_SCAN_DIR Lcd_ScanDir );
void init_gui(void);
void reset_inference(INFERENCE* _inference);
int find_box_by_point(void);
void clear_drawing(void);
void draw_inference_result(void);
#ifdef __cplusplus
}
#endif
#endif // TFLITE_INFERENCE_TEST_LCD_TOUCH_H_
cmake_minimum_required(VERSION 3.12)
# Pull in PICO SDK (must be before project)
include(pico_sdk_import.cmake)
project(pico-tflmicro C CXX ASM)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
pico_sdk_init()
add_library(pico-tflmicro STATIC)
target_include_directories(pico-tflmicro
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/src/
${CMAKE_CURRENT_LIST_DIR}/src/third_party/ruy
${CMAKE_CURRENT_LIST_DIR}/src/third_party/gemmlowp
${CMAKE_CURRENT_LIST_DIR}/src/third_party/kissfft
${CMAKE_CURRENT_LIST_DIR}/src/third_party/flatbuffers
${CMAKE_CURRENT_LIST_DIR}/src/third_party/cmsis/CMSIS/Core/Include
${CMAKE_CURRENT_LIST_DIR}/src/third_party/flatbuffers/include
${CMAKE_CURRENT_LIST_DIR}/src/third_party/cmsis_nn/Include
)
target_compile_definitions(
pico-tflmicro
PUBLIC
TF_LITE_DISABLE_X86_NEON=1
TF_LITE_STATIC_MEMORY=1
TF_LITE_USE_CTIME=1
CMSIS_NN=1
ARDUINO=1
TFLITE_USE_CTIME=1
)
set_target_properties(
pico-tflmicro
PROPERTIES
COMPILE_FLAGS "-Os -fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_link_libraries(
pico-tflmicro
pico_stdlib
pico_multicore
)
target_sources(pico-tflmicro
PRIVATE
{{LIBRARY_SOURCES}}
)
add_library(pico-tflmicro_test STATIC)
target_include_directories(pico-tflmicro_test
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/src/
${CMAKE_CURRENT_LIST_DIR}/src/third_party/ruy
${CMAKE_CURRENT_LIST_DIR}/src/third_party/gemmlowp
${CMAKE_CURRENT_LIST_DIR}/src/third_party/kissfft
${CMAKE_CURRENT_LIST_DIR}/src/third_party/flatbuffers
${CMAKE_CURRENT_LIST_DIR}/src/third_party/cmsis/CMSIS/Core/Include
${CMAKE_CURRENT_LIST_DIR}/src/third_party/flatbuffers/include
${CMAKE_CURRENT_LIST_DIR}/src/third_party/cmsis_nn/Include
)
target_compile_definitions(
pico-tflmicro_test
PUBLIC
TF_LITE_DISABLE_X86_NEON=1
TF_LITE_STATIC_MEMORY=1
CMSIS_NN=1
)
set_target_properties(
pico-tflmicro_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_link_libraries(
pico-tflmicro_test
pico_stdlib
pico_multicore
)
add_subdirectory("examples/hello_world")
add_subdirectory("examples/person_detection")
{{TEST_FOLDERS}}
# TensorFlow Lite Micro
An Open Source Machine Learning Framework for Everyone.
## Introduction
This is a version of the [TensorFlow Lite Micro library](https://www.tensorflow.org/lite/microcontrollers)
for the Raspberry Pi Pico microcontroller. It allows you to run machine
learning models to do things like recognize voice commands, detect people in
images, recognize gestures from an accelerometer, and perform other sensor
analysis tasks. This version includes scripts to sync changes from the upstream
Google codebase. It also takes advantage of the RP2040's dual cores for
increased speed on some operations.
## Getting Started
First you'll need to follow the Pico setup instructions to initialize the
development environment on your machine. Once that is done, make sure that the
`PICO_SDK_PATH` environment variable has been set to the location of the Pico
SDK, either in the shell you're building in, or the CMake configure environment
variable setting of the extension if you're using VS Code.
You should then be able to build the library, tests, and examples. The easiest
way to build is using VS Code's CMake integration, by loading the project and
choosing the build option at the bottom of the window.
Alternatively you can build the entire project, including tests, by running the
following commands from a terminal once you're in this repo's directory:
```bash
mkdir build
cd build
cmake ..
make
```
## What's Included
There are several example applications included. The simplest one to begin with
is the hello_world project. This demonstrates the fundamentals of deploying an
ML model on a device, driving the Pico's LED in a learned sine-wave pattern.
Once you have built the project, a UF2 file you can copy to the Pico should be
present at `build/examples/hello_world/hello_world.uf2`.
Another example is the person detector, but since the Pico doesn't come with
image inputs you'll need to write some code to hook up your own sensor. You can
find a fork of TFLM for the Arducam Pico4ML that does this at [arducam.com/pico4ml-an-rp2040-based-platform-for-tiny-machine-learning/](https://www.arducam.com/pico4ml-an-rp2040-based-platform-for-tiny-machine-learning/).
## Contributing
This repository (https://github.com/raspberrypi/pico-tflmicro) is read-only,
because it has been automatically generated from the master TensorFlow
repository at https://github.com/tensorflow/tensorflow. It's maintained by
@petewarden on a best effort basis, so bugs and PRs may not get addressed. You
can generate an updated version of this project by running the command:
```
sync/sync_with_upstream.sh
```
This should create a Pico-compatible project from the latest version of the
TensorFlow repository.
## Learning More
The [TensorFlow website](https://www.tensorflow.org/lite/microcontrollers) has
information on training, tutorials, and other resources.
The [TinyML Book](https://tinymlbook.com) is a guide to using TensorFlow Lite Micro
across a variety of different systems.
[TensorFlowLite Micro: Embedded Machine Learning on TinyML Systems](https://arxiv.org/pdf/2010.08678.pdf)
has more details on the design and implementation of the framework.
## Licensing
The TensorFlow source code is covered by the Apache 2 license described in
src/tensorflow/LICENSE, components from other libraries have the appropriate
licenses included in their third_party folders.
Results of running person_detection_benchmark with and without multicore optimizations.
To reproduce these, run `make person_detection_benchmark` with and without the
TF_LITE_PICO_MULTICORE macro defined at the top of
src/third_party/cmsis_nn/Source/NNSupportFunctions/arm_nn_mat_mult_nt_t_s8.c
Without multicore CONV2D optimizations:
NoPersonDataIterations(1) took 823658 ticks (823 ms)
DEPTHWISE_CONV_2D took 34553 ticks (34 ms).
DEPTHWISE_CONV_2D took 60260 ticks (60 ms).
CONV_2D took 47509 ticks (47 ms).
DEPTHWISE_CONV_2D took 29581 ticks (29 ms).
CONV_2D took 32941 ticks (32 ms).
DEPTHWISE_CONV_2D took 57434 ticks (57 ms).
CONV_2D took 51301 ticks (51 ms).
DEPTHWISE_CONV_2D took 14411 ticks (14 ms).
CONV_2D took 26003 ticks (26 ms).
DEPTHWISE_CONV_2D took 27689 ticks (27 ms).
CONV_2D took 44571 ticks (44 ms).
DEPTHWISE_CONV_2D took 7025 ticks (7 ms).
CONV_2D took 23344 ticks (23 ms).
DEPTHWISE_CONV_2D took 13935 ticks (13 ms).
CONV_2D took 43007 ticks (43 ms).
DEPTHWISE_CONV_2D took 12996 ticks (12 ms).
CONV_2D took 42947 ticks (42 ms).
DEPTHWISE_CONV_2D took 12983 ticks (12 ms).
CONV_2D took 42953 ticks (42 ms).
DEPTHWISE_CONV_2D took 13023 ticks (13 ms).
CONV_2D took 42979 ticks (42 ms).
DEPTHWISE_CONV_2D took 13015 ticks (13 ms).
CONV_2D took 42951 ticks (42 ms).
DEPTHWISE_CONV_2D took 3522 ticks (3 ms).
CONV_2D took 25795 ticks (25 ms).
DEPTHWISE_CONV_2D took 6016 ticks (6 ms).
CONV_2D took 49461 ticks (49 ms).
AVERAGE_POOL_2D took 874 ticks (0 ms).
CONV_2D took 220 ticks (0 ms).
RESHAPE took 21 ticks (0 ms).
SOFTMAX took 338 ticks (0 ms).
With multicore CONV2D and DEPTHWISE_CONV_2D optimizations:
NoPersonDataIterations(1) took 587400 ticks (587 ms)
DEPTHWISE_CONV_2D took 34550 ticks (34 ms).
DEPTHWISE_CONV_2D took 31942 ticks (31 ms).
CONV_2D took 29140 ticks (29 ms).
DEPTHWISE_CONV_2D took 15765 ticks (15 ms).
CONV_2D took 21402 ticks (21 ms).
DEPTHWISE_CONV_2D took 30346 ticks (30 ms).
CONV_2D took 35317 ticks (35 ms).
DEPTHWISE_CONV_2D took 7792 ticks (7 ms).
CONV_2D took 17922 ticks (17 ms).
DEPTHWISE_CONV_2D took 14706 ticks (14 ms).
CONV_2D took 32168 ticks (32 ms).
DEPTHWISE_CONV_2D took 4780 ticks (4 ms).
CONV_2D took 16981 ticks (16 ms).
DEPTHWISE_CONV_2D took 9800 ticks (9 ms).
CONV_2D took 36303 ticks (36 ms).
DEPTHWISE_CONV_2D took 7141 ticks (7 ms).
CONV_2D took 36236 ticks (36 ms).
DEPTHWISE_CONV_2D took 7137 ticks (7 ms).
CONV_2D took 36343 ticks (36 ms).
DEPTHWISE_CONV_2D took 7166 ticks (7 ms).
CONV_2D took 36217 ticks (36 ms).
DEPTHWISE_CONV_2D took 7148 ticks (7 ms).
CONV_2D took 36216 ticks (36 ms).
DEPTHWISE_CONV_2D took 3624 ticks (3 ms).
CONV_2D took 22197 ticks (22 ms).
DEPTHWISE_CONV_2D took 4526 ticks (4 ms).
CONV_2D took 43024 ticks (43 ms).
AVERAGE_POOL_2D took 876 ticks (0 ms).
CONV_2D took 275 ticks (0 ms).
RESHAPE took 20 ticks (0 ms).
SOFTMAX took 340 ticks (0 ms).
cmake_minimum_required(VERSION 3.12)
project(hello_world C CXX ASM)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
add_executable(hello_world_test "")
target_include_directories(hello_world_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
hello_world_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(hello_world_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/hello_world_float_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/hello_world_int8_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/hello_world_test.cpp
)
target_link_libraries(
hello_world_test
pico-tflmicro
hardware_pwm
pico-tflmicro_test
)
pico_enable_stdio_usb(hello_world_test 1)
pico_enable_stdio_uart(hello_world_test 0)
pico_add_extra_outputs(hello_world_test)
add_executable(hello_world "")
target_include_directories(hello_world
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
hello_world
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(hello_world
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/constants.cpp
${CMAKE_CURRENT_LIST_DIR}/hello_world_float_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/hello_world_int8_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/main.cpp
${CMAKE_CURRENT_LIST_DIR}/main_functions.cpp
${CMAKE_CURRENT_LIST_DIR}/rp2/output_handler.cpp
${CMAKE_CURRENT_LIST_DIR}/constants.h
${CMAKE_CURRENT_LIST_DIR}/main_functions.h
${CMAKE_CURRENT_LIST_DIR}/output_handler.h
)
target_link_libraries(
hello_world
pico-tflmicro
hardware_pwm
)
pico_enable_stdio_usb(hello_world 1)
pico_enable_stdio_uart(hello_world 0)
pico_add_extra_outputs(hello_world)
add_executable(output_handler_test "")
target_include_directories(output_handler_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
output_handler_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(output_handler_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/output_handler_test.cpp
${CMAKE_CURRENT_LIST_DIR}/rp2/output_handler.cpp
${CMAKE_CURRENT_LIST_DIR}/constants.h
${CMAKE_CURRENT_LIST_DIR}/output_handler.h
)
target_link_libraries(
output_handler_test
pico-tflmicro
hardware_pwm
pico-tflmicro_test
)
pico_add_extra_outputs(output_handler_test)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "constants.h"
// This is a small number so that it's easy to read the logs
const int kInferencesPerCycle = 20;
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_CONSTANTS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_CONSTANTS_H_
// This constant represents the range of x values our model was trained on,
// which is from 0 to (2 * Pi). We approximate Pi to avoid requiring additional
// libraries.
const float kXrange = 2.f * 3.14159265359f;
// This constant determines the number of inferences to perform across the range
// of x values defined above. Since each inference takes time, the higher this
// number, the more time it will take to run through the entire range. The value
// of this constant can be tuned so that one full cycle takes a desired amount
// of time. Since different devices take different amounts of time to perform
// inference, this value should be defined per-device.
extern const int kInferencesPerCycle;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_CONSTANTS_H_
#include <cstdint>
extern const unsigned int g_hello_world_float_model_data_size;
extern const unsigned char g_hello_world_float_model_data[];
#include <cstdint>
extern const unsigned int g_hello_world_int8_model_data_size;
extern const unsigned char g_hello_world_int8_model_data[];
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include "tensorflow/lite/core/c/common.h"
#include "hello_world_float_model_data.h"
#include "hello_world_int8_model_data.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/micro/recording_micro_interpreter.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace {
using HelloWorldOpResolver = tflite::MicroMutableOpResolver<1>;
TfLiteStatus RegisterOps(HelloWorldOpResolver& op_resolver) {
TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());
return kTfLiteOk;
}
} // namespace
TfLiteStatus ProfileMemoryAndLatency() {
tflite::MicroProfiler profiler;
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
// Arena size just a round number. The exact arena usage can be determined
// using the RecordingMicroInterpreter.
constexpr int kTensorArenaSize = 3000;
uint8_t tensor_arena[kTensorArenaSize];
constexpr int kNumResourceVariables = 24;
tflite::RecordingMicroAllocator* allocator(
tflite::RecordingMicroAllocator::Create(tensor_arena, kTensorArenaSize));
tflite::RecordingMicroInterpreter interpreter(
tflite::GetModel(g_hello_world_float_model_data), op_resolver, allocator,
tflite::MicroResourceVariables::Create(allocator, kNumResourceVariables),
&profiler);
TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
TFLITE_CHECK_EQ(interpreter.inputs_size(), 1);
interpreter.input(0)->data.f[0] = 1.f;
TF_LITE_ENSURE_STATUS(interpreter.Invoke());
MicroPrintf(""); // Print an empty new line
profiler.LogTicksPerTagCsv();
MicroPrintf(""); // Print an empty new line
interpreter.GetMicroAllocator().PrintAllocations();
return kTfLiteOk;
}
TfLiteStatus LoadFloatModelAndPerformInference() {
const tflite::Model* model =
::tflite::GetModel(g_hello_world_float_model_data);
TFLITE_CHECK_EQ(model->version(), TFLITE_SCHEMA_VERSION);
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
// Arena size just a round number. The exact arena usage can be determined
// using the RecordingMicroInterpreter.
constexpr int kTensorArenaSize = 3000;
uint8_t tensor_arena[kTensorArenaSize];
tflite::MicroInterpreter interpreter(model, op_resolver, tensor_arena,
kTensorArenaSize);
TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
// Check if the predicted output is within a small range of the
// expected output
float epsilon = 0.05f;
constexpr int kNumTestValues = 4;
float golden_inputs[kNumTestValues] = {0.f, 1.f, 3.f, 5.f};
for (int i = 0; i < kNumTestValues; ++i) {
interpreter.input(0)->data.f[0] = golden_inputs[i];
TF_LITE_ENSURE_STATUS(interpreter.Invoke());
float y_pred = interpreter.output(0)->data.f[0];
TFLITE_CHECK_LE(fabs(sin(golden_inputs[i]) - y_pred), epsilon);
}
return kTfLiteOk;
}
TfLiteStatus LoadQuantModelAndPerformInference() {
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
const tflite::Model* model =
::tflite::GetModel(g_hello_world_int8_model_data);
TFLITE_CHECK_EQ(model->version(), TFLITE_SCHEMA_VERSION);
HelloWorldOpResolver op_resolver;
TF_LITE_ENSURE_STATUS(RegisterOps(op_resolver));
// Arena size just a round number. The exact arena usage can be determined
// using the RecordingMicroInterpreter.
constexpr int kTensorArenaSize = 3000;
uint8_t tensor_arena[kTensorArenaSize];
tflite::MicroInterpreter interpreter(model, op_resolver, tensor_arena,
kTensorArenaSize);
TF_LITE_ENSURE_STATUS(interpreter.AllocateTensors());
TfLiteTensor* input = interpreter.input(0);
TFLITE_CHECK_NE(input, nullptr);
TfLiteTensor* output = interpreter.output(0);
TFLITE_CHECK_NE(output, nullptr);
float output_scale = output->params.scale;
int output_zero_point = output->params.zero_point;
// Check if the predicted output is within a small range of the
// expected output
float epsilon = 0.05;
constexpr int kNumTestValues = 4;
float golden_inputs_float[kNumTestValues] = {0.77, 1.57, 2.3, 3.14};
// The int8 values are calculated using the following formula
// (golden_inputs_float[i] / input->params.scale + input->params.zero_point)
int8_t golden_inputs_int8[kNumTestValues] = {-96, -63, -34, 0};
for (int i = 0; i < kNumTestValues; ++i) {
input->data.int8[0] = golden_inputs_int8[i];
TF_LITE_ENSURE_STATUS(interpreter.Invoke());
float y_pred = (output->data.int8[0] - output_zero_point) * output_scale;
TFLITE_CHECK_LE(fabs(sin(golden_inputs_float[i]) - y_pred), epsilon);
}
return kTfLiteOk;
}
int main(int argc, char* argv[]) {
tflite::InitializeTarget();
TF_LITE_ENSURE_STATUS(ProfileMemoryAndLatency());
TF_LITE_ENSURE_STATUS(LoadFloatModelAndPerformInference());
TF_LITE_ENSURE_STATUS(LoadQuantModelAndPerformInference());
MicroPrintf("~~~ALL TESTS PASSED~~~\n");
return kTfLiteOk;
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
// This is the default main used on systems that have the standard C entry
// point. Other devices (for example FreeRTOS or ESP32) that have different
// requirements for entry code (like an app_main function) should specialize
// this main.cc file in a target-specific subfolder.
int main(int argc, char* argv[]) {
setup();
while (true) {
loop();
}
}
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "constants.h"
#include "hello_world_float_model_data.h"
#include "main_functions.h"
#include "output_handler.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
namespace {
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
TfLiteTensor* output = nullptr;
int inference_count = 0;
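// Scratch memory for the model's input, output, and intermediate tensors.
// Sized empirically: if it is too small, AllocateTensors() below will fail.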
constexpr int kTensorArenaSize = 2000;
uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is important for Arduino compatibility.
void setup() {
tflite::InitializeTarget();
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_hello_world_float_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
MicroPrintf(
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// This pulls in all the operation implementations we need.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<1> resolver;
TfLiteStatus resolve_status = resolver.AddFullyConnected();
if (resolve_status != kTfLiteOk) {
MicroPrintf("Op resolution failed");
return;
}
// Build an interpreter to run the model with.
static tflite::MicroInterpreter static_interpreter(
model, resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
MicroPrintf("AllocateTensors() failed");
return;
}
// Obtain pointers to the model's input and output tensors.
input = interpreter->input(0);
output = interpreter->output(0);
// Keep track of how many inferences we have performed.
inference_count = 0;
}
// The name of this function is important for Arduino compatibility.
void loop() {
// Calculate an x value to feed into the model. We compare the current
// inference_count to the number of inferences per cycle to determine
// our position within the range of possible x values the model was
// trained on, and use this to calculate a value.
float position = static_cast<float>(inference_count) /
static_cast<float>(kInferencesPerCycle);
float x = position * kXrange;
// Quantize the input from floating-point to integer
int8_t x_quantized = x / input->params.scale + input->params.zero_point;
// Place the quantized input in the model's input tensor
input->data.int8[0] = x_quantized;
// Run inference, and report any error
TfLiteStatus invoke_status = interpreter->Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Invoke failed on x: %f\n", static_cast<double>(x));
return;
}
// Obtain the quantized output from model's output tensor
int8_t y_quantized = output->data.int8[0];
// Dequantize the output from integer to floating-point
float y = (y_quantized - output->params.zero_point) * output->params.scale;
// Output the results. A custom HandleOutput function can be implemented
// for each supported hardware target.
HandleOutput(x, y);
// Increment the inference_counter, and reset it if we have reached
// the total number per cycle
inference_count += 1;
if (inference_count >= kInferencesPerCycle) inference_count = 0;
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MAIN_FUNCTIONS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MAIN_FUNCTIONS_H_
// Expose a C friendly interface for main functions.
#ifdef __cplusplus
extern "C" {
#endif
// Initializes all data needed for the example. The name is important, and needs
// to be setup() for Arduino compatibility.
void setup();
// Runs one iteration of data gathering and inference. This should be called
// repeatedly from the application code. The name needs to be loop() for Arduino
// compatibility.
void loop();
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_MAIN_FUNCTIONS_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_log.h"
// Called by the main loop to produce some output based on the x and y values
void HandleOutput(float x_value, float y_value);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_HELLO_WORLD_OUTPUT_HANDLER_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "output_handler.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TestCallability) {
// This will have external side-effects (like printing to the debug console
// or lighting an LED) that are hard to observe, so the most we can do is
// make sure the call doesn't crash.
HandleOutput(0, 0);
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "output_handler.h"
#include "pico/stdlib.h"
#include "pico/time.h"
#include "hardware/irq.h"
#include "hardware/resets.h"
#include "hardware/pwm.h"
#include "constants.h"
namespace {
int g_led_brightness = 0;
// For details on what this code is doing, see
// https://github.com/raspberrypi/pico-examples/blob/master/pwm/led_fade
extern "C" void on_pwm_wrap() {
// Clear the interrupt flag that brought us here
pwm_clear_irq(pwm_gpio_to_slice_num(PICO_DEFAULT_LED_PIN));
// Square the value to make the LED's brightness appear more linear
// Note this range matches with the wrap value
pwm_set_gpio_level(PICO_DEFAULT_LED_PIN, g_led_brightness * g_led_brightness);
}
void init_pwm_fade() {
// Tell the LED pin that the PWM is in charge of its value.
gpio_set_function(PICO_DEFAULT_LED_PIN, GPIO_FUNC_PWM);
// Figure out which slice we just connected to the LED pin
uint slice_num = pwm_gpio_to_slice_num(PICO_DEFAULT_LED_PIN);
// Mask our slice's IRQ output into the PWM block's single interrupt line,
// and register our interrupt handler
pwm_clear_irq(slice_num);
pwm_set_irq_enabled(slice_num, true);
irq_set_exclusive_handler(PWM_IRQ_WRAP, on_pwm_wrap);
irq_set_enabled(PWM_IRQ_WRAP, true);
// Get some sensible defaults for the slice configuration. By default, the
// counter is allowed to wrap over its maximum range (0 to 2**16-1)
pwm_config config = pwm_get_default_config();
// Set divider, reduces counter clock to sysclock/this value
pwm_config_set_clkdiv(&config, 4.f);
// Load the configuration into our PWM slice, and set it running.
pwm_init(slice_num, &config, true);
}
} // namespace
void HandleOutput(float x_value, float y_value) {
// Do this only once
static bool is_initialized = false;
if (!is_initialized) {
init_pwm_fade();
is_initialized = true;
}
// Calculate the brightness of the LED such that y=-1 is fully off
// and y=1 is fully on. The LED's brightness can range from 0-255.
g_led_brightness = (int)(127.5f * (y_value + 1));
// Log the current brightness value for display in the console.
MicroPrintf("%d\n", g_led_brightness);
// By default the sine wave is too fast to see in the LED, so slow
// down the whole program deliberately so it's more visible.
sleep_ms(10);
}
cmake_minimum_required(VERSION 3.12)
project(person_detection C CXX ASM)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 11)
add_executable(detection_responder_test "")
target_include_directories(detection_responder_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
detection_responder_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(detection_responder_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/detection_responder.cpp
${CMAKE_CURRENT_LIST_DIR}/detection_responder_test.cpp
${CMAKE_CURRENT_LIST_DIR}/detection_responder.h
)
target_link_libraries(
detection_responder_test
pico-tflmicro
hardware_pwm
)
pico_enable_stdio_usb(detection_responder_test 1)
pico_enable_stdio_uart(detection_responder_test 0)
pico_add_extra_outputs(detection_responder_test)
add_executable(person_detection_benchmark "")
target_include_directories(person_detection_benchmark
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
person_detection_benchmark
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(person_detection_benchmark
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/benchmarks/person_detection_benchmark.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/no_person_image_data.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/person_detect_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/person_image_data.cpp
${CMAKE_CURRENT_LIST_DIR}/person_detect_model_data.h
)
target_link_libraries(
person_detection_benchmark
pico-tflmicro
hardware_pwm
)
pico_enable_stdio_usb(person_detection_benchmark 1)
pico_enable_stdio_uart(person_detection_benchmark 0)
pico_add_extra_outputs(person_detection_benchmark)
add_executable(person_detection_test "")
target_include_directories(person_detection_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
person_detection_test
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(person_detection_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/model_settings.cpp
${CMAKE_CURRENT_LIST_DIR}/person_detection_test.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/no_person_image_data.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/person_detect_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/person_image_data.cpp
${CMAKE_CURRENT_LIST_DIR}/model_settings.h
${CMAKE_CURRENT_LIST_DIR}/no_person_image_data.h
${CMAKE_CURRENT_LIST_DIR}/person_detect_model_data.h
${CMAKE_CURRENT_LIST_DIR}/person_image_data.h
)
target_link_libraries(
person_detection_test
pico-tflmicro
hardware_pwm
)
pico_enable_stdio_usb(person_detection_test 1)
pico_enable_stdio_uart(person_detection_test 0)
pico_add_extra_outputs(person_detection_test)
add_executable(person_detection "")
target_include_directories(person_detection
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
person_detection
PROPERTIES
COMPILE_FLAGS "-fno-rtti -fno-exceptions -fno-threadsafe-statics -nostdlib"
)
target_sources(person_detection
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/detection_responder.cpp
${CMAKE_CURRENT_LIST_DIR}/image_provider.cpp
${CMAKE_CURRENT_LIST_DIR}/main.cpp
${CMAKE_CURRENT_LIST_DIR}/main_functions.cpp
${CMAKE_CURRENT_LIST_DIR}/model_settings.cpp
${CMAKE_CURRENT_LIST_DIR}/tensorflow/lite/micro/tools/make/downloads/person_model/person_detect_model_data.cpp
${CMAKE_CURRENT_LIST_DIR}/detection_responder.h
${CMAKE_CURRENT_LIST_DIR}/image_provider.h
${CMAKE_CURRENT_LIST_DIR}/main_functions.h
${CMAKE_CURRENT_LIST_DIR}/model_settings.h
${CMAKE_CURRENT_LIST_DIR}/person_detect_model_data.h
)
target_link_libraries(
person_detection
pico-tflmicro
hardware_pwm
)
pico_enable_stdio_usb(person_detection 1)
pico_enable_stdio_uart(person_detection 0)
pico_add_extra_outputs(person_detection)
add_executable(image_provider_test "")
target_include_directories(image_provider_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/.
)
set_target_properties(
image_provider_test
PROPERTIES
COMPILE_FLAGS -fno-rtti
COMPILE_FLAGS -fno-exceptions
COMPILE_FLAGS -fno-threadsafe-statics
COMPILE_FLAGS -nostdlib
)
target_sources(image_provider_test
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/image_provider.cpp
${CMAKE_CURRENT_LIST_DIR}/image_provider_test.cpp
${CMAKE_CURRENT_LIST_DIR}/model_settings.cpp
${CMAKE_CURRENT_LIST_DIR}/image_provider.h
${CMAKE_CURRENT_LIST_DIR}/model_settings.h
)
target_link_libraries(
image_provider_test
pico-tflmicro
hardware_pwm
)
pico_enable_stdio_usb(image_provider_test 1)
pico_enable_stdio_uart(image_provider_test 0)
pico_add_extra_outputs(image_provider_test)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "detection_responder.h"
// This dummy implementation writes person and no person scores to the error
// console. Real applications will want to take some custom action instead, and
// should implement their own versions of this function.
void RespondToDetection(tflite::ErrorReporter* error_reporter,
int8_t person_score, int8_t no_person_score) {
TF_LITE_REPORT_ERROR(error_reporter, "person score:%d no person score %d",
person_score, no_person_score);
}
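A hardware-backed variant for the Pico board used in this lab could drive the on-board LED instead of only logging. A minimal sketch, assuming the stock Pico SDK GPIO API and that it replaces the dummy implementation above:
#include "pico/stdlib.h"
#include "detection_responder.h"
// Sketch: light the Pico's on-board LED whenever the person class wins the
// two-way vote, then log the raw scores as the dummy version does.
void RespondToDetection(tflite::ErrorReporter* error_reporter,
                        int8_t person_score, int8_t no_person_score) {
  static bool led_ready = false;
  if (!led_ready) {
    gpio_init(PICO_DEFAULT_LED_PIN);
    gpio_set_dir(PICO_DEFAULT_LED_PIN, GPIO_OUT);
    led_ready = true;
  }
  gpio_put(PICO_DEFAULT_LED_PIN, person_score > no_person_score);
  TF_LITE_REPORT_ERROR(error_reporter, "person score:%d no person score %d",
                       person_score, no_person_score);
}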
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Provides an interface to take an action based on the output from the person
// detection model.
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_DETECTION_RESPONDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_DETECTION_RESPONDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
// Called every time the results of a person detection run are available. The
// `person_score` has the numerical confidence that the captured image contains
// a person, and `no_person_score` has the numerical confidence that the image
does not contain a person. Typically, if `person_score` > `no_person_score`, the
// image is considered to contain a person. This threshold may be adjusted for
// particular applications.
void RespondToDetection(tflite::ErrorReporter* error_reporter,
int8_t person_score, int8_t no_person_score);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_DETECTION_RESPONDER_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "detection_responder.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TestCallability) {
tflite::MicroErrorReporter micro_error_reporter;
// This will have external side-effects (like printing to the debug console
// or lighting an LED) that are hard to observe, so the most we can do is
// make sure the call doesn't crash.
RespondToDetection(&micro_error_reporter, -100, 100);
RespondToDetection(&micro_error_reporter, 100, 50);
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "image_provider.h"
#include "model_settings.h"
TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter, int image_width,
int image_height, int channels, int8_t* image_data) {
for (int i = 0; i < image_width * image_height * channels; ++i) {
image_data[i] = 0;
}
return kTfLiteOk;
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_IMAGE_PROVIDER_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_IMAGE_PROVIDER_H_
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
// This is an abstraction around an image source like a camera, and is
// expected to return 8-bit sample data. The assumption is that this will be
// called in a low duty-cycle fashion in a low-power application. In these
// cases, the imaging sensor need not be run in a streaming mode, but rather can
// be idled in a relatively low-power mode between calls to GetImage(). The
// assumption is that the overhead and time of bringing the low-power sensor out
// of this standby mode is commensurate with the expected duty cycle of the
// application. The underlying sensor may actually be put into a streaming
// configuration, but the image buffer provided to GetImage should not be
// overwritten by the driver code until the next call to GetImage().
//
// The reference implementation can have no platform-specific dependencies, so
// it just returns a static image. For real applications, you should
// ensure there's a specialized implementation that accesses hardware APIs.
TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter, int image_width,
int image_height, int channels, int8_t* image_data);
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_IMAGE_PROVIDER_H_
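A hardware-backed implementation would replace the all-zeros stub above with a real capture. A sketch under stated assumptions: camera_read_frame is a hypothetical driver call, and only the -128 offset (mapping the camera's uint8 samples into the int8 range the quantized model expects) is dictated by the model:
#include "image_provider.h"
#include "model_settings.h"
// Hypothetical camera driver call; assumed to fill `buf` with `len` uint8
// grayscale samples and return true on success.
extern bool camera_read_frame(uint8_t* buf, int len);
TfLiteStatus GetImage(tflite::ErrorReporter* error_reporter, int image_width,
                      int image_height, int channels, int8_t* image_data) {
  static uint8_t frame[kMaxImageSize];
  const int num_samples = image_width * image_height * channels;
  if (!camera_read_frame(frame, num_samples)) {
    TF_LITE_REPORT_ERROR(error_reporter, "Camera capture failed");
    return kTfLiteError;
  }
  // Convert unsigned 8-bit samples to the signed range by subtracting 128.
  for (int i = 0; i < num_samples; ++i) {
    image_data[i] = static_cast<int8_t>(frame[i] - 128);
  }
  return kTfLiteOk;
}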
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "image_provider.h"
#include <limits>
#include "tensorflow/lite/c/common.h"
#include "model_settings.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TestImageProvider) {
tflite::MicroErrorReporter micro_error_reporter;
int8_t image_data[kMaxImageSize];
TfLiteStatus get_status = GetImage(&micro_error_reporter, kNumCols, kNumRows,
kNumChannels, image_data);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, get_status);
TF_LITE_MICRO_EXPECT(image_data != nullptr);
// Make sure we can read all of the returned memory locations.
uint32_t total = 0;
for (int i = 0; i < kMaxImageSize; ++i) {
total += image_data[i];
}
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
// This is the default main used on systems that have the standard C entry
// point. Other devices (for example FreeRTOS or ESP32) that have different
// requirements for entry code (like an app_main function) should specialize
// this main.cc file in a target-specific subfolder.
int main(int argc, char* argv[]) {
setup();
while (true) {
loop();
}
}
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "main_functions.h"
#include "detection_responder.h"
#include "image_provider.h"
#include "model_settings.h"
#include "person_detect_model_data.h"
#include "tensorflow/lite/micro/tflite_bridge/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Globals, used for compatibility with Arduino-style sketches.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;
const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;
TfLiteTensor* input = nullptr;
// In order to use optimized tensorflow lite kernels, a signed int8_t quantized
// model is preferred over the legacy unsigned model format. This means that
// throughout this project, input images must be converted from unsigned to
// signed format. The easiest and quickest way to convert from unsigned to
// signed 8-bit integers is to subtract 128 from the unsigned value to get a
// signed value.
// An area of memory to use for input, output, and intermediate arrays.
constexpr int kTensorArenaSize = 136 * 1024;
static uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is important for Arduino compatibility.
void setup() {
// Set up logging. Google style is to avoid globals or statics because of
// lifetime uncertainty, but since this has a trivial destructor it's okay.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroErrorReporter micro_error_reporter;
error_reporter = &micro_error_reporter;
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_person_detect_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
TF_LITE_REPORT_ERROR(error_reporter,
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
// Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph.
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
//
// tflite::AllOpsResolver resolver;
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddAveragePool2D();
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddReshape();
micro_op_resolver.AddSoftmax();
// Build an interpreter to run the model with.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk) {
TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
return;
}
// Get information about the memory area to use for the model's input.
input = interpreter->input(0);
}
// The name of this function is important for Arduino compatibility.
void loop() {
// Get image from provider.
if (kTfLiteOk != GetImage(error_reporter, kNumCols, kNumRows, kNumChannels,
input->data.int8)) {
TF_LITE_REPORT_ERROR(error_reporter, "Image capture failed.");
}
// Run the model on this input and make sure it succeeds.
if (kTfLiteOk != interpreter->Invoke()) {
TF_LITE_REPORT_ERROR(error_reporter, "Invoke failed.");
}
TfLiteTensor* output = interpreter->output(0);
// Process the inference results.
int8_t person_score = output->data.int8[kPersonIndex];
int8_t no_person_score = output->data.int8[kNotAPersonIndex];
RespondToDetection(error_reporter, person_score, no_person_score);
}
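If an application needs a probability rather than the raw quantized scores read in loop() above, the output tensor's quantization parameters can undo the int8 quantization. A minimal sketch, assuming the standard TfLiteTensor fields:
// Sketch: map a quantized int8 score back to a float score using the output
// tensor's scale and zero point (TfLiteQuantizationParams).
float person_probability =
    (person_score - output->params.zero_point) * output->params.scale;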
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MAIN_FUNCTIONS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MAIN_FUNCTIONS_H_
// Expose a C friendly interface for main functions.
#ifdef __cplusplus
extern "C" {
#endif
// Initializes all data needed for the example. The name is important, and needs
// to be setup() for Arduino compatibility.
void setup();
// Runs one iteration of data gathering and inference. This should be called
// repeatedly from the application code. The name needs to be loop() for Arduino
// compatibility.
void loop();
#ifdef __cplusplus
}
#endif
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MAIN_FUNCTIONS_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "model_settings.h"
const char* kCategoryLabels[kCategoryCount] = {
"notperson",
"person",
};
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MODEL_SETTINGS_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MODEL_SETTINGS_H_
// Keeping these as constant expressions allows us to allocate fixed-sized arrays
// on the stack for our working memory.
// All of these values are derived from the values used during model training;
// if you change your model you'll need to update these constants.
constexpr int kNumCols = 96;
constexpr int kNumRows = 96;
constexpr int kNumChannels = 1;
constexpr int kMaxImageSize = kNumCols * kNumRows * kNumChannels;
constexpr int kCategoryCount = 2;
constexpr int kPersonIndex = 1;
constexpr int kNotAPersonIndex = 0;
extern const char* kCategoryLabels[kCategoryCount];
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_MODEL_SETTINGS_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This data was created from a sample image without a person in it.
// Convert original image to simpler format:
// convert -resize 96x96\! noperson.PNG noperson.bmp3
// Skip the 54 byte bmp3 header and add the rest of the bytes to a C array:
// xxd -s 54 -i /tmp/noperson.bmp3 > /tmp/noperson.cc
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_NO_PERSON_IMAGE_DATA_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_NO_PERSON_IMAGE_DATA_H_
#include <cstdint>
extern const unsigned int g_no_person_image_data_size;
extern const uint8_t g_no_person_image_data[];
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_NO_PERSON_IMAGE_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a standard TensorFlow Lite model file that has been converted into a
// C data array, so it can be easily compiled into a binary for devices that
// don't have a file system. It was created using the command:
// xxd -i person_detect.tflite > person_detect_model_data.cc
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_DETECT_MODEL_DATA_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_DETECT_MODEL_DATA_H_
extern const unsigned char g_person_detect_model_data[];
extern const int g_person_detect_model_data_len;
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_DETECT_MODEL_DATA_H_
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "model_settings.h"
#include "no_person_image_data.h"
#include "person_image_data.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "person_detect_model_data.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/schema/schema_generated.h"
// Create an area of memory to use for input, output, and intermediate arrays.
#if defined(XTENSA) && defined(VISION_P6)
constexpr int tensor_arena_size = 352 * 1024;
#else
constexpr int tensor_arena_size = 136 * 1024;
#endif // defined(XTENSA) && defined(VISION_P6)
uint8_t tensor_arena[tensor_arena_size];
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(TestInvoke) {
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
const tflite::Model* model = ::tflite::GetModel(g_person_detect_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION) {
MicroPrintf(
"Model provided is schema version %d not equal "
"to supported version %d.\n",
model->version(), TFLITE_SCHEMA_VERSION);
}
// Pull in only the operation implementations we need.
// This relies on a complete list of all the ops needed by this graph.
// An easier approach is to just use the AllOpsResolver, but this will
// incur some penalty in code space for op implementations that are not
// needed by this graph.
tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddAveragePool2D(tflite::Register_AVERAGE_POOL_2D_INT8());
micro_op_resolver.AddConv2D(tflite::Register_CONV_2D_INT8());
micro_op_resolver.AddDepthwiseConv2D(
tflite::Register_DEPTHWISE_CONV_2D_INT8());
micro_op_resolver.AddReshape();
micro_op_resolver.AddSoftmax(tflite::Register_SOFTMAX_INT8());
// Build an interpreter to run the model with.
tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,
tensor_arena_size);
interpreter.AllocateTensors();
// Get information about the memory area to use for the model's input.
TfLiteTensor* input = interpreter.input(0);
// Make sure the input has the properties we expect.
TF_LITE_MICRO_EXPECT(input != nullptr);
TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(kNumRows, input->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kNumCols, input->dims->data[2]);
TF_LITE_MICRO_EXPECT_EQ(kNumChannels, input->dims->data[3]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input->type);
// Copy an image with a person into the memory area used for the input.
TFLITE_DCHECK_EQ(input->bytes, static_cast<size_t>(g_person_image_data_size));
memcpy(input->data.int8, g_person_image_data, input->bytes);
// Run the model on this input and make sure it succeeds.
TfLiteStatus invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Invoke failed\n");
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
// Get the output from the model, and make sure it's the expected size and
// type.
TfLiteTensor* output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(kCategoryCount, output->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);
// Make sure that the expected "Person" score is higher than the other class.
int8_t person_score = output->data.int8[kPersonIndex];
int8_t no_person_score = output->data.int8[kNotAPersonIndex];
MicroPrintf("person data. person score: %d, no person score: %d\n",
person_score, no_person_score);
TF_LITE_MICRO_EXPECT_GT(person_score, no_person_score);
memcpy(input->data.int8, g_no_person_image_data, input->bytes);
// Run the model on this "No Person" input.
invoke_status = interpreter.Invoke();
if (invoke_status != kTfLiteOk) {
MicroPrintf("Invoke failed\n");
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
// Get the output from the model, and make sure it's the expected size and
// type.
output = interpreter.output(0);
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(kCategoryCount, output->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);
// Make sure that the expected "No Person" score is higher.
person_score = output->data.int8[kPersonIndex];
no_person_score = output->data.int8[kNotAPersonIndex];
MicroPrintf("no person data. person score: %d, no person score: %d\n",
person_score, no_person_score);
TF_LITE_MICRO_EXPECT_GT(no_person_score, person_score);
MicroPrintf("Ran successfully\n");
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This data was created from a sample image with a person in it.
// Convert original image to simpler format:
// convert -resize 96x96\! person.PNG person.bmp3
// Skip the 54 byte bmp3 header and add the rest of the bytes to a C array:
// xxd -s 54 -i /tmp/person.bmp3 > /tmp/person.cc
#ifndef TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_IMAGE_DATA_H_
#define TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_IMAGE_DATA_H_
#include <cstdint>
extern const unsigned int g_person_image_data_size;
extern const uint8_t g_person_image_data[];
#endif // TENSORFLOW_LITE_MICRO_EXAMPLES_PERSON_DETECTION_PERSON_IMAGE_DATA_H_
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "model_settings.h"
#include "no_person_image_data.h"
#include "person_detect_model_data.h"
#include "person_image_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/benchmarks/micro_benchmark.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
/*
* Person Detection benchmark. Evaluates runtime performance of the visual
* wake words person detection model. This is the same model found in
* examples/person_detection.
*/
namespace tflite {
using PersonDetectionOpResolver = MicroMutableOpResolver<6>;
using PersonDetectionBenchmarkRunner = MicroBenchmarkRunner<int8_t>;
// Create an area of memory to use for input, output, and intermediate arrays.
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
constexpr int kTensorArenaSize = 135 * 1024;
alignas(16) uint8_t tensor_arena[kTensorArenaSize];
uint8_t op_resolver_buffer[sizeof(PersonDetectionOpResolver)];
uint8_t benchmark_runner_buffer[sizeof(PersonDetectionBenchmarkRunner)];
// Initialize benchmark runner instance explicitly to avoid global init order
// issues on Sparkfun. Use new since static variables within a method
// are automatically surrounded by locking, which breaks bluepill and stm32f4.
PersonDetectionBenchmarkRunner *CreateBenchmarkRunner(MicroProfiler *profiler) {
// We allocate PersonDetectionOpResolver from a global buffer
// because the object's lifetime must exceed that of the
// PersonDetectionBenchmarkRunner object.
PersonDetectionOpResolver *op_resolver =
new (op_resolver_buffer) PersonDetectionOpResolver();
op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
op_resolver->AddConv2D(tflite::Register_CONV_2D_INT8REF());
op_resolver->AddDepthwiseConv2D();
op_resolver->AddSoftmax();
op_resolver->AddAveragePool2D();
op_resolver->AddReshape();
return new (benchmark_runner_buffer)
PersonDetectionBenchmarkRunner(g_person_detect_model_data, op_resolver,
tensor_arena, kTensorArenaSize, profiler);
}
void PersonDetectionNIterations(const int8_t *input, int iterations,
const char *tag,
PersonDetectionBenchmarkRunner &benchmark_runner,
MicroProfiler &profiler) {
benchmark_runner.SetInput(input);
uint32_t ticks = 0;
for (int i = 0; i < iterations; ++i) {
profiler.ClearEvents();
benchmark_runner.RunSingleIteration();
ticks += profiler.GetTotalTicks();
}
MicroPrintf("%s took %u ticks (%u ms)", tag, ticks, TicksToMs(ticks));
}
} // namespace tflite
int main(int argc, char **argv) {
tflite::InitializeTarget();
while (true) {
tflite::MicroProfiler profiler;
uint32_t event_handle = profiler.BeginEvent("InitializeBenchmarkRunner");
tflite::PersonDetectionBenchmarkRunner *benchmark_runner =
CreateBenchmarkRunner(&profiler);
profiler.EndEvent(event_handle);
profiler.Log();
MicroPrintf(""); // null MicroPrintf serves as a newline.
tflite::PersonDetectionNIterations(
reinterpret_cast<const int8_t *>(g_person_image_data), 1,
"WithPersonDataIterations(1)", *benchmark_runner, profiler);
profiler.Log();
MicroPrintf(""); // null MicroPrintf serves as a newline.
tflite::PersonDetectionNIterations(
reinterpret_cast<const int8_t *>(g_no_person_image_data), 1,
"NoPersonDataIterations(1)", *benchmark_runner, profiler);
profiler.Log();
MicroPrintf(""); // null MicroPrintf serves as a newline.
tflite::PersonDetectionNIterations(
reinterpret_cast<const int8_t *>(g_person_image_data), 10,
"WithPersonDataIterations(10)", *benchmark_runner, profiler);
MicroPrintf(""); // null MicroPrintf serves as a newline.
tflite::PersonDetectionNIterations(
reinterpret_cast<const int8_t *>(g_no_person_image_data), 10,
"NoPersonDataIterations(10)", *benchmark_runner, profiler);
MicroPrintf(""); // null MicroPrintf serves as a newline.
}
}
# This is a copy of <PICO_SDK_PATH>/external/pico_sdk_import.cmake
# This can be dropped into an external project to help locate this SDK
# It should be include()ed prior to project()
# todo document
if (DEFINED ENV{PICO_SDK_PATH} AND (NOT PICO_SDK_PATH))
set(PICO_SDK_PATH $ENV{PICO_SDK_PATH})
message("Using PICO_SDK_PATH from environment ('${PICO_SDK_PATH}')")
endif ()
if (DEFINED ENV{PICO_SDK_FETCH_FROM_GIT} AND (NOT PICO_SDK_FETCH_FROM_GIT))
set(PICO_SDK_FETCH_FROM_GIT $ENV{PICO_SDK_FETCH_FROM_GIT})
message("Using PICO_SDK_FETCH_FROM_GIT from environment ('${PICO_SDK_FETCH_FROM_GIT}')")
endif ()
if (DEFINED ENV{PICO_SDK_FETCH_FROM_GIT_PATH} AND (NOT PICO_SDK_FETCH_FROM_GIT_PATH))
set(PICO_SDK_FETCH_FROM_GIT_PATH $ENV{PICO_SDK_FETCH_FROM_GIT_PATH})
message("Using PICO_SDK_FETCH_FROM_GIT_PATH from environment ('${PICO_SDK_FETCH_FROM_GIT_PATH}')")
endif ()
set(PICO_SDK_PATH "${PICO_SDK_PATH}" CACHE PATH "Path to the PICO SDK")
set(PICO_SDK_FETCH_FROM_GIT "${PICO_SDK_FETCH_FROM_GIT}" CACHE BOOL "Set to ON to fetch copy of PICO SDK from git if not otherwise locatable")
set(PICO_SDK_FETCH_FROM_GIT_PATH "${PICO_SDK_FETCH_FROM_GIT_PATH}" CACHE FILEPATH "location to download SDK")
if (NOT PICO_SDK_PATH)
if (PICO_SDK_FETCH_FROM_GIT)
include(FetchContent)
set(FETCHCONTENT_BASE_DIR_SAVE ${FETCHCONTENT_BASE_DIR})
if (PICO_SDK_FETCH_FROM_GIT_PATH)
get_filename_component(FETCHCONTENT_BASE_DIR "${PICO_SDK_FETCH_FROM_GIT_PATH}" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
endif ()
FetchContent_Declare(
pico_sdk
GIT_REPOSITORY https://github.com/raspberrypi/pico-sdk
GIT_TAG master
)
if (NOT pico_sdk)
message("Downloading PICO SDK")
FetchContent_Populate(pico_sdk)
set(PICO_SDK_PATH ${pico_sdk_SOURCE_DIR})
endif ()
set(FETCHCONTENT_BASE_DIR ${FETCHCONTENT_BASE_DIR_SAVE})
else ()
message(FATAL_ERROR
"PICO SDK location was not specified. Please set PICO_SDK_PATH or set PICO_SDK_FETCH_FROM_GIT to on to fetch from git."
)
endif ()
endif ()
get_filename_component(PICO_SDK_PATH "${PICO_SDK_PATH}" REALPATH BASE_DIR "${CMAKE_BINARY_DIR}")
if (NOT EXISTS ${PICO_SDK_PATH})
message(FATAL_ERROR "Directory '${PICO_SDK_PATH}' not found")
endif ()
set(PICO_SDK_INIT_CMAKE_FILE ${PICO_SDK_PATH}/pico_sdk_init.cmake)
if (NOT EXISTS ${PICO_SDK_INIT_CMAKE_FILE})
message(FATAL_ERROR "Directory '${PICO_SDK_PATH}' does not appear to contain the PICO SDK")
endif ()
set(PICO_SDK_PATH ${PICO_SDK_PATH} CACHE PATH "Path to the PICO SDK" FORCE)
include(${PICO_SDK_INIT_CMAKE_FILE})
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include "signal/src/circular_buffer.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kDelayLengthIndex = 0; // 'delay_length'
struct TFLMSignalFrontendDelayParams {
int32_t frame_size;
int32_t delay_length;
int32_t outer_dims;
int8_t** state_buffers;
tflm_signal::CircularBuffer** circular_buffers;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* params = static_cast<TFLMSignalFrontendDelayParams*>(
context->AllocatePersistentBuffer(context,
sizeof(TFLMSignalFrontendDelayParams)));
if (params == nullptr) {
return nullptr;
}
FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer), length);
params->delay_length = fbw.ElementAsInt32(kDelayLengthIndex);
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
auto* params =
reinterpret_cast<TFLMSignalFrontendDelayParams*>(node->user_data);
TF_LITE_ENSURE(context, params != nullptr);
RuntimeShape input_shape = GetTensorShape(input);
int innermost_dim = input_shape.Dims(input_shape.DimensionsCount() - 1);
params->outer_dims = input_shape.FlatSize() / innermost_dim;
params->frame_size = innermost_dim;
params->state_buffers =
static_cast<int8_t**>(context->AllocatePersistentBuffer(
context, params->outer_dims * sizeof(int8_t*)));
params->circular_buffers = static_cast<tflm_signal::CircularBuffer**>(
context->AllocatePersistentBuffer(
context, params->outer_dims * sizeof(tflm_signal::CircularBuffer*)));
for (int i = 0; i < params->outer_dims; i++) {
size_t capacity = params->frame_size + params->delay_length;
size_t state_size = tflm_signal::CircularBufferGetNeededMemory(capacity);
params->state_buffers[i] =
static_cast<int8_t*>(context->AllocatePersistentBuffer(
context, state_size * sizeof(int8_t)));
params->circular_buffers[i] = tflm_signal::CircularBufferInit(
capacity, params->state_buffers[i], state_size);
tflm_signal::CircularBufferWriteZeros(params->circular_buffers[i],
params->delay_length);
}
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TFLMSignalFrontendDelayParams*>(node->user_data);
const TfLiteEvalTensor* input =
micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output = micro::GetEvalOutput(context, node, kOutputTensor);
const int16_t* input_data = micro::GetTensorData<int16_t>(input);
int16_t* output_data = micro::GetTensorData<int16_t>(output);
for (int dim_index = 0, sample_index = 0; dim_index < params->outer_dims;
dim_index++, sample_index += params->frame_size) {
tflm_signal::CircularBufferWrite(params->circular_buffers[dim_index],
&input_data[sample_index],
params->frame_size);
tflm_signal::CircularBufferGet(params->circular_buffers[dim_index],
params->frame_size,
&output_data[sample_index]);
tflm_signal::CircularBufferDiscard(params->circular_buffers[dim_index],
params->frame_size);
}
return kTfLiteOk;
}
void Reset(TfLiteContext* context, void* buffer) {
auto* params = static_cast<TFLMSignalFrontendDelayParams*>(buffer);
for (int i = 0; i < params->outer_dims; ++i) {
tflm_signal::CircularBufferReset(params->circular_buffers[i]);
tflm_signal::CircularBufferWriteZeros(params->circular_buffers[i],
params->delay_length);
}
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_DELAY() {
static TFLMRegistration r =
micro::RegisterOp(Init, Prepare, Eval, nullptr, Reset);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
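Because DELAY is a custom op, an application has to register it with the resolver before AllocateTensors(). A minimal sketch; the op name "SignalDelay" is an assumption and must match the custom-op name the converter recorded in the .tflite file:
// Sketch: expose the DELAY kernel to a MicroInterpreter via AddCustom().
tflite::MicroMutableOpResolver<1> resolver;
resolver.AddCustom("SignalDelay", tflite::tflm_signal::Register_DELAY());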
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_DELAY_FLEXBUFFERS_GENERATED_DATA_H_
#define SIGNAL_MICRO_KERNELS_DELAY_FLEXBUFFERS_GENERATED_DATA_H_
extern const int g_gen_data_size_3_delay;
extern const unsigned char g_gen_data_3_delay[];
extern const int g_gen_data_size_5_delay;
extern const unsigned char g_gen_data_5_delay[];
#endif // SIGNAL_MICRO_KERNELS_DELAY_FLEXBUFFERS_GENERATED_DATA_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/energy.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kEndIndexIndex = 0; // 'end_index'
constexpr int kStartIndexIndex = 1; // 'start_index'
struct TFLMSignalEnergyParams {
int32_t end_index;
int32_t start_index;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* data =
static_cast<TFLMSignalEnergyParams*>(context->AllocatePersistentBuffer(
context, sizeof(TFLMSignalEnergyParams)));
if (data == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
data->end_index = fbw.ElementAsInt32(kEndIndexIndex);
data->start_index = fbw.ElementAsInt32(kStartIndexIndex);
return data;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt32);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TFLMSignalEnergyParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const Complex<int16_t>* input_data =
tflite::micro::GetTensorData<Complex<int16_t>>(input);
uint32_t* output_data = tflite::micro::GetTensorData<uint32_t>(output);
tflm_signal::SpectrumToEnergy(input_data, params->start_index,
params->end_index, output_data);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_ENERGY() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
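Conceptually, SpectrumToEnergy reduces each complex FFT bin to its power. A sketch of the assumed per-bin arithmetic (the real/imag member names of Complex<int16_t> and the exclusive end bound are assumptions):
// Sketch: energy = re^2 + im^2 per bin. With int16 inputs the worst case
// 2 * 32767^2 still fits in a uint32.
for (int i = params->start_index; i < params->end_index; ++i) {
  const int32_t re = input_data[i].real;
  const int32_t im = input_data[i].imag;
  output_data[i] = static_cast<uint32_t>(re * re + im * im);
}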
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_ENERGY_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_ENERGY_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_start_index_2_end_index_4;
extern const unsigned char g_gen_data_start_index_2_end_index_4[];
extern const int g_gen_data_size_start_index_0_end_index_4;
extern const unsigned char g_gen_data_start_index_0_end_index_4[];
extern const int g_gen_data_size_start_index_4_end_index_8;
extern const unsigned char g_gen_data_start_index_4_end_index_8[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_ENERGY_FLEXBUFFERS_DATA_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/micro/kernels/fft_auto_scale_kernel.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kScaleBitTensor = 1;
TfLiteStatus FftAutoScalePrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteTensor* scale_bit =
micro_context->AllocateTempOutputTensor(node, kScaleBitTensor);
TF_LITE_ENSURE(context, scale_bit != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(scale_bit), 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, scale_bit->type, kTfLiteInt32);
micro_context->DeallocateTempTfLiteTensor(scale_bit);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/micro/kernels/fft_auto_scale_kernel.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "signal/src/fft_auto_scale.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kScaleBitTensor = 1;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteEvalTensor* scale_bit =
tflite::micro::GetEvalOutput(context, node, kScaleBitTensor);
const int16_t* input_data = tflite::micro::GetTensorData<int16_t>(input);
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
int32_t* scale_bit_data = tflite::micro::GetTensorData<int32_t>(scale_bit);
*scale_bit_data =
tflm_signal::FftAutoScale(input_data, output->dims->data[0], output_data);
return kTfLiteOk;
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_FFT_AUTO_SCALE() {
static TFLMRegistration r =
tflite::micro::RegisterOp(nullptr, FftAutoScalePrepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
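FftAutoScale maximizes int16 headroom ahead of a fixed-point FFT: it scales the frame up by a power of two and reports the shift so downstream ops can compensate. A sketch of the idea, not the library's actual implementation:
#include <stdint.h>
// Sketch: measure the largest magnitude, scale every sample up by the
// largest power of two that still fits in int16, and return the shift count.
int32_t FftAutoScaleSketch(const int16_t* input, int size, int16_t* output) {
  int max_abs = 0;
  for (int i = 0; i < size; ++i) {
    const int v = input[i] < 0 ? -input[i] : input[i];
    if (v > max_abs) max_abs = v;
  }
  int scale_bits = 0;
  while (max_abs != 0 && (max_abs << 1) <= INT16_MAX) {
    max_abs <<= 1;
    ++scale_bits;
  }
  for (int i = 0; i < size; ++i) {
    output[i] = static_cast<int16_t>(input[i] * (1 << scale_bits));
  }
  return scale_bits;
}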
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_FFT_AUTO_SCALE_KERNEL_H_
#define SIGNAL_MICRO_KERNELS_FFT_AUTO_SCALE_KERNEL_H_
#include "tensorflow/lite/c/common.h"
namespace tflite {
TfLiteStatus FftAutoScalePrepare(TfLiteContext* context, TfLiteNode* node);
} // namespace tflite
#endif // SIGNAL_MICRO_KERNELS_FFT_AUTO_SCALE_KERNEL_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FFT_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FFT_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_fft_length_64_float;
extern const unsigned char g_gen_data_fft_length_64_float[];
extern const int g_gen_data_size_fft_length_64_int16;
extern const unsigned char g_gen_data_fft_length_64_int16[];
extern const int g_gen_data_size_fft_length_64_int32;
extern const unsigned char g_gen_data_fft_length_64_int32[];
extern const int g_gen_data_size_fft_length_512_float;
extern const unsigned char g_gen_data_fft_length_512_float[];
extern const int g_gen_data_size_fft_length_512_int16;
extern const unsigned char g_gen_data_fft_length_512_int16[];
extern const int g_gen_data_size_fft_length_512_int32;
extern const unsigned char g_gen_data_fft_length_512_int32[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FFT_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank.h"
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kWeightTensor = 1;
constexpr int kUnweightTensor = 2;
constexpr int kChFreqStartsTensor = 3;
constexpr int kChWeightStartsTensor = 4;
constexpr int kChannelWidthsTensor = 5;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kNumChannelsIndex = 0; // 'num_channels'
struct TFLMSignalFilterBankParams {
tflm_signal::FilterbankConfig config;
uint64_t* work_area;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TFLMSignalFilterBankParams*>(
context->AllocatePersistentBuffer(context,
sizeof(TFLMSignalFilterBankParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->config.num_channels = fbw.ElementAsInt32(kNumChannelsIndex);
params->work_area = static_cast<uint64_t*>(context->AllocatePersistentBuffer(
context, (params->config.num_channels + 1) * sizeof(uint64_t)));
if (params->work_area == nullptr) {
return nullptr;
}
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 6);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kWeightTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kUnweightTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kChFreqStartsTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kChWeightStartsTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kChannelWidthsTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt64);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TFLMSignalFilterBankParams*>(node->user_data);
const TfLiteEvalTensor* input0 =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kWeightTensor);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kUnweightTensor);
const TfLiteEvalTensor* input3 =
tflite::micro::GetEvalInput(context, node, kChFreqStartsTensor);
const TfLiteEvalTensor* input4 =
tflite::micro::GetEvalInput(context, node, kChWeightStartsTensor);
const TfLiteEvalTensor* input5 =
tflite::micro::GetEvalInput(context, node, kChannelWidthsTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
params->config.weights = tflite::micro::GetTensorData<int16_t>(input1);
params->config.unweights = tflite::micro::GetTensorData<int16_t>(input2);
params->config.channel_frequency_starts =
tflite::micro::GetTensorData<int16_t>(input3);
params->config.channel_weight_starts =
tflite::micro::GetTensorData<int16_t>(input4);
params->config.channel_widths = tflite::micro::GetTensorData<int16_t>(input5);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input0);
uint64_t* output_data = tflite::micro::GetTensorData<uint64_t>(output);
tflm_signal::FilterbankAccumulateChannels(&params->config, input_data,
params->work_area);
size_t output_size;
TfLiteTypeSizeOf(output->type, &output_size);
output_size *= ElementCount(*output->dims);
// Discard channel 0, which is just scratch
memcpy(output_data, params->work_area + 1, output_size);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
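// --- Editor's sketch (not part of the upstream file) ---
// Minimal wiring of this custom op into a TFLM interpreter, assuming the
// flatbuffer names the op "SignalFilterBank" (the string must match whatever
// your model actually uses):
//
//   #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
//
//   tflite::MicroMutableOpResolver<1> resolver;
//   resolver.AddCustom("SignalFilterBank",
//                      tflite::tflm_signal::Register_FILTER_BANK());
//
// Note that Eval copies from work_area + 1: FilterbankAccumulateChannels
// fills num_channels + 1 slots and slot 0 is scratch, which is why Init
// allocates one extra uint64.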
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_filter_bank_32_channel;
extern const unsigned char g_gen_data_filter_bank_32_channel[];
extern const int g_gen_data_size_filter_bank_16_channel;
extern const unsigned char g_gen_data_filter_bank_16_channel[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_log.h"
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kInputCorrectionBitsIndex = 0; // 'input_correction_bits'
constexpr int kOutputScaleIndex = 1; // 'output_scale'
struct TFLMSignalLogParams {
int input_correction_bits;
int output_scale;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TFLMSignalLogParams*>(
context->AllocatePersistentBuffer(context, sizeof(TFLMSignalLogParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->input_correction_bits = fbw.ElementAsInt32(kInputCorrectionBitsIndex);
params->output_scale = fbw.ElementAsInt32(kOutputScaleIndex);
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TFLMSignalLogParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input);
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
int num_channels = input->dims->data[0];
tflm_signal::FilterbankLog(input_data, num_channels, params->output_scale,
params->input_correction_bits, output_data);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK_LOG() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
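// --- Editor's sketch (not part of the upstream file) ---
// Per channel this op computes, roughly,
//   output[i] = sat_int16(output_scale * ln(input[i] << input_correction_bits))
// i.e. a fixed-point natural log of the bit-shift-corrected uint32 channel
// energy, scaled and saturated into int16; see signal/src/filter_bank_log.h
// for the exact contract. Registration follows the same pattern as above,
// assuming the op name "SignalFilterBankLog":
//
//   resolver.AddCustom("SignalFilterBankLog",
//                      tflite::tflm_signal::Register_FILTER_BANK_LOG());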
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_LOG_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_LOG_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_filter_bank_log_scale_1600_correction_bits_3;
extern const unsigned char
g_gen_data_filter_bank_log_scale_1600_correction_bits_3[];
extern const int g_gen_data_size_filter_bank_log_scale_32768_correction_bits_5;
extern const unsigned char
g_gen_data_filter_bank_log_scale_32768_correction_bits_5[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_LOG_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_spectral_subtraction.h"
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kNoiseEstimateTensor = 1;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'alternate_one_minus_smoothing'
constexpr int kAlternateOneMinusSmoothingIndex = 0;
constexpr int kAlternateSmoothingIndex = 1; // 'alternate_smoothing'
constexpr int kClampingIndex = 2; // 'clamping'
constexpr int kMinSignalRemainingIndex = 3; // 'min_signal_remaining'
constexpr int kNumChannelsIndex = 4; // 'num_channels'
constexpr int kOneMinusSmoothingIndex = 5; // 'one_minus_smoothing'
constexpr int kSmoothingIndex = 6; // 'smoothing'
constexpr int kSmoothingBitsIndex = 7; // 'smoothing_bits'
constexpr int kSpectralSubtractionBitsIndex = 8; // 'spectral_subtraction_bits'
struct TFLMSignalSpectralSubtractionParams {
tflm_signal::SpectralSubtractionConfig config;
uint32_t* noise_estimate;
size_t noise_estimate_size;
};
void ResetState(TFLMSignalSpectralSubtractionParams* params) {
memset(params->noise_estimate, 0,
sizeof(uint32_t) * params->config.num_channels);
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TFLMSignalSpectralSubtractionParams*>(
context->AllocatePersistentBuffer(
context, sizeof(TFLMSignalSpectralSubtractionParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->config.alternate_one_minus_smoothing =
fbw.ElementAsInt32(kAlternateOneMinusSmoothingIndex);
params->config.alternate_smoothing =
fbw.ElementAsInt32(kAlternateSmoothingIndex);
params->config.clamping = fbw.ElementAsBool(kClampingIndex);
params->config.min_signal_remaining =
fbw.ElementAsInt32(kMinSignalRemainingIndex);
params->config.num_channels = fbw.ElementAsInt32(kNumChannelsIndex);
params->config.one_minus_smoothing =
fbw.ElementAsInt32(kOneMinusSmoothingIndex);
params->config.smoothing = fbw.ElementAsInt32(kSmoothingIndex);
params->config.smoothing_bits = fbw.ElementAsInt32(kSmoothingBitsIndex);
params->config.spectral_subtraction_bits =
fbw.ElementAsInt32(kSpectralSubtractionBitsIndex);
params->noise_estimate =
static_cast<uint32_t*>(context->AllocatePersistentBuffer(
context, params->config.num_channels * sizeof(uint32_t)));
if (params->noise_estimate == nullptr) {
return nullptr;
}
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TfLiteTensor* noise_estimate =
micro_context->AllocateTempOutputTensor(node, kNoiseEstimateTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, noise_estimate != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(noise_estimate), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, noise_estimate->type, kTfLiteUInt32);
auto* params =
reinterpret_cast<TFLMSignalSpectralSubtractionParams*>(node->user_data);
TfLiteTypeSizeOf(output->type, &params->noise_estimate_size);
params->noise_estimate_size *= ElementCount(*noise_estimate->dims);
ResetState(params);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(noise_estimate);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TFLMSignalSpectralSubtractionParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteEvalTensor* noise_estimate =
tflite::micro::GetEvalOutput(context, node, kNoiseEstimateTensor);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input);
uint32_t* output_data = tflite::micro::GetTensorData<uint32_t>(output);
uint32_t* noise_estimate_data =
tflite::micro::GetTensorData<uint32_t>(noise_estimate);
FilterbankSpectralSubtraction(&params->config, input_data, output_data,
params->noise_estimate);
memcpy(noise_estimate_data, params->noise_estimate,
params->noise_estimate_size);
return kTfLiteOk;
}
void Reset(TfLiteContext* context, void* buffer) {
ResetState(static_cast<TFLMSignalSpectralSubtractionParams*>(buffer));
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK_SPECTRAL_SUBTRACTION() {
static TFLMRegistration r =
tflite::micro::RegisterOp(Init, Prepare, Eval, /*Free*/ nullptr, Reset);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
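// --- Editor's sketch (not part of the upstream file) ---
// The noise estimate is state: it lives in a persistent buffer, survives
// across Eval calls, and is zeroed by Reset (note the Reset hook passed to
// RegisterOp above). Conceptually each call updates it with an exponential
// moving average in Q(smoothing_bits) fixed point, roughly
//   noise[i] = (smoothing * signal[i]
//               + one_minus_smoothing * noise[i]) >> smoothing_bits
// and then subtracts a fraction of the estimate from the signal, keeping at
// least min_signal_remaining; see
// signal/src/filter_bank_spectral_subtraction.h for the exact contract.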
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_SPECTRAL_SUBTRACTION_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_SPECTRAL_SUBTRACTION_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_filter_bank_spectral_subtraction_32_channel;
extern const unsigned char
g_gen_data_filter_bank_spectral_subtraction_32_channel[];
extern const int g_gen_data_size_filter_bank_spectral_subtraction_16_channel;
extern const unsigned char
g_gen_data_filter_bank_spectral_subtraction_16_channel[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_SPECTRAL_SUBTRACTION_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_square_root.h"
#include <stdint.h>
#include "signal/micro/kernels/filter_bank_square_root.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kScaleBitsTensor = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* scale_bits =
tflite::micro::GetEvalInput(context, node, kScaleBitsTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const uint64_t* input_data = tflite::micro::GetTensorData<uint64_t>(input);
const int32_t* scale_bits_data =
tflite::micro::GetTensorData<int32_t>(scale_bits);
uint32_t* output_data = tflite::micro::GetTensorData<uint32_t>(output);
int32_t num_channels = input->dims->data[0];
tflm_signal::FilterbankSqrt(input_data, num_channels, *scale_bits_data,
output_data);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK_SQUARE_ROOT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(nullptr, FilterBankSquareRootPrepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_FILTER_BANK_SQUARE_ROOT_H_
#define SIGNAL_MICRO_KERNELS_FILTER_BANK_SQUARE_ROOT_H_
#include "tensorflow/lite/c/common.h"
namespace tflite {
TfLiteStatus FilterBankSquareRootPrepare(TfLiteContext* context,
TfLiteNode* node);
} // namespace tflite
#endif // SIGNAL_MICRO_KERNELS_FILTER_BANK_SQUARE_ROOT_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/micro/kernels/filter_bank_square_root.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
constexpr int kInputTensor = 0;
constexpr int kScaleBitsTensor = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus FilterBankSquareRootPrepare(TfLiteContext* context,
TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* scale_bits =
micro_context->AllocateTempInputTensor(node, kScaleBitsTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, scale_bits != nullptr);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(scale_bits), 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt64);
TF_LITE_ENSURE_TYPES_EQ(context, scale_bits->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt32);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(scale_bits);
return kTfLiteOk;
}
} // namespace tflite
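// --- Editor's sketch (not part of the upstream file) ---
// This op reduces the uint64 channel energies back to uint32 with an integer
// square root followed by a right shift of scale_bits (hence the scalar int32
// second input). A worked example under that assumption:
//   input = 1 << 40, scale_bits = 4
//   sqrt(2^40) = 2^20, and 2^20 >> 4 = 2^16 = 65536.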
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include "signal/src/circular_buffer.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kOutputValidTensor = 1;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kFrameSizeIndex = 0; // 'frame_size'
constexpr int kFrameStepIndex = 1; // 'frame_step'
constexpr int kPrefillIndex = 2; // 'prefill'
struct TFLMSignalFramerParams {
int32_t frame_size;
int32_t frame_step;
int32_t outer_dims;
int32_t n_frames;
bool prefill;
int8_t** state_buffers;
tflite::tflm_signal::CircularBuffer** circular_buffers;
};
void ResetState(TFLMSignalFramerParams* params) {
for (int i = 0; i < params->outer_dims; ++i) {
tflite::tflm_signal::CircularBufferReset(params->circular_buffers[i]);
if (params->prefill) {
tflite::tflm_signal::CircularBufferWriteZeros(
params->circular_buffers[i], params->frame_size - params->frame_step);
}
}
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
auto* params =
static_cast<TFLMSignalFramerParams*>(context->AllocatePersistentBuffer(
context, sizeof(TFLMSignalFramerParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(buffer_t, length);
params->frame_size = fbw.ElementAsInt32(kFrameSizeIndex);
params->frame_step = fbw.ElementAsInt32(kFrameStepIndex);
params->prefill = fbw.ElementAsBool(kPrefillIndex);
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteTensor* output_valid =
micro_context->AllocateTempOutputTensor(node, kOutputValidTensor);
TF_LITE_ENSURE(context, output_valid != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input) + 1, NumDimensions(output));
TF_LITE_ENSURE_EQ(context, NumDimensions(output_valid), 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output_valid->type, kTfLiteBool);
auto* params = reinterpret_cast<TFLMSignalFramerParams*>(node->user_data);
RuntimeShape input_shape = GetTensorShape(input);
int innermost_dim = input_shape.Dims(input_shape.DimensionsCount() - 1);
TF_LITE_ENSURE(context, innermost_dim >= params->frame_step);
TF_LITE_ENSURE_EQ(context, innermost_dim % params->frame_step, 0);
params->outer_dims = input_shape.FlatSize() / innermost_dim;
params->n_frames = innermost_dim / params->frame_step;
params->state_buffers =
static_cast<int8_t**>(context->AllocatePersistentBuffer(
context, params->outer_dims * sizeof(int8_t*)));
params->circular_buffers = static_cast<tflite::tflm_signal::CircularBuffer**>(
context->AllocatePersistentBuffer(
context,
params->outer_dims * sizeof(tflite::tflm_signal::CircularBuffer*)));
for (int i = 0; i < params->outer_dims; i++) {
// Calculate the capacity of the circular buffer. Round the frame size up
// to a multiple of the frame step; this saves memory relative to the
// simpler frame_size + frame_step. For example, with frame_step = 160 and
// frame_size = 400: capacity = 480 vs. frame_step + frame_size = 560.
size_t capacity = (params->frame_size + params->frame_step - 1) /
params->frame_step * params->frame_step;
size_t state_size =
tflite::tflm_signal::CircularBufferGetNeededMemory(capacity);
params->state_buffers[i] =
static_cast<int8_t*>(context->AllocatePersistentBuffer(
context, state_size * sizeof(int8_t)));
params->circular_buffers[i] = tflite::tflm_signal::CircularBufferInit(
capacity, params->state_buffers[i], state_size);
}
ResetState(params);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(output_valid);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TFLMSignalFramerParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteEvalTensor* output_valid =
tflite::micro::GetEvalOutput(context, node, kOutputValidTensor);
const int16_t* input_data = tflite::micro::GetTensorData<int16_t>(input);
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
bool* output_valid_data = tflite::micro::GetTensorData<bool>(output_valid);
*output_valid_data = true;
for (int i = 0; i < params->outer_dims; i++) {
for (int frame = 0; frame < params->n_frames; frame++) {
int input_idx = (i * params->n_frames + frame) * params->frame_step;
int output_idx = (i * params->n_frames + frame) * params->frame_size;
tflite::tflm_signal::CircularBufferWrite(params->circular_buffers[i],
&input_data[input_idx],
params->frame_step);
if (tflite::tflm_signal::CircularBufferAvailable(
params->circular_buffers[i]) >=
static_cast<size_t>(params->frame_size)) {
tflite::tflm_signal::CircularBufferGet(params->circular_buffers[i],
params->frame_size,
&output_data[output_idx]);
tflite::tflm_signal::CircularBufferDiscard(params->circular_buffers[i],
params->frame_step);
} else {
*output_valid_data = false;
}
}
}
return kTfLiteOk;
}
void Reset(TfLiteContext* context, void* buffer) {
ResetState(static_cast<TFLMSignalFramerParams*>(buffer));
}
} // namespace
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
TFLMRegistration* Register_FRAMER() {
static TFLMRegistration r =
tflite::micro::RegisterOp(Init, Prepare, Eval, nullptr, Reset);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
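// --- Editor's sketch (not part of the upstream file) ---
// A worked example of the framer's buffering, reusing the numbers from the
// capacity comment above (frame_size = 400, frame_step = 160): each frame
// slot pushes 160 fresh samples into the circular buffer, and a 400-sample
// frame is emitted once at least 400 samples are available. With
// prefill = true, Reset seeds 400 - 160 = 240 zeros, so the very first frame
// is already valid; with prefill = false, output_valid reads false until the
// third 160-sample push (480 >= 400).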
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FRAMER_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FRAMER_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_3_1_0_framer;
extern const unsigned char g_gen_data_3_1_0_framer[];
extern const int g_gen_data_size_5_2_1_framer;
extern const unsigned char g_gen_data_5_2_1_framer[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FRAMER_FLEXBUFFERS_DATA_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/irfft.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'T' is added implicitly by the TensorFlow framework when the type is resolved
// during graph construction.
// constexpr int kTypeIndex = 0; // 'T' (unused)
constexpr int kFftLengthIndex = 1; // 'fft_length'
struct TfLiteAudioFrontendIrfftParams {
int32_t fft_length;
int32_t input_size;
int32_t input_length;
int32_t output_length;
TfLiteType fft_type;
int8_t* state;
};
template <typename T, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t)>
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TfLiteAudioFrontendIrfftParams*>(
context->AllocatePersistentBuffer(
context, sizeof(TfLiteAudioFrontendIrfftParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->fft_length = fbw.ElementAsInt32(kFftLengthIndex);
params->fft_type = typeToTfLiteType<T>();
size_t state_size = (*get_needed_memory_func)(params->fft_length);
params->state = reinterpret_cast<int8_t*>(
context->AllocatePersistentBuffer(context, state_size * sizeof(int8_t)));
if (params->state == nullptr) {
return nullptr;
}
(*init_func)(params->fft_length, params->state, state_size);
return params;
}
template <TfLiteType TfLiteTypeEnum>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, TfLiteTypeEnum);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, TfLiteTypeEnum);
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
// Divide by 2 because input is complex.
params->input_length =
input_shape.Dims(input_shape.DimensionsCount() - 1) / 2;
params->input_size = input_shape.FlatSize() / 2;
params->output_length = output_shape.Dims(output_shape.DimensionsCount() - 1);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T, void (*apply_func)(void*, const Complex<T>* input, T*)>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const Complex<T>* input_data =
tflite::micro::GetTensorData<Complex<T>>(input);
T* output_data = tflite::micro::GetTensorData<T>(output);
for (int input_idx = 0, output_idx = 0; input_idx < params->input_size;
input_idx += params->input_length, output_idx += params->output_length) {
(*apply_func)(params->state, &input_data[input_idx],
&output_data[output_idx]);
}
return kTfLiteOk;
}
void* InitAll(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
auto tensor_type = static_cast<tflite::TensorType>(m["T"].AsInt32());
switch (tensor_type) {
case TensorType_INT16: {
return Init<int16_t, tflm_signal::IrfftInt16GetNeededMemory,
tflm_signal::IrfftInt16Init>(context, buffer, length);
}
case TensorType_INT32: {
return Init<int32_t, tflm_signal::IrfftInt32GetNeededMemory,
tflm_signal::IrfftInt32Init>(context, buffer, length);
}
case TensorType_FLOAT32: {
return Init<float, tflm_signal::IrfftFloatGetNeededMemory,
tflm_signal::IrfftFloatInit>(context, buffer, length);
}
default:
return nullptr;
}
}
TfLiteStatus PrepareAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Prepare<kTfLiteInt16>(context, node);
}
case kTfLiteInt32: {
return Prepare<kTfLiteInt32>(context, node);
}
case kTfLiteFloat32: {
return Prepare<kTfLiteFloat32>(context, node);
}
default:
return kTfLiteError;
}
}
TfLiteStatus EvalAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Eval<int16_t, tflm_signal::IrfftInt16Apply>(context, node);
}
case kTfLiteInt32: {
return Eval<int32_t, tflm_signal::IrfftInt32Apply>(context, node);
}
case kTfLiteFloat32: {
return Eval<float, tflm_signal::IrfftFloatApply>(context, node);
}
default:
return kTfLiteError;
}
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_IRFFT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(InitAll, PrepareAll, EvalAll);
return &r;
}
TFLMRegistration* Register_IRFFT_FLOAT() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<float, IrfftFloatGetNeededMemory, IrfftFloatInit>,
Prepare<kTfLiteFloat32>, Eval<float, IrfftFloatApply>);
return &r;
}
TFLMRegistration* Register_IRFFT_INT16() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int16_t, IrfftInt16GetNeededMemory, IrfftInt16Init>,
Prepare<kTfLiteInt16>, Eval<int16_t, IrfftInt16Apply>);
return &r;
}
TFLMRegistration* Register_IRFFT_INT32() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int32_t, IrfftInt32GetNeededMemory, IrfftInt32Init>,
Prepare<kTfLiteInt32>, Eval<int32_t, IrfftInt32Apply>);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
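// --- Editor's sketch (not part of the upstream file) ---
// Register_IRFFT dispatches on the 'T' attribute at Init time, so one
// registration serves int16, int32, and float models; the typed variants
// (Register_IRFFT_INT16, etc.) skip the dispatch and pull in only one
// template instantiation, which is usually preferable when flash is tight.
// A sketch, assuming the op name "SignalIrfft" matches your model:
//
//   resolver.AddCustom("SignalIrfft",
//                      tflite::tflm_signal::Register_IRFFT_INT16());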
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_IRFFT_H_
#define SIGNAL_MICRO_KERNELS_IRFFT_H_
#include "tensorflow/lite/micro/micro_common.h"
namespace tflite {
namespace tflm_signal {
TFLMRegistration* Register_IRFFT();
TFLMRegistration* Register_IRFFT_FLOAT();
TFLMRegistration* Register_IRFFT_INT16();
TFLMRegistration* Register_IRFFT_INT32();
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_MICRO_KERNELS_IRFFT_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/overlap_add.h"
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'T' is added implicitly by the TensorFlow framework when the type is resolved
// during graph construction.
// constexpr int kTypeIndex = 0; // 'T' (unused)
constexpr int kFrameStepIndex = 1; // 'frame_step'
template <typename T>
struct TFLMSignalOverlapAddParams {
int32_t frame_size;
int32_t frame_step;
int32_t outer_dims;
int32_t n_frames;
TfLiteType type;
T** state_buffers;
};
template <typename T>
void ResetState(TFLMSignalOverlapAddParams<T>* params) {
for (int i = 0; i < params->outer_dims; i++) {
memset(params->state_buffers[i], 0, sizeof(T) * params->frame_size);
}
}
template <typename T>
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
auto* params = static_cast<TFLMSignalOverlapAddParams<T>*>(
context->AllocatePersistentBuffer(context,
sizeof(TFLMSignalOverlapAddParams<T>)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(buffer_t, length);
params->type = typeToTfLiteType<T>();
params->frame_step = fbw.ElementAsInt32(kFrameStepIndex);
return params;
}
template <typename T, TfLiteType TfLiteTypeEnum>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output) + 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, TfLiteTypeEnum);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, TfLiteTypeEnum);
auto* params =
reinterpret_cast<TFLMSignalOverlapAddParams<T>*>(node->user_data);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
TF_LITE_ENSURE(context, input_shape.DimensionsCount() >= 2);
TF_LITE_ENSURE_EQ(context, input_shape.DimensionsCount(),
output_shape.DimensionsCount() + 1);
params->frame_size = input_shape.Dims(input_shape.DimensionsCount() - 1);
params->n_frames = input_shape.Dims(input_shape.DimensionsCount() - 2);
params->outer_dims =
input_shape.FlatSize() / (params->frame_size * params->n_frames);
params->state_buffers = static_cast<T**>(context->AllocatePersistentBuffer(
context, params->outer_dims * sizeof(T*)));
TF_LITE_ENSURE(context, params->state_buffers != nullptr);
for (int i = 0; i < params->outer_dims; i++) {
params->state_buffers[i] =
static_cast<T*>(context->AllocatePersistentBuffer(
context, params->frame_size * sizeof(T)));
}
ResetState(params);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TFLMSignalOverlapAddParams<T>*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const T* input_data = tflite::micro::GetTensorData<T>(input);
T* output_data = tflite::micro::GetTensorData<T>(output);
for (int i = 0; i < params->outer_dims; i++) {
T* buffer = params->state_buffers[i];
for (int frame = 0; frame < params->n_frames; frame++) {
int input_index = (i * params->n_frames + frame) * params->frame_size;
int output_index = (i * params->n_frames + frame) * params->frame_step;
tflm_signal::OverlapAdd(&input_data[input_index], buffer,
params->frame_size, &output_data[output_index],
params->frame_step);
}
}
return kTfLiteOk;
}
template <typename T>
void Reset(TfLiteContext* context, void* buffer) {
ResetState(static_cast<TFLMSignalOverlapAddParams<T>*>(buffer));
}
void* InitAll(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
auto tensor_type = static_cast<tflite::TensorType>(m["T"].AsInt32());
switch (tensor_type) {
case TensorType_INT16: {
return Init<int16_t>(context, buffer, length);
}
case TensorType_FLOAT32: {
return Init<float>(context, buffer, length);
}
default:
return nullptr;
}
}
TfLiteStatus PrepareAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TFLMSignalOverlapAddParams<void>*>(node->user_data);
switch (params->type) {
case kTfLiteInt16: {
return Prepare<int16_t, kTfLiteInt16>(context, node);
}
case kTfLiteFloat32: {
return Prepare<float, kTfLiteFloat32>(context, node);
}
default:
return kTfLiteError;
}
}
TfLiteStatus EvalAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TFLMSignalOverlapAddParams<void>*>(node->user_data);
switch (params->type) {
case kTfLiteInt16: {
return Eval<int16_t>(context, node);
}
case kTfLiteFloat32: {
return Eval<float>(context, node);
}
default:
return kTfLiteError;
}
}
void ResetAll(TfLiteContext* context, void* buffer) {
auto* params = reinterpret_cast<TFLMSignalOverlapAddParams<void>*>(buffer);
switch (params->type) {
case kTfLiteInt16: {
Reset<int16_t>(context, buffer);
break;
}
case kTfLiteFloat32: {
Reset<float>(context, buffer);
break;
}
default:
break;
}
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_OVERLAP_ADD() {
static TFLMRegistration r = tflite::micro::RegisterOp(
InitAll, PrepareAll, EvalAll, nullptr, ResetAll);
return &r;
}
TFLMRegistration* Register_OVERLAP_ADD_FLOAT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(Init<float>, Prepare<float, kTfLiteFloat32>,
Eval<float>, nullptr, Reset<float>);
return &r;
}
TFLMRegistration* Register_OVERLAP_ADD_INT16() {
static TFLMRegistration r =
tflite::micro::RegisterOp(Init<int16_t>, Prepare<int16_t, kTfLiteInt16>,
Eval<int16_t>, nullptr, Reset<int16_t>);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
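// --- Editor's sketch (not part of the upstream file) ---
// Assuming the standard overlap-add contract for tflm_signal::OverlapAdd, a
// small worked example with frame_size = 4 and frame_step = 2: each input
// frame of 4 samples is summed into the 4-sample state buffer, the first 2
// samples of that sum are emitted, and the buffer slides left by 2 (zero
// filling the tail). The output is therefore n_frames * frame_step samples
// per outer dim, which is why Prepare() requires the output to have exactly
// one fewer dimension than the input.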
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_OVERLAP_ADD_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_OVERLAP_ADD_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_overlap_add_float;
extern const unsigned char g_gen_data_overlap_add_float[];
extern const int g_gen_data_size_overlap_add_int16;
extern const unsigned char g_gen_data_overlap_add_int16[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_OVERLAP_ADD_FLEXBUFFERS_DATA_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/pcan_argc_fixed.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_context.h"
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
constexpr int kInputTensor = 0;
constexpr int kNoiseEstimateTensor = 1;
constexpr int kGainLutTensor = 2;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kSnrShiftIndex = 0; // 'snr_shift'
struct TfLitePcanParams {
int snr_shift;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* params = static_cast<TfLitePcanParams*>(
context->AllocatePersistentBuffer(context, sizeof(TfLitePcanParams)));
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->snr_shift = fbw.ElementAsInt32(kSnrShiftIndex);
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* noise_estimate =
micro_context->AllocateTempInputTensor(node, kNoiseEstimateTensor);
TF_LITE_ENSURE(context, noise_estimate != nullptr);
TfLiteTensor* gain_lut =
micro_context->AllocateTempInputTensor(node, kGainLutTensor);
TF_LITE_ENSURE(context, gain_lut != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(noise_estimate), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(gain_lut), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, noise_estimate->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, gain_lut->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt32);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(noise_estimate);
micro_context->DeallocateTempTfLiteTensor(gain_lut);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLitePcanParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteEvalTensor* noise_estimate =
tflite::micro::GetEvalInput(context, node, kNoiseEstimateTensor);
TF_LITE_ENSURE(context, noise_estimate != nullptr);
const TfLiteEvalTensor* gain_lut =
tflite::micro::GetEvalInput(context, node, kGainLutTensor);
TF_LITE_ENSURE(context, gain_lut != nullptr);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input);
const uint32_t* noise_estimate_data =
tflite::micro::GetTensorData<uint32_t>(noise_estimate);
const int16_t* gain_lut_data =
tflite::micro::GetTensorData<int16_t>(gain_lut);
uint32_t* output_data = tflite::micro::GetTensorData<uint32_t>(output);
int num_channels = input->dims->data[0];
size_t output_byte_size;
TF_LITE_ENSURE_OK(
context, tflite::TfLiteEvalTensorByteLength(output, &output_byte_size));
memcpy(output_data, input_data, output_byte_size);
tflite::tflm_signal::ApplyPcanAutoGainControlFixed(
gain_lut_data, params->snr_shift, noise_estimate_data, output_data,
num_channels);
return kTfLiteOk;
}
TFLMRegistration* Register_PCAN() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
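// --- Editor's sketch (not part of the upstream file) ---
// PCAN (per-channel amplitude normalization) rescales each channel in place
// with a gain looked up from gain_lut and driven by the channel's noise
// estimate shifted by snr_shift; in the usual micro-frontend pipeline it sits
// between spectral subtraction (which supplies the noise estimate) and the
// log stage. Registration, assuming the op name "SignalPCAN":
//
//   resolver.AddCustom("SignalPCAN", tflite::tflm_signal::Register_PCAN());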