fix(proxy): implement true HTTP streaming for chat completions API (#5350)

This commit is contained in:
Sam Hoang Van 2025-06-18 16:19:48 +07:00 committed by GitHub
parent 6cee466f52
commit eb5655bbd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,6 +7,7 @@ use std::net::SocketAddr;
use std::sync::LazyLock; use std::sync::LazyLock;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use futures_util::StreamExt;
use flate2::read::GzDecoder; use flate2::read::GzDecoder;
use std::io::Read; use std::io::Read;
@ -389,11 +390,11 @@ async fn proxy_request(
&config.trusted_hosts, &config.trusted_hosts,
); );
// Read response body // Handle streaming vs non-streaming responses
match response.bytes().await { if path.contains("/models") && method == hyper::Method::GET {
Ok(bytes) => { // For /models endpoint, we need to buffer and filter the response
// Check if this is a /models endpoint request and filter the response match response.bytes().await {
if path.contains("/models") && method == hyper::Method::GET { Ok(bytes) => {
match filter_models_response(&bytes) { match filter_models_response(&bytes) {
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()), Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
Err(e) => { Err(e) => {
@ -401,24 +402,46 @@ async fn proxy_request(
Ok(builder.body(Body::from(bytes)).unwrap()) Ok(builder.body(Body::from(bytes)).unwrap())
} }
} }
} else { },
Ok(builder.body(Body::from(bytes)).unwrap()) Err(e) => {
log::error!("Failed to read response body: {}", e);
let mut error_response =
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
Ok(error_response
.body(Body::from("Error reading upstream response"))
.unwrap())
} }
},
Err(e) => {
log::error!("Failed to read response body: {}", e);
let mut error_response =
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
Ok(error_response
.body(Body::from("Error reading upstream response"))
.unwrap())
} }
} else {
// For every other request (including streaming endpoints like chat completions), forward the upstream body to the client chunk by chunk without buffering it
let mut stream = response.bytes_stream();
let (mut sender, body) = hyper::Body::channel();
// Spawn a task to forward the stream
tokio::spawn(async move {
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
if sender.send_data(chunk).await.is_err() {
log::debug!("Client disconnected during streaming");
break;
}
}
Err(e) => {
log::error!("Stream error: {}", e);
break;
}
}
}
});
Ok(builder.body(body).unwrap())
} }
} }
Err(e) => { Err(e) => {
@ -640,9 +663,11 @@ pub async fn start_server(
trusted_hosts, trusted_hosts,
}; };
// Create HTTP client // Create HTTP client with longer timeout for streaming
let client = Client::builder() let client = Client::builder()
.timeout(std::time::Duration::from_secs(30)) .timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming
.pool_max_idle_per_host(10)
.pool_idle_timeout(std::time::Duration::from_secs(30))
.build()?; .build()?;
// Create service handler // Create service handler