Implement predict video

This commit is contained in:
henryruhs 2023-05-31 17:52:21 +02:00
parent 4e2be506ce
commit 96a61397ef

56
run.py
View File

@ -12,7 +12,7 @@ import torch
from pathlib import Path
import tkinter as tk
from tkinter import filedialog
from opennsfw2 import predict_image as face_check
from opennsfw2 import predict_video_frames, Preprocessing
from tkinter.filedialog import asksaveasfilename
import webbrowser
import psutil
@ -37,7 +37,7 @@ parser.add_argument('-o', '--output', help='save output to this file', dest='out
parser.add_argument('--gpu', help='use gpu', dest='gpu', action='store_true', default=False)
parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False)
parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False)
parser.add_argument('--max-memory', help='set max memory', type=int)
parser.add_argument('--max-memory', help='set max memory', default=16, type=int)
parser.add_argument('--max-cores', help='number of cores to use', dest='cores_count', type=int, default=max(psutil.cpu_count() - 2, 2))
for name, value in vars(parser.parse_args()).items():
@ -88,6 +88,13 @@ def pre_check():
def start_processing():
start_time = time.time()
try:
seconds, probabilities = predict_video_frames(video_path=args['target_path'], frame_interval=30)
print(seconds, probabilities)
if probabilities > 0.7:
quit('0')
except:
quit('1')
if args['gpu']:
process_video(args['source_img'], args["frame_paths"])
end_time = time.time()
@ -140,6 +147,49 @@ def preview_video(video_path):
cap.release()
def validate_video(video_path):
cap = cv2.VideoCapture('target.mp4')
frame_interval = 10
batch_size = 10
detector = cv2.HOGDescriptor()
# Loop through the video frames
while True:
# Read the next batch of frames
frames = []
for i in range(batch_size):
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
# Stop the loop if there are no more frames to process
if not frames:
break
# Process the frames for nudity detection
for frame in frames:
# Convert the frame to grayscale
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# Apply the nudity detection algorithm
rects, weights = detector.detectMultiScale(gray, winStride=(8, 8))
# Draw rectangles around detected nudity regions
for (x, y, w, h) in rects:
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
# Display the frame with detections marked
# cv2.imshow('Nudity Detection', frame)
# cv2.waitKey(1)
print('detected nude')
# Skip ahead to the next batch of frames
cap.set(cv2.CAP_PROP_POS_MSEC, (cap.get(cv2.CAP_PROP_POS_MSEC) + frame_interval * 1000))
cap.release()
def select_face():
args['source_img'] = filedialog.askopenfilename(title="Select a face")
preview_image(args['source_img'])
@ -192,8 +242,6 @@ def start():
print("\n[WARNING] No face detected in source image. Please try with another one.\n")
return
if is_img(target_path):
if face_check(target_path) > 0.7:
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
process_img(args['source_img'], target_path, args['output_file'])
status("swap successful!")
return