diffusers-训练自己的模型

2023-12-21 08:16:25

一、搭建dataset

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
基于datasets这个库创建的dataloader,底层代码还待探索

二、修改模型结构(非必要)

尽量可以利用已有的预训练权重去训练模型,但是权重并不一定能够完全是适配,所以还需要自己来视情况做修改,未能加载预训练权重的那一部分参数必须要重新开始训练,不存在finetune一说
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

三、无条件样本生成

先搭建环境
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
train_unconditional.py这个代码待细看
在这里插入图片描述
train_unconditional.py中创建的unet

model = UNet2DModel(
    sample_size=args.resolution,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
# Initialize the scheduler
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=args.ddpm_num_steps,
        beta_schedule=args.ddpm_beta_schedule,
        prediction_type=args.prediction_type,
    )
else:
    noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

# Initialize the optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

默认的优化器和采样器

dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

augmentations = transforms.Compose(
    [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
        transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

读取数据模块

在这里插入图片描述

accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --output_dir="ddpm-ema-flowers-64" \
  --mixed_precision="fp16" \
  --push_to_hub
# 单卡训

accelerate launch --multi_gpu train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --output_dir="ddpm-ema-flowers-64" \
  --mixed_precision="fp16" \
  --push_to_hub
多卡训

在这里插入图片描述

预测代码(小疑问,这个路径咋确定的呢?)

四、

文章来源:https://blog.csdn.net/qq_45692660/article/details/135119717
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。