Add deferred real-time processor for enrichments (#22880)

* implement deferred real-time processor with background task handling

* add tests

* fix typing
This commit is contained in:
Josh Hawkins
2026-04-14 22:39:44 -05:00
committed by GitHub
parent 4232792248
commit a47be12ac5
4 changed files with 546 additions and 123 deletions
+125 -1
View File
@@ -1,8 +1,12 @@
"""Local only processors for handling real time object processing."""
import logging
import threading
from abc import ABC, abstractmethod
from typing import Any
from collections import deque
from concurrent.futures import Future
from queue import Empty, Full, Queue
from typing import Any, Callable
import numpy as np
@@ -74,3 +78,123 @@ class RealTimeProcessorApi(ABC):
payload: The updated configuration object.
"""
pass
def drain_results(self) -> list[dict[str, Any]]:
    """Collect and return any results staged for IPC side-effects.

    Deferred processors build results up on a background worker thread;
    the maintainer invokes this once per loop iteration to pick them up
    and publish them from the main thread. The synchronous default has
    nothing pending, so it simply yields an empty list.
    """
    return []
def shutdown(self) -> None:
    """Tear down any background work and release held resources.

    Invoked when the processor is removed or the maintainer stops.
    Synchronous processors hold no background state, so the default
    implementation intentionally does nothing.
    """
class DeferredRealtimeProcessorApi(RealTimeProcessorApi):
    """Base class for processors that offload heavy work to a background thread.

    Subclasses implement:
    - process_frame(): do cheap gating + crop + copy, then call _enqueue_task()
    - _process_task(task): heavy work (inference, consensus) on the worker thread
    - handle_request(): optionally use _enqueue_request() for sync request/response
    - expire_object(): call _enqueue_task() with a control message

    The worker thread owns all processor state. No locks are needed because
    only the worker mutates state. Results that need IPC are placed in
    _pending_results via _emit_result(), and the maintainer drains them
    each loop iteration.
    """

    def __init__(
        self,
        config: FrigateConfig,
        metrics: DataProcessorMetrics,
        max_queue: int = 8,
    ) -> None:
        """Start the background worker.

        Args:
            config: Global Frigate configuration.
            metrics: Shared data-processor metrics object.
            max_queue: Bound on queued tasks; beyond this new tasks are dropped.
        """
        super().__init__(config, metrics)
        self._task_queue: Queue = Queue(maxsize=max_queue)
        self._pending_results: deque[dict[str, Any]] = deque()
        # Guards _pending_results, which is touched by both the worker
        # (_emit_result) and the maintainer thread (drain_results).
        self._results_lock = threading.Lock()
        self._stop_event = threading.Event()
        self._worker = threading.Thread(
            target=self._drain_loop,
            daemon=True,
            name=f"{type(self).__name__}_worker",
        )
        self._worker.start()

    @staticmethod
    def _is_request(task: Any) -> bool:
        """Return True if task is a request/response tuple carrying a Future."""
        return (
            isinstance(task, tuple)
            and len(task) == 2
            and isinstance(task[1], Future)
        )

    def _drain_loop(self) -> None:
        """Worker thread main loop — drains the task queue until stopped."""
        while not self._stop_event.is_set():
            try:
                task = self._task_queue.get(timeout=0.5)
            except Empty:
                continue

            if self._is_request(task):
                # Request/response: (callable_and_args, future)
                (func, args), future = task
                try:
                    future.set_result(func(args))
                except Exception as e:
                    future.set_exception(e)
            else:
                try:
                    self._process_task(task)
                except Exception:
                    logger.exception("Error processing deferred task")

    def _enqueue_task(self, task: Any) -> bool:
        """Enqueue a task for the worker. Returns False if queue is full (dropped)."""
        try:
            self._task_queue.put_nowait(task)
            return True
        except Full:
            logger.debug("Deferred processor queue full, dropping task")
            return False

    def _enqueue_request(self, func: Callable, args: Any, timeout: float = 10.0) -> Any:
        """Enqueue a request and block until the worker returns a result.

        Raises whatever func raised on the worker, queue.Full on enqueue
        timeout, or concurrent.futures errors (TimeoutError/CancelledError)
        from waiting on the response.
        """
        future: Future = Future()
        self._task_queue.put(((func, args), future), timeout=timeout)
        return future.result(timeout=timeout)

    def _emit_result(self, result: dict[str, Any]) -> None:
        """Called by the worker thread to stage a result for the maintainer."""
        with self._results_lock:
            self._pending_results.append(result)

    def drain_results(self) -> list[dict[str, Any]]:
        """Called by the maintainer on the main thread to collect pending results."""
        with self._results_lock:
            results = list(self._pending_results)
            self._pending_results.clear()
        return results

    def shutdown(self) -> None:
        """Signal the worker to stop, join it, and fail any queued requests.

        Request futures still sitting in the queue are cancelled so callers
        blocked in _enqueue_request() fail fast instead of waiting out their
        full timeout. A worker that fails to stop within 5s is logged.
        """
        self._stop_event.set()
        self._worker.join(timeout=5.0)
        if self._worker.is_alive():
            logger.warning(
                "%s worker did not stop within 5s", type(self).__name__
            )
        # Drain anything the worker never got to; cancel request futures so
        # their waiters are released immediately.
        while True:
            try:
                task = self._task_queue.get_nowait()
            except Empty:
                break
            if self._is_request(task):
                task[1].cancel()

    @abstractmethod
    def _process_task(self, task: Any) -> None:
        """Process a single task on the worker thread.

        Subclasses implement inference, consensus, training image saves here.
        Call _emit_result() to stage results for the maintainer to publish.
        """
        pass
@@ -1,7 +1,6 @@
"""Real time processor that works with classification tflite models."""
import datetime
import json
import logging
import os
from typing import Any
@@ -10,25 +9,18 @@ import cv2
import numpy as np
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum
from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
EventMetadataTypeEnum,
)
from frigate.comms.event_metadata_updater import EventMetadataPublisher
from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig
from frigate.config.classification import (
CustomClassificationConfig,
ObjectClassificationType,
)
from frigate.config.classification import CustomClassificationConfig
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
from frigate.log import suppress_stderr_during
from frigate.types import TrackedObjectUpdateTypesEnum
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
from frigate.util.image import calculate_region
from frigate.util.object import box_overlaps
from ..types import DataProcessorMetrics
from .api import RealTimeProcessorApi
from .api import DeferredRealtimeProcessorApi
try:
from tflite_runtime.interpreter import Interpreter
@@ -40,7 +32,7 @@ logger = logging.getLogger(__name__)
MAX_OBJECT_CLASSIFICATIONS = 16
class CustomStateClassificationProcessor(RealTimeProcessorApi):
class CustomStateClassificationProcessor(DeferredRealtimeProcessorApi):
def __init__(
self,
config: FrigateConfig,
@@ -48,7 +40,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
requestor: InterProcessRequestor,
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
super().__init__(config, metrics, max_queue=4)
self.model_config = model_config
if not self.model_config.name:
@@ -259,14 +251,34 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
)
return
frame = rgb[y1:y2, x1:x2]
cropped_frame = rgb[y1:y2, x1:x2]
try:
resized_frame = cv2.resize(frame, (224, 224))
resized_frame = cv2.resize(cropped_frame, (224, 224))
except Exception:
logger.warning("Failed to resize image for state classification")
return
# Copy for training image saves on worker thread
crop_bgr = cv2.cvtColor(cropped_frame, cv2.COLOR_RGB2BGR)
self._enqueue_task(("classify", camera, now, resized_frame, crop_bgr))
def _process_task(self, task: Any) -> None:
kind = task[0]
if kind == "classify":
_, camera, timestamp, resized_frame, crop_bgr = task
self._classify_state(camera, timestamp, resized_frame, crop_bgr)
elif kind == "reload":
self.__build_detector()
def _classify_state(
self,
camera: str,
timestamp: float,
resized_frame: np.ndarray,
crop_bgr: np.ndarray,
) -> None:
if self.interpreter is None:
# When interpreter is None, always save (score is 0.0, which is < 1.0)
if self._should_save_image(camera, "unknown", 0.0):
@@ -277,15 +289,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
)
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
crop_bgr,
"none-none",
now,
timestamp,
"unknown",
0.0,
max_files=save_attempts,
)
return
if not self.tensor_input_details or not self.tensor_output_details:
return
input = np.expand_dims(resized_frame, axis=0)
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
self.interpreter.invoke()
@@ -298,7 +313,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
)
best_id = int(np.argmax(probs))
score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now)
self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)
detected_state = self.labelmap[best_id]
@@ -310,9 +325,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
)
write_classification_attempt(
self.train_dir,
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
crop_bgr,
"none-none",
now,
timestamp,
detected_state,
score,
max_files=save_attempts,
@@ -327,9 +342,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
verified_state = self.verify_state_change(camera, detected_state)
if verified_state is not None:
self.requestor.send_data(
f"{camera}/classification/{self.model_config.name}",
verified_state,
self._emit_result(
{
"type": "classification",
"processor": "state",
"model_name": self.model_config.name,
"camera": camera,
"state": verified_state,
}
)
def handle_request(
@@ -337,14 +357,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
) -> dict[str, Any] | None:
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name:
self.__build_detector()
logger.info(
f"Successfully loaded updated model for {self.model_config.name}"
)
return {
"success": True,
"message": f"Loaded {self.model_config.name} model.",
}
def _do_reload(data: dict[str, Any]) -> dict[str, Any]:
self.__build_detector()
logger.info(
f"Successfully loaded updated model for {self.model_config.name}"
)
return {
"success": True,
"message": f"Loaded {self.model_config.name} model.",
}
result: dict[str, Any] = self._enqueue_request(_do_reload, request_data)
return result
else:
return None
else:
@@ -354,7 +379,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
pass
class CustomObjectClassificationProcessor(RealTimeProcessorApi):
class CustomObjectClassificationProcessor(DeferredRealtimeProcessorApi):
def __init__(
self,
config: FrigateConfig,
@@ -363,7 +388,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
requestor: InterProcessRequestor,
metrics: DataProcessorMetrics,
):
super().__init__(config, metrics)
super().__init__(config, metrics, max_queue=8)
self.model_config = model_config
if not self.model_config.name:
@@ -536,18 +561,41 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
crop = rgb[
y:y2,
x:x2,
]
crop = rgb[y:y2, x:x2]
if crop.shape != (224, 224):
try:
resized_crop = cv2.resize(crop, (224, 224))
except Exception:
logger.warning("Failed to resize image for state classification")
return
try:
resized_crop = cv2.resize(crop, (224, 224))
except Exception:
logger.warning("Failed to resize image for object classification")
return
# Copy crop for training images (will be used on worker thread)
crop_bgr = cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)
self._enqueue_task(
("classify", object_id, obj_data["camera"], now, resized_crop, crop_bgr)
)
def _process_task(self, task: Any) -> None:
kind = task[0]
if kind == "classify":
_, object_id, camera, timestamp, resized_crop, crop_bgr = task
self._classify_object(object_id, camera, timestamp, resized_crop, crop_bgr)
elif kind == "expire":
_, object_id = task
if object_id in self.classification_history:
self.classification_history.pop(object_id)
elif kind == "reload":
self.__build_detector()
def _classify_object(
self,
object_id: str,
camera: str,
timestamp: float,
resized_crop: np.ndarray,
crop_bgr: np.ndarray,
) -> None:
if self.interpreter is None:
save_attempts = (
self.model_config.save_attempts
@@ -556,9 +604,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)
write_classification_attempt(
self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
crop_bgr,
object_id,
now,
timestamp,
"unknown",
0.0,
max_files=save_attempts,
@@ -569,7 +617,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
if object_id not in self.classification_history:
self.classification_history[object_id] = []
self.classification_history[object_id].append(("unknown", 0.0, now))
self.classification_history[object_id].append(("unknown", 0.0, timestamp))
return
if not self.tensor_input_details or not self.tensor_output_details:
return
input = np.expand_dims(resized_crop, axis=0)
@@ -584,7 +635,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)
best_id = int(np.argmax(probs))
score = round(probs[best_id], 2)
self.__update_metrics(datetime.datetime.now().timestamp() - now)
self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)
save_attempts = (
self.model_config.save_attempts
@@ -593,9 +644,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
)
write_classification_attempt(
self.train_dir,
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
crop_bgr,
object_id,
now,
timestamp,
self.labelmap[best_id],
score,
max_files=save_attempts,
@@ -610,92 +661,57 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
sub_label = self.labelmap[best_id]
logger.debug(
f"{self.model_config.name}: Object {object_id} (label={obj_data['label']}) passed threshold with sub_label={sub_label}, score={score}"
f"{self.model_config.name}: Object {object_id} passed threshold with sub_label={sub_label}, score={score}"
)
consensus_label, consensus_score = self.get_weighted_score(
object_id, sub_label, score, now
object_id, sub_label, score, timestamp
)
logger.debug(
f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
)
if consensus_label is not None:
camera = obj_data["camera"]
logger.debug(
f"{self.model_config.name}: Publishing sub_label={consensus_label} for {obj_data['label']} object {object_id} on {camera}"
if consensus_label is not None and self.model_config.object_config is not None:
self._emit_result(
{
"type": "classification",
"processor": "object",
"model_name": self.model_config.name,
"classification_type": self.model_config.object_config.classification_type,
"object_id": object_id,
"camera": camera,
"timestamp": timestamp,
"label": consensus_label,
"score": consensus_score,
}
)
if (
self.model_config.object_config.classification_type
== ObjectClassificationType.sub_label
):
self.sub_label_publisher.publish(
(object_id, consensus_label, consensus_score),
EventMetadataTypeEnum.sub_label,
)
self.requestor.send_data(
"tracked_object_update",
json.dumps(
{
"type": TrackedObjectUpdateTypesEnum.classification,
"id": object_id,
"camera": camera,
"timestamp": now,
"model": self.model_config.name,
"sub_label": consensus_label,
"score": consensus_score,
}
),
)
elif (
self.model_config.object_config.classification_type
== ObjectClassificationType.attribute
):
self.sub_label_publisher.publish(
(
object_id,
self.model_config.name,
consensus_label,
consensus_score,
),
EventMetadataTypeEnum.attribute.value,
)
self.requestor.send_data(
"tracked_object_update",
json.dumps(
{
"type": TrackedObjectUpdateTypesEnum.classification,
"id": object_id,
"camera": camera,
"timestamp": now,
"model": self.model_config.name,
"attribute": consensus_label,
"score": consensus_score,
}
),
)
def handle_request(self, topic: str, request_data: dict) -> dict | None:
def handle_request(
self, topic: str, request_data: dict[str, Any]
) -> dict[str, Any] | None:
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
if request_data.get("model_name") == self.model_config.name:
self.__build_detector()
logger.info(
f"Successfully loaded updated model for {self.model_config.name}"
)
return {
"success": True,
"message": f"Loaded {self.model_config.name} model.",
}
def _do_reload(data: dict[str, Any]) -> dict[str, Any]:
self.__build_detector()
logger.info(
f"Successfully loaded updated model for {self.model_config.name}"
)
return {
"success": True,
"message": f"Loaded {self.model_config.name} model.",
}
result: dict[str, Any] = self._enqueue_request(_do_reload, request_data)
return result
else:
return None
else:
return None
def expire_object(self, object_id: str, camera: str) -> None:
if object_id in self.classification_history:
self.classification_history.pop(object_id)
self._enqueue_task(("expire", object_id))
def write_classification_attempt(
+78 -6
View File
@@ -2,6 +2,7 @@
import base64
import datetime
import json
import logging
import threading
from multiprocessing.synchronize import Event as MpEvent
@@ -33,6 +34,7 @@ from frigate.config.camera.updater import (
CameraConfigUpdateEnum,
CameraConfigUpdateSubscriber,
)
from frigate.config.classification import ObjectClassificationType
from frigate.data_processing.common.license_plate.model import (
LicensePlateModelRunner,
)
@@ -61,6 +63,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum
from frigate.genai import GenAIClientManager
from frigate.models import Event, Recordings, ReviewSegment, Trigger
from frigate.types import TrackedObjectUpdateTypesEnum
from frigate.util.builtin import serialize
from frigate.util.file import get_event_thumbnail_bytes
from frigate.util.image import SharedMemoryFrameManager
@@ -274,10 +277,15 @@ class EmbeddingMaintainer(threading.Thread):
self._process_recordings_updates()
self._process_review_updates()
self._process_frame_updates()
self._process_deferred_results()
self._expire_dedicated_lpr()
self._process_finalized()
self._process_event_metadata()
# Shutdown deferred processors
for processor in self.realtime_processors:
processor.shutdown()
self.config_updater.stop()
self.enrichment_config_subscriber.stop()
self.event_subscriber.stop()
@@ -316,10 +324,9 @@ class EmbeddingMaintainer(threading.Thread):
model_name = topic.split("/")[-1]
if model_config is None:
self.realtime_processors = [
processor
for processor in self.realtime_processors
if not (
remaining = []
for processor in self.realtime_processors:
if (
isinstance(
processor,
(
@@ -328,8 +335,11 @@ class EmbeddingMaintainer(threading.Thread):
),
)
and processor.model_config.name == model_name
)
]
):
processor.shutdown()
else:
remaining.append(processor)
self.realtime_processors = remaining
logger.info(
f"Successfully removed classification processor for model: {model_name}"
@@ -697,6 +707,68 @@ class EmbeddingMaintainer(threading.Thread):
self.frame_manager.close(frame_name)
def _process_deferred_results(self) -> None:
    """Drain results from deferred processors and perform IPC side-effects.

    Runs on the maintainer thread each loop iteration so publishes happen
    off the processors' worker threads.
    """
    for processor in self.realtime_processors:
        for result in processor.drain_results():
            if result.get("type") != "classification":
                continue

            if result["processor"] == "state":
                # State results map directly to a per-camera topic publish.
                self.requestor.send_data(
                    f"{result['camera']}/classification/{result['model_name']}",
                    result["state"],
                )
            elif result["processor"] == "object":
                object_id = result["object_id"]
                model_name = result["model_name"]
                label = result["label"]
                score = result["score"]
                classification_type = result["classification_type"]

                if classification_type == ObjectClassificationType.sub_label:
                    self.event_metadata_publisher.publish(
                        (object_id, label, score),
                        EventMetadataTypeEnum.sub_label,
                    )
                    update_key = "sub_label"
                elif classification_type == ObjectClassificationType.attribute:
                    # NOTE(review): attribute publishes the enum's .value while
                    # sub_label publishes the member itself — mirrors prior
                    # behavior; confirm the subscriber expects this asymmetry.
                    self.event_metadata_publisher.publish(
                        (object_id, model_name, label, score),
                        EventMetadataTypeEnum.attribute.value,
                    )
                    update_key = "attribute"
                else:
                    continue

                # Shared tracked_object_update payload; only the label key
                # name differs between sub_label and attribute results.
                self.requestor.send_data(
                    "tracked_object_update",
                    json.dumps(
                        {
                            "type": TrackedObjectUpdateTypesEnum.classification,
                            "id": object_id,
                            "camera": result["camera"],
                            "timestamp": result["timestamp"],
                            "model": model_name,
                            update_key: label,
                            "score": score,
                        }
                    ),
                )
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
"""Embed the thumbnail for an event."""
if not self.config.semantic_search.enabled:
+211
View File
@@ -0,0 +1,211 @@
"""Tests for DeferredRealtimeProcessorApi."""
import sys
import time
import unittest
from typing import Any
from unittest.mock import MagicMock, patch
import numpy as np
from frigate.data_processing.real_time.api import DeferredRealtimeProcessorApi
# Mock TFLite before importing classification module
# The classification module tries these runtimes at import time; stubbing
# them in sys.modules lets the tests import it without the native wheels.
_MOCK_MODULES = [
    "tflite_runtime",
    "tflite_runtime.interpreter",
    "ai_edge_litert",
    "ai_edge_litert.interpreter",
]
for mod in _MOCK_MODULES:
    # Only stub when absent so a real installation is left untouched.
    if mod not in sys.modules:
        sys.modules[mod] = MagicMock()

from frigate.data_processing.real_time.custom_classification import (  # noqa: E402
    CustomObjectClassificationProcessor,
)
class StubDeferredProcessor(DeferredRealtimeProcessorApi):
    """Bare-bones concrete subclass used to exercise the deferred base class."""

    def __init__(self, max_queue: int = 8):
        super().__init__(MagicMock(), MagicMock(), max_queue=max_queue)
        self.processed_items: list[tuple] = []

    def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
        """Forward every frame to the worker — the stub applies no gating."""
        self._enqueue_task(("frame", obj_data, frame.copy()))

    def _process_task(self, task: tuple) -> None:
        kind, *payload = task
        if kind == "frame":
            obj_data, frame = payload
            self.processed_items.append((obj_data["id"], frame.shape))
            self._emit_result(
                {
                    "type": "test_result",
                    "id": obj_data["id"],
                    "label": "cat",
                    "score": 0.95,
                }
            )
        elif kind == "expire":
            (object_id,) = payload
            self.processed_items.append(("expired", object_id))

    def handle_request(
        self, topic: str, request_data: dict[str, Any]
    ) -> dict[str, Any] | None:
        if topic != "reload":
            return None

        def _do_reload(data):
            return {"success": True, "model": data.get("name")}

        return self._enqueue_request(_do_reload, request_data)

    def expire_object(self, object_id: str, camera: str) -> None:
        self._enqueue_task(("expire", object_id))
class TestDeferredProcessorBase(unittest.TestCase):
    """Tests for the deferred-processor base behavior via the stub subclass."""

    def _make_stub(self, max_queue: int = 8) -> StubDeferredProcessor:
        """Create a stub processor and guarantee its worker thread is joined."""
        proc = StubDeferredProcessor(max_queue=max_queue)
        self.addCleanup(proc.shutdown)
        return proc

    def _wait_for(self, predicate, timeout: float = 2.0) -> bool:
        """Poll predicate until true or timeout — avoids flaky fixed sleeps."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(0.01)
        return predicate()

    def test_enqueue_and_drain(self):
        """Tasks enqueued on main thread are processed by worker, results are drainable."""
        proc = self._make_stub()
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        proc.process_frame({"id": "obj1"}, frame)
        proc.process_frame({"id": "obj2"}, frame)
        # Wait for the worker instead of sleeping a fixed interval
        self.assertTrue(self._wait_for(lambda: len(proc._pending_results) >= 2))
        results = proc.drain_results()
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0]["id"], "obj1")
        self.assertEqual(results[1]["id"], "obj2")
        # Second drain should be empty
        self.assertEqual(len(proc.drain_results()), 0)

    def test_backpressure_drops_tasks(self):
        """When queue is full, new tasks are silently dropped."""
        proc = self._make_stub(max_queue=2)
        frame = np.zeros((10, 10, 3), dtype=np.uint8)
        for i in range(10):
            proc.process_frame({"id": f"obj{i}"}, frame)
        # Wait until the worker has consumed the queue and emitted something
        self.assertTrue(
            self._wait_for(
                lambda: proc._task_queue.empty() and len(proc._pending_results) > 0
            )
        )
        results = proc.drain_results()
        # The key property: no crash, no unbounded growth
        self.assertLessEqual(len(results), 10)
        self.assertGreater(len(results), 0)

    def test_handle_request_through_worker(self):
        """handle_request blocks until the worker processes it and returns a response."""
        proc = self._make_stub()
        result = proc.handle_request("reload", {"name": "my_model"})
        self.assertEqual(result, {"success": True, "model": "my_model"})

    def test_expire_object_serialized_with_work(self):
        """expire_object goes through the queue, serialized with inference work."""
        proc = self._make_stub()
        frame = np.zeros((10, 10, 3), dtype=np.uint8)
        proc.process_frame({"id": "obj1"}, frame)
        proc.expire_object("obj1", "front_door")
        self.assertTrue(self._wait_for(lambda: len(proc.processed_items) >= 2))
        # Both should have been processed in order
        self.assertEqual(len(proc.processed_items), 2)
        self.assertEqual(proc.processed_items[0][0], "obj1")
        self.assertEqual(proc.processed_items[1], ("expired", "obj1"))

    def test_shutdown_joins_worker(self):
        """shutdown() signals the worker to stop and joins the thread."""
        proc = StubDeferredProcessor()
        proc.shutdown()
        self.assertFalse(proc._worker.is_alive())

    def test_drain_results_returns_list(self):
        """drain_results returns a plain list, not a deque."""
        proc = self._make_stub()
        self.assertIsInstance(proc.drain_results(), list)
class TestCustomObjectClassificationDeferred(unittest.TestCase):
    """Test that CustomObjectClassificationProcessor uses the deferred pattern correctly."""

    def _wait_for(self, predicate, timeout: float = 2.0) -> bool:
        """Poll predicate until true or timeout — avoids flaky fixed sleeps."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(0.01)
        return predicate()

    def _make_processor(self):
        """Build a processor with mocked config/IPC and no real interpreter."""
        config = MagicMock()
        model_config = MagicMock()
        model_config.name = "test_breed"
        model_config.object_config = MagicMock()
        model_config.object_config.objects = ["dog"]
        model_config.threshold = 0.5
        model_config.save_attempts = 10
        model_config.object_config.classification_type = "sub_label"
        publisher = MagicMock()
        requestor = MagicMock()
        metrics = MagicMock()
        metrics.classification_speeds = {}
        metrics.classification_cps = {}
        with patch.object(
            CustomObjectClassificationProcessor,
            "_CustomObjectClassificationProcessor__build_detector",
        ):
            proc = CustomObjectClassificationProcessor(
                config, model_config, publisher, requestor, metrics
            )
        # Join the worker thread at test end so threads don't leak
        self.addCleanup(proc.shutdown)
        proc.interpreter = None
        proc.tensor_input_details = [{"index": 0}]
        proc.tensor_output_details = [{"index": 0}]
        proc.labelmap = {0: "labrador", 1: "poodle", 2: "none"}
        return proc

    def test_is_deferred_processor(self):
        """CustomObjectClassificationProcessor should be a DeferredRealtimeProcessorApi."""
        proc = self._make_processor()
        self.assertIsInstance(proc, DeferredRealtimeProcessorApi)

    def test_expire_clears_history(self):
        """expire_object should clear classification history for the object."""
        proc = self._make_processor()
        proc.classification_history["obj1"] = [("labrador", 0.9, 1.0)]
        proc.expire_object("obj1", "front")
        # Wait for the worker to process the expire task instead of sleeping
        self.assertTrue(
            self._wait_for(lambda: "obj1" not in proc.classification_history)
        )

    def test_drain_results_empty_when_no_model(self):
        """With no interpreter, process_frame saves training images but emits no results."""
        proc = self._make_processor()
        proc.interpreter = None
        frame = np.zeros((150, 100), dtype=np.uint8)
        obj_data = {
            "id": "obj1",
            "label": "dog",
            "false_positive": False,
            "end_time": None,
            "box": [10, 10, 50, 50],
            "camera": "front",
        }
        with patch(
            "frigate.data_processing.real_time.custom_classification.write_classification_attempt"
        ):
            proc.process_frame(obj_data, frame)
            # The no-interpreter path appends to classification_history after
            # the (patched) training-image save, so this confirms the worker
            # finished while the patch is still active.
            self.assertTrue(
                self._wait_for(lambda: "obj1" in proc.classification_history)
            )
        results = proc.drain_results()
        self.assertEqual(len(results), 0)
if __name__ == "__main__":
unittest.main()