# ReHiFace-S/app.py
# Snapshot: 2024-08-20 17:06:10 +08:00 (175 lines, 6.8 KiB, Python)

import os
import subprocess
from collections import deque
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import cv2
import numpy as np
import numexpr as ne

from face_feature.hifi_image_api import HifiImage
from HifiFaceAPI_parallel_trt_roi_realtime_sr_api import HifiFaceRealTime
from face_lib.face_swap import HifiFace
from face_restore.gfpgan_onnx_api import GFPGAN
from face_restore.xseg_onnx_api import XSEG
from face_detect.face_align_68 import face_alignment_landmark
from face_detect.face_detect import FaceDetect
from options.hifi_test_options import HifiTestOptions
from color_transfer import color_transfer
# Options are parsed once, at import time (presumably from the command line);
# ``opt.input_size`` becomes the processor's face crop size.
opt = HifiTestOptions().parse()

# Shared FaceSwapProcessor, created lazily by initialize_processor().
processor = None


def initialize_processor():
    """Create the shared ``FaceSwapProcessor`` on first use.

    Model loading is deferred until the first request; every later call
    returns immediately without touching the existing instance.
    """
    global processor
    if processor is not None:
        return
    processor = FaceSwapProcessor(crop_size=opt.input_size)
class FaceSwapProcessor:
    """Swap the largest detected face in a frame with a source identity.

    Holds the detector, landmark aligner, face swapper, XSeg segmenter and
    GFPGAN restorer; all of them are created once in ``__init__``.
    """

    def __init__(self, crop_size=256):
        """Load every model used by :meth:`process_frame`.

        Args:
            crop_size: Side length, in pixels, of the square face crop requested
                from the aligner (also passed to ``HifiImage``).
        """
        self.hi = HifiImage(crop_size=crop_size)
        self.xseg = XSEG(model_type='xseg_0611', provider='gpu')
        self.hf = HifiFace(model_name='er8_bs1', provider='gpu')
        self.scrfd_detector = FaceDetect(mode='scrfd_500m', tracking_thres=0.15)
        self.face_alignment = face_alignment_landmark(lm_type=68)
        self.gfp = GFPGAN(model_type='GFPGANv1.4', provider='gpu')
        self.crop_size = crop_size

    @staticmethod
    def _largest_face_index(rois):
        """Return the index of the ROI with the largest area.

        ROIs are ``(x1, y1, x2, y2)`` boxes (that is how ``process_frame``
        slices the frame with them), so the area is ``(x2 - x1) * (y2 - y1)``.
        Ties resolve to the first box, as with ``np.argmax``.
        """
        areas = [(roi[2] - roi[0]) * (roi[3] - roi[1]) for roi in rois]
        return int(np.argmax(areas))

    def reverse2wholeimage_hifi_trt_roi(self, swaped_img, mat_rev, img_mask, frame, roi_img, roi_box):
        """Warp the swapped face back into its ROI and paste it into ``frame``.

        Args:
            swaped_img: Swapped face crop; presumably float RGB in [0, 1]
                (it is channel-reversed and scaled by 255 below) -- TODO confirm.
            mat_rev: 2x3 affine matrix mapping the face crop into ROI space.
            img_mask: Blend weights for the ROI, shape (h, w, 1).
            frame: Full frame; the ROI region is overwritten in place.
            roi_img: The ROI sub-image of ``frame``; its size sets the warp size.
            roi_box: ROI as ``(x1, y1, x2, y2)`` pixel coordinates in ``frame``.

        Returns:
            ``frame`` with the blended ROI written back.
        """
        target_image = cv2.warpAffine(swaped_img, mat_rev, roi_img.shape[:2][::-1], borderMode=cv2.BORDER_REPLICATE)[
                       ...,
                       ::-1]  # reverse channel order (presumably RGB -> BGR to match the frame)
        local_dict = {
            'img_mask': img_mask,
            'target_image': target_image,
            'roi_img': roi_img,
        }
        # Alpha-blend the swapped face (rescaled to 0..255) over the original ROI.
        img = ne.evaluate('img_mask * (target_image * 255)+(1 - img_mask) * roi_img', local_dict=local_dict,
                          global_dict=None)
        img = img.astype(np.uint8)
        frame[roi_box[1]:roi_box[3], roi_box[0]:roi_box[2]] = img
        return frame

    def process_frame(self, frame, image_latent, use_gfpgan, sr_weight, use_color_trans, color_trans_mode):
        """Swap the largest face in ``frame`` with the source identity.

        Args:
            frame: Video frame (OpenCV BGR order is assumed, since face crops
                are converted BGR->RGB); modified in place when a face is found.
            image_latent: Sequence whose first element is the source identity
                latent; it is flattened to shape (1, -1) for the swapper.
            use_gfpgan: Run GFPGAN restoration on the swapped face.
            sr_weight: Weight of the GFPGAN output when blended with the raw
                swap (the blend is skipped when it equals 1.0; only used
                together with ``use_gfpgan``).
            use_color_trans: Match the swapped face's colors to the target face.
            color_trans_mode: Method name forwarded to ``color_transfer``.

        Returns:
            The frame with the face replaced, or ``frame`` unchanged when no
            face is detected.
        """
        _, bboxes, kpss = self.scrfd_detector.get_bboxes(frame, max_num=0)
        rois, faces, Ms, _ = self.face_alignment.forward(
            frame, bboxes, kpss, limit=5, min_face_size=30,
            crop_size=(self.crop_size, self.crop_size), apply_roi=True
        )
        if len(faces) == 0:
            return frame

        # Only one face per frame is swapped: the largest by ROI area.
        index = 0 if len(faces) == 1 else self._largest_face_index(rois)
        face = np.array(faces[index])
        mat = Ms[index]
        roi_box = rois[index]
        roi_img = frame[roi_box[1]:roi_box[3], roi_box[0]:roi_box[2]]

        face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
        # The swapper's own mask output is not used; XSeg segments the swapped face instead.
        _, swap_face_out = self.hf.forward(face, image_latent[0].reshape(1, -1))
        seg_mask = self.xseg.forward(swap_face_out)
        mask = cv2.warpAffine(seg_mask.astype(np.float32), mat, roi_img.shape[:2][::-1])
        # NOTE: values > 0.2 are set to 1 and the uint8 cast truncates the rest
        # (assuming non-negative XSeg output) to 0, so this is a hard 0/1 mask.
        mask[mask > 0.2] = 1
        mask = mask[:, :, np.newaxis].astype(np.uint8)

        # Swapper output is channel-first; the (x + 1) / 2 mapping below implies a [-1, 1] range.
        swap_face = swap_face_out[0].transpose((1, 2, 0)).astype(np.float32)
        target_face = (face / 255).astype(np.float32)
        if use_gfpgan:
            sr_face = self.gfp.forward(swap_face)
            if sr_weight != 1.0:
                sr_face = cv2.addWeighted(sr_face, sr_weight, swap_face, 1.0 - sr_weight, 0)
            swap_face = sr_face
        if use_color_trans:
            # color_transfer works in [0, 1]; convert there and back to [-1, 1].
            transed_face = color_transfer(color_trans_mode, (swap_face + 1) / 2, target_face)
            swap_face = (transed_face * 2) - 1
        swap_face = (swap_face + 1) / 2
        return self.reverse2wholeimage_hifi_trt_roi(
            swap_face, mat, mask,
            frame, roi_img, roi_box
        )
def process_image_video(image, video_path, use_gfpgan, sr_weight, use_color_trans, color_trans_mode):
    """Swap the face from ``image`` into every frame of ``video_path``.

    Frames are processed by a thread pool but written in their original
    order. Only a bounded number of frames is in flight at any time, so
    memory use does not grow with the length of the video.

    Args:
        image: Source face image (a PIL image, as configured in the Gradio UI).
        video_path: Path of the input video.
        use_gfpgan, sr_weight, use_color_trans, color_trans_mode: Forwarded
            unchanged to ``FaceSwapProcessor.process_frame``.

    Returns:
        Path of the generated video (``data/output/temp.mp4``), after
        ``add_audio_to_video`` has tried to attach the input's audio.

    Raises:
        ValueError: If the source image or video is missing, or the video
            cannot be opened.
    """
    global processor
    if image is None:
        raise ValueError('A source image is required.')
    if not video_path:
        raise ValueError('An input video is required.')
    initialize_processor()
    src_latent, _ = processor.hi.get_face_feature(image)
    image_latent = [src_latent]

    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        video.release()
        raise ValueError(f'Could not open input video: {video_path}')
    video_fps = video.get(cv2.CAP_PROP_FPS)
    video_size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    output_dir = 'data/output/'
    # makedirs also creates a missing 'data/' parent, where os.mkdir would fail.
    os.makedirs(output_dir, exist_ok=True)
    swap_video_path = output_dir + 'temp.mp4'
    video_writer = cv2.VideoWriter(swap_video_path, cv2.VideoWriter_fourcc(*'mp4v'), video_fps, video_size)

    max_workers = os.cpu_count()
    # Upper bound on frames submitted but not yet written; keeps every worker
    # busy without buffering the whole video in memory.
    max_pending = 2 * (max_workers or 1)

    def write_oldest(pending):
        # Futures are written strictly in submission order to preserve frame order.
        processed_frame = pending.popleft().result()
        if processed_frame is not None:
            video_writer.write(processed_frame)

    try:
        # NOTE(review): the detector is built with tracking_thres; if it keeps
        # state between frames, concurrent/out-of-order calls may be unsafe -- confirm.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            pending = deque()
            while True:
                ret, frame = video.read()
                if not ret:
                    break
                pending.append(executor.submit(processor.process_frame, frame, image_latent, use_gfpgan,
                                               sr_weight, use_color_trans, color_trans_mode))
                if len(pending) >= max_pending:
                    write_oldest(pending)
            while pending:
                write_oldest(pending)
    finally:
        # Release capture and writer even if a frame fails, so the file handles are not leaked.
        video.release()
        video_writer.release()

    add_audio_to_video(video_path, swap_video_path)
    return swap_video_path
def _run_ffmpeg(args):
    """Run ``ffmpeg`` quietly with ``args``; return True if it exited with status 0.

    Arguments are passed as a list (no shell), so file names containing quotes,
    ``$`` or other shell metacharacters are handled literally. A missing ffmpeg
    binary is reported as failure instead of raising.
    """
    try:
        result = subprocess.run(
            ['ffmpeg', '-y', '-hide_banner', '-loglevel', 'error', *args],
            stdin=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        return False
    return result.returncode == 0


def add_audio_to_video(original_video_path, swapped_video_path):
    """Mux the audio of ``original_video_path`` into ``swapped_video_path`` in place.

    The audio is first extracted to a ``.wav`` next to the original video
    (an existing file of that name is reused). It is then combined with the
    swapped video, copying the video stream and encoding the audio as AAC.
    If either ffmpeg step fails (e.g. the input has no audio track, or ffmpeg
    is not installed), the silent swapped video is left untouched.
    """
    # splitext only strips the final extension; split('.') broke on dotted directories.
    audio_file_path = os.path.splitext(original_video_path)[0] + '.wav'
    if not os.path.exists(audio_file_path):
        if not _run_ffmpeg(['-i', original_video_path, '-f', 'wav', '-vn', audio_file_path]):
            # Drop any partial output so a later call does not reuse a bad .wav.
            if os.path.exists(audio_file_path):
                os.remove(audio_file_path)
            return

    root, ext = os.path.splitext(swapped_video_path)
    temp_output_path = root + '_with_audio' + ext
    if not _run_ffmpeg(['-i', swapped_video_path, '-i', audio_file_path,
                        '-c:v', 'copy', '-c:a', 'aac', temp_output_path]):
        if os.path.exists(temp_output_path):
            os.remove(temp_output_path)
        return
    # os.replace overwrites the silent video in one step; the original is only
    # replaced once the muxed file exists.
    os.replace(temp_output_path, swapped_video_path)
def _build_interface():
    """Assemble the Gradio UI that drives ``process_image_video``.

    The input components are listed in the same order as the parameters
    of ``process_image_video``, which is how Gradio maps them.
    """
    source_image = gr.Image(type="pil", label="Source Image")
    input_video = gr.Video(label="Input Video")
    gfpgan_toggle = gr.Checkbox(label="Use GFPGAN [Super-Resolution]")
    sr_weight_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1,
                                 label="SR Weight [only support GFPGAN enabled]", value=1.0)
    color_transfer_toggle = gr.Checkbox(label="Use Color Transfer")
    color_mode_choice = gr.Dropdown(choices=["rct", "lct", "mkl", "idt", "sot"],
                                    label="Color Transfer Mode [only support Color-Transfer enabled]",
                                    value="rct")
    return gr.Interface(
        fn=process_image_video,
        inputs=[
            source_image,
            input_video,
            gfpgan_toggle,
            sr_weight_slider,
            color_transfer_toggle,
            color_mode_choice,
        ],
        outputs=gr.Video(label="Output Video"),
        title="Video Generation",
        description="Upload an image and a video, and the system will generate a new video based on the input.",
    )


# Gradio interface setup
iface = _build_interface()

if __name__ == "__main__":
    iface.launch()