374 lines
14 KiB
JavaScript
Raw Normal View History

2024-05-15 12:57:15 +08:00
import axios from "axios";
import path from "path";
import { DEFINE_STRING } from "../../define/define_string";
import { define } from "../../define/define";
import { ImageStyleDefine } from "../../define/iamgeStyleDefine";
import { cloneDeep } from 'lodash';
let fspromises = require("fs").promises;
const sharp = require('sharp');
import { SdSettingDefine } from "../../define/setting/sdSettingDefine";
import { PublicMethod } from "./publicMethod";
import { Tools } from "../tools";
2024-05-24 13:46:19 +08:00
import { errorMessage, successMessage } from "../generalTools";
import { SdApi } from "../../api/sdApi";
const { v4: uuidv4 } = require('uuid');
2024-05-15 12:57:15 +08:00
/**
 * Stable Diffusion WebUI helpers: loading server metadata (models, LoRAs,
 * samplers), resolving image-style information, and generating images.
 */
export class SD {
    /**
     * @param {*} global Shared application state. This class reads
     *   `global.config` (`project_path`, `webui_api_url`), `global.requestQuene`,
     *   `global.fileQueue` and `global.newWindow`.
     */
    constructor(global) {
        this.global = global;
        this.pm = new PublicMethod(global);
        this.tools = new Tools();
        this.sdApi = new SdApi();
    }

    /**
     * Fetch all LoRA information from the current SD server.
     * @param {string|null} baseURL Server base URL, passed through to SdApi.
     * @returns {Promise<*>} `successMessage(result)` on success, otherwise
     *   `errorMessage(error.toString())`; SdApi errors are not rethrown.
     */
    async GetAllLoras(baseURL = null) {
        try {
            return successMessage(await this.sdApi.getAllLoras(baseURL));
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Fetch all checkpoint models from the current SD server.
     * @param {string|null} baseURL Server base URL, passed through to SdApi.
     * @returns {Promise<*>} `successMessage(result)` or `errorMessage(...)`.
     */
    async GetAllSDModel(baseURL = null) {
        try {
            return successMessage(await this.sdApi.getAllSDModel(baseURL));
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Fetch all samplers from the current SD server.
     * @param {string|null} baseURL Server base URL, passed through to SdApi.
     * @returns {Promise<*>} `successMessage(result)` or `errorMessage(...)`.
     */
    async GetAllSamplers(baseURL = null) {
        try {
            return successMessage(await this.sdApi.getAllSamplers(baseURL));
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Load models, LoRAs and samplers from the SD server, prepend a default
     * "无" (none) option to each list, drop LoRA `metadata`, and persist the
     * three lists via `SdSettingDefine.SavePropertyValue` under the keys
     * "sd_model", "lora" and "sampler".
     * @param {string|null} baseURL Server base URL, passed through to SdApi.
     * @returns {Promise<*>} `successMessage({ sd_model, lora, sampler })`, or
     *   `errorMessage(...)` if any lookup or save fails.
     */
    async LoadSDServiceData(baseURL = null) {
        try {
            // The lookups are independent. Each Get* wrapper catches its own
            // errors and returns an error result, so they can run concurrently.
            const [sd_model, lora, sampler] = await Promise.all([
                this.GetAllSDModel(baseURL),
                this.GetAllLoras(baseURL),
                this.GetAllSamplers(baseURL),
            ]);
            // Validate BEFORE touching `.data.data`. A failed result is not
            // guaranteed to carry a payload; dereferencing it first used to
            // replace this message with an opaque TypeError.
            if (!(sd_model.code && lora.code && sampler.code)) {
                throw new Error("获取SD数据错误请检查SD WEBUI链接");
            }
            // Prepend a default "无" option to each list. New arrays are built
            // so the API results are not mutated.
            const data = {
                sd_model: [{ title: "无", name: "无", description: "无" }, ...sd_model.data.data],
                // LoRA metadata is stripped before returning/persisting.
                lora: [{ Key: "无", name: "无", description: "无" }, ...lora.data.data]
                    .map(({ metadata, ...rest }) => rest),
                sampler: [{ name: "无", description: "无" }, ...sampler.data.data],
            };
            // Save to the configuration file. Kept sequential: the three writes
            // presumably target the same settings file.
            await SdSettingDefine.SavePropertyValue("sd_model", data.sd_model);
            await SdSettingDefine.SavePropertyValue("lora", data.lora);
            await SdSettingDefine.SavePropertyValue("sampler", data.sampler);
            return successMessage(data);
        } catch (error) {
            return errorMessage(error.toString());
        }
    }

    /**
     * Get the image-style menu.
     * @returns {Promise<{code: number, data?: *, message?: string}>}
     *   `{ code: 1, data }` on success, `{ code: 0, message }` on failure.
     */
    async GetImageStyleMenu() {
        try {
            return {
                code: 1,
                data: ImageStyleDefine.getImageStyleMenu(),
            };
        } catch (error) {
            // Surface the real error instead of a fixed placeholder message,
            // consistent with the other style lookups.
            return {
                code: 0,
                message: error.toString(),
            };
        }
    }

    /**
     * Get the sub-style entries for a list of style ids.
     * @param {string} value JSON-encoded array of ids. Empty/falsy, `null`, or
     *   a non-array payload is treated as "no ids".
     * @returns {Promise<{code: number, data?: Array, message?: string}>}
     *   Matching entries in input-id order (ids with no match are skipped).
     *   Each entry is a deep copy whose `image` is resolved under
     *   `define.image_path/style/`.
     */
    async GetImageStyleInfomation(value) {
        try {
            const parsed = value ? JSON.parse(value) : [];
            const ids = Array.isArray(parsed) ? parsed : [];
            const allSubStyles = ImageStyleDefine.getAllSubStyle();
            const matched = [];
            for (const id of ids) {
                // Loose equality kept on purpose: ids may presumably arrive as
                // either strings or numbers.
                const found = allSubStyles.find((item) => item.id == id);
                if (found !== undefined) {
                    matched.push(found);
                }
            }
            // Deep-copy so rewriting `image` does not touch the shared definitions.
            const resolved = cloneDeep(matched);
            for (const style of resolved) {
                style.image = path.join(define.image_path, "style/" + style.image);
            }
            return {
                code: 1,
                data: resolved,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString(),
            };
        }
    }

    /**
     * Get the sub-styles of the category with the given id.
     * @param {*} value Category id.
     * @returns {Promise<{code: number, data?: Array, message?: string}>}
     *   Deep copies of the sub-styles with `image` resolved under
     *   `define.image_path/style/`.
     */
    async GetStyleImageSubList(value) {
        try {
            const resolved = cloneDeep(ImageStyleDefine.getImagePathById(value));
            for (const style of resolved) {
                style.image = path.join(define.image_path, "style/" + style.image);
            }
            return {
                code: 1,
                data: resolved,
            };
        } catch (error) {
            return {
                code: 0,
                message: error.toString(),
            };
        }
    }

    /**
     * Generate images via `sdApi.txt2img` and save each returned base64 image
     * as a PNG under `define.temp_sd_image`.
     * @param {string} value JSON-encoded array. Index 0 is the request body
     *   passed to `sdApi.txt2img`.
     *   NOTE(review): indexes 1 (image identifier) and 2 (baseUrl) are
     *   documented by callers but not used here — confirm intended.
     * @returns {Promise<*>} `successMessage([{ base64, image_path }, ...])`, or
     *   `errorMessage(...)` on failure.
     */
    async txt2img(value) {
        try {
            const params = JSON.parse(value);
            const res = await this.sdApi.txt2img(params[0]);
            const image_paths = [];
            // Save each base64 image to disk under a unique name and record the
            // path that saveBase64ToImage returns.
            for (const base64 of res.data.images || []) {
                const image_name = `sd_${Date.now()}_${uuidv4()}.png`;
                const image_path = await this.tools.saveBase64ToImage(
                    base64,
                    path.join(define.temp_sd_image, image_name),
                );
                image_paths.push({ base64, image_path });
            }
            return successMessage(image_paths);
        } catch (error) {
            return errorMessage("生图失败,错误信息如下:" + error.toString());
        }
    }

    /**
     * Generate one image for a task, in either img2img or txt2img mode.
     *
     * The mode and per-image settings come from `<image>.json`; global
     * settings come from `define.sd_setting`.
     *
     * @param {string} image Path of the source image.
     * @param {*} task_list Task entry. Reads `out_folder`, `id`,
     *   `image_style_list`, `image_style` and `lora`.
     * @param {number} [seed=-1] Seed to request. When -1, the seed from the
     *   response `info` is returned instead.
     * @returns {Promise<number>} The seed.
     * @throws Rethrows any failure, after removing the batch's tasks from the
     *   request queue and scheduling an "error" status write to
     *   scripts/task_list.json.
     */
    async OneImageGeneration(image, task_list, seed = -1) {
        const taskPath = path.join(this.global.config.project_path, "scripts/task_list.json");
        try {
            // Read the per-image settings once (they were previously parsed twice).
            const imageJson = JSON.parse(await fspromises.readFile(image + '.json', 'utf-8'));
            const sd_setting = JSON.parse(await fspromises.readFile(define.sd_setting, 'utf-8'));
            const model = imageJson.model;

            // The result is first written to image_path. It is then handed to
            // deletePngAndDeleteExifData together with target_image_path
            // (presumably to strip EXIF and place the final file).
            let image_path;
            let target_image_path;
            if (imageJson.name) {
                const outDir = `tmp/${task_list.out_folder}`;
                image_path = path.join(this.global.config.project_path, `${outDir}/tmp_${imageJson.name}`);
                target_image_path = path.join(this.global.config.project_path, `${outDir}/${imageJson.name}`);
            } else {
                target_image_path = image.replaceAll("input_crop", task_list.out_folder);
                image_path = target_image_path.split(".png")[0] + "_tmp.png";
            }

            // Build the prompt from these parts, in order:
            // global prompt + selected style strings + optional style + optional LoRA + per-image prompt.
            const image_styles = await ImageStyleDefine.getImageStyleStringByIds(task_list.image_style_list ? task_list.image_style_list : []);
            let prompt = sd_setting.webui.prompt + image_styles;
            if (task_list.image_style != null) {
                prompt += `((${task_list.image_style})), `;
            }
            if (task_list.lora != null) {
                prompt += `${task_list.lora}, `;
            }
            prompt += imageJson.webui_config.prompt;

            let web_api;
            let body;
            if (model == "img2img") {
                web_api = this.global.config.webui_api_url + 'sdapi/v1/img2img';
                // img2img reuses the per-image webui_config as the request body.
                body = imageJson["webui_config"];
                body.prompt = prompt;
                body.seed = seed;
                body.init_images = [(await fspromises.readFile(image)).toString('base64')];
                body.height = sd_setting.webui.height;
                body.width = sd_setting.webui.width;
            } else if (model == "txt2img") {
                web_api = this.global.config.webui_api_url + 'sdapi/v1/txt2img';
                body = {
                    "prompt": prompt,
                    "negative_prompt": imageJson.webui_config.negative_prompt,
                    "seed": seed,
                    "sampler_name": imageJson.webui_config.sampler_name,
                    // Prompt relevance (CFG scale)
                    "cfg_scale": imageJson.webui_config.cfg_scale,
                    "width": sd_setting.webui.width,
                    "height": sd_setting.webui.height,
                    "batch_size": 1,
                    "n_iter": 1,
                    "steps": imageJson.webui_config.steps,
                    "save_images": false,
                };
            } else {
                throw new Error("SD 模式错误");
            }
            // Enable ADetailer (face/hand repair) when this image asks for it.
            if (imageJson.adetailer) {
                body.alwayson_scripts = {
                    ADetailer: { args: sd_setting.adetailer },
                };
            }

            const response = await axios.post(web_api, body);
            const info = JSON.parse(response.data.info);
            // Loose comparison kept in case the seed arrives as a string.
            if (seed == -1) {
                seed = info.seed;
            }
            // Single-image generation: only the first returned image is used.
            // split(",", 1)[0] keeps the text before the first comma, i.e. the
            // whole string for bare base64.
            const imageData = Buffer.from(response.data.images[0].split(",", 1)[0], 'base64');
            await sharp(imageData).toFile(image_path);
            await this.tools.deletePngAndDeleteExifData(image_path, target_image_path);
            return seed;
        } catch (error) {
            // On failure, remove the whole batch's tasks from the request queue.
            this.global.requestQuene.removeTask(task_list.out_folder, null);
            this.global.fileQueue.enqueue(async () => {
                // Record the failure status for this task in task_list.json.
                const task_list_json = JSON.parse(await fspromises.readFile(taskPath, 'utf-8'));
                for (const task of task_list_json.task_list) {
                    if (task.id == task_list.id) {
                        task.status = "error";
                        task.errorMessage = error.toString();
                    }
                }
                await fspromises.writeFile(taskPath, JSON.stringify(task_list_json));
                this.global.newWindow[0].win.webContents.send(DEFINE_STRING.IMAGE_TASK_STATUS_REFRESH, {
                    out_folder: task_list.out_folder,
                    status: "error",
                });
            });
            throw error;
        }
    }
}