# jan/jan-inference/sd/main.py
# 2023-08-31 00:24:14 +07:00

# 91 lines
# 3.0 KiB
# Python

from fastapi import FastAPI, BackgroundTasks, HTTPException, Form
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import subprocess
import os
from uuid import uuid4
from pydantic import BaseModel
import boto3
from botocore.client import Config
# ASGI application object; routes are registered on it further down.
app = FastAPI()

# --- Local paths (each overridable through an environment variable) --------
# Directory where generated images are written before being uploaded.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "output")
# Executable launched by run_command() to produce an image.
SD_PATH = os.getenv("SD_PATH", "./sd")
# Directory and file name of the model handed to the executable via --model.
MODEL_DIR = os.getenv("MODEL_DIR", "./models")
MODEL_NAME = os.getenv("MODEL_NAME", "v1-5-pruned-emaonly.safetensors.q4_0.bin")

# --- S3 settings -----------------------------------------------------------
# Endpoint this service talks to when uploading.
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "http://localhost:9000")
# Endpoint used to build the URL returned to API clients.
S3_PUBLIC_ENDPOINT_URL = os.getenv("S3_PUBLIC_ENDPOINT_URL", "http://localhost:9000")
# NOTE(review): the credential defaults look like local-development values;
# confirm they are always overridden in real deployments.
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID", "minio")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY", "minio123")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "jan")

# Shared S3 resource and the bucket that receives generated images.
s3 = boto3.resource(
    "s3",
    endpoint_url=S3_ENDPOINT_URL,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
    region_name="us-east-1",
)
s3_bucket = s3.Bucket(S3_BUCKET_NAME)
# Request body accepted by POST /inferences/txt2img.
# Every field is forwarded to the SD command line built in run_command().
class Payload(BaseModel):
    prompt: str  # passed as --prompt
    neg_prompt: str  # passed as --negative-prompt
    seed: int  # passed as --seed
    steps: int  # passed as --steps
    width: int  # passed as --width
    height: int  # passed as --height
# Make sure the working directories exist before the first request arrives.
# exist_ok=True keeps creation idempotent and removes the check-then-create
# race that could raise FileExistsError when several workers start together.
os.makedirs(OUTPUT_DIR, exist_ok=True)  # where generated images are written
os.makedirs(MODEL_DIR, exist_ok=True)  # where the model file is looked up
def run_command(payload: Payload, filename: str):
    """Run the SD executable to render ``payload`` into ``OUTPUT_DIR/filename``.

    Args:
        payload: Generation parameters; each field maps to one command-line
            option of the executable.
        filename: Name of the image file to create inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: Status 500 if the executable cannot be started or
            exits with a non-zero status.
    """
    command = [
        SD_PATH,
        "--model", os.path.join(MODEL_DIR, MODEL_NAME),
        # The argument list is handed to the process verbatim (no shell is
        # involved), so the prompts must not be wrapped in quote characters:
        # they would become part of the prompt text itself.
        "--prompt", payload.prompt,
        "--negative-prompt", payload.neg_prompt,
        "--height", str(payload.height),
        "--width", str(payload.width),
        "--steps", str(payload.steps),
        "--seed", str(payload.seed),
        "--mode", "txt2img",
        "-o", os.path.join(OUTPUT_DIR, filename),
    ]
    try:
        # check=True is what makes a non-zero exit status raise
        # CalledProcessError; without it failures pass silently.
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # OSError covers a missing or non-executable SD_PATH.
        raise HTTPException(
            status_code=500, detail="Failed to execute the command."
        ) from err
@app.post("/inferences/txt2img")
async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
    """Generate an image from ``payload``, upload it to S3 and return its URL.

    Args:
        background_tasks: Injected by FastAPI; currently unused and kept so
            the handler signature stays unchanged.
        payload: Text-to-image generation parameters.

    Returns:
        A dict ``{"url": ...}`` pointing at the uploaded image under the
        public S3 endpoint.

    Raises:
        HTTPException: Propagated from ``run_command`` when generation fails.
    """
    # Imported here so the handler brings its own dependency into scope.
    from fastapi.concurrency import run_in_threadpool

    # A random UUID keeps concurrent requests from overwriting each other.
    filename = f"{uuid4()}.png"
    local_path = os.path.join(OUTPUT_DIR, filename)

    # Both steps block (a subprocess, then a network upload). Running them in
    # the thread pool keeps the event loop free to serve other requests while
    # this one is still generating. Note that overlapping requests may now
    # run the executable in parallel.
    await run_in_threadpool(run_command, payload, filename)
    await run_in_threadpool(s3_bucket.upload_file, local_path, filename)

    return {"url": f"{S3_PUBLIC_ENDPOINT_URL}/{S3_BUCKET_NAME}/{filename}"}
if __name__ == "__main__":
    # Script entry point: serve the app on every interface, port 8002.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8002)