当前位置:AIGC资讯 > AIGC > 正文

Datawhale X 魔搭 AI夏令营第四期 AIGC方向 学习笔记(一)

本期主要任务是了解AI文生图的原理并进行相关实践

下面是对baseline部分代码的功能介绍:

安装Data-juicere和DiffSynth-Studio

!pip install simple-aesthetics-predictor
!pip install -v -e data-juicer
!pip uninstall pytorch-lightning -y
!pip install peft lightning pandas torchvision
!pip install -e DiffSynth-Studio

基本的通过pip安装,"!"控制语句在终端进行操作。simple-aesthetics-predictor 这个包,参考pypi上项目描述,是一个基于CLIP的美学预测器,用于预测图片的美学质量。"-v"、"-e"命令用于设定安装模式. data-juicer ,参考github上的原项目Readme文件,是一个“用于大语言模型的一站式数据处理系统”。peft 与参数高效微调相关,lightning 是用于简化训练过程的库,pandas和torchvision就不多说了。DiffSynth-Studio 则是一种用于实现图片和视频风格转换的引擎。

下载数据集

从modelscope上下载某个数据集,指定了目标数据集的路径,子集名称,拆分部分(训练集)和下载完成后的缓存目录。

保存数据集中的图片和元数据

os.makedirs("./data/lora_dataset/train", exist_ok=True)
os.makedirs("./data/data-juicer/input", exist_ok=True)
with open("./data/data-juicer/input/metadata.jsonl", "w") as f:
    for data_id, data in enumerate(tqdm(ds)):
        image = data["image"].convert("RGB")
        image.save(f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg")
        metadata = {"text": "二次元", "image": [f"/mnt/workspace/kolors/data/lora_dataset/train/{data_id}.jpg"]}
        f.write(json.dumps(metadata))
        f.write("\n")

这部分主要进行对下载得到的数据集的遍历,将其中的图片转化成RGB格式后保存到指定路径(../data/lora_dataset/train)。另外创建由文本和对应图片构成的字典作为元数据写入json文件保存

数据处理

在变量 data_juicer_config 中定义了数据处理的各项配置信息,并将其写入yaml文件中。之后调用dj-process命令开启数据处理,并通过该配置文件传入相关参数。

保存处理好的数据

主要是从 result.jsonl 文件中进行文本和图像的保存,并将文件名和文本信息存至csv文件中

训练模型 

from diffsynth import download_models
download_models(["Kolors", "SDXL-vae-fp16-fix"])

!python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py -h

下载模型;终端查看训练脚本输入参数

cmd = """
python DiffSynth-Studio/examples/train/kolors/train_kolors_lora.py \
  --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
  --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
  --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
  --lora_rank 16 \
  --lora_alpha 4.0 \
  --dataset_path data/lora_dataset_processed \
  --output_path ./models \
  --max_epochs 1 \
  --center_crop \
  --use_gradient_checkpointing \
  --precision "16-mixed"
""".strip()

os.system(cmd)

 这一段定义了训练过程需要在终端执行的命令,主要包含以下内容:指定了预训练需要的Unet模型路径、文本编码器模型路径和fp16VAE模型路径;指定lora的等级和alpha值相关参数;指定数据集路径、输出路径;指定最大训练轮数,使用中心裁剪、梯度检查点,和精度参数。

加载模型 

def load_lora(model, lora_rank, lora_alpha, lora_path):
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out"],
    )
    model = inject_adapter_in_model(lora_config, model)
    state_dict = torch.load(lora_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    return model

# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=[
                                 "models/kolors/Kolors/text_encoder",
                                 "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
                                 "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
                             ])
pipe = SDXLImagePipeline.from_model_manager(model_manager)

# Load LoRA
pipe.unet = load_lora(
    pipe.unet,
    lora_rank=16, # This parameter should be consistent with that in your training script.
    lora_alpha=2.0, # lora_alpha can control the weight of LoRA.
    lora_path="models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt"
)

 load_lora 函数加载loRA模型并进行相关参数配置(

model:要注入 LoRA 适配器的原始模型。

lora_rank:LoRA 适配器的秩,用于控制适配器的复杂度。

lora_alpha:LoRA 适配器的缩放因子,用于控制其权重。

lora_path:包含预训练 LoRA 权重的文件路径。) 

使用 inject_adapt_in_model 将loRA注入原始模型,加载loRA预训练的权重字典并应用至模型中。

后续部分并不熟悉各实例的作用,暂且一放。

生成图像

torch.manual_seed(0)
image = pipe(
    prompt="二次元,一个紫色短发小女孩,在家中沙发上坐着,双手托着腮,很无聊,全身,粉色连衣裙",
    negative_prompt="丑陋、变形、嘈杂、模糊、低对比度",
    cfg_scale=4,
    num_inference_steps=50, height=1024, width=1024,
)
image.save("1.jpg")

设置随机种子值使随机操作具有可重复性。使用pipe对象进行生成,给出正负面提示词,配置尺度,推理步数和图像尺寸参数。

在尝试了自己的一系列提示词后得到如下八张图,内容类似baseline原本给的,主线换成了足球:

主要存在的问题:部分画风不统一;部分图细节不佳;对“足球”一词的表现错误(应该是中文输入翻译问题);部分提示词的信息未有效表现出来

总结

### 文章总结
本期文章主要围绕AI图像生成的原理与实践进行展开,涵盖了从环境搭建、数据准备、模型训练到图像生成的全流程。以下是文章关键内容的总结:
#### 1. 环境搭建
- **安装必要的库**:使用pip安装了一系列相关的Python库,包括`data-juicer`(用于数据处理)、`simple-aesthetics-predictor`(用于美学预测)、`peft`和`lightning`(与模型训练和参数微调相关)、`torchvision`和`pandas`(用于数据处理),以及`DiffSynth-Studio`(图像和视频风格转换引擎)。
#### 2. 数据准备
- **下载数据集**:从modelscope下载指定数据集,并选择训练集部分和缓存目录。
- **保存图片和元数据**:遍历数据集,将图片转换为RGB格式并存储,同时生成包含图片路径和描述性文本的元数据文件。
#### 3. 数据处理
- **配置并运行数据处理**:在配置文件中定义数据处理参数,通过`dj-process`命令执行数据处理流程,完成后将处理结果保存至指定文件。
#### 4. 模型训练
- **下载预训练模型**:从DiffSynth仓库下载所需的预训练模型,包括Unet模型、文本编码器模型以及fp16VAE模型。
- **执行训练脚本**:编写并执行训练命令,其中包含指定模型路径、数据集路径、训练参数等信息。训练过程使用了LoRA技术来调整预训练模型以适应特定任务。
#### 5. 模型加载与图像生成
- **加载LoRA模型**:编写加载LoRA模型的函数,将其注入到Unet模型中,并加载预训练的LoRA权重。
- **图像生成**:通过指定的pipe对象输入文本提示词生成图像,调整配置包括正负面提示词、图像尺寸等参数,最终保存生成的图像文件。
#### 6. 实践结果与反思
- **生成结果**:成功生成了一系列图像,但存在部分画风不一致、细节不佳、特定词汇表现错误(可能是中文输入翻译问题)以及提示词信息未完全呈现等问题。
通过这次实践,文章展示了从数据准备到模型训练再到最终图像生成的完整流程,同时也指出了在生成图像过程中遇到的一些问题和需要### 改进建议
针对上述实践中遇到的问题,以下是一些具体的改进建议,旨在提高AI图像生成的质量和一致性:
#### 1. 数据处理与优化
- **增强数据多样性**:增加数据集中的图片数量和种类,尤其是包含目标主题(如本例中的“足球”)的相关图片,以提高模型对特定主题的理解和泛化能力。
- **标签标准化**:确保元数据中的文本描述准确且标准化,避免出现因翻译或理解错误导致的信息损失。可以考虑采用专业的翻译服务或使用标准化的标签库。
#### 2. 模型训练调整
- **微调训练参数**:尝试调整训练过程中的各项参数,如学习率、训练轮数、梯度检查点的使用情况等,以找到最佳的模型训练配置。
- **增强LaRA训练**:针对LoRA(或LaRA)适配器的训练,可以探索不同的秩(`lora_rank`)和缩放因子(`lora_alpha`)设置,以找到最适合特定任务的参数组合。同时,可以尝试使用更大的训练数据集或更复杂的模型架构来提高LoRA的性能。
#### 3. 提示词优化
- **精确描述**:在生成图像时,使用更具体、更精确的提示词来描述想要的场景和细节,这有助于引导模型生成更符合预期的输出。
- **避免歧义**:确保提示词之间没有歧义或冲突,有些词语可能在不同语境下有不同的含义,需要仔细选择和组合。
- **使用权重指导**:为不同的提示词设置权重,以控制它们在图像生成过程中的重要性。例如,可以给关键特征分配更高的权重,以确保它们在生成图像中得到更好的表现。
#### 4. 反馈与迭代
- **用户反馈**:收集和分析用户对于生成图像的反馈意见,了解他们的需求和偏好。根据用户反馈调整模型训练或生成参数,以提升用户满意度和生成图像的质量。
- **持续迭代**:将上述改进措施逐一实施并观察效果,不断迭代优化模型和训练过程,以提高整体生成能力的稳定性和准确性。
#### 5. 技术探索与集成
- **引入新技术**:关注AI领域的新技术和研究成果,及时将其应用到图像生成任务中。例如,可以探索引入对抗性训练(GANs)、多模态学习等先进技术来提高图像生成的质量和多样性。
- **多模型融合**:尝试将不同的预训练模型进行融合,利用它们的互补优势来提高整体性能。例如,

更新时间 2024-08-19