mirror of
https://github.com/s0md3v/roop.git
synced 2024-08-24 08:31:17 +08:00
Implement predict video
This commit is contained in:
parent
4e2be506ce
commit
96a61397ef
56
run.py
56
run.py
@ -12,7 +12,7 @@ import torch
|
||||
from pathlib import Path
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog
|
||||
from opennsfw2 import predict_image as face_check
|
||||
from opennsfw2 import predict_video_frames, Preprocessing
|
||||
from tkinter.filedialog import asksaveasfilename
|
||||
import webbrowser
|
||||
import psutil
|
||||
@ -37,7 +37,7 @@ parser.add_argument('-o', '--output', help='save output to this file', dest='out
|
||||
parser.add_argument('--gpu', help='use gpu', dest='gpu', action='store_true', default=False)
|
||||
parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False)
|
||||
parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False)
|
||||
parser.add_argument('--max-memory', help='set max memory', type=int)
|
||||
parser.add_argument('--max-memory', help='set max memory', default=16, type=int)
|
||||
parser.add_argument('--max-cores', help='number of cores to use', dest='cores_count', type=int, default=max(psutil.cpu_count() - 2, 2))
|
||||
|
||||
for name, value in vars(parser.parse_args()).items():
|
||||
@ -88,6 +88,13 @@ def pre_check():
|
||||
|
||||
def start_processing():
|
||||
start_time = time.time()
|
||||
try:
|
||||
seconds, probabilities = predict_video_frames(video_path=args['target_path'], frame_interval=30)
|
||||
print(seconds, probabilities)
|
||||
if probabilities > 0.7:
|
||||
quit('0')
|
||||
except:
|
||||
quit('1')
|
||||
if args['gpu']:
|
||||
process_video(args['source_img'], args["frame_paths"])
|
||||
end_time = time.time()
|
||||
@ -140,6 +147,49 @@ def preview_video(video_path):
|
||||
cap.release()
|
||||
|
||||
|
||||
def validate_video(video_path):
|
||||
cap = cv2.VideoCapture('target.mp4')
|
||||
frame_interval = 10
|
||||
batch_size = 10
|
||||
detector = cv2.HOGDescriptor()
|
||||
|
||||
# Loop through the video frames
|
||||
while True:
|
||||
# Read the next batch of frames
|
||||
frames = []
|
||||
for i in range(batch_size):
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
frames.append(frame)
|
||||
|
||||
# Stop the loop if there are no more frames to process
|
||||
if not frames:
|
||||
break
|
||||
|
||||
# Process the frames for nudity detection
|
||||
for frame in frames:
|
||||
# Convert the frame to grayscale
|
||||
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# Apply the nudity detection algorithm
|
||||
rects, weights = detector.detectMultiScale(gray, winStride=(8, 8))
|
||||
|
||||
# Draw rectangles around detected nudity regions
|
||||
for (x, y, w, h) in rects:
|
||||
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
|
||||
|
||||
# Display the frame with detections marked
|
||||
# cv2.imshow('Nudity Detection', frame)
|
||||
# cv2.waitKey(1)
|
||||
print('detected nude')
|
||||
|
||||
# Skip ahead to the next batch of frames
|
||||
cap.set(cv2.CAP_PROP_POS_MSEC, (cap.get(cv2.CAP_PROP_POS_MSEC) + frame_interval * 1000))
|
||||
|
||||
cap.release()
|
||||
|
||||
|
||||
def select_face():
|
||||
args['source_img'] = filedialog.askopenfilename(title="Select a face")
|
||||
preview_image(args['source_img'])
|
||||
@ -192,8 +242,6 @@ def start():
|
||||
print("\n[WARNING] No face detected in source image. Please try with another one.\n")
|
||||
return
|
||||
if is_img(target_path):
|
||||
if face_check(target_path) > 0.7:
|
||||
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
|
||||
process_img(args['source_img'], target_path, args['output_file'])
|
||||
status("swap successful!")
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user