diff --git a/src-tauri/src/core/server.rs b/src-tauri/src/core/server.rs index 627ec6a7c..ee8b1cbb1 100644 --- a/src-tauri/src/core/server.rs +++ b/src-tauri/src/core/server.rs @@ -7,6 +7,7 @@ use std::net::SocketAddr; use std::sync::LazyLock; use tokio::sync::Mutex; use tokio::task::JoinHandle; +use futures_util::StreamExt; use flate2::read::GzDecoder; use std::io::Read; @@ -389,11 +390,11 @@ async fn proxy_request( &config.trusted_hosts, ); - // Read response body - match response.bytes().await { - Ok(bytes) => { - // Check if this is a /models endpoint request and filter the response - if path.contains("/models") && method == hyper::Method::GET { + // Handle streaming vs non-streaming responses + if path.contains("/models") && method == hyper::Method::GET { + // For /models endpoint, we need to buffer and filter the response + match response.bytes().await { + Ok(bytes) => { match filter_models_response(&bytes) { Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()), Err(e) => { @@ -401,24 +402,46 @@ async fn proxy_request( Ok(builder.body(Body::from(bytes)).unwrap()) } } - } else { - Ok(builder.body(Body::from(bytes)).unwrap()) + }, + Err(e) => { + log::error!("Failed to read response body: {}", e); + let mut error_response = + Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + Ok(error_response + .body(Body::from("Error reading upstream response")) + .unwrap()) } - }, - Err(e) => { - log::error!("Failed to read response body: {}", e); - let mut error_response = - Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR); - error_response = add_cors_headers_with_host_and_origin( - error_response, - &host_header, - &origin_header, - &config.trusted_hosts, - ); - Ok(error_response - .body(Body::from("Error reading upstream response")) - .unwrap()) } + } else { + // For streaming endpoints (like chat completions), we need to collect and forward the stream + let mut stream = response.bytes_stream(); + let (mut sender, body) = hyper::Body::channel(); + + // Spawn a task to forward the stream + tokio::spawn(async move { + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if sender.send_data(chunk).await.is_err() { + log::debug!("Client disconnected during streaming"); + break; + } + } + Err(e) => { + log::error!("Stream error: {}", e); + break; + } + } + } + }); + + Ok(builder.body(body).unwrap()) } } Err(e) => { @@ -640,9 +663,11 @@ pub async fn start_server( trusted_hosts, }; - // Create HTTP client + // Create HTTP client with longer timeout for streaming let client = Client::builder() - .timeout(std::time::Duration::from_secs(30)) + .timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming + .pool_max_idle_per_host(10) + .pool_idle_timeout(std::time::Duration::from_secs(30)) .build()?; // Create service handler