fix(proxy): implement true HTTP streaming for chat completions API (#5350)
This commit is contained in:
parent
6cee466f52
commit
eb5655bbd4
@ -7,6 +7,7 @@ use std::net::SocketAddr;
|
|||||||
use std::sync::LazyLock;
|
use std::sync::LazyLock;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
use futures_util::StreamExt;
|
||||||
use flate2::read::GzDecoder;
|
use flate2::read::GzDecoder;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
|
||||||
@ -389,11 +390,11 @@ async fn proxy_request(
|
|||||||
&config.trusted_hosts,
|
&config.trusted_hosts,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Read response body
|
// Handle streaming vs non-streaming responses
|
||||||
match response.bytes().await {
|
if path.contains("/models") && method == hyper::Method::GET {
|
||||||
Ok(bytes) => {
|
// For /models endpoint, we need to buffer and filter the response
|
||||||
// Check if this is a /models endpoint request and filter the response
|
match response.bytes().await {
|
||||||
if path.contains("/models") && method == hyper::Method::GET {
|
Ok(bytes) => {
|
||||||
match filter_models_response(&bytes) {
|
match filter_models_response(&bytes) {
|
||||||
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@ -401,24 +402,46 @@ async fn proxy_request(
|
|||||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
},
|
||||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
Err(e) => {
|
||||||
|
log::error!("Failed to read response body: {}", e);
|
||||||
|
let mut error_response =
|
||||||
|
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
error_response = add_cors_headers_with_host_and_origin(
|
||||||
|
error_response,
|
||||||
|
&host_header,
|
||||||
|
&origin_header,
|
||||||
|
&config.trusted_hosts,
|
||||||
|
);
|
||||||
|
Ok(error_response
|
||||||
|
.body(Body::from("Error reading upstream response"))
|
||||||
|
.unwrap())
|
||||||
}
|
}
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Failed to read response body: {}", e);
|
|
||||||
let mut error_response =
|
|
||||||
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
error_response = add_cors_headers_with_host_and_origin(
|
|
||||||
error_response,
|
|
||||||
&host_header,
|
|
||||||
&origin_header,
|
|
||||||
&config.trusted_hosts,
|
|
||||||
);
|
|
||||||
Ok(error_response
|
|
||||||
.body(Body::from("Error reading upstream response"))
|
|
||||||
.unwrap())
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// For every other request (e.g. chat completions), forward the upstream body chunk by chunk as it arrives instead of buffering it
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
let (mut sender, body) = hyper::Body::channel();
|
||||||
|
|
||||||
|
// Spawn a task to forward the stream
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
if sender.send_data(chunk).await.is_err() {
|
||||||
|
log::debug!("Client disconnected during streaming");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Stream error: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(builder.body(body).unwrap())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@ -640,9 +663,11 @@ pub async fn start_server(
|
|||||||
trusted_hosts,
|
trusted_hosts,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create HTTP client
|
// Create HTTP client with longer timeout for streaming
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming
|
||||||
|
.pool_max_idle_per_host(10)
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(30))
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
// Create service handler
|
// Create service handler
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user