当前位置:AIGC资讯 > AIGC > 正文

stable diffusion API 调用,超级详细代码示例和说明

本文主要介绍 stable diffusion API 调用,准确来说是对 stable diffusion webui 的 API 调用。接口文档可以查看:

http://sd-webui.test.cn/docs

这里的 sd-webui.test.cn 需替换为你自己部署的 sd webui 地址(Endpoint)。

文生图是:/sdapi/v1/txt2img 这个 POST 接口。

图生图是:/sdapi/v1/img2img 这个 POST 接口。

本文主要介绍文生图 txt2img 接口。

文生图 txt2img 接口

以下是添加了两个 ControlNet、共 1 个批次、每批次生成 4 张图(batch_size = 4),并指定了基础模型、VAE 等参数的入参 JSON:

{
  "alwayson_scripts": {
    "controlnet": {
      "args": [
        {
          "control_mode": 0,
          "enabled": true,
          "guidance_end": 0.5,
          "guidance_start": 0.0,
          "input_image": "base64SrcImg",
          "lowvram": false,
          "model": "control_v11p_sd15_softedge [a8575a2a]",
          "module": "softedge_pidinet",
          "pixel_perfect": true,
          "processor_res": 0,
          "resize_mode": 1,
          "threshold_a": 0,
          "threshold_b": 0,
          "weight": 0.3
        },
        {
          "control_mode": 0,
          "enabled": true,
          "guidance_end": 0.5,
          "guidance_start": 0.0,
          "input_image": "base64SrcImg",
          "lowvram": false,
          "model": "control_v11f1p_sd15_depth [cfd03158]",
          "module": "depth_midas",
          "pixel_perfect": true,
          "processor_res": 0,
          "resize_mode": 1,
          "threshold_a": 0,
          "threshold_b": 0,
          "weight": 0.75
        }
      ]
    }
  },
  "batch_size": 4,
  "cfg_scale": 7,
  "height": 512,
  "negative_prompt": "EasyNegative, paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans,extra fingers,fewer fingers,strange fingers,bad hand,backlight, (worst quality, low quality:1.4), watermark, logo, bad anatomy,lace,rabbit,back,",
  "override_settings": {
    "sd_model_checkpoint": "chosenMix_chosenMix.ckpt [dd0aacadb6]",
    "sd_vae": "pastel-waifu-diffusion.vae.pt"
  },
  "clip_skip": 2,
  "prompt": ",(best quality:1.25),( masterpiece:1.25), (ultra high res:1.25), (no human:1.3),<lora:tachi-e:1>,(white background:2)",
  "restore_faces": false,
  "sampler_index": "DPM++ SDE Karras",
  "sampler_name": "",
  "script_args": [
  ],
  "seed": -1,
  "steps": 28,
  "tiling": false,
  "width": 512
}

其中 ControlNet 参数解释可以参考:
sd-webui-controlnet 接口调用 API 文档

- input_image:用于此单元的图像。默认为 null
- mask:用于过滤图像的掩码。默认为 null
- module:在将图像传给此单元之前对其使用的预处理器,可选值为 /controlnet/module_list 接口返回的值。默认为 none
- model:此单元用于调节(conditioning)的模型名称,可选值为 /controlnet/model_list 接口返回的值。默认为 None
- weight:此单元的权重。默认为 1
- resize_mode:如何调整输入图像以适应生成的输出分辨率。默认为 Scale to Fit (Inner Fit)。可选值:
  - 0 或 Just Resize:直接将图像调整为目标宽度/高度
  - 1 或 Scale to Fit (Inner Fit):按比例缩放并裁剪以适应最小尺寸,保持比例
  - 2 或 Envelope (Outer Fit):按比例缩放以适应最大尺寸,保持比例
- lowvram:是否以更长的处理时间来弥补 GPU 显存不足。默认为 false
- processor_res:预处理器的分辨率。默认为 64
- threshold_a:预处理器的第一个参数,仅在预处理器接受参数时生效。默认为 64
- threshold_b:预处理器的第二个参数,用法与上述相同。默认为 64
- guidance_start:此单元开始生效时的生成进度比例。默认为 0.0
- guidance_end:此单元停止生效时的生成进度比例。默认为 1.0
- control_mode:控制模式,用法可参见相关 issue。默认为 0。可选值:
  - 0 或 Balanced:平衡,对提示词和控制模型没有偏好
  - 1 或 My prompt is more important:提示词比控制模型更有影响力
  - 2 或 ControlNet is more important:控制模型比提示词更有影响力
- pixel_perfect:启用像素完美(pixel-perfect)预处理。默认为 false

Java 封装入参类

StableDiffusionTextToImg 类:

import com.fasterxml.jackson.annotation.JsonInclude;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.List;

@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class StableDiffusionTextToImg implements Serializable {

    /**
     * 去噪强度
     */
    private Integer denoising_strength;
    private Integer firstphase_width;
    private Integer firstphase_height;

    /**
     * 高清修复
     * 缩写hr代表的就是webui中的"高分辨率修复 (Hires. fix)",相关的参数对应的是webui中的这些选项:
     */
    private Boolean enable_hr;
    /**
     * default 2
     */
    private Integer hr_scale;
    private String hr_upscaler;
    private Integer hr_second_pass_steps;
    private Integer hr_resize_x;
    private Integer hr_resize_y;
    private String hr_sampler_name;
    private String hr_prompt;
    private String hr_negative_prompt;

    /**
     * 正向提示词, 默认 ""
     * lora 需要放在 prompt 里
     */
    private String prompt;

    /**
     * 反向提示词, 默认 ""
     */
    private String negative_prompt;

    private List<String> styles;

    /**
     * 随机数种子 (Seed)
     */
    private Integer seed;

    private Integer clip_skip;


    /**
     *
     */
    private Integer subseed;

    /**
     *
     */
    private Integer subseed_strength;

    /**
     * 高度
     */
    private Integer seed_resize_from_h;

    /**
     * 宽度
     */
    private Integer seed_resize_from_w;


    /**
     * 采样方法 (Sampler), 默认 null
     */
    private String sampler_name;

    /**
     * 采样方法 (Sampler) 下标
     */
    private String sampler_index;

    /**
     * 批次数 default: 1
     */
    private Integer batch_size;

    /**
     * 每批的数量 default: 1
     */
    private Integer n_iter;

    /**
     * 迭代步数 (Steps), 默认 50
     */
    private Integer steps;

    /**
     * 提示词引导系数, 默认7
     */
    private Double cfg_scale;

    /**
     * 宽度
     */
    private Integer width;

    /**
     * 高度
     */
    private Integer height;

    /**
     * 面部修复, 默认 false
     */
    private Boolean restore_faces;

    /**
     * 平铺图 默认 false
     */
    private Boolean tiling;

    /**
     * 默认 false
     */
    private Boolean do_not_save_samples;

    /**
     * 默认 false
     */
    private Boolean do_not_save_grid;

    /**
     * 默认 null
     */
    private Integer eta;

    /**
     * 默认 0
     */
    private Integer s_min_uncond;

    /**
     * 默认 0
     */
    private Integer s_churn;

    /**
     * 默认 null
     */
    private Integer s_tmax;

    /**
     * 默认 0
     */
    private Integer s_tmin;

    /**
     * 默认 1
     */
    private Integer s_noise;

    /**
     * 默认 null
     */
    private OverrideSettings override_settings;

    /**
     * 默认 true
     */
    private Boolean override_settings_restore_afterwards;


    private List<Object> script_args;

    /**
     * 默认 null
     */
    private String script_name;

    /**
     * 默认 true
     */
    private Boolean send_images;

    /**
     * 默认 false
     */
    private Boolean save_images;

    /**
     * 默认 {}
     */
    private AlwaysonScripts alwayson_scripts;

上述重要的参数都标注了注释,基本够用,下面也会给出入参类的构建示例。

OverrideSettings 类:用于指定基础模型和 VAE:

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Per-request overrides of webui settings, sent as {@code override_settings}.
 * Used in this article to choose the base checkpoint and the VAE.
 */
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Data
public class OverrideSettings {

    /** Checkpoint (base model) title, e.g. {@code "chosenMix_chosenMix.ckpt [dd0aacadb6]"}. */
    private String sd_model_checkpoint;

    /** VAE file name, e.g. {@code "pastel-waifu-diffusion.vae.pt"}. */
    private String sd_vae;
}

AlwaysonScripts 类,其中可以指定 ControlNet:

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Container for always-on scripts, sent as {@code alwayson_scripts}; currently only ControlNet.
 *
 * <p>Reference: https://zhuanlan.zhihu.com/p/624042359
 */
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Data
public class AlwaysonScripts {

    /** ControlNet script configuration, serialized under the {@code controlnet} key. */
    private ControlNet controlnet;
}

ControlNet 类,其中可以指定多组 Args(一个 Args 是一个 ControlNet)

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;

/**
 * ControlNet script configuration: one {@link Args} entry per ControlNet unit.
 */
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Data
public class ControlNet {

    /** ControlNet units, serialized as the {@code args} array. */
    private List<Args> args;
}

Args 类,即指定一个 ControlNet 的所有参数:

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * 参考:https://github.com/Mikubill/sd-webui-controlnet/wiki/API#integrating-sdapiv12img
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class Args {
    private boolean enabled;
    /**
     * PreProcessor 例如:"module": "lineart_coarse"
     */
    private String module;
    private String model;

    /**
     * defaults to 1
     */
    private double weight = 1.0;
    private String input_image;
    private String mask;

    private int control_mode = 0;

    /**
     * enable pixel-perfect preprocessor. defaults to false
     */
    private boolean pixel_perfect;

    /**
     * whether to compensate low GPU memory with processing time. defaults to false
     */
    private boolean lowvram;
    private int processor_res;
    private int threshold_a;
    private int threshold_b;
    private double guidance_start;
    private double guidance_end = 1.0;

StableDiffusionTextToImgResponse 类,即 stable diffusion webui 的响应结构:

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.List;

/**
 * Response body returned by the stable diffusion webui for {@code POST /sdapi/v1/txt2img}.
 */
@NoArgsConstructor
@AllArgsConstructor
@Builder
@Data
public class StableDiffusionTextToImgResponse implements Serializable {

    /** Generated images, each one a Base64-encoded string. */
    private List<String> images;

    /** Echo of the request parameters, including the defaults filled in by the server. */
    private StableDiffusionTextToImg parameters;

    /** Combined parameter string reported by the server. */
    private String info;
}

Java 测试调用文生图 API

StableDiffusionTest1 类:

import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.assertj.core.util.Lists;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

import java.io.*;
import java.util.*;

@Slf4j
public class StableDiffusionTest1 {

    @Test
    public void testSdApi() throws IOException {
        StableDiffusionTextToImg body = getArtisticWordStableDiffusionTextToImg();
        final List<String> images = callSdApi(body);
        for (String image : images) {
            writeBase642ImageFile(image, String.format("./%s.png", UUID.randomUUID().toString().replaceAll("-", "")));
        }
    }

    public static void writeBase642ImageFile(String image, String fileName) {
        try (OutputStream outputStream = new FileOutputStream(fileName)) {
            byte[] imageBytes = Base64.getDecoder().decode(image);
            ByteArrayInputStream inputStream = new ByteArrayInputStream(imageBytes);

            byte[] buffer = new byte[1024];
            int bytesRead;
            while ((bytesRead = inputStream.read(buffer)) != -1) {
                outputStream.write(buffer, 0, bytesRead);
            }

            log.info("图片写入成功!");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private StableDiffusionTextToImg getArtisticWordStableDiffusionTextToImg() throws IOException {
        final String base64SrcImg = convertImageToBase64("./cat-768x512.png");
        Args args1 = Args.builder()
                .enabled(true)
                .control_mode(0)
                .guidance_start(0)
                .guidance_end(0.5)
                .weight(0.3)
                .pixel_perfect(true)
                .resize_mode(1)
                .model("control_v11p_sd15_softedge [a8575a2a]")
                .module("softedge_pidinet")
                .input_image(base64SrcImg)
                .build();

        Args args2 = Args.builder()
                .enabled(true)
                .control_mode(0)
                .guidance_start(0)
                .guidance_end(0.5)
                .weight(0.75)
                .pixel_perfect(true)
                .resize_mode(1)
                .model("control_v11f1p_sd15_depth [cfd03158]")
                .module("depth_midas")
                .input_image(base64SrcImg)
                .build();

        String vae = "vae-ft-mse-840000-ema-pruned.safetensors";
        StableDiffusionTextToImg body = StableDiffusionTextToImg.builder().sampler_name("")
                .prompt("(cake:1.8),( 3D:1.8),( shadow:1.8),(best quality:1.25),( masterpiece:1.25), (ultra high res:1.25), (no human:1.3),<lora:tachi-e:1>,(white background:2)")
                .negative_prompt("EasyNegative, paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, glans,extra fingers,fewer fingers,strange fingers,bad hand,backlight, (worst quality, low quality:1.4), watermark, logo, bad anatomy,lace,rabbit,back,")
                .sampler_index("DPM++ SDE Karras")
                .seed(-1)
                .width(768)
                .height(512)
                .restore_faces(false)
                .tiling(false)
                .clip_skip(2)
                .batch_size(4)
                .script_args(new ArrayList<>())
                .alwayson_scripts(AlwaysonScripts.builder().controlnet(ControlNet.builder()
                        .args(Lists.newArrayList(args1, args2)).build()).build())
                .steps(28).override_settings(OverrideSettings.builder()
                        .sd_model_checkpoint("chosenMix_chosenMix.ckpt [dd0aacadb6]")
                        .sd_vae(vae)
                        .build())
                .cfg_scale(7.0).build();
        return body;
    }

    public static String convertImageToBase64(String imagePath) throws IOException {
        File file = new File(imagePath);
        FileInputStream fileInputStream = new FileInputStream(file);
        byte[] imageData = new byte[(int) file.length()];
        fileInputStream.read(imageData);
        fileInputStream.close();
        return Base64.getEncoder().encodeToString(imageData);
    }

    private List<String> callSdApi(StableDiffusionTextToImg body) {
        RestTemplate restTemplate = new RestTemplate();
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);

        HttpEntity<StableDiffusionTextToImg> requestEntity = new HttpEntity<>(body, headers);
        ResponseEntity<JSONObject> entity = restTemplate.postForEntity("http://sd.cn/sdapi/v1/txt2img", requestEntity, JSONObject.class);
        final StableDiffusionTextToImgResponse stableDiffusionTextToImgResponse = handleResponse(entity);
        final List<String> images = stableDiffusionTextToImgResponse.getImages();

        if (CollectionUtils.isEmpty(images)) {
            log.info("empty images");
            return Lists.newArrayList();
        }

        return images;
    }


    private StableDiffusionTextToImgResponse handleResponse(ResponseEntity<JSONObject> response) {
        if (Objects.isNull(response) || !response.getStatusCode().is2xxSuccessful()) {
            log.warn("call stable diffusion api status code: {}", JSONObject.toJSONString(response));
        }

        final JSONObject body = response.getBody();
        if (Objects.isNull(body)) {
            log.error("send request failed. response body is empty");
        }
        return body.toJavaObject(StableDiffusionTextToImgResponse.class);
    }
}

更新时间 2023-11-09