fix(proxy): implement true HTTP streaming for chat completions API (#5350)
This commit is contained in:
parent
6cee466f52
commit
eb5655bbd4
@ -7,6 +7,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::LazyLock;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use futures_util::StreamExt;
|
||||
use flate2::read::GzDecoder;
|
||||
use std::io::Read;
|
||||
|
||||
@ -389,11 +390,11 @@ async fn proxy_request(
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
|
||||
// Read response body
|
||||
match response.bytes().await {
|
||||
Ok(bytes) => {
|
||||
// Check if this is a /models endpoint request and filter the response
|
||||
if path.contains("/models") && method == hyper::Method::GET {
|
||||
// Handle streaming vs non-streaming responses
|
||||
if path.contains("/models") && method == hyper::Method::GET {
|
||||
// For /models endpoint, we need to buffer and filter the response
|
||||
match response.bytes().await {
|
||||
Ok(bytes) => {
|
||||
match filter_models_response(&bytes) {
|
||||
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
||||
Err(e) => {
|
||||
@ -401,24 +402,46 @@ async fn proxy_request(
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to read response body: {}", e);
|
||||
let mut error_response =
|
||||
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
Ok(error_response
|
||||
.body(Body::from("Error reading upstream response"))
|
||||
.unwrap())
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to read response body: {}", e);
|
||||
let mut error_response =
|
||||
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
Ok(error_response
|
||||
.body(Body::from("Error reading upstream response"))
|
||||
.unwrap())
|
||||
}
|
||||
} else {
|
||||
// For streaming endpoints (like chat completions), we need to collect and forward the stream
|
||||
let mut stream = response.bytes_stream();
|
||||
let (mut sender, body) = hyper::Body::channel();
|
||||
|
||||
// Spawn a task to forward the stream
|
||||
tokio::spawn(async move {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if sender.send_data(chunk).await.is_err() {
|
||||
log::debug!("Client disconnected during streaming");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Stream error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(builder.body(body).unwrap())
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@ -640,9 +663,11 @@ pub async fn start_server(
|
||||
trusted_hosts,
|
||||
};
|
||||
|
||||
// Create HTTP client
|
||||
// Create HTTP client with longer timeout for streaming
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming
|
||||
.pool_max_idle_per_host(10)
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
// Create service handler
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user