diff --git a/run.py b/run.py index c530977..3a8b6b0 100755 --- a/run.py +++ b/run.py @@ -12,7 +12,7 @@ import torch from pathlib import Path import tkinter as tk from tkinter import filedialog -from opennsfw2 import predict_video_frames +from opennsfw2 import predict_video_frames, predict_image from tkinter.filedialog import asksaveasfilename import webbrowser import psutil @@ -88,9 +88,6 @@ def pre_check(): def start_processing(): start_time = time.time() - seconds, probabilities = predict_video_frames(video_path=args['target_path'], frame_interval=50) - if any(probability > 0.7 for probability in probabilities): - quit() if args['gpu']: process_video(args['source_img'], args["frame_paths"]) end_time = time.time() @@ -195,9 +192,14 @@ def start(): print("\n[WARNING] No face detected in source image. Please try with another one.\n") return if is_img(target_path): + if predict_image(args[target_path]) > 0.7: + quit() process_img(args['source_img'], target_path, args['output_file']) status("swap successful!") return + seconds, probabilities = predict_video_frames(video_path=args['target_path'], frame_interval=50) + if any(probability > 0.7 for probability in probabilities): + quit() video_name_full = target_path.split("/")[-1] video_name = os.path.splitext(video_name_full)[0] output_dir = os.path.dirname(target_path) + "/" + video_name