Testing TensorFlow Model Loading Implementation in OpenPilot

This test program checks that a TensorFlow model stored as a serialized GraphDef (.pb file) can be read from disk and imported into a graph through the TensorFlow C API, called from C++ in the OpenPilot codebase.

Test Coverage Overview

The test covers two steps of the TensorFlow C API workflow: loading the model file's bytes into a TF_Buffer and importing that buffer into a TF_Graph.

  • Tests file reading and buffer management
  • Validates TensorFlow graph import process
  • Handles memory allocation and deallocation
  • Verifies error conditions and status codes
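The file-reading step listed above can also be written with C++ streams instead of C stdio. The following is a minimal sketch under that assumption, not code from openpilot; `read_file_bytes` is a hypothetical name. Like the test's `read_file`, it seeks to the end to learn the size, rewinds, and reads the whole file in binary mode.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper (not part of openpilot): reads a whole file in binary
// mode, mirroring the fseek/ftell/rewind/fread steps of the test's read_file.
// Returns std::nullopt if the file cannot be opened, sized, or fully read.
std::optional<std::vector<std::uint8_t>> read_file_bytes(const std::string& path) {
  // std::ios::ate opens the file positioned at the end, like fseek(SEEK_END).
  std::ifstream in(path, std::ios::binary | std::ios::ate);
  if (!in) {
    return std::nullopt;
  }
  std::streamoff size = in.tellg();  // like ftell; -1 on failure
  if (size < 0) {
    return std::nullopt;
  }
  in.seekg(0, std::ios::beg);  // like rewind
  std::vector<std::uint8_t> data(static_cast<std::size_t>(size));
  if (size > 0 &&
      !in.read(reinterpret_cast<char*>(data.data()),
               static_cast<std::streamsize>(size))) {
    return std::nullopt;  // short read, like num_read != 1
  }
  return data;
}
```

An empty file yields an empty vector rather than an error, and the vector frees its storage automatically, so no separate deallocation step is needed on the C++ side.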

Implementation Analysis

The test calls TensorFlow's C API directly from C++ rather than through a higher-level wrapper, so nothing sits between the model file and the graph import.

Model data is read with C stdio calls (fopen, fseek, ftell, fread) into a calloc'd buffer. That buffer is handed to a TF_Buffer together with a deallocator callback, and the outcome of the import is checked through a TF_Status.
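To show the buffer handoff in isolation, here is a sketch that mimics it without linking TensorFlow. `FakeBuffer`, `wrap_bytes`, and `fake_delete_buffer` are hypothetical stand-ins: `FakeBuffer` has the same three fields as `TF_Buffer` (`data`, `length`, `data_deallocator`), and `fake_delete_buffer` approximates `TF_DeleteBuffer`, assuming it invokes the deallocator on the data before releasing the struct.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for TF_Buffer with the same three fields, so the
// ownership handoff can be shown without linking TensorFlow.
struct FakeBuffer {
  const void* data;
  std::size_t length;
  void (*data_deallocator)(void* data, std::size_t length);
};

// Counts deallocator calls so the handoff can be observed in a test.
static int g_dealloc_calls = 0;

// Same role as the test's DeallocateBuffer: frees calloc/malloc'd data.
static void DeallocateBuffer(void* data, std::size_t) {
  ++g_dealloc_calls;
  std::free(data);
}

// Hypothetical: transfers ownership of malloc'd bytes to a new buffer, the way
// main() fills in buf->data, buf->length, and buf->data_deallocator.
static FakeBuffer* wrap_bytes(void* bytes, std::size_t len) {
  FakeBuffer* buf = new FakeBuffer{};
  buf->data = bytes;
  buf->length = len;
  buf->data_deallocator = DeallocateBuffer;
  return buf;
}

// Hypothetical approximation of TF_DeleteBuffer: call the deallocator (if set)
// on the data, then release the buffer struct itself.
static void fake_delete_buffer(FakeBuffer* buf) {
  if (buf == nullptr) {
    return;
  }
  if (buf->data_deallocator != nullptr) {
    buf->data_deallocator(const_cast<void*>(buf->data), buf->length);
  }
  delete buf;
}
```

The point of the callback is that whoever deletes the buffer frees the data with the allocator that created it, so the caller must not free the bytes again after handing them over.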

Technical Details

  • Uses TensorFlow C API (tensorflow/c/c_api.h)
  • Reads the model file with a custom helper (read_file)
  • Employs TF_Buffer for model data management
  • Uses TF_Graph and TF_Status for graph import and error reporting
  • Supplies a deallocation callback (DeallocateBuffer) that TF_DeleteBuffer calls to free the model data

Best Practices Demonstrated

The test illustrates several practices that are useful when integrating a C library such as TensorFlow's C API.

  • Proper resource cleanup and memory management
  • Explicit error handling and status checking
  • Modular function design for file operations
  • Clear separation of model loading and graph import steps
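The resource-cleanup point above can be made automatic in C++ by wrapping each C handle in a `std::unique_ptr` with the matching delete function as its deleter. The sketch below assumes a hypothetical C-style API (`Handle`, `handle_new`, `handle_delete`) standing in for pairs such as `TF_NewGraph`/`TF_DeleteGraph`; it is not how the original test is written.

```cpp
#include <cassert>
#include <memory>

// Hypothetical C-style API standing in for pairs such as
// TF_NewGraph/TF_DeleteGraph and TF_NewStatus/TF_DeleteStatus.
struct Handle {
  int id;
};

static int g_live_handles = 0;  // tracks handles not yet deleted

static Handle* handle_new(int id) {
  ++g_live_handles;
  return new Handle{id};
}

static void handle_delete(Handle* h) {
  --g_live_handles;
  delete h;
}

// unique_ptr stores the C deleter and calls it when the pointer goes out of
// scope, including on early returns.
using HandlePtr = std::unique_ptr<Handle, decltype(&handle_delete)>;

static HandlePtr make_handle(int id) {
  return HandlePtr(handle_new(id), &handle_delete);
}

// Hypothetical workflow: the early-return path needs no manual cleanup.
static bool use_two_handles(bool fail_early) {
  HandlePtr graph = make_handle(1);
  HandlePtr status = make_handle(2);
  if (fail_early) {
    return false;  // both handles are still released here
  }
  return graph->id == 1 && status->id == 2;
}
```

With TensorFlow linked, the same idea would read `std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph(TF_NewGraph(), &TF_DeleteGraph);`, and likewise for `TF_Status` with `TF_DeleteStatus`.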

commaai/openpilot

selfdrive/modeld/tests/tf_test/main.cc

            
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include "tensorflow/c/c_api.h"

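// Reads the whole file at `path` into a heap buffer allocated with calloc.
// On success returns the buffer (the caller owns it) and stores its size in
// *out_len; returns NULL if the file cannot be opened, sized, or read.
void* read_file(const char* path, size_t* out_len) {
  FILE* f = fopen(path, "rb");  // binary mode: .pb files are not text
  if (!f) {
    return NULL;
  }
  long f_len = -1;
  if (fseek(f, 0, SEEK_END) == 0) {
    f_len = ftell(f);
  }
  rewind(f);
  if (f_len <= 0) {  // seek/ftell failed or the file is empty
    fclose(f);
    return NULL;
  }

  char* buf = (char*)calloc(f_len, 1);
  assert(buf);

  size_t num_read = fread(buf, f_len, 1, f);
  fclose(f);

  if (num_read != 1) {
    free(buf);
    return NULL;
  }

  if (out_len) {
    *out_len = f_len;
  }

  return buf;
}

// Deallocator handed to TF_Buffer; TF_DeleteBuffer calls it to free the
// model data that read_file allocated.
static void DeallocateBuffer(void* data, size_t) {
  free(data);
}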

int main(int argc, char* argv[]) {
  if (argc < 2) {
    fprintf(stderr, "usage: %s <model path without .pb suffix>\n",
            argc > 0 ? argv[0] : "main");
    return 1;
  }
  TF_Buffer* buf;
  TF_Graph* graph;
  TF_Status* status;
  char *path = argv[1];

  // load model
  {
    size_t model_size = 0;
    char tmp[1024];
    snprintf(tmp, sizeof(tmp), "%s.pb", path);
    printf("loading model %s\n", tmp);
    uint8_t *model_data = (uint8_t *)read_file(tmp, &model_size);
    if (!model_data) {
      printf("FAIL: could not read %s\n", tmp);
      return 1;
    }
    buf = TF_NewBuffer();
    buf->data = model_data;
    buf->length = model_size;
    buf->data_deallocator = DeallocateBuffer;  // TF_DeleteBuffer frees model_data
    printf("loaded model of size %zu\n", model_size);
  }

  // import graph
  status = TF_NewStatus();
  graph = TF_NewGraph();
  TF_ImportGraphDefOptions *opts = TF_NewImportGraphDefOptions();
  TF_GraphImportGraphDef(graph, buf, opts, status);
  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteBuffer(buf);

  int ret = 0;
  if (TF_GetCode(status) != TF_OK) {
    printf("FAIL: %s\n", TF_Message(status));
    ret = 1;
  } else {
    printf("SUCCESS\n");
  }

  // release the remaining TensorFlow objects
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return ret;
}
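main() builds the model path by appending ".pb" to its argument with snprintf into a fixed 1024-byte array, so an overlong argument would be silently truncated and the wrong file opened. A minimal sketch of a truncation check follows; `build_model_path` is a hypothetical helper, not part of the original test.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Hypothetical helper: writes "<base>.pb" into out, as main() does with
// snprintf, but also reports whether the full path fit. snprintf returns the
// length the untruncated string would have had, so n >= out_size means the
// output was cut short (it is still NUL-terminated when out_size > 0).
static bool build_model_path(const char* base, char* out, std::size_t out_size) {
  int n = std::snprintf(out, out_size, "%s.pb", base);
  return n >= 0 && static_cast<std::size_t>(n) < out_size;
}
```

In main(), a false return could be reported as a failure before read_file is ever called.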