fix(inference): Update ggml sd
This commit is contained in:
parent
b89722e439
commit
ceb01e5eda
@ -4,16 +4,26 @@ from fastapi.staticfiles import StaticFiles
|
||||
import subprocess
|
||||
import os
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Directory that generated images are written to and that is mounted for
# static serving.
OUTPUT_DIR = "output"
|
||||
# Executable launched by run_command; overridable via the SD_PATH env var.
SD_PATH = os.environ.get("SD_PATH", "./sd")
|
||||
# Directory holding the model file; overridable via the MODEL_DIR env var.
MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
|
||||
# Base URL of the service; overridable via the BASE_URL env var.
BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000")
|
||||
# File name of the model inside MODEL_DIR; overridable via the MODEL_NAME
# env var.
MODEL_NAME = os.environ.get(
|
||||
"MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin")
|
||||
|
||||
|
||||
class Payload(BaseModel):
    """Request body describing one text-to-image generation.

    Each field is forwarded by ``run_command`` to the ``sd`` executable as
    the command-line option noted next to it.
    """

    # Text prompt (--prompt)
    prompt: str
    # Negative prompt (--negative-prompt)
    neg_prompt: str
    # Seed value (--seed)
    seed: int
    # Number of steps (--steps)
    steps: int
    # Image width (--width)
    width: int
    # Image height (--height)
    height: int
|
||||
|
||||
|
||||
# Create the OUTPUT_DIR directory if it does not exist. exist_ok=True makes
# this a single call and avoids the race between checking os.path.exists()
# and creating the directory.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
@ -26,33 +36,37 @@ if not os.path.exists(MODEL_DIR):
|
||||
# Serve the files in OUTPUT_DIR as static content under the /output path.
app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")
|
||||
|
||||
|
||||
def run_command(payload: Payload, filename: str):
    """Run the ``sd`` executable in txt2img mode for one request.

    Builds the argument list for the binary at ``SD_PATH`` from the request
    payload and runs it so that the image is written to
    ``OUTPUT_DIR/filename``.

    Args:
        payload: Generation parameters (prompt, negative prompt, seed,
            steps, width, height).
        filename: Name of the output file to create inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: With status 500 when the executable exits with a
            non-zero status.
        subprocess.TimeoutExpired: When the run exceeds five minutes.
    """
    # The command is an argument list executed without a shell, so every
    # element reaches the executable verbatim. The prompts are therefore
    # passed as-is: wrapping them in quote characters would make the quotes
    # part of the prompt text.
    command = [
        SD_PATH,
        "--model", os.path.join(MODEL_DIR, MODEL_NAME),
        "--prompt", payload.prompt,
        "--negative-prompt", payload.neg_prompt,
        "--height", str(payload.height),
        "--width", str(payload.width),
        "--steps", str(payload.steps),
        "--seed", str(payload.seed),
        "--mode", "txt2img",
        "-o", os.path.join(OUTPUT_DIR, filename),
    ]

    try:
        # check=True is what turns a non-zero exit status into
        # CalledProcessError; without it the handler below never runs.
        subprocess.run(command, timeout=5 * 60, check=True)
    except subprocess.CalledProcessError as err:
        raise HTTPException(
            status_code=500, detail="Failed to execute the command.") from err
|
||||
|
||||
|
||||
@app.post("/inferences/txt2img")
async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
    """Queue a txt2img generation and return the URL path of its output.

    Args:
        background_tasks: FastAPI background task queue for this request.
        payload: Generation parameters taken from the request body.

    Returns:
        A dict with a single ``"url"`` key holding the ``/serve/<filename>``
        path of the image. The generation is only scheduled here, so the
        file may not exist yet when the response is sent.
    """
    # Generate a unique filename using uuid4()
    filename = f"{uuid4()}.png"

    # Run the command as a background task so the request does not block
    # while the image is generated.
    background_tasks.add_task(run_command, payload, filename)

    # Return the expected path of the output file
    return {"url": f'/serve/{filename}'}
|
||||
|
||||
|
||||
@app.get("/serve/{filename}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user