【Stable diffusion inpaiting】训练自己数据集

2023-12-28 23:47:40

https://github.com/advimman/lama/tree/7dee0e4a3cf5f73f86a820674bf471454f52b74f

prepare your data:
1) Create masks named as `[images_name]_maskXXX[image_suffix]`, put images and masks in the same folder.
You can use the script for random masks generation.
Check the format of the files:
image1_mask001.png
image1.png
image2_mask001.png
image2.png
Specify image_suffix, e.g. .png or .jpg or _input.jpg in configs/prediction/default.yaml.

https://github.com/advimman/lama/blob/main/bin/gen_mask_dataset.py

如果图像不是正方形,使用crop或者transform变换

import os
import argparse
from PIL import Image

def crop_images(input_folder, output_folder, size):
    # 遍历输入文件夹中的所有文件
    for filename in os.listdir(input_folder):
        if filename.endswith(".png"):
            # 构建输入和输出文件的路径
            input_path = os.path.join(input_folder, filename)
            output_path = os.path.join(output_folder, filename)
            
            # 打开图像并裁剪为指定大小
            image = Image.open(input_path)
            cropped_image = image.crop((0, 0, size, size))
            
            # 保存裁剪后的图像
            cropped_image.save(output_path)

if __name__ == "__main__":
    # 创建解析器并添加参数
    parser = argparse.ArgumentParser(description="Crop PNG images in a folder to 512x512 size.")
    parser.add_argument("input_folder", default='./infrared_only', help="Path to the input folder.")
    parser.add_argument("output_folder", default='./square_infrared_only', help="Path to the output folder.")
    parser.add_argument("--size", type=int, default=512, help="Size of the cropped images. Default is 512.")
    
    # 解析命令行参数
    args = parser.parse_args()
    
    # 调用函数进行裁剪
    crop_images(args.input_folder, args.output_folder, args.size)

配置文件修改

generator_kind: random

mask_generator_kwargs:
  irregular_proba: 1
  irregular_kwargs:
    min_times: 4
    max_times: 5
    max_width: 50
    max_angle: 4
    max_len: 100

  box_proba: 0.3
  box_kwargs:
    margin: 0
    bbox_min_size: 10
    bbox_max_size: 50
    max_times: 5
    min_times: 1

  segm_proba: 0
  squares_proba: 0

  variants_n: 5

max_masks_per_image: 1

cropping:
  out_min_size: 256
  handle_small_mode: upscale
  out_square_crop: True
  crop_min_overlap: 1

max_tamper_area: 0.5

解释

generator_kind: 这个参数指定了生成器的类型,这里设置为"random",表示使用随机生成器[1]。

mask_generator_kwargs: 这个参数是一个字典,包含了生成掩码的相关参数设置。

irregular_proba: 这个参数指定了生成不规则掩码的概率,设置为1表示始终生成不规则掩码。

irregular_kwargs: 这个参数是一个字典,包含了生成不规则掩码时的具体参数设置。

min_times: 每个不规则掩码最小生成次数。
max_times: 每个不规则掩码最大生成次数。
max_width: 不规则掩码的最大宽度。
max_angle: 不规则掩码的最大角度。
max_len: 不规则掩码的最大长度。
box_proba: 这个参数指定了生成方框掩码的概率,设置为0.3表示以30%的概率生成方框掩码。

box_kwargs: 这个参数是一个字典,包含了生成方框掩码时的具体参数设置。

margin: 方框掩码的边距。
bbox_min_size: 方框掩码的最小尺寸。
bbox_max_size: 方框掩码的最大尺寸。
max_times: 每个方框掩码最大生成次数。
min_times: 每个方框掩码最小生成次数。
segm_proba: 这个参数指定了生成分割掩码的概率,设置为0表示不生成分割掩码。

squares_proba: 这个参数指定了生成方形掩码的概率,设置为0表示不生成方形掩码。

variants_n: 这个参数指定了生成器生成的变体数量,这里设置为5。

max_masks_per_image: 这个参数指定了每张图像生成的最大掩码数量,这里设置为1。

cropping: 这个参数是一个字典,包含了裁剪图像的相关参数设置。

out_min_size: 裁剪后的图像最小尺寸。
handle_small_mode: 处理小尺寸图像的模式,这里设置为"upscale"表示放大处理。
out_square_crop: 是否进行方形裁剪,这里设置为True表示进行方形裁剪。
crop_min_overlap: 裁剪时的最小重叠区域。
max_tamper_area: 这个参数指定了最大篡改区域的面积比例,这里设置为0.5。

运行配置文件

https://github.com/lorenzo-stacchio/Stable-Diffusion-Inpaint/blob/main/scripts/generate_llama_mask/README.md

python scripts/generate_llama_mask/gen_mask_dataset.py --config ./scripts/generate_llama_mask/data_gen_configs/random_medium_256.yaml --indir data/infrared/mask --outdir data/infrared/mask_gen --ext png

检查输入图像是不是8bit,不是的话需要转换。

image = Image.open(infile).convert('RGB')

变成

if not bit16:
    image = Image.open(infile).convert('RGB')
else:
    #读取16位深度图(像素范围0~65535),并将其转化为8位(像素范围0~255)
    uint16_img = cv2.imread(infile, -1)    #在cv2.imread参数中加入-1,表示不改变读取图像的类型直接读取,否则默认的读取类型为8位。
    uint16_img -= uint16_img.min()
    uint16_img = uint16_img / (uint16_img.max() - uint16_img.min())
    uint16_img *= 255
    #使得越近的地方深度值越大,越远的地方深度值越小,以达到伪彩色图近蓝远红的目的。
    # uint16_img = 255 - uint16_img

    # cv2 中的色度图有十几种,其中最常用的是 cv2.COLORMAP_JET,蓝色表示较高的深度值,红色表示较低的深度值。
    # cv.convertScaleAbs() 函数中的 alpha 的大小与深度图中的有效距离有关,如果像我一样默认深度图中的所有深度值都在有效距离内,并已经手动将16位深度转化为了8位深度,则 alpha 可以设为1。
    # im_color=cv2.applyColorMap(cv2.convertScaleAbs(uint16_img,alpha=1),cv2.COLORMAP_JET)
    uint8_img = cv2.convertScaleAbs(uint16_img,alpha=1)
    if len(uint8_img.shape) == 2:
        uint8_img_rgb = np.repeat(uint8_img[:, :, np.newaxis], 3, axis=2)
    elif len(uint8_img.shape) == 3:
        if uint8_img.shape[2] == 1:
            uint8_img_rgb = np.repeat(uint8_img, 3, axis=2)
    else:
        raise TypeError
    #convert to mat png
    image=Image.fromarray(uint8_img_rgb).convert('RGB')

csv

image_path,mask_path,partition
desk_pc_mouse2.png,desk_pc_mouse2_mask.png,train
desk_pc_mouse2.png,desk_pc_mouse2_mask.png,validation

python scripts\generate_llama_mask\generate_csv.py --llama_masked_outdir output_generated_folder/ --csv_out_path out_path.csv

data/infrared/mask_gen

model:

type: 检测器的名称,这里是MaskRCNN [1]
backbone: 主干网络的配置
type: 主干网络的类别,这里是ResNet [1]
depth: 主干网络的深度,这里是50 [1]
num_stages: 主干网络状态的数目,这些状态产生的特征图作为后续的head的输入 [1]
out_indices: 每个状态产生的特征图输出的索引 [1]
frozen_stages: 第一个状态的权重是否被冻结 [1]
norm_cfg: 归一化层的配置项 [1]
norm_eval: 是否冻结BN里的统计项 [1]
style: 主干网络的风格 [1]
init_cfg: 加载通过ImageNet预训练的模型 [1]
neck: 检测器的neck配置
type: neck的类型,这里是FPN [1]
in_channels: 输入通道数 [1]
out_channels: 金字塔特征图每一层的输出通道 [1]
num_outs: 输出的范围 [1]
rpn_head: RPN head的配置
type: RPN head的类型,这里是RPNHead [1]
in_channels: 每个输入特征图的输入通道 [1]
feat_channels: head卷积层的特征通道 [1]
anchor_generator: 锚点生成器的配置 [1]
bbox_coder: 在训练和测试期间对框进行编码和解码 [1]
loss_cls: 分类分支的损失函数配置 [1]
loss_bbox: 回归分支的损失函数配置 [1]
roi_head: RoIHead的配置
type: RoI head的类型 [1]
bbox_roi_extractor: 用于bbox回归的RoI特征提取器 [1]
python scripts/generate_llama_mask/generate_csv.py --llama_masked_outdir data/infrared/mask_gen/ --csv_out_path data/infrared/out_path.csv

训练

python3 main_inpainting.py --train --name  custom_training --base  configs/latent-diffusion/inpainting_catsv2.yaml  --gpus 0,1   --seed  42

文章来源:https://blog.csdn.net/prinTao/article/details/135276652
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。