import { basicApi } from "./apiBasic";
import { define } from "../define/define";
import { promises as fspromises } from 'fs';
import { errorMessage, successMessage } from "../main/generalTools";

/**
 * Join a WebUI base URL and an API path.
 *
 * Plain concatenation only works when the base URL ends with "/"; a base such
 * as "http://127.0.0.1:7860" would otherwise yield "http://127.0.0.1:7860sdapi/...".
 * A missing trailing slash is therefore inserted. Any other base (already
 * slash-terminated, empty, or nullish) is concatenated exactly as before.
 *
 * @param {string|null|undefined} base - WebUI base URL.
 * @param {string} path - API path without a leading slash, e.g. "sdapi/v1/loras".
 * @returns {string} The full request URL.
 */
function joinApiUrl(base, path) {
  if (typeof base === "string" && base !== "" && !base.endsWith("/")) {
    return `${base}/${path}`;
  }
  return `${base}${path}`;
}

/**
 * Thin client for the Stable Diffusion WebUI HTTP API (`sdapi/v1/*` endpoints).
 */
export class SdApi {
  constructor() {
    // Initial base URL, taken from the global config when it is present.
    // txt2img replaces it with the value from the SD settings file once loaded.
    this.baseUrl = global.config?.webui_api_url;
    // Parsed SD settings file; loaded lazily by loadSdSetting().
    this.sd_setting = null;
  }

  /**
   * Build the request URL for an API path.
   *
   * NOTE(review): only txt2img loads the settings file. Until it has run once,
   * this.baseUrl comes solely from global.config. Confirm that is intended.
   *
   * @param {string} path - API path without a leading slash.
   * @param {string|null} [baseURL=null] - Explicit base URL; when null/undefined,
   *   this.baseUrl is used.
   * @returns {string} The full request URL.
   */
  resolveUrl(path, baseURL = null) {
    return joinApiUrl(baseURL ?? this.baseUrl, path);
  }

  /**
   * Load and cache the SD settings file (define.sd_setting) on first use.
   *
   * The file is read as UTF-8 JSON. On first load, this.baseUrl is replaced
   * with the file's `setting.webui_api_url`.
   *
   * @returns {Promise<object>} The parsed settings object (cached on this.sd_setting).
   * @throws If the file cannot be read or is not valid JSON.
   */
  async loadSdSetting() {
    if (this.sd_setting == null) {
      const raw = await fspromises.readFile(define.sd_setting, 'utf-8');
      this.sd_setting = JSON.parse(raw);
      this.baseUrl = this.sd_setting.setting.webui_api_url;
    }
    return this.sd_setting;
  }

  /**
   * Fetch information about all LoRAs on the current SD server.
   * @param {string|null} [baseURL=null] - Overrides this.baseUrl when provided.
   * @returns Whatever basicApi.get resolves to for `sdapi/v1/loras`.
   */
  async getAllLoras(baseURL = null) {
    return basicApi.get(this.resolveUrl("sdapi/v1/loras", baseURL));
  }

  /**
   * Fetch all checkpoint models on the current SD server.
   * @param {string|null} [baseURL=null] - Overrides this.baseUrl when provided.
   * @returns Whatever basicApi.get resolves to for `sdapi/v1/sd-models`.
   */
  async getAllSDModel(baseURL = null) {
    return basicApi.get(this.resolveUrl("sdapi/v1/sd-models", baseURL));
  }

  /**
   * Fetch all samplers available on the current connection.
   * @param {string|null} [baseURL=null] - Overrides this.baseUrl when provided.
   * @returns Whatever basicApi.get resolves to for `sdapi/v1/samplers`.
   */
  async getAllSamplers(baseURL = null) {
    return basicApi.get(this.resolveUrl("sdapi/v1/samplers", baseURL));
  }

  /**
   * Run a text-to-image request.
   *
   * The request is built from `data`, with these fields taken from the `webui`
   * section of the SD settings file (loaded on first call):
   * - `prompt` is prefixed with the configured common prompt.
   * - `negative_prompt`, `sampler_name`, `cfg_scale` and `steps` are overridden.
   *
   * Other fields are set as follows:
   * - `n_iter` is forced to 1 and `save_images` to false.
   * - `batch_size` defaults to 1 when falsy.
   * - `width` and `height` default to 512 when null or undefined.
   *
   * A new payload object is sent, so the caller's `data` is not modified.
   * Reusing the same object (e.g. on retry) therefore does not prepend the
   * prefix twice.
   *
   * @param {object} data - Request fields; any extra keys are passed through.
   * @param {string|null} [baseURL=null] - Overrides this.baseUrl when provided.
   * @returns Whatever basicApi.post resolves to for `sdapi/v1/txt2img`.
   * @throws If the settings file cannot be loaded, or if basicApi.post rejects.
   */
  async txt2img(data, baseURL = null) {
    const { webui } = await this.loadSdSetting();

    const payload = {
      ...data,
      // Prepend the common prompt prefix; missing parts become "" rather than "undefined".
      prompt: (webui.prompt ?? "") + (data.prompt ?? ""),
      negative_prompt: webui.negative_prompt,
      sampler_name: webui.sampler_name,
      cfg_scale: webui.cfg_scale,
      n_iter: 1,
      steps: webui.steps,
      save_images: false,
      // `||` rather than `??`: a batch_size of 0 also falls back to 1, as before.
      batch_size: data.batch_size || 1,
      width: data.width ?? 512,
      height: data.height ?? 512,
    };

    return basicApi.post(this.resolveUrl("sdapi/v1/txt2img", baseURL), payload);
  }
}