mirror of
https://github.com/blakeblackshear/frigate.git
synced 2026-04-22 15:07:41 +08:00
Add deferred real-time processor for enrichments (#22880)
* implement deferred real-time processor with background task handling * add tests * fix typing
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
"""Local only processors for handling real time object processing."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from collections import deque
|
||||
from concurrent.futures import Future
|
||||
from queue import Empty, Full, Queue
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -74,3 +78,123 @@ class RealTimeProcessorApi(ABC):
|
||||
payload: The updated configuration object.
|
||||
"""
|
||||
pass
|
||||
|
||||
def drain_results(self) -> list[dict[str, Any]]:
|
||||
"""Return pending results that need IPC side-effects.
|
||||
|
||||
Deferred processors accumulate results on a worker thread.
|
||||
The maintainer calls this each loop iteration to collect them
|
||||
and perform publishes on the main thread.
|
||||
|
||||
Synchronous processors return an empty list (default).
|
||||
"""
|
||||
return []
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Stop any background work and release resources.
|
||||
|
||||
Called when the processor is being removed or the maintainer
|
||||
is shutting down. Default is a no-op for synchronous processors.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DeferredRealtimeProcessorApi(RealTimeProcessorApi):
|
||||
"""Base class for processors that offload heavy work to a background thread.
|
||||
|
||||
Subclasses implement:
|
||||
- process_frame(): do cheap gating + crop + copy, then call _enqueue_task()
|
||||
- _process_task(task): heavy work (inference, consensus) on the worker thread
|
||||
- handle_request(): optionally use _enqueue_request() for sync request/response
|
||||
- expire_object(): call _enqueue_task() with a control message
|
||||
|
||||
The worker thread owns all processor state. No locks are needed because
|
||||
only the worker mutates state. Results that need IPC are placed in
|
||||
_pending_results via _emit_result(), and the maintainer drains them
|
||||
each loop iteration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FrigateConfig,
|
||||
metrics: DataProcessorMetrics,
|
||||
max_queue: int = 8,
|
||||
) -> None:
|
||||
super().__init__(config, metrics)
|
||||
self._task_queue: Queue = Queue(maxsize=max_queue)
|
||||
self._pending_results: deque[dict[str, Any]] = deque()
|
||||
self._results_lock = threading.Lock()
|
||||
self._stop_event = threading.Event()
|
||||
self._worker = threading.Thread(
|
||||
target=self._drain_loop,
|
||||
daemon=True,
|
||||
name=f"{type(self).__name__}_worker",
|
||||
)
|
||||
self._worker.start()
|
||||
|
||||
def _drain_loop(self) -> None:
|
||||
"""Worker thread main loop — drains the task queue until stopped."""
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
task = self._task_queue.get(timeout=0.5)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
if (
|
||||
isinstance(task, tuple)
|
||||
and len(task) == 2
|
||||
and isinstance(task[1], Future)
|
||||
):
|
||||
# Request/response: (callable_and_args, future)
|
||||
(func, args), future = task
|
||||
try:
|
||||
result = func(args)
|
||||
future.set_result(result)
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
else:
|
||||
try:
|
||||
self._process_task(task)
|
||||
except Exception:
|
||||
logger.exception("Error processing deferred task")
|
||||
|
||||
def _enqueue_task(self, task: Any) -> bool:
|
||||
"""Enqueue a task for the worker. Returns False if queue is full (dropped)."""
|
||||
try:
|
||||
self._task_queue.put_nowait(task)
|
||||
return True
|
||||
except Full:
|
||||
logger.debug("Deferred processor queue full, dropping task")
|
||||
return False
|
||||
|
||||
def _enqueue_request(self, func: Callable, args: Any, timeout: float = 10.0) -> Any:
|
||||
"""Enqueue a request and block until the worker returns a result."""
|
||||
future: Future = Future()
|
||||
self._task_queue.put(((func, args), future), timeout=timeout)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
def _emit_result(self, result: dict[str, Any]) -> None:
|
||||
"""Called by the worker thread to stage a result for the maintainer."""
|
||||
with self._results_lock:
|
||||
self._pending_results.append(result)
|
||||
|
||||
def drain_results(self) -> list[dict[str, Any]]:
|
||||
"""Called by the maintainer on the main thread to collect pending results."""
|
||||
with self._results_lock:
|
||||
results = list(self._pending_results)
|
||||
self._pending_results.clear()
|
||||
return results
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Signal the worker to stop and wait for it to finish."""
|
||||
self._stop_event.set()
|
||||
self._worker.join(timeout=5.0)
|
||||
|
||||
@abstractmethod
|
||||
def _process_task(self, task: Any) -> None:
|
||||
"""Process a single task on the worker thread.
|
||||
|
||||
Subclasses implement inference, consensus, training image saves here.
|
||||
Call _emit_result() to stage results for the maintainer to publish.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Real time processor that works with classification tflite models."""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
@@ -10,25 +9,18 @@ import cv2
|
||||
import numpy as np
|
||||
|
||||
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum
|
||||
from frigate.comms.event_metadata_updater import (
|
||||
EventMetadataPublisher,
|
||||
EventMetadataTypeEnum,
|
||||
)
|
||||
from frigate.comms.event_metadata_updater import EventMetadataPublisher
|
||||
from frigate.comms.inter_process import InterProcessRequestor
|
||||
from frigate.config import FrigateConfig
|
||||
from frigate.config.classification import (
|
||||
CustomClassificationConfig,
|
||||
ObjectClassificationType,
|
||||
)
|
||||
from frigate.config.classification import CustomClassificationConfig
|
||||
from frigate.const import CLIPS_DIR, MODEL_CACHE_DIR
|
||||
from frigate.log import suppress_stderr_during
|
||||
from frigate.types import TrackedObjectUpdateTypesEnum
|
||||
from frigate.util.builtin import EventsPerSecond, InferenceSpeed, load_labels
|
||||
from frigate.util.image import calculate_region
|
||||
from frigate.util.object import box_overlaps
|
||||
|
||||
from ..types import DataProcessorMetrics
|
||||
from .api import RealTimeProcessorApi
|
||||
from .api import DeferredRealtimeProcessorApi
|
||||
|
||||
try:
|
||||
from tflite_runtime.interpreter import Interpreter
|
||||
@@ -40,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
MAX_OBJECT_CLASSIFICATIONS = 16
|
||||
|
||||
|
||||
class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
class CustomStateClassificationProcessor(DeferredRealtimeProcessorApi):
|
||||
def __init__(
|
||||
self,
|
||||
config: FrigateConfig,
|
||||
@@ -48,7 +40,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
requestor: InterProcessRequestor,
|
||||
metrics: DataProcessorMetrics,
|
||||
):
|
||||
super().__init__(config, metrics)
|
||||
super().__init__(config, metrics, max_queue=4)
|
||||
self.model_config = model_config
|
||||
|
||||
if not self.model_config.name:
|
||||
@@ -259,14 +251,34 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
return
|
||||
|
||||
frame = rgb[y1:y2, x1:x2]
|
||||
cropped_frame = rgb[y1:y2, x1:x2]
|
||||
|
||||
try:
|
||||
resized_frame = cv2.resize(frame, (224, 224))
|
||||
resized_frame = cv2.resize(cropped_frame, (224, 224))
|
||||
except Exception:
|
||||
logger.warning("Failed to resize image for state classification")
|
||||
return
|
||||
|
||||
# Copy for training image saves on worker thread
|
||||
crop_bgr = cv2.cvtColor(cropped_frame, cv2.COLOR_RGB2BGR)
|
||||
|
||||
self._enqueue_task(("classify", camera, now, resized_frame, crop_bgr))
|
||||
|
||||
def _process_task(self, task: Any) -> None:
|
||||
kind = task[0]
|
||||
if kind == "classify":
|
||||
_, camera, timestamp, resized_frame, crop_bgr = task
|
||||
self._classify_state(camera, timestamp, resized_frame, crop_bgr)
|
||||
elif kind == "reload":
|
||||
self.__build_detector()
|
||||
|
||||
def _classify_state(
|
||||
self,
|
||||
camera: str,
|
||||
timestamp: float,
|
||||
resized_frame: np.ndarray,
|
||||
crop_bgr: np.ndarray,
|
||||
) -> None:
|
||||
if self.interpreter is None:
|
||||
# When interpreter is None, always save (score is 0.0, which is < 1.0)
|
||||
if self._should_save_image(camera, "unknown", 0.0):
|
||||
@@ -277,15 +289,18 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
|
||||
crop_bgr,
|
||||
"none-none",
|
||||
now,
|
||||
timestamp,
|
||||
"unknown",
|
||||
0.0,
|
||||
max_files=save_attempts,
|
||||
)
|
||||
return
|
||||
|
||||
if not self.tensor_input_details or not self.tensor_output_details:
|
||||
return
|
||||
|
||||
input = np.expand_dims(resized_frame, axis=0)
|
||||
self.interpreter.set_tensor(self.tensor_input_details[0]["index"], input)
|
||||
self.interpreter.invoke()
|
||||
@@ -298,7 +313,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
best_id = int(np.argmax(probs))
|
||||
score = round(probs[best_id], 2)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)
|
||||
|
||||
detected_state = self.labelmap[best_id]
|
||||
|
||||
@@ -310,9 +325,9 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(frame, cv2.COLOR_RGB2BGR),
|
||||
crop_bgr,
|
||||
"none-none",
|
||||
now,
|
||||
timestamp,
|
||||
detected_state,
|
||||
score,
|
||||
max_files=save_attempts,
|
||||
@@ -327,9 +342,14 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
verified_state = self.verify_state_change(camera, detected_state)
|
||||
|
||||
if verified_state is not None:
|
||||
self.requestor.send_data(
|
||||
f"{camera}/classification/{self.model_config.name}",
|
||||
verified_state,
|
||||
self._emit_result(
|
||||
{
|
||||
"type": "classification",
|
||||
"processor": "state",
|
||||
"model_name": self.model_config.name,
|
||||
"camera": camera,
|
||||
"state": verified_state,
|
||||
}
|
||||
)
|
||||
|
||||
def handle_request(
|
||||
@@ -337,14 +357,19 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
) -> dict[str, Any] | None:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.__build_detector()
|
||||
logger.info(
|
||||
f"Successfully loaded updated model for {self.model_config.name}"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Loaded {self.model_config.name} model.",
|
||||
}
|
||||
|
||||
def _do_reload(data: dict[str, Any]) -> dict[str, Any]:
|
||||
self.__build_detector()
|
||||
logger.info(
|
||||
f"Successfully loaded updated model for {self.model_config.name}"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Loaded {self.model_config.name} model.",
|
||||
}
|
||||
|
||||
result: dict[str, Any] = self._enqueue_request(_do_reload, request_data)
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
@@ -354,7 +379,7 @@ class CustomStateClassificationProcessor(RealTimeProcessorApi):
|
||||
pass
|
||||
|
||||
|
||||
class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
class CustomObjectClassificationProcessor(DeferredRealtimeProcessorApi):
|
||||
def __init__(
|
||||
self,
|
||||
config: FrigateConfig,
|
||||
@@ -363,7 +388,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
requestor: InterProcessRequestor,
|
||||
metrics: DataProcessorMetrics,
|
||||
):
|
||||
super().__init__(config, metrics)
|
||||
super().__init__(config, metrics, max_queue=8)
|
||||
self.model_config = model_config
|
||||
|
||||
if not self.model_config.name:
|
||||
@@ -536,18 +561,41 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
|
||||
rgb = cv2.cvtColor(frame, cv2.COLOR_YUV2RGB_I420)
|
||||
crop = rgb[
|
||||
y:y2,
|
||||
x:x2,
|
||||
]
|
||||
crop = rgb[y:y2, x:x2]
|
||||
|
||||
if crop.shape != (224, 224):
|
||||
try:
|
||||
resized_crop = cv2.resize(crop, (224, 224))
|
||||
except Exception:
|
||||
logger.warning("Failed to resize image for state classification")
|
||||
return
|
||||
try:
|
||||
resized_crop = cv2.resize(crop, (224, 224))
|
||||
except Exception:
|
||||
logger.warning("Failed to resize image for object classification")
|
||||
return
|
||||
|
||||
# Copy crop for training images (will be used on worker thread)
|
||||
crop_bgr = cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)
|
||||
|
||||
self._enqueue_task(
|
||||
("classify", object_id, obj_data["camera"], now, resized_crop, crop_bgr)
|
||||
)
|
||||
|
||||
def _process_task(self, task: Any) -> None:
|
||||
kind = task[0]
|
||||
if kind == "classify":
|
||||
_, object_id, camera, timestamp, resized_crop, crop_bgr = task
|
||||
self._classify_object(object_id, camera, timestamp, resized_crop, crop_bgr)
|
||||
elif kind == "expire":
|
||||
_, object_id = task
|
||||
if object_id in self.classification_history:
|
||||
self.classification_history.pop(object_id)
|
||||
elif kind == "reload":
|
||||
self.__build_detector()
|
||||
|
||||
def _classify_object(
|
||||
self,
|
||||
object_id: str,
|
||||
camera: str,
|
||||
timestamp: float,
|
||||
resized_crop: np.ndarray,
|
||||
crop_bgr: np.ndarray,
|
||||
) -> None:
|
||||
if self.interpreter is None:
|
||||
save_attempts = (
|
||||
self.model_config.save_attempts
|
||||
@@ -556,9 +604,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||
crop_bgr,
|
||||
object_id,
|
||||
now,
|
||||
timestamp,
|
||||
"unknown",
|
||||
0.0,
|
||||
max_files=save_attempts,
|
||||
@@ -569,7 +617,10 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
if object_id not in self.classification_history:
|
||||
self.classification_history[object_id] = []
|
||||
|
||||
self.classification_history[object_id].append(("unknown", 0.0, now))
|
||||
self.classification_history[object_id].append(("unknown", 0.0, timestamp))
|
||||
return
|
||||
|
||||
if not self.tensor_input_details or not self.tensor_output_details:
|
||||
return
|
||||
|
||||
input = np.expand_dims(resized_crop, axis=0)
|
||||
@@ -584,7 +635,7 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
best_id = int(np.argmax(probs))
|
||||
score = round(probs[best_id], 2)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - now)
|
||||
self.__update_metrics(datetime.datetime.now().timestamp() - timestamp)
|
||||
|
||||
save_attempts = (
|
||||
self.model_config.save_attempts
|
||||
@@ -593,9 +644,9 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
)
|
||||
write_classification_attempt(
|
||||
self.train_dir,
|
||||
cv2.cvtColor(crop, cv2.COLOR_RGB2BGR),
|
||||
crop_bgr,
|
||||
object_id,
|
||||
now,
|
||||
timestamp,
|
||||
self.labelmap[best_id],
|
||||
score,
|
||||
max_files=save_attempts,
|
||||
@@ -610,92 +661,57 @@ class CustomObjectClassificationProcessor(RealTimeProcessorApi):
|
||||
sub_label = self.labelmap[best_id]
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: Object {object_id} (label={obj_data['label']}) passed threshold with sub_label={sub_label}, score={score}"
|
||||
f"{self.model_config.name}: Object {object_id} passed threshold with sub_label={sub_label}, score={score}"
|
||||
)
|
||||
|
||||
consensus_label, consensus_score = self.get_weighted_score(
|
||||
object_id, sub_label, score, now
|
||||
object_id, sub_label, score, timestamp
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: get_weighted_score returned consensus_label={consensus_label}, consensus_score={consensus_score} for {object_id}"
|
||||
)
|
||||
|
||||
if consensus_label is not None:
|
||||
camera = obj_data["camera"]
|
||||
logger.debug(
|
||||
f"{self.model_config.name}: Publishing sub_label={consensus_label} for {obj_data['label']} object {object_id} on {camera}"
|
||||
if consensus_label is not None and self.model_config.object_config is not None:
|
||||
self._emit_result(
|
||||
{
|
||||
"type": "classification",
|
||||
"processor": "object",
|
||||
"model_name": self.model_config.name,
|
||||
"classification_type": self.model_config.object_config.classification_type,
|
||||
"object_id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": timestamp,
|
||||
"label": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.sub_label
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(object_id, consensus_label, consensus_score),
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": now,
|
||||
"model": self.model_config.name,
|
||||
"sub_label": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
),
|
||||
)
|
||||
elif (
|
||||
self.model_config.object_config.classification_type
|
||||
== ObjectClassificationType.attribute
|
||||
):
|
||||
self.sub_label_publisher.publish(
|
||||
(
|
||||
object_id,
|
||||
self.model_config.name,
|
||||
consensus_label,
|
||||
consensus_score,
|
||||
),
|
||||
EventMetadataTypeEnum.attribute.value,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": now,
|
||||
"model": self.model_config.name,
|
||||
"attribute": consensus_label,
|
||||
"score": consensus_score,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def handle_request(self, topic: str, request_data: dict) -> dict | None:
|
||||
def handle_request(
|
||||
self, topic: str, request_data: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
if topic == EmbeddingsRequestEnum.reload_classification_model.value:
|
||||
if request_data.get("model_name") == self.model_config.name:
|
||||
self.__build_detector()
|
||||
logger.info(
|
||||
f"Successfully loaded updated model for {self.model_config.name}"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Loaded {self.model_config.name} model.",
|
||||
}
|
||||
|
||||
def _do_reload(data: dict[str, Any]) -> dict[str, Any]:
|
||||
self.__build_detector()
|
||||
logger.info(
|
||||
f"Successfully loaded updated model for {self.model_config.name}"
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Loaded {self.model_config.name} model.",
|
||||
}
|
||||
|
||||
result: dict[str, Any] = self._enqueue_request(_do_reload, request_data)
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id: str, camera: str) -> None:
|
||||
if object_id in self.classification_history:
|
||||
self.classification_history.pop(object_id)
|
||||
self._enqueue_task(("expire", object_id))
|
||||
|
||||
|
||||
def write_classification_attempt(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from multiprocessing.synchronize import Event as MpEvent
|
||||
@@ -33,6 +34,7 @@ from frigate.config.camera.updater import (
|
||||
CameraConfigUpdateEnum,
|
||||
CameraConfigUpdateSubscriber,
|
||||
)
|
||||
from frigate.config.classification import ObjectClassificationType
|
||||
from frigate.data_processing.common.license_plate.model import (
|
||||
LicensePlateModelRunner,
|
||||
)
|
||||
@@ -61,6 +63,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.events.types import EventTypeEnum, RegenerateDescriptionEnum
|
||||
from frigate.genai import GenAIClientManager
|
||||
from frigate.models import Event, Recordings, ReviewSegment, Trigger
|
||||
from frigate.types import TrackedObjectUpdateTypesEnum
|
||||
from frigate.util.builtin import serialize
|
||||
from frigate.util.file import get_event_thumbnail_bytes
|
||||
from frigate.util.image import SharedMemoryFrameManager
|
||||
@@ -274,10 +277,15 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
self._process_recordings_updates()
|
||||
self._process_review_updates()
|
||||
self._process_frame_updates()
|
||||
self._process_deferred_results()
|
||||
self._expire_dedicated_lpr()
|
||||
self._process_finalized()
|
||||
self._process_event_metadata()
|
||||
|
||||
# Shutdown deferred processors
|
||||
for processor in self.realtime_processors:
|
||||
processor.shutdown()
|
||||
|
||||
self.config_updater.stop()
|
||||
self.enrichment_config_subscriber.stop()
|
||||
self.event_subscriber.stop()
|
||||
@@ -316,10 +324,9 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
model_name = topic.split("/")[-1]
|
||||
|
||||
if model_config is None:
|
||||
self.realtime_processors = [
|
||||
processor
|
||||
for processor in self.realtime_processors
|
||||
if not (
|
||||
remaining = []
|
||||
for processor in self.realtime_processors:
|
||||
if (
|
||||
isinstance(
|
||||
processor,
|
||||
(
|
||||
@@ -328,8 +335,11 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
),
|
||||
)
|
||||
and processor.model_config.name == model_name
|
||||
)
|
||||
]
|
||||
):
|
||||
processor.shutdown()
|
||||
else:
|
||||
remaining.append(processor)
|
||||
self.realtime_processors = remaining
|
||||
|
||||
logger.info(
|
||||
f"Successfully removed classification processor for model: {model_name}"
|
||||
@@ -697,6 +707,68 @@ class EmbeddingMaintainer(threading.Thread):
|
||||
|
||||
self.frame_manager.close(frame_name)
|
||||
|
||||
def _process_deferred_results(self) -> None:
|
||||
"""Drain results from deferred processors and perform IPC side-effects."""
|
||||
for processor in self.realtime_processors:
|
||||
results = processor.drain_results()
|
||||
|
||||
for result in results:
|
||||
if result.get("type") != "classification":
|
||||
continue
|
||||
|
||||
if result["processor"] == "state":
|
||||
self.requestor.send_data(
|
||||
f"{result['camera']}/classification/{result['model_name']}",
|
||||
result["state"],
|
||||
)
|
||||
elif result["processor"] == "object":
|
||||
object_id = result["object_id"]
|
||||
camera = result["camera"]
|
||||
timestamp = result["timestamp"]
|
||||
model_name = result["model_name"]
|
||||
label = result["label"]
|
||||
score = result["score"]
|
||||
classification_type = result["classification_type"]
|
||||
|
||||
if classification_type == ObjectClassificationType.sub_label:
|
||||
self.event_metadata_publisher.publish(
|
||||
(object_id, label, score),
|
||||
EventMetadataTypeEnum.sub_label,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": timestamp,
|
||||
"model": model_name,
|
||||
"sub_label": label,
|
||||
"score": score,
|
||||
}
|
||||
),
|
||||
)
|
||||
elif classification_type == ObjectClassificationType.attribute:
|
||||
self.event_metadata_publisher.publish(
|
||||
(object_id, model_name, label, score),
|
||||
EventMetadataTypeEnum.attribute.value,
|
||||
)
|
||||
self.requestor.send_data(
|
||||
"tracked_object_update",
|
||||
json.dumps(
|
||||
{
|
||||
"type": TrackedObjectUpdateTypesEnum.classification,
|
||||
"id": object_id,
|
||||
"camera": camera,
|
||||
"timestamp": timestamp,
|
||||
"model": model_name,
|
||||
"attribute": label,
|
||||
"score": score,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def _embed_thumbnail(self, event_id: str, thumbnail: bytes) -> None:
|
||||
"""Embed the thumbnail for an event."""
|
||||
if not self.config.semantic_search.enabled:
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
"""Tests for DeferredRealtimeProcessorApi."""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
|
||||
from frigate.data_processing.real_time.api import DeferredRealtimeProcessorApi
|
||||
|
||||
# Mock TFLite before importing classification module
|
||||
_MOCK_MODULES = [
|
||||
"tflite_runtime",
|
||||
"tflite_runtime.interpreter",
|
||||
"ai_edge_litert",
|
||||
"ai_edge_litert.interpreter",
|
||||
]
|
||||
for mod in _MOCK_MODULES:
|
||||
if mod not in sys.modules:
|
||||
sys.modules[mod] = MagicMock()
|
||||
|
||||
from frigate.data_processing.real_time.custom_classification import ( # noqa: E402
|
||||
CustomObjectClassificationProcessor,
|
||||
)
|
||||
|
||||
|
||||
class StubDeferredProcessor(DeferredRealtimeProcessorApi):
|
||||
"""Minimal concrete subclass for testing the deferred base."""
|
||||
|
||||
def __init__(self, max_queue: int = 8):
|
||||
config = MagicMock()
|
||||
metrics = MagicMock()
|
||||
super().__init__(config, metrics, max_queue=max_queue)
|
||||
self.processed_items: list[tuple] = []
|
||||
|
||||
def process_frame(self, obj_data: dict[str, Any], frame: np.ndarray) -> None:
|
||||
"""Enqueue every call — no gating logic in the stub."""
|
||||
self._enqueue_task(("frame", obj_data, frame.copy()))
|
||||
|
||||
def _process_task(self, task: tuple) -> None:
|
||||
kind = task[0]
|
||||
if kind == "frame":
|
||||
_, obj_data, frame = task
|
||||
self.processed_items.append((obj_data["id"], frame.shape))
|
||||
self._emit_result(
|
||||
{
|
||||
"type": "test_result",
|
||||
"id": obj_data["id"],
|
||||
"label": "cat",
|
||||
"score": 0.95,
|
||||
}
|
||||
)
|
||||
elif kind == "expire":
|
||||
_, object_id = task
|
||||
self.processed_items.append(("expired", object_id))
|
||||
|
||||
def handle_request(
|
||||
self, topic: str, request_data: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
if topic == "reload":
|
||||
|
||||
def _do_reload(data):
|
||||
return {"success": True, "model": data.get("name")}
|
||||
|
||||
return self._enqueue_request(_do_reload, request_data)
|
||||
return None
|
||||
|
||||
def expire_object(self, object_id: str, camera: str) -> None:
|
||||
self._enqueue_task(("expire", object_id))
|
||||
|
||||
|
||||
class TestDeferredProcessorBase(unittest.TestCase):
|
||||
def test_enqueue_and_drain(self):
|
||||
"""Tasks enqueued on main thread are processed by worker, results are drainable."""
|
||||
proc = StubDeferredProcessor()
|
||||
frame = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
proc.process_frame({"id": "obj1"}, frame)
|
||||
proc.process_frame({"id": "obj2"}, frame)
|
||||
|
||||
# Give the worker time to process
|
||||
time.sleep(0.1)
|
||||
|
||||
results = proc.drain_results()
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0]["id"], "obj1")
|
||||
self.assertEqual(results[1]["id"], "obj2")
|
||||
|
||||
# Second drain should be empty
|
||||
self.assertEqual(len(proc.drain_results()), 0)
|
||||
|
||||
def test_backpressure_drops_tasks(self):
|
||||
"""When queue is full, new tasks are silently dropped."""
|
||||
proc = StubDeferredProcessor(max_queue=2)
|
||||
|
||||
frame = np.zeros((10, 10, 3), dtype=np.uint8)
|
||||
for i in range(10):
|
||||
proc.process_frame({"id": f"obj{i}"}, frame)
|
||||
|
||||
time.sleep(0.2)
|
||||
results = proc.drain_results()
|
||||
# The key property: no crash, no unbounded growth
|
||||
self.assertLessEqual(len(results), 10)
|
||||
self.assertGreater(len(results), 0)
|
||||
|
||||
def test_handle_request_through_worker(self):
|
||||
"""handle_request blocks until the worker processes it and returns a response."""
|
||||
proc = StubDeferredProcessor()
|
||||
result = proc.handle_request("reload", {"name": "my_model"})
|
||||
self.assertEqual(result, {"success": True, "model": "my_model"})
|
||||
|
||||
def test_expire_object_serialized_with_work(self):
|
||||
"""expire_object goes through the queue, serialized with inference work."""
|
||||
proc = StubDeferredProcessor()
|
||||
frame = np.zeros((10, 10, 3), dtype=np.uint8)
|
||||
proc.process_frame({"id": "obj1"}, frame)
|
||||
proc.expire_object("obj1", "front_door")
|
||||
|
||||
time.sleep(0.1)
|
||||
# Both should have been processed in order
|
||||
self.assertEqual(len(proc.processed_items), 2)
|
||||
self.assertEqual(proc.processed_items[0][0], "obj1")
|
||||
self.assertEqual(proc.processed_items[1], ("expired", "obj1"))
|
||||
|
||||
def test_shutdown_joins_worker(self):
|
||||
"""shutdown() signals the worker to stop and joins the thread."""
|
||||
proc = StubDeferredProcessor()
|
||||
proc.shutdown()
|
||||
self.assertFalse(proc._worker.is_alive())
|
||||
|
||||
def test_drain_results_returns_list(self):
|
||||
"""drain_results returns a plain list, not a deque."""
|
||||
proc = StubDeferredProcessor()
|
||||
results = proc.drain_results()
|
||||
self.assertIsInstance(results, list)
|
||||
|
||||
|
||||
class TestCustomObjectClassificationDeferred(unittest.TestCase):
|
||||
"""Test that CustomObjectClassificationProcessor uses the deferred pattern correctly."""
|
||||
|
||||
def _make_processor(self):
|
||||
config = MagicMock()
|
||||
model_config = MagicMock()
|
||||
model_config.name = "test_breed"
|
||||
model_config.object_config = MagicMock()
|
||||
model_config.object_config.objects = ["dog"]
|
||||
model_config.threshold = 0.5
|
||||
model_config.save_attempts = 10
|
||||
model_config.object_config.classification_type = "sub_label"
|
||||
publisher = MagicMock()
|
||||
requestor = MagicMock()
|
||||
metrics = MagicMock()
|
||||
metrics.classification_speeds = {}
|
||||
metrics.classification_cps = {}
|
||||
|
||||
with patch.object(
|
||||
CustomObjectClassificationProcessor,
|
||||
"_CustomObjectClassificationProcessor__build_detector",
|
||||
):
|
||||
proc = CustomObjectClassificationProcessor(
|
||||
config, model_config, publisher, requestor, metrics
|
||||
)
|
||||
proc.interpreter = None
|
||||
proc.tensor_input_details = [{"index": 0}]
|
||||
proc.tensor_output_details = [{"index": 0}]
|
||||
proc.labelmap = {0: "labrador", 1: "poodle", 2: "none"}
|
||||
return proc
|
||||
|
||||
def test_is_deferred_processor(self):
|
||||
"""CustomObjectClassificationProcessor should be a DeferredRealtimeProcessorApi."""
|
||||
proc = self._make_processor()
|
||||
self.assertIsInstance(proc, DeferredRealtimeProcessorApi)
|
||||
|
||||
def test_expire_clears_history(self):
|
||||
"""expire_object should clear classification history for the object."""
|
||||
proc = self._make_processor()
|
||||
proc.classification_history["obj1"] = [("labrador", 0.9, 1.0)]
|
||||
|
||||
proc.expire_object("obj1", "front")
|
||||
time.sleep(0.1)
|
||||
|
||||
self.assertNotIn("obj1", proc.classification_history)
|
||||
|
||||
def test_drain_results_empty_when_no_model(self):
|
||||
"""With no interpreter, process_frame saves training images but emits no results."""
|
||||
proc = self._make_processor()
|
||||
proc.interpreter = None
|
||||
|
||||
frame = np.zeros((150, 100), dtype=np.uint8)
|
||||
obj_data = {
|
||||
"id": "obj1",
|
||||
"label": "dog",
|
||||
"false_positive": False,
|
||||
"end_time": None,
|
||||
"box": [10, 10, 50, 50],
|
||||
"camera": "front",
|
||||
}
|
||||
|
||||
with patch(
|
||||
"frigate.data_processing.real_time.custom_classification.write_classification_attempt"
|
||||
):
|
||||
proc.process_frame(obj_data, frame)
|
||||
|
||||
time.sleep(0.1)
|
||||
results = proc.drain_results()
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user