fix(inference): Update ggml sd

This commit is contained in:
vuonghoainam 2023-08-30 16:44:05 +07:00
parent b89722e439
commit ceb01e5eda

View File

@ -4,16 +4,26 @@ from fastapi.staticfiles import StaticFiles
import subprocess import subprocess
import os import os
from uuid import uuid4 from uuid import uuid4
from pydantic import BaseModel
app = FastAPI() app = FastAPI()
OUTPUT_DIR = "output" OUTPUT_DIR = "output"
SD_PATH = os.environ.get("SD_PATH", "./sd") SD_PATH = os.environ.get("SD_PATH", "./sd")
MODEL_DIR = os.environ.get("MODEL_DIR", "./models") MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000")
MODEL_NAME = os.environ.get( MODEL_NAME = os.environ.get(
"MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin") "MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin")
class Payload(BaseModel):
prompt: str
neg_prompt: str
seed: int
steps: int
width: int
height: int
# Create the OUTPUT_DIR directory if it does not exist # Create the OUTPUT_DIR directory if it does not exist
if not os.path.exists(OUTPUT_DIR): if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR) os.makedirs(OUTPUT_DIR)
@ -26,33 +36,37 @@ if not os.path.exists(MODEL_DIR):
app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output") app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")
def run_command(prompt: str, filename: str): def run_command(payload: Payload, filename: str):
# Construct the command based on your provided example # Construct the command based on your provided example
command = [SD_PATH, command = [SD_PATH,
"-m", os.path.join(MODEL_DIR, MODEL_NAME), "--model", f'{os.path.join(MODEL_DIR, MODEL_NAME)}',
"-p", prompt, "--prompt", f'"{payload.prompt}"',
"-o", os.path.join(OUTPUT_DIR, filename) "--negative-prompt", f'"{payload.neg_prompt}"',
"--height", str(payload.height),
"--width", str(payload.width),
"--steps", str(payload.steps),
"--seed", str(payload.seed),
"--mode", 'txt2img',
"-o", f'{os.path.join(OUTPUT_DIR, filename)}',
] ]
try: try:
sub_output = subprocess.run(command, timeout=5*60, capture_output=True, subprocess.run(command, timeout=5*60)
check=True, encoding="utf-8")
print(sub_output.stdout)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
raise HTTPException( raise HTTPException(
status_code=500, detail="Failed to execute the command.") status_code=500, detail="Failed to execute the command.")
@app.post("/inference/") @app.post("/inferences/txt2img")
async def run_inference(background_tasks: BackgroundTasks, prompt: str = Form()): async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
# Generate a unique filename using uuid4() # Generate a unique filename using uuid4()
filename = f"{uuid4()}.png" filename = f"{uuid4()}.png"
# We will use background task to run the command so it won't block # We will use background task to run the command so it won't block
background_tasks.add_task(run_command, prompt, filename) background_tasks.add_task(run_command, payload, filename)
# Return the expected path of the output file # Return the expected path of the output file
return {"url": f'{BASE_URL}/serve/{filename}'} return {"url": f'/serve/{filename}'}
@app.get("/serve/{filename}") @app.get("/serve/{filename}")