⚠ Switch to EXCALIDRAW VIEW in the MORE OPTIONS menu of this document. ⚠ You can decompress Drawing data with the command palette: ‘Decompress current Excalidraw file’. For more info check in plugin settings under ‘Saving’
Code Block
# Only use the blending dataset class in training
# ###
# def collate_fn(batch):
# batch = [item for item in batch if item is not None] # 过滤无效样本
# if len(batch) == 0:
# return [], [], []
# images, labels, img_names = zip(*batch)
# return torch.stack(images), torch.tensor(labels), list(img_names)
# ###
if 'dataset_type' in config and config['dataset_type'] == 'blend':
if config['model_name'] == 'facexray':
train_set = FFBlendDataset(config)
elif config['model_name'] == 'fwa':
train_set = FWABlendDataset(config)
elif config['model_name'] == 'sbi':
train_set = SBIDataset(config, mode='train')
elif config['model_name'] == 'lsda':
train_set = LSDADataset(config, mode='train')
else:
raise NotImplementedError(
'Only facexray, fwa, sbi, and lsda are currently supported for blending dataset'
)
elif 'dataset_type' in config and config['dataset_type'] == 'pair':
train_set = pairDataset(config, mode='train') # Only use the pair dataset class in training
elif 'dataset_type' in config and config['dataset_type'] == 'iid':
train_set = IIDDataset(config, mode='train')
elif 'dataset_type' in config and config['dataset_type'] == 'I2G':
train_set = I2GDataset(config, mode='train')
elif 'dataset_type' in config and config['dataset_type'] == 'lrl':
train_set = LRLDataset(config, mode='train')
else:
train_set = DeepfakeAbstractBaseDataset(
config=config,
mode='train',
)
if config['model_name'] == 'lsda':
from dataset.lsda_dataset import CustomSampler
custom_sampler = CustomSampler(num_groups=2*360, n_frame_per_vid=config['frame_num']['train'], batch_size=config['train_batchSize'], videos_per_group=5)
train_data_loader = \
torch.utils.data.DataLoader(
dataset=train_set,
batch_size=config['train_batchSize'],
num_workers=int(config['workers']),
sampler=custom_sampler,
collate_fn=train_set.collate_fn,
)
elif config['ddp']:
sampler = DistributedSampler(train_set)
train_data_loader = \
torch.utils.data.DataLoader(
dataset=train_set,
batch_size=config['train_batchSize'],
num_workers=int(config['workers']),
collate_fn=train_set.collate_fn,
sampler=sampler
)
else:
train_data_loader = \
torch.utils.data.DataLoader(
dataset=train_set,
batch_size=config['train_batchSize'],
shuffle=True,
num_workers=int(config['workers']),
collate_fn=train_set.collate_fn,
)
return train_data_loader
Excalidraw Data
Text Elements
def prepare_training_data(config)
当 dataset_type 为 'blend' 且 model_name 为 'lsda' 时，会进入 lsda_dataset 里的 LSDADataset（并使用其中的 CustomSampler）
train_data_loader
如果探测器不需要特殊的数据处理（即 config 中的 dataset_type 不是 blend / pair / iid / I2G / lrl 之一），则会进入 abstract_dataset 里的 DeepfakeAbstractBaseDataset
def prepare_testing_data(config)
test_data_loader
prepare the model (detector)
lsda_detector.py
DS
Element Links
MSYrD2Fm: abstract_dataset