91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
from fastapi import FastAPI, BackgroundTasks, HTTPException, Form
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
import subprocess
|
|
import os
|
|
from uuid import uuid4
|
|
from pydantic import BaseModel
|
|
import boto3
|
|
from botocore.client import Config
|
|
|
|
# ASGI application object; the route decorator further down registers on it.
app = FastAPI()
|
|
|
|
# Filesystem locations and model selection. Each value is read from the
# environment variable of the same name and falls back to the default shown.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "output")
SD_PATH = os.getenv("SD_PATH", "./sd")
MODEL_DIR = os.getenv("MODEL_DIR", "./models")
MODEL_NAME = os.getenv(
    "MODEL_NAME",
    "v1-5-pruned-emaonly.safetensors.q4_0.bin",
)
|
|
|
|
# S3 connection settings, overridable through the environment.
# S3_ENDPOINT_URL is handed to the boto3 resource; S3_PUBLIC_ENDPOINT_URL is
# only used to build the URL returned to clients.
S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "http://localhost:9000")
S3_PUBLIC_ENDPOINT_URL = os.getenv(
    "S3_PUBLIC_ENDPOINT_URL",
    "http://localhost:9000",
)
S3_ACCESS_KEY_ID = os.getenv("S3_ACCESS_KEY_ID", "minio")
S3_SECRET_ACCESS_KEY = os.getenv("S3_SECRET_ACCESS_KEY", "minio123")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "jan")
|
|
|
|
# Module-wide S3 resource built from the settings above, using Signature
# Version 4 request signing.
s3 = boto3.resource(
    "s3",
    region_name="us-east-1",
    endpoint_url=S3_ENDPOINT_URL,
    aws_access_key_id=S3_ACCESS_KEY_ID,
    aws_secret_access_key=S3_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)

# Handle to the bucket that generated images are uploaded to.
s3_bucket = s3.Bucket(S3_BUCKET_NAME)
|
|
|
|
|
|
class Payload(BaseModel):
    """Request body accepted by the txt2img inference endpoint.

    Every field is required; the values are forwarded as command-line
    arguments to the generator executable.
    """

    # Text inputs: what to render and what to steer away from.
    prompt: str
    neg_prompt: str

    # Values passed through as --seed and --steps.
    seed: int
    steps: int

    # Values passed through as --width and --height.
    width: int
    height: int
|
|
|
|
|
|
# Make sure the output and model directories exist at import time.
# exist_ok=True replaces the previous "check, then create" sequence, which
# could raise FileExistsError when two processes started simultaneously and
# both saw the directory as missing.
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
|
|
|
|
def run_command(payload: Payload, filename: str):
    """Run the generator executable at ``SD_PATH`` in txt2img mode.

    Args:
        payload: Generation parameters (prompts, seed, steps, width, height).
        filename: Name of the file to write inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: Status 500 if the executable cannot be started or
            exits with a non-zero status.
    """
    command = [
        SD_PATH,
        "--model", os.path.join(MODEL_DIR, MODEL_NAME),
        # The command is an argument list executed without a shell, so each
        # element reaches the program verbatim. Wrapping the prompts in quote
        # characters would make those quotes part of the prompt text.
        "--prompt", payload.prompt,
        "--negative-prompt", payload.neg_prompt,
        "--height", str(payload.height),
        "--width", str(payload.width),
        "--steps", str(payload.steps),
        "--seed", str(payload.seed),
        "--mode", "txt2img",
        "-o", os.path.join(OUTPUT_DIR, filename),
    ]

    try:
        # Without check=True a non-zero exit status is ignored and
        # CalledProcessError is never raised.
        subprocess.run(command, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # OSError covers a missing or non-executable binary at SD_PATH.
        raise HTTPException(
            status_code=500, detail="Failed to execute the command.") from err
|
|
|
|
|
|
@app.post("/inferences/txt2img")
async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
    """Generate an image for ``payload``, upload it to S3 and return its URL.

    Args:
        background_tasks: Accepted for interface compatibility; not used,
            because generation runs inline before the upload.
        payload: Generation parameters forwarded to ``run_command``.

    Returns:
        A dict with a single ``url`` key built from the public endpoint,
        the bucket name and the generated filename.

    Raises:
        HTTPException: Status 500 if no output file exists after the
            generator has run.
    """
    # A random UUID keeps output names unique across requests.
    filename = f"{uuid4()}.png"
    output_path = os.path.join(OUTPUT_DIR, filename)

    # NOTE(review): generation and upload are blocking calls made directly
    # inside an async handler, so the event loop is stalled for their whole
    # duration. Offloading them would also allow several generator processes
    # to run at once — confirm that is acceptable before changing this.
    run_command(payload, filename)

    # Report a generator failure explicitly instead of letting the upload
    # fail on a missing file.
    if not os.path.isfile(output_path):
        raise HTTPException(
            status_code=500, detail="Image generation produced no output file.")

    s3_bucket.upload_file(output_path, filename)

    return {"url": f'{S3_PUBLIC_ENDPOINT_URL}/{S3_BUCKET_NAME}/{filename}'}
|
|
|
|
|
|
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly, so it is
    # imported here rather than at module import time.
    import uvicorn

    # Listen on every interface, port 8002.
    server_options = {"host": "0.0.0.0", "port": 8002}
    uvicorn.run(app, **server_options)
|