feat: AMD DML optimization - GPU face detection, detection throttle, pre-load fix

This commit is contained in:
ozp3
2026-03-28 13:09:20 +03:00
parent 9e6f30c0a4
commit eac2ad2307
7 changed files with 52 additions and 17 deletions
BIN
View File
Binary file not shown.
+5 -2
View File
@@ -2,7 +2,7 @@ import os
import sys
# single thread doubles cuda performance - needs to be set before torch import
if any(arg.startswith('--execution-provider') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '6'
# reduce tensorflow log level
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
@@ -291,9 +291,12 @@ def run() -> None:
for frame_processor in get_frame_processors_modules(modules.globals.frame_processors):
if not frame_processor.pre_check():
return
# Pre-load face analyser in main thread before GUI starts
#from modules.face_analyser import get_face_analyser
#get_face_analyser()
limit_resources()
if modules.globals.headless:
start()
else:
window = ui.init(start, destroy, modules.globals.lang)
window.mainloop()
window.mainloop()
+7 -3
View File
@@ -35,7 +35,9 @@ def get_face_analyser() -> Any:
def get_one_face(frame: Frame) -> Any:
face = get_face_analyser().get(frame)
import modules.globals as g
with g.dml_lock:
face = get_face_analyser().get(frame)
try:
return min(face, key=lambda x: x.bbox[0])
except ValueError:
@@ -43,8 +45,10 @@ def get_one_face(frame: Frame) -> Any:
def get_many_faces(frame: Frame) -> Any:
    """Detect every face in ``frame`` using the shared face analyser.

    The analyser call runs while holding ``modules.globals.dml_lock``, the
    same lock ``swap_face`` takes around the face-swapper call, so detection
    and swapping never execute concurrently.

    Args:
        frame: Image to run face detection on.

    Returns:
        Whatever ``get_face_analyser().get(frame)`` returns, or ``None`` if
        that call raises ``IndexError``.
    """
    # Imported at call time, exactly as the commit does it.
    # NOTE(review): presumably this avoids an import cycle with
    # modules.globals -- confirm before hoisting to module level.
    import modules.globals as g

    try:
        # The ``with`` block releases the lock even if inference raises.
        with g.dml_lock:
            return get_face_analyser().get(frame)
    except IndexError:
        return None
@@ -196,4 +200,4 @@ def dump_faces(centroids: Any, frame_face_embeddings: list):
if temp_frame[int(y_min):int(y_max), int(x_min):int(x_max)].size > 0:
cv2.imwrite(temp_directory_path + f"/{i}/{frame['frame']}_{j}.png", temp_frame[int(y_min):int(y_max), int(x_min):int(x_max)])
j += 1
j += 1
+3
View File
@@ -71,3 +71,6 @@ interpolation_weight: float = 0 # Blend weight for current frame (0.0-1.0). Low
# --- END: Added for Frame Interpolation ---
# --- END OF FILE globals.py ---
import threading
dml_lock = threading.Lock()
+5 -5
View File
@@ -110,7 +110,6 @@ def get_face_swapper() -> Any:
))
else:
providers_config.append(p)
FACE_SWAPPER = insightface.model_zoo.get_model(
model_path,
providers=providers_config,
@@ -153,9 +152,10 @@ def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
if not temp_frame.flags['C_CONTIGUOUS']:
temp_frame = np.ascontiguousarray(temp_frame)
swapped_frame_raw = face_swapper.get(
temp_frame, target_face, source_face, paste_back=True
)
with modules.globals.dml_lock:
swapped_frame_raw = face_swapper.get(
temp_frame, target_face, source_face, paste_back=True
)
# --- START: CRITICAL FIX FOR ORT 1.17 ---
# Check the output type and range from the model
@@ -1183,4 +1183,4 @@ def apply_color_transfer(source, target):
# traceback.print_exc()
return source
return result_bgr
return result_bgr
+27 -7
View File
@@ -72,8 +72,8 @@ ROOT_WIDTH = 600
PREVIEW = None
PREVIEW_MAX_HEIGHT = 700
PREVIEW_MAX_WIDTH = 1200
PREVIEW_DEFAULT_WIDTH = 960
PREVIEW_DEFAULT_HEIGHT = 540
PREVIEW_DEFAULT_WIDTH = 640
PREVIEW_DEFAULT_HEIGHT = 360
POPUP_WIDTH = 750
POPUP_HEIGHT = 810
@@ -1000,6 +1000,10 @@ def webcam_preview(root: ctk.CTk, camera_index: int):
if modules.globals.source_path is None:
update_status("Please select a source image first")
return
from modules.processors.frame.face_swapper import get_face_swapper
from modules.face_analyser import get_face_analyser
get_face_analyser()
get_face_swapper()
create_webcam_preview(camera_index)
else:
modules.globals.source_target_map = []
@@ -1105,7 +1109,7 @@ def _detection_thread_func(latest_frame_holder, detection_result, detection_lock
frame = latest_frame_holder[0]
if frame is None:
time.sleep(0.005)
time.sleep(0.2)
continue
if modules.globals.many_faces:
@@ -1157,7 +1161,22 @@ def _processing_thread_func(capture_queue, processed_queue, stop_event,
source_image = get_one_face(cv2.imread(modules.globals.source_path))
# Read latest detection results (brief lock to avoid blocking detection thread)
with detection_lock:
# Run detection inline since detection thread is disabled
# Run detection every 3 frames, reuse cached result otherwise
if not hasattr(_processing_thread_func, '_det_count'):
_processing_thread_func._det_count = 0
_processing_thread_func._det_count += 1
if _processing_thread_func._det_count % 3 == 0:
if modules.globals.many_faces:
cached_target_face = None
cached_many_faces = get_many_faces(temp_frame)
detection_result['many_faces'] = cached_many_faces
else:
cached_target_face = get_one_face(temp_frame)
cached_many_faces = None
detection_result['target_face'] = cached_target_face
else:
cached_target_face = detection_result.get('target_face')
cached_many_faces = detection_result.get('many_faces')
@@ -1275,7 +1294,7 @@ def create_webcam_preview(camera_index: int):
args=(latest_frame_holder, detection_result, detection_lock, stop_event),
daemon=True,
)
det_thread.start()
# det_thread.start()
# Start processing thread
proc_thread = threading.Thread(
@@ -1316,7 +1335,7 @@ def create_webcam_preview(camera_index: int):
temp_frame = fit_image_to_size(
temp_frame, PREVIEW.winfo_width(), PREVIEW.winfo_height()
)
temp_frame = temp_frame.copy()
image = gpu_cvt_color(temp_frame, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
image = ImageOps.contain(
@@ -1574,4 +1593,5 @@ def update_webcam_target(
target_label_dict_live[button_num] = target_image
else:
update_pop_live_status("Face could not be detected in last upload!")
return map
return map
+5
View File
@@ -0,0 +1,5 @@
@echo off
REM Launcher: runs run.py with the dml execution provider from the project's virtual environment.
REM Switch to the directory that contains this script; /d also changes the current drive.
cd /d "%~dp0"
REM Activate the virtual environment located in the venv folder next to this script.
REM "call" is required so control returns here after the activate script finishes.
call venv\Scripts\activate
REM Start the application with the dml execution provider.
REM NOTE(review): dml presumably selects DirectML - confirm against run.py's argument handling.
python run.py --execution-provider dml
REM Keep the console window open so output and errors stay visible after exit.
pause