找回密码
 立即注册
首页 业界区 业界 人脸伪造判别分类网络CNN&Transformer

人脸伪造判别分类网络CNN&Transformer

数察啜 6 天前
作者:SkyXZ
CSDN:SkyXZ~-CSDN博客
博客园:SkyXZ - 博客园
一、获取数据集

        FaceForensics++ 是一个取证数据集,由1000段原始视频序列组成,这些视频通过四种自动人脸操纵方法进行处理:Deepfakes、Face2Face、FaceSwap 和 NeuralTextures。数据来自 977 段 YouTube 视频,所有视频中都包含一张可跟踪的、主要为正面且没有遮挡的人脸,使得自动篡改方法能够生成逼真的伪造视频。同时,由于该数据集提供了二值掩码,这些数据可以用于图像和视频分类以及分割。此外,官方还提供了 1000 个 Deepfakes 模型,用于生成和扩充新数据。

  • 原始论文:https://arxiv.org/abs/1901.08971
  • GitHub链接:https://github.com/ondyari/FaceForensics
        FaceForensics++数据集无法直接下载,需要按照要求填写谷歌表单来申请获取https://docs.google.com/forms/d/e/1FAIpQLSdRRR3L5zAv6tQ_CKxmK4W96tAab_pfBu2EKAgQbeDVhmXagg/viewform
1.png

        等待几天之后会收到如下邮件,里面会附上数据集的下载Code,直接使用下载脚本下载即可获取:
2.png
  1. #!/usr/bin/env python
  2. """ Downloads FaceForensics++ and Deep Fake Detection public data release
  3. Example usage:
  4.     see -h or https://github.com/ondyari/FaceForensics
  5. """
  6. # -*- coding: utf-8 -*-
  7. import argparse
  8. import os
  9. import urllib
  10. import urllib.request
  11. import tempfile
  12. import time
  13. import sys
  14. import json
  15. import random
  16. from tqdm import tqdm
  17. from os.path import join
  18. # URLs and filenames
  19. FILELIST_URL = 'misc/filelist.json'
  20. DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json'
  21. DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',]
  22. # Parameters
  23. DATASETS = {
  24.     'original_youtube_videos': 'misc/downloaded_youtube_videos.zip',
  25.     'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip',
  26.     'original': 'original_sequences/youtube',
  27.     'DeepFakeDetection_original': 'original_sequences/actors',
  28.     'Deepfakes': 'manipulated_sequences/Deepfakes',
  29.     'DeepFakeDetection': 'manipulated_sequences/DeepFakeDetection',
  30.     'Face2Face': 'manipulated_sequences/Face2Face',
  31.     'FaceShifter': 'manipulated_sequences/FaceShifter',
  32.     'FaceSwap': 'manipulated_sequences/FaceSwap',
  33.     'NeuralTextures': 'manipulated_sequences/NeuralTextures'
  34.     }
  35. ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes',
  36.                 'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap',
  37.                 'NeuralTextures']
  38. COMPRESSION = ['raw', 'c23', 'c40']
  39. TYPE = ['videos', 'masks', 'models']
  40. SERVERS = ['EU', 'EU2', 'CA']
  41. def parse_args():
  42.     parser = argparse.ArgumentParser(
  43.         description='Downloads FaceForensics v2 public data release.',
  44.         formatter_class=argparse.ArgumentDefaultsHelpFormatter
  45.     )
  46.     parser.add_argument('output_path', type=str, help='Output directory.')
  47.     parser.add_argument('-d', '--dataset', type=str, default='all',
  48.                         help='Which dataset to download, either pristine or '
  49.                              'manipulated data or the downloaded youtube '
  50.                              'videos.',
  51.                         choices=list(DATASETS.keys()) + ['all']
  52.                         )
  53.     parser.add_argument('-c', '--compression', type=str, default='raw',
  54.                         help='Which compression degree. All videos '
  55.                              'have been generated with h264 with a varying '
  56.                              'codec. Raw (c0) videos are lossless compressed.',
  57.                         choices=COMPRESSION
  58.                         )
  59.     parser.add_argument('-t', '--type', type=str, default='videos',
  60.                         help='Which file type, i.e. videos, masks, for our '
  61.                              'manipulation methods, models, for Deepfakes.',
  62.                         choices=TYPE
  63.                         )
  64.     parser.add_argument('-n', '--num_videos', type=int, default=None,
  65.                         help='Select a number of videos number to '
  66.                              "download if you don't want to download the full"
  67.                              ' dataset.')
  68.     parser.add_argument('--server', type=str, default='EU',
  69.                         help='Server to download the data from. If you '
  70.                              'encounter a slow download speed, consider '
  71.                              'changing the server.',
  72.                         choices=SERVERS
  73.                         )
  74.     args = parser.parse_args()
  75.     # URLs
  76.     server = args.server
  77.     if server == 'EU':
  78.         server_url = 'http://canis.vc.in.tum.de:8100/'
  79.     elif server == 'EU2':
  80.         server_url = 'http://kaldir.vc.in.tum.de/faceforensics/'
  81.     elif server == 'CA':
  82.         server_url = 'http://falas.cmpt.sfu.ca:8100/'
  83.     else:
  84.         raise Exception('Wrong server name. Choices: {}'.format(str(SERVERS)))
  85.     args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf'
  86.     args.base_url = server_url + 'v3/'
  87.     args.deepfakes_model_url = server_url + 'v3/manipulated_sequences/' + \
  88.                                'Deepfakes/models/'
  89.     return args
  90. def download_files(filenames, base_url, output_path, report_progress=True):
  91.     os.makedirs(output_path, exist_ok=True)
  92.     if report_progress:
  93.         filenames = tqdm(filenames)
  94.     for filename in filenames:
  95.         download_file(base_url + filename, join(output_path, filename))
  96. def reporthook(count, block_size, total_size):
  97.     global start_time
  98.     if count == 0:
  99.         start_time = time.time()
  100.         return
  101.     duration = time.time() - start_time
  102.     progress_size = int(count * block_size)
  103.     speed = int(progress_size / (1024 * duration))
  104.     percent = int(count * block_size * 100 / total_size)
  105.     sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" %
  106.                      (percent, progress_size / (1024 * 1024), speed, duration))
  107.     sys.stdout.flush()
  108. def download_file(url, out_file, report_progress=False):
  109.     out_dir = os.path.dirname(out_file)
  110.     if not os.path.isfile(out_file):
  111.         fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
  112.         f = os.fdopen(fh, 'w')
  113.         f.close()
  114.         if report_progress:
  115.             urllib.request.urlretrieve(url, out_file_tmp,
  116.                                        reporthook=reporthook)
  117.         else:
  118.             urllib.request.urlretrieve(url, out_file_tmp)
  119.         os.rename(out_file_tmp, out_file)
  120.     else:
  121.         tqdm.write('WARNING: skipping download of existing file ' + out_file)
  122. def main(args):
  123.     # TOS
  124.     print('By pressing any key to continue you confirm that you have agreed '\
  125.           'to the FaceForensics terms of use as described at:')
  126.     print(args.tos_url)
  127.     print('***')
  128.     print('Press any key to continue, or CTRL-C to exit.')
  129.     _ = input('')
  130.     # Extract arguments
  131.     c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS
  132.     c_type = args.type
  133.     c_compression = args.compression
  134.     num_videos = args.num_videos
  135.     output_path = args.output_path
  136.     os.makedirs(output_path, exist_ok=True)
  137.     # Check for special dataset cases
  138.     for dataset in c_datasets:
  139.         dataset_path = DATASETS[dataset]
  140.         # Special cases
  141.         if 'original_youtube_videos' in dataset:
  142.             # Here we download the original youtube videos zip file
  143.             print('Downloading original youtube videos.')
  144.             if not 'info' in dataset_path:
  145.                 print('Please be patient, this may take a while (~40gb)')
  146.                 suffix = ''
  147.             else:
  148.                     suffix = 'info'
  149.             download_file(args.base_url + '/' + dataset_path,
  150.                           out_file=join(output_path,
  151.                                         'downloaded_videos{}.zip'.format(
  152.                                             suffix)),
  153.                           report_progress=True)
  154.             return
  155.         # Else: regular datasets
  156.         print('Downloading {} of dataset "{}"'.format(
  157.             c_type, dataset_path
  158.         ))
  159.         # Get filelists and video lenghts list from server
  160.         if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path:
  161.                 filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' +
  162.                 DEEPFEAKES_DETECTION_URL).read().decode("utf-8"))
  163.                 if 'actors' in dataset_path:
  164.                         filelist = filepaths['actors']
  165.                 else:
  166.                         filelist = filepaths['DeepFakesDetection']
  167.         elif 'original' in dataset_path:
  168.             # Load filelist from server
  169.             file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
  170.                 FILELIST_URL).read().decode("utf-8"))
  171.             filelist = []
  172.             for pair in file_pairs:
  173.                     filelist += pair
  174.         else:
  175.             # Load filelist from server
  176.             file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
  177.                 FILELIST_URL).read().decode("utf-8"))
  178.             # Get filelist
  179.             filelist = []
  180.             for pair in file_pairs:
  181.                 filelist.append('_'.join(pair))
  182.                 if c_type != 'models':
  183.                     filelist.append('_'.join(pair[::-1]))
  184.         # Maybe limit number of videos for download
  185.         if num_videos is not None and num_videos > 0:
  186.                 print('Downloading the first {} videos'.format(num_videos))
  187.                 filelist = filelist[:num_videos]
  188.         # Server and local paths
  189.         dataset_videos_url = args.base_url + '{}/{}/{}/'.format(
  190.             dataset_path, c_compression, c_type)
  191.         dataset_mask_url = args.base_url + '{}/{}/videos/'.format(
  192.             dataset_path, 'masks', c_type)
  193.         if c_type == 'videos':
  194.             dataset_output_path = join(output_path, dataset_path, c_compression,
  195.                                        c_type)
  196.             print('Output path: {}'.format(dataset_output_path))
  197.             filelist = [filename + '.mp4' for filename in filelist]
  198.             download_files(filelist, dataset_videos_url, dataset_output_path)
  199.         elif c_type == 'masks':
  200.             dataset_output_path = join(output_path, dataset_path, c_type,
  201.                                        'videos')
  202.             print('Output path: {}'.format(dataset_output_path))
  203.             if 'original' in dataset:
  204.                 if args.dataset != 'all':
  205.                     print('Only videos available for original data. Aborting.')
  206.                     return
  207.                 else:
  208.                     print('Only videos available for original data. '
  209.                           'Skipping original.\n')
  210.                     continue
  211.             if 'FaceShifter' in dataset:
  212.                 print('Masks not available for FaceShifter. Aborting.')
  213.                 return
  214.             filelist = [filename + '.mp4' for filename in filelist]
  215.             download_files(filelist, dataset_mask_url, dataset_output_path)
  216.         # Else: models for deepfakes
  217.         else:
  218.             if dataset != 'Deepfakes' and c_type == 'models':
  219.                 print('Models only available for Deepfakes. Aborting')
  220.                 return
  221.             dataset_output_path = join(output_path, dataset_path, c_type)
  222.             print('Output path: {}'.format(dataset_output_path))
  223.             # Get Deepfakes models
  224.             for folder in tqdm(filelist):
  225.                 folder_filelist = DEEPFAKES_MODEL_NAMES
  226.                 # Folder paths
  227.                 folder_base_url = args.deepfakes_model_url + folder + '/'
  228.                 folder_dataset_output_path = join(dataset_output_path,
  229.                                                   folder)
  230.                 download_files(folder_filelist, folder_base_url,
  231.                                folder_dataset_output_path,
  232.                                report_progress=False)   # already done
  233. if __name__ == "__main__":
  234.     args = parse_args()
  235.     main(args)
复制代码
        接下来使用如下命令即可下载数据集
  1. python download-FaceForensics.py
  2.     <output path>
  3.     -d <dataset type, e.g., Face2Face, original or all>
  4.     -c <compression quality, e.g., c23 or raw>
  5.     -t <file type, e.g., videos, masks or models>
复制代码
         表示数据集的保存路径,即下载后的 FaceForensics++ 或 DeepFakeDetection 数据将被存放的位置。例如,可以设置为当前项目下的 ./data/,也可以设置为单独的数据盘路径,如 /mnt/data2/qi.xiong/Dataset/FaceForensics/。下载脚本会在该目录下自动构建对应的数据集层级结构.
        d 用于指定下载的数据类型(dataset type)。常见可选项包括 original、Face2Face、Deepfakes、FaceSwap、NeuralTextures、DeepFakeDetection 以及 all 等。其中,original 表示下载原始真实视频序列,通常对应 original_sequences/youtube;Face2Face、Deepfakes、FaceSwap 和 NeuralTextures 表示下载四种主要伪造方法生成的数据;DeepFakeDetection 表示下载 DeepFakeDetection 扩展数据;all 表示一次性下载全部可用数据。若仅用于常规 deepfake 检测实验,通常优先选择 original 与四种主流伪造类型。
        c 用于指定压缩等级(compression quality)。常用选项为 raw、c23 和 c40。其中,raw 表示原始或无损压缩版本,数据体积最大,但保留了最完整的图像细节;c23 表示较高质量压缩版本,是目前较常见、也较平衡的一种设置,既能保留较好的视觉质量,又显著降低存储开销;c40 表示压缩更强、质量更低的数据版本,更适合做强压缩场景下的鲁棒性测试。实际使用中,如果只是复现主流实验或进行预处理,通常推荐优先下载 c23 视频版本。
        t 用于指定文件类型(file type)。常见选项包括 videos、masks 和 models。其中,videos 表示下载视频文件,这是最常用的选项;masks 表示下载伪造区域的二值掩码,适用于伪造区域定位、分割或可解释性分析任务;models 主要与部分伪造方法相关,用于获取对应的生成模型文件。对于大多数 deepfake 分类或人脸抽帧任务,仅下载 videos 即可。
        下载完成的数据集格式如下:
  1. (xq) qi.xiong@instance-ujccspas:/mnt/data2/qi.xiong/Dataset/FaceForensics$ tree -L 3
  2. .
  3. ├── manipulated_sequences
  4. │   ├── DeepFakeDetection
  5. │   │   ├── c23
  6. │   │   └── masks
  7. │   ├── Deepfakes
  8. │   │   ├── c23
  9. │   │   └── masks
  10. │   ├── Face2Face
  11. │   │   ├── c23
  12. │   │   └── masks
  13. │   ├── FaceShifter
  14. │   │   └── c23
  15. │   ├── FaceSwap
  16. │   │   └── c23
  17. │   └── NeuralTextures
  18. │       └── c23
  19. └── original_sequences
  20.     ├── actors
  21.     │   └── c23
  22.     └── youtube
  23.         └── c23
  24. 22 directories, 0 files
复制代码
二、数据集预处理

        我们前面下载得到的数据集仍然是视频格式,因此在正式用于 deepfake 检测之前,还需要先进行预处理。通常来说,这类任务不会直接将整段视频输入模型,而是先从视频中抽取若干具有代表性的帧,再从每一帧中提取对应的人脸区域。这样做一方面可以明显降低后续数据处理和模型训练的开销,另一方面也能让模型更聚焦于真正有用的面部伪造信息。FaceForensics++ 官方文档中也提到,通常更推荐先下载压缩后的视频,再自行完成帧提取。本文这里采用一种比较简化且实用的处理方式:从每个视频中均匀抽取固定数量的帧,然后使用 RetinaFace 对这些帧进行人脸检测,并将检测到的人脸区域裁剪保存。相比一些传统方法,RetinaFace 在检测精度和鲁棒性方面通常更有优势,尤其是在侧脸、光照变化较大或者人脸尺度变化明显的情况下,检测结果往往更加稳定。需要说明的是,本文这里的预处理目标比较明确,即只做人脸抽帧和人脸裁剪,不额外涉及关键点对齐、伪造区域掩码生成等更复杂的步骤,因此整个流程会更加清晰,也更适合作为 FaceForensics++ 数据预处理的基础版本。
  1. git clone https://github.com/ternaus/retinaface.git
  2. cd retinaface
  3. pip install -v -e .
复制代码
        我们配置好了retinaface之后,即可使用如下脚本继续转换:
  1. from glob import glob
  2. import os
  3. import cv2
  4. from tqdm import tqdm
  5. import numpy as np
  6. import argparse
  7. from retinaface.pre_trained_models import get_model
  8. import torch
  9. def facecrop(model, org_path, save_path, num_frames=10):
  10.     cap_org = cv2.VideoCapture(org_path)
  11.     frame_count_org = int(cap_org.get(cv2.CAP_PROP_FRAME_COUNT))
  12.     if frame_count_org <= 0:
  13.         print(f"Invalid video: {org_path}")
  14.         cap_org.release()
  15.         return
  16.     frame_idxs = np.linspace(0, frame_count_org - 1, num_frames, endpoint=True, dtype=int)
  17.     frame_idxs = set(frame_idxs.tolist())
  18.     for cnt_frame in range(frame_count_org):
  19.         ret_org, frame_org = cap_org.read()
  20.         if not ret_org or frame_org is None:
  21.             continue
  22.         if cnt_frame not in frame_idxs:
  23.             continue
  24.         frame = cv2.cvtColor(frame_org, cv2.COLOR_BGR2RGB)
  25.         faces = model.predict_jsons(frame)
  26.         if len(faces) == 0:
  27.             continue
  28.         save_path_frames = os.path.join(
  29.             save_path, 'frames_retina', os.path.basename(org_path).replace('.mp4', '')
  30.         )
  31.         os.makedirs(save_path_frames, exist_ok=True)
  32.         for face_idx, face in enumerate(faces):
  33.             bbox = face.get('bbox', None)
  34.             if bbox is None or len(bbox) < 4:
  35.                 continue
  36.             x0, y0, x1, y1 = map(int, bbox[:4])
  37.             x0 = max(0, x0)
  38.             y0 = max(0, y0)
  39.             x1 = min(frame_org.shape[1], x1)
  40.             y1 = min(frame_org.shape[0], y1)
  41.             if x1 <= x0 or y1 <= y0:
  42.                 continue
  43.             cropped_face = frame_org[y0:y1, x0:x1]
  44.             face_image_path = os.path.join(
  45.                 save_path_frames, f'frame_{cnt_frame}_face_{face_idx}.png'
  46.             )
  47.             cv2.imwrite(face_image_path, cropped_face)
  48.     cap_org.release()
  49. if __name__ == '__main__':
  50.     parser = argparse.ArgumentParser()
  51.     parser.add_argument(
  52.         '-d',
  53.         dest='dataset',
  54.         choices=[
  55.             'Original',
  56.             'DeepFakeDetection_original',
  57.             'DeepFakeDetection',
  58.             'Deepfakes',
  59.             'Face2Face',
  60.             'FaceShifter',
  61.             'FaceSwap',
  62.             'NeuralTextures'
  63.         ]
  64.     )
  65.     parser.add_argument('-c', dest='comp', choices=['raw', 'c23', 'c40'], default='raw')
  66.     parser.add_argument('-n', dest='num_frames', type=int, default=20)
  67.     args = parser.parse_args()
  68.     if args.dataset == 'Original':
  69.         dataset_path = 'data/FaceForensics++/original_sequences/youtube/{}/'.format(args.comp)
  70.     elif args.dataset == 'DeepFakeDetection_original':
  71.         dataset_path = 'data/FaceForensics++/original_sequences/actors/{}/'.format(args.comp)
  72.     elif args.dataset in ['DeepFakeDetection', 'FaceShifter', 'Face2Face', 'Deepfakes', 'FaceSwap', 'NeuralTextures']:
  73.         dataset_path = 'data/FaceForensics++/manipulated_sequences/{}/{}/'.format(args.dataset, args.comp)
  74.     else:
  75.         raise NotImplementedError
  76.     device = torch.device('cpu')
  77.     model = get_model("resnet50_2020-07-20", max_size=2048, device=device)
  78.     model.eval()
  79.     movies_path = dataset_path + 'videos/'
  80.     movies_path_list = sorted(glob(movies_path + '*.mp4'))
  81.     print("{} : videos are exist in {}".format(len(movies_path_list), args.dataset))
  82.     for i in tqdm(range(len(movies_path_list))):
  83.         facecrop(model, movies_path_list[i], save_path=dataset_path, num_frames=args.num_frames)
复制代码
三、人脸分类网络

        我们接下来直接使用Timm库来验证CNN和Transformer作为Backbone对人脸伪造分类的识别性能,我们将支持两种分类方式,分别是二分类和五分类,二分类即单纯的True/False,五分类则在正确区分的基础上额外实现分类人脸伪造的方式
        所有代码已上传至GitHub:https://github.com/xiongqi123123/fakefaceclsnet
        数据集加载及数据增强代码如下:
  1. data/FaceForensics++/original_sequences/youtube/c23/videos/
复制代码
  1. data/FaceForensics++/original_sequences/youtube/c23/frames_retina/
复制代码
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 将 frames_retina 组织为 fakefacecls 所需结构
  5. 用法:
  6.   python setup_ffpp_dataset.py
  7.   python setup_ffpp_dataset.py --data_root /path/to/data/FaceForensics++
  8. 输出:
  9.   data/FaceForensics++/ffpp/
  10.   ├── train.json, val.json, test.json
  11.   ├── Origin/c23/larger_images/     -> symlinks to frames_retina
  12.   ├── Deepfakes/c23/larger_images/
  13.   ├── Face2Face/c23/larger_images/
  14.   ├── FaceSwap/c23/larger_images/
  15.   └── NeuralTextures/c23/larger_images/
  16. """
  17. import argparse
  18. import json
  19. import os
  20. from pathlib import Path
  21. # FF++ 官方划分 (来自 https://github.com/ondyari/FaceForensics)
  22. TRAIN_JSON = [
  23.     ["071", "054"], ["087", "081"], ["881", "856"], ["187", "234"], ["645", "688"],
  24.     ["754", "758"], ["811", "920"], ["710", "788"], ["628", "568"], ["312", "021"],
  25.     ["950", "836"], ["059", "050"], ["524", "580"], ["751", "752"], ["918", "934"],
  26.     ["604", "703"], ["296", "293"], ["518", "131"], ["536", "540"], ["969", "897"],
  27.     ["372", "413"], ["357", "432"], ["809", "799"], ["092", "098"], ["302", "323"],
  28.     ["981", "985"], ["512", "495"], ["088", "060"], ["795", "907"], ["535", "587"],
  29.     ["297", "270"], ["838", "810"], ["850", "764"], ["476", "400"], ["268", "269"],
  30.     ["033", "097"], ["226", "491"], ["784", "769"], ["195", "442"], ["678", "460"],
  31.     ["320", "328"], ["451", "449"], ["409", "382"], ["556", "588"], ["027", "009"],
  32.     ["196", "310"], ["241", "210"], ["295", "099"], ["043", "110"], ["753", "789"],
  33.     ["716", "712"], ["508", "831"], ["005", "010"], ["276", "185"], ["498", "433"],
  34.     ["294", "292"], ["105", "180"], ["984", "967"], ["318", "334"], ["356", "324"],
  35.     ["344", "020"], ["289", "228"], ["022", "489"], ["137", "165"], ["095", "053"],
  36.     ["999", "960"], ["481", "469"], ["534", "490"], ["543", "559"], ["150", "153"],
  37.     ["598", "178"], ["475", "265"], ["671", "677"], ["204", "230"], ["863", "853"],
  38.     ["561", "998"], ["163", "031"], ["655", "444"], ["038", "125"], ["735", "774"],
  39.     ["184", "205"], ["499", "539"], ["717", "684"], ["878", "866"], ["127", "129"],
  40.     ["286", "267"], ["032", "944"], ["681", "711"], ["236", "237"], ["989", "993"],
  41.     ["537", "563"], ["814", "871"], ["509", "525"], ["221", "206"], ["808", "829"],
  42.     ["696", "686"], ["431", "447"], ["737", "719"], ["609", "596"], ["408", "424"],
  43.     ["976", "954"], ["156", "243"], ["434", "438"], ["627", "658"], ["025", "067"],
  44.     ["635", "642"], ["523", "541"], ["572", "554"], ["215", "208"], ["651", "835"],
  45.     ["975", "978"], ["792", "903"], ["931", "936"], ["846", "845"], ["899", "914"],
  46.     ["209", "016"], ["398", "457"], ["797", "844"], ["360", "437"], ["738", "804"],
  47.     ["694", "767"], ["790", "014"], ["657", "644"], ["374", "407"], ["728", "673"],
  48.     ["193", "030"], ["876", "891"], ["553", "545"], ["331", "260"], ["873", "872"],
  49.     ["109", "107"], ["121", "093"], ["143", "140"], ["778", "798"], ["983", "113"],
  50.     ["504", "502"], ["709", "390"], ["940", "941"], ["894", "848"], ["311", "387"],
  51.     ["562", "626"], ["330", "162"], ["112", "892"], ["765", "867"], ["124", "085"],
  52.     ["665", "679"], ["414", "385"], ["555", "516"], ["072", "037"], ["086", "090"],
  53.     ["202", "348"], ["341", "340"], ["333", "377"], ["082", "103"], ["569", "921"],
  54.     ["750", "743"], ["211", "177"], ["770", "791"], ["329", "327"], ["613", "685"],
  55.     ["007", "132"], ["304", "300"], ["860", "905"], ["986", "994"], ["378", "368"],
  56.     ["761", "766"], ["232", "248"], ["136", "285"], ["601", "653"], ["693", "698"],
  57.     ["359", "317"], ["246", "258"], ["500", "592"], ["776", "676"], ["262", "301"],
  58.     ["307", "365"], ["600", "505"], ["833", "826"], ["361", "448"], ["473", "366"],
  59.     ["885", "802"], ["277", "335"], ["667", "446"], ["522", "337"], ["018", "019"],
  60.     ["430", "459"], ["886", "877"], ["456", "435"], ["239", "218"], ["771", "849"],
  61.     ["065", "089"], ["654", "648"], ["151", "225"], ["152", "149"], ["229", "247"],
  62.     ["624", "570"], ["290", "240"], ["011", "805"], ["461", "250"], ["251", "375"],
  63.     ["639", "841"], ["602", "397"], ["028", "068"], ["338", "336"], ["964", "174"],
  64.     ["782", "787"], ["478", "506"], ["313", "283"], ["659", "749"], ["690", "689"],
  65.     ["893", "913"], ["197", "224"], ["253", "183"], ["373", "394"], ["803", "017"],
  66.     ["305", "513"], ["051", "332"], ["238", "282"], ["621", "546"], ["401", "395"],
  67.     ["510", "528"], ["410", "411"], ["049", "946"], ["663", "231"], ["477", "487"],
  68.     ["252", "266"], ["952", "882"], ["315", "322"], ["216", "164"], ["061", "080"],
  69.     ["603", "575"], ["828", "830"], ["723", "704"], ["870", "001"], ["201", "203"],
  70.     ["652", "773"], ["108", "052"], ["272", "396"], ["040", "997"], ["988", "966"],
  71.     ["281", "474"], ["077", "100"], ["146", "256"], ["972", "718"], ["303", "309"],
  72.     ["582", "172"], ["222", "168"], ["884", "968"], ["217", "117"], ["118", "120"],
  73.     ["242", "182"], ["858", "861"], ["101", "096"], ["697", "581"], ["763", "930"],
  74.     ["839", "864"], ["542", "520"], ["122", "144"], ["687", "615"], ["544", "532"],
  75.     ["721", "715"], ["179", "212"], ["591", "605"], ["275", "887"], ["996", "056"],
  76.     ["825", "074"], ["530", "594"], ["757", "573"], ["611", "760"], ["189", "200"],
  77.     ["392", "339"], ["734", "699"], ["977", "075"], ["879", "963"], ["910", "911"],
  78.     ["889", "045"], ["962", "929"], ["515", "519"], ["062", "066"], ["937", "888"],
  79.     ["199", "181"], ["785", "736"], ["079", "076"], ["155", "576"], ["748", "355"],
  80.     ["819", "786"], ["577", "593"], ["464", "463"], ["439", "441"], ["574", "547"],
  81.     ["747", "854"], ["403", "497"], ["965", "948"], ["726", "713"], ["943", "942"],
  82.     ["160", "928"], ["496", "417"], ["700", "813"], ["756", "503"], ["213", "083"],
  83.     ["039", "058"], ["781", "806"], ["620", "619"], ["351", "346"], ["959", "957"],
  84.     ["264", "271"], ["006", "002"], ["391", "406"], ["631", "551"], ["501", "326"],
  85.     ["412", "274"], ["641", "662"], ["111", "094"], ["166", "167"], ["130", "139"],
  86.     ["938", "987"], ["055", "147"], ["990", "008"], ["013", "883"], ["614", "616"],
  87.     ["772", "708"], ["840", "800"], ["415", "484"], ["287", "426"], ["680", "486"],
  88.     ["057", "070"], ["590", "034"], ["194", "235"], ["291", "874"], ["902", "901"],
  89.     ["343", "363"], ["279", "298"], ["393", "405"], ["674", "744"], ["244", "822"],
  90.     ["133", "148"], ["636", "578"], ["637", "427"], ["041", "063"], ["869", "780"],
  91.     ["733", "935"], ["259", "345"], ["069", "961"], ["783", "916"], ["191", "188"],
  92.     ["526", "436"], ["123", "119"], ["207", "908"], ["796", "740"], ["815", "730"],
  93.     ["173", "171"], ["383", "353"], ["458", "722"], ["533", "450"], ["618", "629"],
  94.     ["646", "643"], ["531", "549"], ["428", "466"], ["859", "843"], ["692", "610"],
  95. ]
  96. VAL_JSON = [
  97.     ["720", "672"], ["939", "115"], ["284", "263"], ["402", "453"], ["820", "818"],
  98.     ["762", "832"], ["834", "852"], ["922", "898"], ["104", "126"], ["106", "198"],
  99.     ["159", "175"], ["416", "342"], ["857", "909"], ["599", "585"], ["443", "514"],
  100.     ["566", "617"], ["472", "511"], ["325", "492"], ["816", "649"], ["583", "558"],
  101.     ["933", "925"], ["419", "824"], ["465", "482"], ["565", "589"], ["261", "254"],
  102.     ["992", "980"], ["157", "245"], ["571", "746"], ["947", "951"], ["926", "900"],
  103.     ["493", "538"], ["468", "470"], ["915", "895"], ["362", "354"], ["440", "364"],
  104.     ["640", "638"], ["827", "817"], ["793", "768"], ["837", "890"], ["004", "982"],
  105.     ["192", "134"], ["745", "777"], ["299", "145"], ["742", "775"], ["586", "223"],
  106.     ["483", "370"], ["779", "794"], ["971", "564"], ["273", "807"], ["991", "064"],
  107.     ["664", "668"], ["823", "584"], ["656", "666"], ["557", "560"], ["471", "455"],
  108.     ["042", "084"], ["979", "875"], ["316", "369"], ["091", "116"], ["023", "923"],
  109.     ["702", "612"], ["904", "046"], ["647", "622"], ["958", "956"], ["606", "567"],
  110.     ["632", "548"], ["927", "912"], ["350", "349"], ["595", "597"], ["727", "729"],
  111. ]
  112. TEST_JSON = [
  113.     ["953", "974"], ["012", "026"], ["078", "955"], ["623", "630"], ["919", "015"],
  114.     ["367", "371"], ["847", "906"], ["529", "633"], ["418", "507"], ["227", "169"],
  115.     ["389", "480"], ["821", "812"], ["670", "661"], ["158", "379"], ["423", "421"],
  116.     ["352", "319"], ["579", "701"], ["488", "399"], ["695", "422"], ["288", "321"],
  117.     ["705", "707"], ["306", "278"], ["865", "739"], ["995", "233"], ["755", "759"],
  118.     ["467", "462"], ["314", "347"], ["741", "731"], ["970", "973"], ["634", "660"],
  119.     ["494", "445"], ["706", "479"], ["186", "170"], ["176", "190"], ["380", "358"],
  120.     ["214", "255"], ["454", "527"], ["425", "485"], ["388", "308"], ["384", "932"],
  121.     ["035", "036"], ["257", "420"], ["924", "917"], ["114", "102"], ["732", "691"],
  122.     ["550", "452"], ["280", "249"], ["842", "714"], ["625", "650"], ["024", "073"],
  123.     ["044", "945"], ["896", "128"], ["862", "047"], ["607", "683"], ["517", "521"],
  124.     ["682", "669"], ["138", "142"], ["552", "851"], ["376", "381"], ["000", "003"],
  125.     ["048", "029"], ["724", "725"], ["608", "675"], ["386", "154"], ["220", "219"],
  126.     ["801", "855"], ["161", "141"], ["949", "868"], ["880", "135"], ["429", "404"],
  127. ]
  128. # 路径映射: (method, codec) -> (frames_retina 相对路径)
  129. ORIGIN_FRAMES = "original_sequences/youtube/{codec}/frames_retina"
  130. MANIPULATED_FRAMES = "manipulated_sequences/{method}/{codec}/frames_retina"
  131. METHODS = ["Deepfakes", "Face2Face", "FaceSwap", "NeuralTextures", "FaceShifter"]  # 可选 DeepFakeDetection
  132. def main():
  133.     root = Path(__file__).resolve().parent / "FaceForensics++"
  134.     parser = argparse.ArgumentParser()
  135.     parser.add_argument("--data_root", default=str(root), help="FaceForensics++ 根目录")
  136.     parser.add_argument("--codec", default="c23")
  137.     parser.add_argument("--methods", nargs="+", default=METHODS)
  138.     args = parser.parse_args()
  139.     data_root = Path(args.data_root)
  140.     codec = args.codec
  141.     ffpp = data_root / "ffpp"
  142.     ffpp.mkdir(parents=True, exist_ok=True)
  143.     # 1. 保存 JSON
  144.     for name, pairs in [("train", TRAIN_JSON), ("val", VAL_JSON), ("test", TEST_JSON)]:
  145.         f = ffpp / f"{name}.json"
  146.         with open(f, "w") as fp:
  147.             json.dump(pairs, fp, indent=2)
  148.         print(f"  {f}")
  149.     # 2. Origin: larger_images/{id} -> symlink to frames_retina/xxx
  150.     origin_frames = data_root / ORIGIN_FRAMES.format(codec=codec)
  151.     origin_larger = ffpp / "Origin" / codec / "larger_images"
  152.     origin_larger.mkdir(parents=True, exist_ok=True)
  153.     if origin_frames.exists():
  154.         for vid in sorted(origin_frames.iterdir()):
  155.             if vid.is_dir():
  156.                 dst = origin_larger / vid.name
  157.                 if not dst.exists():
  158.                     dst.symlink_to(vid.resolve())
  159.         print(f"  Origin: {origin_larger} ({len(list(origin_larger.iterdir()))} videos)")
  160.     else:
  161.         print(f"  [skip] Origin {origin_frames} not found")
  162.     # 3. Manipulated: larger_images/{id1_id2} -> symlink to frames_retina/xxx
  163.     for method in args.methods:
  164.         man_frames = data_root / MANIPULATED_FRAMES.format(method=method, codec=codec)
  165.         man_larger = ffpp / method / codec / "larger_images"
  166.         man_larger.mkdir(parents=True, exist_ok=True)
  167.         if man_frames.exists():
  168.             n = 0
  169.             for vid in sorted(man_frames.iterdir()):
  170.                 if vid.is_dir():
  171.                     dst = man_larger / vid.name
  172.                     if not dst.exists():
  173.                         dst.symlink_to(vid.resolve())
  174.                     n += 1
  175.             print(f"  {method}: {man_larger} ({n} videos)")
  176.         else:
  177.             print(f"  [skip] {method} {man_frames} not found")
  178.     print(f"\n完成: ffpp 目录 -> {ffpp}")
  179.     print("\n使用方式:")
  180.     print("  1. fakefacecls: export FFPP_ROOT=" + str(ffpp.resolve()))
  181.     print("  2. multiple-attention: 在 datasets/data.py 中设置 ffpproot = '" + str(ffpp.resolve()) + "/'")
  182. if __name__ == "__main__":
  183.     main()
复制代码
        网络直接使用Timm的预置模型:
  1. import os
  2. import random
  3. import torch
  4. import cv2
  5. from torch.utils.data import Dataset
  6. import albumentations as A
  7. from albumentations import Compose
  8. from .augmentations import augmentations
  9. from . import data
  10. class DeepfakeDataset(Dataset):
  11.     def __init__(
  12.         self,
  13.         phase='train',
  14.         datalabel='',
  15.         resize=(224, 224),
  16.         imgs_per_video=30,
  17.         min_frames=0,
  18.         normalize=None,
  19.         frame_interval=10,
  20.         max_frames=300,
  21.         augment='augment0',
  22.     ):
  23.         assert phase in ['train', 'val', 'test']
  24.         normalize = normalize or dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  25.         self.datalabel = datalabel
  26.         self.phase = phase
  27.         self.imgs_per_video = imgs_per_video
  28.         self.frame_interval = frame_interval
  29.         self.epoch = 0
  30.         self.max_frames = max_frames
  31.         self.min_frames = min_frames if min_frames else max_frames * 0.3
  32.         self.aug = augmentations.get(augment, augmentations['augment0'])
  33.         self.resize = resize
  34.         self.trans = Compose([
  35.             A.Resize(resize[0], resize[1]),  # 小图(如19x14)需先 resize,CenterCrop 会报错
  36.             A.Normalize(mean=normalize['mean'], std=normalize['std']),
  37.             A.ToTensorV2(),
  38.         ])
  39.         self.dataset = self._build_dataset()
  40.         self._frame_cache = {}  # 缓存 os.listdir,避免每帧重复读目录
  41.     def _build_dataset(self):
  42.         if isinstance(self.datalabel, (list, tuple)):
  43.             return self.datalabel
  44.         if 'ff-5' in self.datalabel:
  45.             codec = self.datalabel.split('-')[2]
  46.             out = []
  47.             for idx, tag in enumerate(['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']):
  48.                 for item in data.FF_dataset(tag, codec, self.phase):
  49.                     out.append([item[0], idx])
  50.             return out
  51.         if 'ff-all' in self.datalabel:
  52.             codec = self.datalabel.split('-')[2]
  53.             out = []
  54.             for tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']:
  55.                 out.extend(data.FF_dataset(tag, codec, self.phase))
  56.             if self.phase != 'test':
  57.                 out = data.make_balance(out)
  58.             return out
  59.         if 'ff' in self.datalabel:
  60.             parts = self.datalabel.split('-')
  61.             codec = parts[2]
  62.             tag = parts[1]
  63.             return data.FF_dataset(tag, codec, self.phase) + data.FF_dataset('Origin', codec, self.phase)
  64.         if 'celeb' in self.datalabel:
  65.             return data.Celeb_test
  66.         if 'deeper' in self.datalabel:
  67.             codec = self.datalabel.split('-')[1]
  68.             return data.deeperforensics_dataset(self.phase) + data.FF_dataset('Origin', codec, self.phase)
  69.         if 'dfdc' in self.datalabel:
  70.             return data.dfdc_dataset(self.phase)
  71.         raise ValueError(f'Unknown datalabel: {self.datalabel}')
  72.     def next_epoch(self):
  73.         self.epoch += 1
  74.     def __getitem__(self, item):
  75.         for _ in range(len(self.dataset)):  # 避免无限递归
  76.             try:
  77.                 vid = self.dataset[item // self.imgs_per_video]
  78.                 vid_path = vid[0]
  79.                 if vid_path not in self._frame_cache:
  80.                     self._frame_cache[vid_path] = sorted(os.listdir(vid_path))
  81.                 vd = self._frame_cache[vid_path]
  82.                 if len(vd) < self.min_frames:
  83.                     raise ValueError(f"frames {len(vd)} < min_frames {self.min_frames}")
  84.                 idx = (item % self.imgs_per_video * self.frame_interval + self.epoch) % min(len(vd), self.max_frames)
  85.                 fname = vd[idx]
  86.                 img = cv2.imread(os.path.join(vid[0], fname))
  87.                 if img is None:
  88.                     raise ValueError(f"cv2.imread failed: {fname}")
  89.                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  90.                 if self.phase == 'train':
  91.                     img = self.aug(image=img)['image']
  92.                 return self.trans(image=img)['image'], vid[1]
  93.             except Exception as e:
  94.                 if os.environ.get('DEBUG_DATASET') == '1' and not getattr(self, '_debug_printed', False):
  95.                     import traceback
  96.                     vp = self.dataset[item // self.imgs_per_video][0] if item < len(self) else '?'
  97.                     print(f'[DEBUG] item={item} path={vp} err={e}')
  98.                     traceback.print_exc()
  99.                     self._debug_printed = True  # 只打印第一次
  100.                 if self.phase == 'test':
  101.                     return torch.zeros(3, self.resize[0], self.resize[1]), -1
  102.                 item = (item + self.imgs_per_video) % len(self)
  103.         return torch.zeros(3, self.resize[0], self.resize[1]), -1  # 全部失败时返回占位
  104.     def __len__(self):
  105.         return len(self.dataset) * self.imgs_per_video
复制代码
        然后就是训练的代码
  1. import os
  2. import json
  3. import random
  4. # 数据根目录:FFPP_ROOT 或默认 FFDeepFake/data/FaceForensics++/ffpp
  5. _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
  6. _FFDEEPFAKE_ROOT = os.path.dirname(os.path.dirname(_SCRIPT_DIR))  # fakefacecls/ -> FFDeepFake
  7. _FFDEEPFAKE_ROOT = os.path.dirname(_FFDEEPFAKE_ROOT)  # FFDeepFake
  8. _data_root = os.path.join(_FFDEEPFAKE_ROOT, 'data')
  9. _DEFAULT_FFPP = os.path.join(_data_root, 'FaceForensics++', 'ffpp')
  10. ffpproot = os.environ.get('FFPP_ROOT', _DEFAULT_FFPP)
  11. if ffpproot and not ffpproot.endswith(os.sep):
  12.     ffpproot += os.sep
  13. dfdcroot = os.path.join(_data_root, 'dfdc')
  14. celebroot = os.path.join(_data_root, 'celebDF')
  15. deeperforensics_root = os.path.join(_data_root, 'deeper')
  16. def load_json(name):
  17.     with open(name) as f:
  18.         return json.load(f)
  19. def FF_dataset(tag='Origin', codec='c0', part='train'):
  20.     assert tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face', 'FaceShifter']
  21.     assert codec in ['c0', 'c23', 'c40', 'all']
  22.     assert part in ['train', 'val', 'test', 'all']
  23.     if part == 'all':
  24.         return FF_dataset(tag, codec, 'train') + FF_dataset(tag, codec, 'val') + FF_dataset(tag, codec, 'test')
  25.     if codec == 'all':
  26.         return FF_dataset(tag, 'c0', part) + FF_dataset(tag, 'c23', part) + FF_dataset(tag, 'c40', part)
  27.     path = os.path.join(ffpproot, tag, codec, 'larger_images')
  28.     metafile = load_json(os.path.join(ffpproot, part + '.json'))
  29.     files = []
  30.     if tag == 'Origin':
  31.         for i in metafile:
  32.             files.append([os.path.join(path, i[0]), 0])
  33.             files.append([os.path.join(path, i[1]), 0])
  34.     else:
  35.         for i in metafile:
  36.             files.append([os.path.join(path, i[0] + '_' + i[1]), 1])
  37.             files.append([os.path.join(path, i[1] + '_' + i[0]), 1])
  38.     return files
  39. def make_balance(data):
  40.     tr = [x for x in data if x[1] == 0]
  41.     tf = [x for x in data if x[1] == 1]
  42.     if len(tr) > len(tf):
  43.         tr, tf = tf, tr
  44.     rate = len(tf) // len(tr)
  45.     res = len(tf) - rate * len(tr)
  46.     tr = tr * rate + random.sample(tr, res)
  47.     return tr + tf
  48. def dfdc_dataset(part='train'):
  49.     assert part in ['train', 'val', 'test']
  50.     lf = load_json(os.path.join(dfdcroot, 'DFDC.json'))
  51.     if part == 'train':
  52.         path = os.path.join(dfdcroot, 'dfdc')
  53.         files = make_balance(lf['train'])
  54.     elif part == 'test':
  55.         path = os.path.join(dfdcroot, 'dfdc-test')
  56.         files = lf['test']
  57.     else:
  58.         path = os.path.join(dfdcroot, 'dfdc-val')
  59.         files = lf['val']
  60.     return [[os.path.join(path, i[0]), i[1]] for i in files]
  61. def deeperforensics_dataset(part='train'):
  62.     a = os.listdir(deeperforensics_root)
  63.     d = {i.split('_')[0]: i for i in a}
  64.     metafile = load_json(os.path.join(ffpproot, part + '.json'))
  65.     files = []
  66.     for i in metafile:
  67.         p = os.path.join(deeperforensics_root, d[i[0]])
  68.         files.append([p, 1])
  69.         p = os.path.join(deeperforensics_root, d[i[1]])
  70.         files.append([p, 1])
  71.     return files
  72. try:
  73.     Celeb_test = list(map(lambda x: [os.path.join(celebroot, x[0]), 1 - x[1]], load_json(os.path.join(celebroot, 'celeb.json'))))
  74. except Exception:
  75.     Celeb_test = []
复制代码
        可以看到训练的效果非常的好,基本一个Epoch就可以在Test验证集上达到0.8以上的正确率,且可以观察发现Transformer作为Backbone的效果远比CNN的效果好
3.png


来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册