diff --git a/jan-inference/sd/main.py b/jan-inference/sd/main.py
index f31380dd3..bf77767f2 100644
--- a/jan-inference/sd/main.py
+++ b/jan-inference/sd/main.py
@@ -4,16 +4,31 @@ from fastapi.staticfiles import StaticFiles
 import subprocess
 import os
 from uuid import uuid4
+from pydantic import BaseModel
 
 app = FastAPI()
 
 OUTPUT_DIR = "output"
 SD_PATH = os.environ.get("SD_PATH", "./sd")
 MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
-BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000")
 MODEL_NAME = os.environ.get(
     "MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin")
 
+
+class Payload(BaseModel):
+    """JSON request body accepted by the txt2img endpoint.
+
+    run_command forwards every field to the sd executable as a
+    command-line option.
+    """
+    prompt: str  # passed as --prompt
+    neg_prompt: str  # passed as --negative-prompt
+    seed: int  # passed as --seed
+    steps: int  # passed as --steps
+    width: int  # passed as --width
+    height: int  # passed as --height
+
+
 # Create the OUTPUT_DIR directory if it does not exist
 if not os.path.exists(OUTPUT_DIR):
     os.makedirs(OUTPUT_DIR)
@@ -26,33 +41,56 @@ if not os.path.exists(MODEL_DIR):
 app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")
 
 
-def run_command(prompt: str, filename: str):
+def run_command(payload: Payload, filename: str):
+    """Invoke the sd executable in txt2img mode for one request.
+
+    Args:
+        payload: Generation parameters forwarded as command-line options.
+        filename: Name of the output file, passed to sd under OUTPUT_DIR.
+
+    Raises:
+        HTTPException: If sd exits with a non-zero status.
+        subprocess.TimeoutExpired: If sd runs longer than five minutes.
+    """
     # Construct the command based on your provided example
     command = [SD_PATH,
-               "-m", os.path.join(MODEL_DIR, MODEL_NAME),
-               "-p", prompt,
-               "-o", os.path.join(OUTPUT_DIR, filename)
+               "--model", os.path.join(MODEL_DIR, MODEL_NAME),
+               # No shell is involved, so each list item reaches sd verbatim;
+               # wrapping the prompts in quotes would make the quote characters
+               # part of the prompt text.
+               "--prompt", payload.prompt,
+               "--negative-prompt", payload.neg_prompt,
+               "--height", str(payload.height),
+               "--width", str(payload.width),
+               "--steps", str(payload.steps),
+               "--seed", str(payload.seed),
+               "--mode", "txt2img",
+               "-o", os.path.join(OUTPUT_DIR, filename),
                ]
 
     try:
-        sub_output = subprocess.run(command, timeout=5*60, capture_output=True,
-                                    check=True, encoding="utf-8")
-        print(sub_output.stdout)
+        # check=True is required for CalledProcessError to be raised at all.
+        subprocess.run(command, timeout=5*60, check=True)
     except subprocess.CalledProcessError:
         raise HTTPException(
             status_code=500, detail="Failed to execute the command.")
 
 
-@app.post("/inference/")
-async def run_inference(background_tasks: BackgroundTasks, prompt: str = Form()):
+@app.post("/inferences/txt2img")
+async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
+    """Schedule a txt2img generation and return the URL of its output file.
+
+    The image is produced by a background task, so the file behind the
+    returned URL may not exist yet when the response is sent.
+    """
     # Generate a unique filename using uuid4()
     filename = f"{uuid4()}.png"
 
     # We will use background task to run the command so it won't block
-    background_tasks.add_task(run_command, prompt, filename)
+    background_tasks.add_task(run_command, payload, filename)
 
     # Return the expected path of the output file
-    return {"url": f'{BASE_URL}/serve/{filename}'}
+    return {"url": f'/serve/{filename}'}
 
 
 @app.get("/serve/{filename}")