Merge pull request #19 from janhq/fix_inf_sd_ggml

Fix inf sd ggml
This commit is contained in:
namvuong 2023-08-30 20:09:08 +07:00 committed by GitHub
commit d8cc0d4262
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 11 deletions

View File

@ -190,15 +190,16 @@ services:
dockerfile: inference.Dockerfile
# Mount the directory that contains the downloaded model.
volumes:
- ./jan-inference/sd/models:/models
- ./jan-inference/sd/output/:/serving/output
- ./jan-inference/sd/models:/models/
- ./jan-inference/sd/output/:/output/
command: /bin/bash -c "python -m uvicorn main:app --proxy-headers --host 0.0.0.0 --port 8000"
environment:
# Specify the path to the model for the web application.
BASE_URL: http://0.0.0.0:8000
BASE_URL: http://0.0.0.0:8001
MODEL_NAME: ${SD_MODEL_FILE}.q4_0.bin
MODEL_DIR: "/models"
SD_PATH: "/sd"
MODEL_DIR: /models
OUTPUT_DIR: /output
SD_PATH: /sd
PYTHONUNBUFFERED: 1
ports:
- 8001:8000

View File

@ -1,4 +1,6 @@
FROM python:3.9.17 as build
ARG UBUNTU_VERSION=22.04
FROM ubuntu:$UBUNTU_VERSION as build
RUN apt-get update && apt-get install -y build-essential git cmake

View File

@ -8,11 +8,12 @@ from pydantic import BaseModel
app = FastAPI()
OUTPUT_DIR = "output"
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "output")
SD_PATH = os.environ.get("SD_PATH", "./sd")
MODEL_DIR = os.environ.get("MODEL_DIR", "./models")
MODEL_NAME = os.environ.get(
"MODEL_NAME", "v1-5-pruned-emaonly-ggml-model-q5_0.bin")
BASE_URL = os.environ.get("BASE_URL", "http://localhost:8000")
class Payload(BaseModel):
@ -51,7 +52,7 @@ def run_command(payload: Payload, filename: str):
]
try:
subprocess.run(command, timeout=5*60)
subprocess.run(command)
except subprocess.CalledProcessError:
raise HTTPException(
status_code=500, detail="Failed to execute the command.")
@ -63,10 +64,11 @@ async def run_inference(background_tasks: BackgroundTasks, payload: Payload):
filename = f"{uuid4()}.png"
# We will use a background task to run the command so that it won't block
background_tasks.add_task(run_command, payload, filename)
# background_tasks.add_task(run_command, payload, filename)
run_command(payload, filename)
# Return the expected path of the output file
return {"url": f'/serve/{filename}'}
return {"url": f'{BASE_URL}/serve/{filename}'}
@app.get("/serve/{filename}")

@ -1 +1 @@
Subproject commit 0d7f04b135cd48e8d62aecd09a52eb2afa482744
Subproject commit c8f85a4e3063e2cdb27db57b8f6167da16453e0c