feat(inf-sd): Add local s3 support for sd api
This commit is contained in:
parent
0ac19f1f39
commit
19a0fe448c
@ -2,3 +2,4 @@
|
|||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
python-multipart
|
python-multipart
|
||||||
|
boto3
|
||||||
@ -5,6 +5,8 @@ import subprocess
|
|||||||
import os
|
import os
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
import boto3
|
||||||
|
from botocore.client import Config
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@ -12,8 +14,23 @@ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "output")
|
|||||||
SD_PATH = os.environ.get("SD_PATH", "./sd")
|
SD_PATH = os.environ.get("SD_PATH", "./sd")
|
||||||
MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
|
MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
|
||||||
MODEL_NAME = os.environ.get(
|
MODEL_NAME = os.environ.get(
|
||||||
"MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin")
|
"MODEL_NAME", "v1-5-pruned-emaonly.safetensors.q4_0.bin")
|
||||||
BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000")
|
|
||||||
|
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", "http://localhost:9000")
|
||||||
|
S3_PUBLIC_ENDPOINT_URL = os.environ.get(
|
||||||
|
"S3_PUBLIC_ENDPOINT_URL", "http://localhost:9000")
|
||||||
|
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", "minio")
|
||||||
|
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", "minio123")
|
||||||
|
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "jan")
|
||||||
|
|
||||||
|
s3 = boto3.resource('s3',
|
||||||
|
endpoint_url=S3_ENDPOINT_URL,
|
||||||
|
aws_access_key_id=S3_ACCESS_KEY_ID,
|
||||||
|
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
|
||||||
|
config=Config(signature_version='s3v4'),
|
||||||
|
region_name='us-east-1')
|
||||||
|
|
||||||
|
s3_bucket = s3.Bucket(S3_BUCKET_NAME)
|
||||||
|
|
||||||
|
|
||||||
class Payload(BaseModel):
|
class Payload(BaseModel):
|
||||||
@ -33,9 +50,6 @@ if not os.path.exists(OUTPUT_DIR):
|
|||||||
if not os.path.exists(MODEL_DIR):
|
if not os.path.exists(MODEL_DIR):
|
||||||
os.makedirs(MODEL_DIR)
|
os.makedirs(MODEL_DIR)
|
||||||
|
|
||||||
# Serve files from the "files" directory
|
|
||||||
app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")
|
|
||||||
|
|
||||||
|
|
||||||
def run_command(payload: Payload, filename: str):
|
def run_command(payload: Payload, filename: str):
|
||||||
# Construct the command based on your provided example
|
# Construct the command based on your provided example
|
||||||
@ -66,21 +80,11 @@ async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
|
|||||||
# We will use background task to run the command so it won't block
|
# We will use background task to run the command so it won't block
|
||||||
# background_tasks.add_task(run_command, payload, filename)
|
# background_tasks.add_task(run_command, payload, filename)
|
||||||
run_command(payload, filename)
|
run_command(payload, filename)
|
||||||
|
s3_bucket.upload_file(f'{os.path.join(OUTPUT_DIR, filename)}', filename)
|
||||||
# Return the expected path of the output file
|
# Return the expected path of the output file
|
||||||
return {"url": f'{BASE_URL}/serve/{filename}'}
|
return {"url": f'{S3_PUBLIC_ENDPOINT_URL}/{S3_BUCKET_NAME}/{filename}'}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/serve/{filename}")
|
|
||||||
async def serve_file(filename: str):
|
|
||||||
file_path = os.path.join(OUTPUT_DIR, filename)
|
|
||||||
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
return FileResponse(file_path)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=404, detail="File not found")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8002)
|
||||||
|
|||||||
@ -15,3 +15,10 @@ LLM_MODEL_FILE=llama-2-7b-chat.ggmlv3.q4_1.bin
|
|||||||
## SD
|
## SD
|
||||||
SD_MODEL_URL=https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
|
SD_MODEL_URL=https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
|
||||||
SD_MODEL_FILE=v1-5-pruned-emaonly.safetensors
|
SD_MODEL_FILE=v1-5-pruned-emaonly.safetensors
|
||||||
|
|
||||||
|
# Minio
|
||||||
|
S3_ACCESS_KEY_ID=minio
|
||||||
|
S3_SECRET_ACCESS_KEY=minio123
|
||||||
|
S3_BUCKET_NAME=jan
|
||||||
|
S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
S3_PUBLIC_ENDPOINT_URL=http://127.0.0.1:9000
|
||||||
Loading…
x
Reference in New Issue
Block a user