Fix GPU for NVIDIA

This commit is contained in:
henryruhs 2023-05-31 16:30:24 +02:00
parent 38fb60efca
commit 4e2be506ce

21
run.py Normal file → Executable file
View File

@ -8,24 +8,24 @@ import glob
import argparse
import multiprocessing as mp
import os
import random
import torch
from pathlib import Path
import tkinter as tk
from tkinter import filedialog
from opennsfw2 import predict_image as face_check
from tkinter.filedialog import asksaveasfilename
import core.globals
from core.processor import process_video, process_img
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
from core.config import get_face
import webbrowser
import psutil
import cv2
import threading
from PIL import Image, ImageTk
import core.globals
from core.processor import process_video, process_img
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
from core.config import get_face
if 'ROCMExecutionProvider' not in core.globals.providers:
import torch
if 'ROCMExecutionProvider' in core.globals.providers:
del torch
pool = None
args = {}
@ -69,8 +69,7 @@ def pre_check():
if not os.path.isfile(model_path):
quit('File "inswapper_128.onnx" does not exist!')
if '--gpu' in sys.argv:
NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider']
if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1:
if 'ROCMExecutionProvider' not in core.globals.providers:
CUDA_VERSION = torch.version.cuda
CUDNN_VERSION = torch.backends.cudnn.version()
if not torch.cuda.is_available() or not CUDA_VERSION:
@ -89,10 +88,6 @@ def pre_check():
def start_processing():
start_time = time.time()
threshold = len(['frame_args']) if len(args['frame_paths']) <= 10 else 10
for i in range(threshold):
if face_check(random.choice(args['frame_paths'])) > 0.8:
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
if args['gpu']:
process_video(args['source_img'], args["frame_paths"])
end_time = time.time()