From 5f66c3cf83316e107c3f9b9c066d1d739747eeb1 Mon Sep 17 00:00:00 2001 From: Shiju Date: Thu, 4 Jun 2026 20:47:11 +0530 Subject: [PATCH 1/5] feat(inference): allow local embeddings route Route OpenAI-compatible embeddings through the local inference proxy so sandboxed vector workloads reach a configured provider via the same route classification and auth path that chat, completion, and model discovery already use. - Add openai_embeddings to the OpenAI-compatible protocol set so providers (openai, nvidia) advertise embeddings routing. - Classify POST /v1/embeddings as the openai_embeddings protocol in the sandbox L7 patterns. - Serve embeddings buffered with an accurate Content-Length, since the response is a single JSON object rather than an SSE token stream. The streaming path appends an SSE error frame on a size-cap or idle-timeout truncation, which would corrupt a one-object body the client parses whole. protocol_returns_buffered_body() selects the path. - Probe an embeddings-only backend against /v1/embeddings during validation, after the chat and completion protocols so a multi-protocol route still prefers those. - Extract two shared helpers. http_status_text() backs both response formatters and adds 401/422/429/503 for embeddings passthrough and router error mapping; write_inference_router_error() backs the streaming and buffered routing paths. - Return an OpenAI-shaped embeddings body from the mock route. Tests cover profile lookup, L7 pattern detection, the mock body, and buffered Content-Length framing with no chunked transfer-encoding and no SSE error frame. Signed-off-by: Shiju --- crates/openshell-core/src/inference.rs | 12 ++ crates/openshell-router/src/backend.rs | 16 ++ crates/openshell-router/src/mock.rs | 41 ++++ crates/openshell-sandbox/src/l7/inference.rs | 69 +++++-- crates/openshell-sandbox/src/proxy.rs | 185 ++++++++++++++++--- crates/openshell-server/src/inference.rs | 1 + 6 files changed, 279 insertions(+), 45 deletions(-) diff --git a/crates/openshell-core/src/inference.rs b/crates/openshell-core/src/inference.rs index c04feb6b4..3071d53cd 100644 --- a/crates/openshell-core/src/inference.rs +++ b/crates/openshell-core/src/inference.rs @@ -56,6 +56,7 @@ const OPENAI_PROTOCOLS: &[&str] = &[ "openai_chat_completions", "openai_completions", "openai_responses", + "openai_embeddings", "model_discovery", ]; @@ -305,6 +306,17 @@ mod tests { assert!(profile_for("OpenAI").is_some()); // case insensitive } + #[test] + fn openai_compatible_profiles_include_embeddings() { + for provider_type in ["openai", "nvidia"] { + let profile = profile_for(provider_type).expect("provider profile should exist"); + assert!( + profile.protocols.contains(&"openai_embeddings"), + "{provider_type} should route OpenAI-compatible embeddings" + ); + } + } + #[test] fn profile_for_unknown_types() { assert!(profile_for("github").is_none()); diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index 272da265a..39a2c6011 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -381,6 +381,22 @@ fn validation_probe(route: &ResolvedRoute) -> Result ProxyRespo let body = match protocol { "openai_chat_completions" => openai_chat_completion_body(&route.model), "openai_completions" => openai_completion_body(&route.model), + "openai_embeddings" => openai_embeddings_body(&route.model), "anthropic_messages" => anthropic_messages_body(&route.model), _ => generic_body(&route.model), }; @@ -90,6 +91,26 @@ fn openai_completion_body(model: &str) -> Vec { .expect("static JSON must serialize") } +fn openai_embeddings_body(model: &str) -> Vec { + // Shape must match the OpenAI embeddings response (`object: "list"` with a + // `data` array of `{object, index, embedding}`) so callers that deserialize + // into an embeddings type get a structurally valid — if canned — vector. + serde_json::to_vec(&serde_json::json!({ + "object": "list", + "data": [{ + "object": "embedding", + "index": 0, + "embedding": [0.0_f32, 0.0_f32, 0.0_f32] + }], + "model": model, + "usage": { + "prompt_tokens": 1, + "total_tokens": 1 + } + })) + .expect("static JSON must serialize") +} + fn anthropic_messages_body(model: &str) -> Vec { serde_json::to_vec(&serde_json::json!({ "id": "mock-msg-001", @@ -196,6 +217,26 @@ mod tests { ); } + #[test] + fn mock_openai_embeddings() { + let route = make_route( + "mock://test", + &["openai_embeddings"], + "text-embedding-3-small", + ); + let resp = mock_response(&route, "openai_embeddings"); + assert_eq!(resp.status, 200); + + let body: serde_json::Value = serde_json::from_slice(&resp.body).unwrap(); + assert_eq!(body["object"], "list"); + assert_eq!(body["model"], "text-embedding-3-small"); + assert_eq!(body["data"][0]["object"], "embedding"); + assert!( + body["data"][0]["embedding"].is_array(), + "embedding must be a numeric array, got: {body}" + ); + } + #[test] fn mock_generic_protocol() { let route = make_route("mock://test", &["unknown_protocol"], "some-model"); diff --git a/crates/openshell-sandbox/src/l7/inference.rs b/crates/openshell-sandbox/src/l7/inference.rs index acda0bb36..d9b29586c 100644 --- a/crates/openshell-sandbox/src/l7/inference.rs +++ b/crates/openshell-sandbox/src/l7/inference.rs @@ -37,6 +37,12 @@ pub fn default_patterns() -> Vec { protocol: "openai_responses".to_string(), kind: "responses".to_string(), }, + InferenceApiPattern { + method: "POST".to_string(), + path_glob: "/v1/embeddings".to_string(), + protocol: "openai_embeddings".to_string(), + kind: "embeddings".to_string(), + }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/messages".to_string(), @@ -82,6 +88,21 @@ pub fn detect_inference_pattern<'a>( }) } +/// Whether a detected inference protocol returns a single buffered JSON body +/// rather than a Server-Sent Events token stream. +/// +/// SSE-capable protocols (chat completions, completions, responses, Anthropic +/// messages) are served through the chunked streaming path so tokens reach the +/// client incrementally. Embeddings always return one JSON object, so they must +/// be framed with an accurate `Content-Length`. Streaming a buffered body would +/// let a mid-body truncation (the streaming size cap or idle timeout) append a +/// [`format_sse_error`] event to a payload the client parses as a single JSON +/// object, silently corrupting it. Add future non-streaming protocols here. +#[must_use] +pub fn protocol_returns_buffered_body(protocol: &str) -> bool { + matches!(protocol, "openai_embeddings") +} + /// A parsed HTTP request from the intercepted tunnel. pub struct ParsedHttpRequest { pub method: String, @@ -267,20 +288,37 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { .map(|offset| start + offset) } -/// Format an HTTP/1.1 response from status, headers, and body. -pub fn format_http_response(status: u16, headers: &[(String, String)], body: &[u8]) -> Vec { - use std::fmt::Write; - - let status_text = match status { +/// Reason phrase for an HTTP status code used on the inference proxy path. +/// +/// Covers the statuses produced by the router error mapping (400/401/403/500/ +/// 502/503) and the upstream codes an inference backend can pass through +/// verbatim (404/405 on unknown model or method, 422 on malformed embeddings +/// input, 429 on rate limit). Unknown codes fall back to `"Unknown"` so the +/// status line is still well-formed. +fn http_status_text(status: u16) -> &'static str { + match status { 200 => "OK", 400 => "Bad Request", + 401 => "Unauthorized", 403 => "Forbidden", + 404 => "Not Found", + 405 => "Method Not Allowed", 411 => "Length Required", 413 => "Payload Too Large", + 422 => "Unprocessable Entity", + 429 => "Too Many Requests", 500 => "Internal Server Error", 502 => "Bad Gateway", + 503 => "Service Unavailable", _ => "Unknown", - }; + } +} + +/// Format an HTTP/1.1 response from status, headers, and body. +pub fn format_http_response(status: u16, headers: &[(String, String)], body: &[u8]) -> Vec { + use std::fmt::Write; + + let status_text = http_status_text(status); let mut response = format!("HTTP/1.1 {status} {status_text}\r\n"); let mut has_content_length = false; @@ -310,17 +348,7 @@ pub fn format_http_response(status: u16, headers: &[(String, String)], body: &[u pub fn format_http_response_header(status: u16, headers: &[(String, String)]) -> Vec { use std::fmt::Write; - let status_text = match status { - 200 => "OK", - 400 => "Bad Request", - 403 => "Forbidden", - 411 => "Length Required", - 413 => "Payload Too Large", - 500 => "Internal Server Error", - 502 => "Bad Gateway", - 503 => "Service Unavailable", - _ => "Unknown", - }; + let status_text = http_status_text(status); let mut response = format!("HTTP/1.1 {status} {status_text}\r\n"); for (name, value) in headers { @@ -439,10 +467,13 @@ mod tests { } #[test] - fn no_match_for_embeddings() { + fn detect_openai_embeddings() { let patterns = default_patterns(); let result = detect_inference_pattern("POST", "/v1/embeddings", &patterns); - assert!(result.is_none()); + assert!(result.is_some()); + let pattern = result.unwrap(); + assert_eq!(pattern.protocol, "openai_embeddings"); + assert_eq!(pattern.kind, "embeddings"); } #[test] diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index ae100d734..3e922f4d2 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -1666,6 +1666,34 @@ async fn route_inference_request( return Ok(true); } + // Non-streaming protocols (embeddings) return a single JSON object, not + // an SSE token stream. Serve them buffered with an accurate + // Content-Length: the streaming path would append an SSE error frame to + // the body on a size-cap or idle-timeout truncation, corrupting a + // payload the client parses as one JSON object. + if crate::l7::inference::protocol_returns_buffered_body(&pattern.protocol) { + match ctx + .router + .proxy_with_candidates( + &pattern.protocol, + &request.method, + &normalized_path, + request.headers.clone(), + bytes::Bytes::from(request.body.clone()), + &routes, + ) + .await + { + Ok(resp) => { + let resp_headers = sanitize_inference_response_headers(resp.headers); + let response = format_http_response(resp.status, &resp_headers, &resp.body); + write_all(tls_client, &response).await?; + } + Err(e) => write_inference_router_error(tls_client, &e).await?, + } + return Ok(true); + } + match ctx .router .proxy_with_candidates_streaming( @@ -1760,29 +1788,7 @@ async fn route_inference_request( // Terminate the chunked stream. write_all(tls_client, format_chunk_terminator()).await?; } - Err(e) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "inference endpoint detected but upstream service failed: {e}" - )) - .build(); - ocsf_emit!(event); - } - let (status, msg) = router_error_to_http(&e); - let body = serde_json::json!({"error": msg}); - let body_bytes = body.to_string(); - let response = format_http_response( - status, - &[("content-type".to_string(), "application/json".to_string())], - body_bytes.as_bytes(), - ); - write_all(tls_client, &response).await?; - } + Err(e) => write_inference_router_error(tls_client, &e).await?, } Ok(true) } else { @@ -1814,11 +1820,42 @@ async fn route_inference_request( } } +/// Emit an OCSF failure event and write a buffered JSON error response for a +/// router error hit while proxying an inference request. +/// +/// Shared by the streaming and buffered routing paths so both surface upstream +/// failures with the same status mapping and the same audit record. +async fn write_inference_router_error( + tls_client: &mut (impl tokio::io::AsyncWrite + Unpin), + err: &openshell_router::RouterError, +) -> Result<()> { + use crate::l7::inference::format_http_response; + + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "inference endpoint detected but upstream service failed: {err}" + )) + .build(); + ocsf_emit!(event); + + let (status, msg) = router_error_to_http(err); + let body = serde_json::json!({ "error": msg }).to_string(); + let response = format_http_response( + status, + &[("content-type".to_string(), "application/json".to_string())], + body.as_bytes(), + ); + write_all(tls_client, &response).await +} + /// Map router errors to HTTP status codes and sanitized messages. /// -/// Returns generic error messages instead of verbatim internal details. -/// Full error context (upstream URLs, hostnames, TLS details) is logged -/// server-side by the caller at `warn` level for debugging. +/// Returns generic, client-safe messages instead of verbatim internal details; +/// the full error is recorded in the OCSF failure event by the caller. fn router_error_to_http(err: &openshell_router::RouterError) -> (u16, String) { use openshell_router::RouterError; match err { @@ -5433,6 +5470,102 @@ network_policies: } } + fn embeddings_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "text-embedding-3-small".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_embeddings".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + /// Embeddings responses are a single buffered JSON object, not an SSE + /// stream. They must be framed with `Content-Length` and must never be sent + /// through the chunked streaming path, whose truncation handlers would + /// append an SSE `proxy_stream_error` frame into the JSON body. + #[tokio::test] + async fn inference_embeddings_served_buffered_with_content_length() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = r#"{"object":"list","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2]}],"model":"text-embedding-3-small"}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + // Buffered upstream response with Content-Length (no chunked TE). + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![embeddings_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + let body = r#"{"model":"text-embedding-3-small","input":"hello"}"#; + let request = format!( + "POST /v1/embeddings HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + + server_task.await.unwrap().unwrap(); + upstream_task.await.unwrap(); + + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected buffered 200 response, got: {response}" + ); + let lower = response.to_ascii_lowercase(); + assert!( + lower.contains("content-length:"), + "embeddings response must be Content-Length framed, got: {response}" + ); + assert!( + !lower.contains("transfer-encoding: chunked"), + "embeddings response must NOT be chunked, got: {response}" + ); + assert!( + !response.contains("proxy_stream_error"), + "embeddings response must not carry an SSE error frame, got: {response}" + ); + assert!( + response.contains(r#""object":"list""#), + "embeddings JSON body must be forwarded intact, got: {response}" + ); + } + async fn read_forwarded_inference_request(stream: &mut S) { use crate::l7::inference::{ParseResult, try_parse_http_request}; diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 58b5feb2a..13496cd99 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -1284,6 +1284,7 @@ mod tests { "openai_chat_completions".to_string(), "openai_completions".to_string(), "openai_responses".to_string(), + "openai_embeddings".to_string(), "model_discovery".to_string(), ] ); From 2afa1325da875c82e5bdc739cd51cb3b2572577f Mon Sep 17 00:00:00 2001 From: Shiju Date: Fri, 5 Jun 2026 15:49:00 +0530 Subject: [PATCH 2/5] fix(inference): cap buffered inference response body The buffered proxy path read the whole upstream response into memory with no size bound. The route timeout bounds elapsed time but not memory, so a misbehaving or oversized upstream could force unbounded allocation in the sandbox proxy. The streaming path already caps each response at 32 MiB; the buffered path did not. Cap the buffered read at the same 32 MiB. An advertised over-cap body is rejected from its Content-Length before any bytes are read, and chunks accumulate under the same bound so a chunked or mislabeled body cannot slip past. An over-cap response fails as an upstream protocol error, surfaced as HTTP 502 at the proxy boundary, and is never partially returned. Tests - cargo test -p openshell-router \ proxy_to_backend_rejects_over_cap_response_body Signed-off-by: Shiju --- crates/openshell-router/src/backend.rs | 151 ++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 5 deletions(-) diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index 39a2c6011..a1e73badc 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -6,6 +6,13 @@ use crate::config::{AuthHeader, ResolvedRoute}; use crate::mock; use std::collections::HashSet; +/// Maximum buffered inference response body, in bytes. The buffered path +/// reads the whole response into memory; the route timeout bounds time, not +/// memory, so without this cap an oversized upstream could force unbounded +/// allocation. Mirrors the sandbox streaming byte cap. Over-cap responses fail +/// as an upstream protocol error. +const MAX_BUFFERED_RESPONSE_BODY: usize = 32 * 1024 * 1024; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedEndpoint { pub url: String, @@ -556,18 +563,58 @@ pub async fn proxy_to_backend( ) -> Result { let response = send_backend_request(client, route, method, path, headers, body).await?; let (status, resp_headers) = extract_response_metadata(&response); - let resp_body = response - .bytes() - .await - .map_err(|e| RouterError::UpstreamProtocol(format!("failed to read response body: {e}")))?; + let body = read_capped_response_body(response, MAX_BUFFERED_RESPONSE_BODY).await?; Ok(ProxyResponse { status, headers: resp_headers, - body: resp_body, + body, }) } +/// Read a response body fully into memory, rejecting anything over `max` bytes. +/// +/// Used by the buffered proxy path so a misbehaving upstream cannot force +/// unbounded allocation. The `Content-Length` check is a fast early-out; the +/// chunk loop is the real guard and bounds an absent, chunked, or +/// under-reported length. The cap counts the bytes reqwest yields: with no +/// decompression features enabled (see `Cargo.toml`) those are wire bytes, so +/// enabling a compression feature later would change what the cap measures. +/// Over-cap responses fail as `UpstreamProtocol` and are never partially +/// returned. +async fn read_capped_response_body( + mut response: reqwest::Response, + max: usize, +) -> Result { + if let Some(len) = response.content_length() + && len > max as u64 + { + return Err(RouterError::UpstreamProtocol(format!( + "inference response body of {len} bytes exceeds the {max} byte cap" + ))); + } + + // Preallocate to the advertised length when it is within the cap; the loop + // still enforces the bound for an absent or under-reported length. + let mut body: Vec = match response.content_length() { + Some(len) if len <= max as u64 => Vec::with_capacity(usize::try_from(len).unwrap_or(max)), + _ => Vec::new(), + }; + while let Some(chunk) = response + .chunk() + .await + .map_err(|e| RouterError::UpstreamProtocol(format!("failed to read response body: {e}")))? + { + if body.len() + chunk.len() > max { + return Err(RouterError::UpstreamProtocol(format!( + "inference response body exceeds the {max} byte cap" + ))); + } + body.extend_from_slice(&chunk); + } + Ok(bytes::Bytes::from(body)) +} + /// Forward a raw HTTP request to the backend, returning response headers /// immediately without buffering the body. /// @@ -739,6 +786,100 @@ mod tests { } } + /// The buffered path must reject an over-cap upstream response rather than + /// buffer it. Guards the DoS/OOM exposure of reading the body unbounded. + #[tokio::test] + async fn proxy_to_backend_rejects_over_cap_response_body() { + use super::{MAX_BUFFERED_RESPONSE_BODY, proxy_to_backend}; + + let mock_server = MockServer::start().await; + // One byte over the cap. wiremock sets an accurate Content-Length, so + // the size check rejects before the body is buffered. + let oversized = vec![b'a'; MAX_BUFFERED_RESPONSE_BODY + 1]; + Mock::given(method("GET")) + .and(path("/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized)) + .mount(&mock_server) + .await; + + let route = test_route(&mock_server.uri(), &["model_discovery"], AuthHeader::Bearer); + let client = reqwest::Client::new(); + let result = proxy_to_backend( + &client, + &route, + "model_discovery", + "GET", + "/v1/models", + vec![], + bytes::Bytes::new(), + ) + .await; + + assert!( + matches!(result, Err(crate::RouterError::UpstreamProtocol(_))), + "over-cap response must fail as UpstreamProtocol, got: {result:?}" + ); + } + + /// Spawn a one-shot HTTP/1.1 upstream that replies with a chunked body and + /// no `Content-Length`, so the buffered read cannot pre-check a length and + /// must enforce the cap inside the chunk loop. + async fn spawn_chunked_upstream(chunks: &'static [&'static str]) -> std::net::SocketAddr { + use std::fmt::Write as _; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let (mut sock, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 1024]; + let _ = sock.read(&mut buf).await; + let mut resp = String::from( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\n\r\n", + ); + for c in chunks { + let _ = write!(resp, "{:x}\r\n{c}\r\n", c.len()); + } + resp.push_str("0\r\n\r\n"); + sock.write_all(resp.as_bytes()).await.unwrap(); + }); + addr + } + + /// The chunk-accumulation guard (not the `Content-Length` pre-check) must + /// reject an over-cap body when the response advertises no length. + #[tokio::test] + async fn read_capped_response_body_rejects_over_cap_chunked() { + let addr = spawn_chunked_upstream(&["aaaa", "bbbb", "cccc"]).await; + let response = reqwest::Client::new() + .get(format!("http://{addr}/")) + .send() + .await + .unwrap(); + assert!( + response.content_length().is_none(), + "chunked response should advertise no Content-Length" + ); + let result = super::read_capped_response_body(response, 8).await; + assert!( + matches!(result, Err(crate::RouterError::UpstreamProtocol(_))), + "over-cap chunked body must be rejected by the loop, got: {result:?}" + ); + } + + /// A body exactly at the cap is accepted (inclusive bound) and returned + /// intact through the chunk loop. + #[tokio::test] + async fn read_capped_response_body_accepts_body_at_cap() { + let addr = spawn_chunked_upstream(&["aaaa", "bbbb"]).await; + let response = reqwest::Client::new() + .get(format!("http://{addr}/")) + .send() + .await + .unwrap(); + let body = super::read_capped_response_body(response, 8).await.unwrap(); + assert_eq!(&body[..], b"aaaabbbb"); + } + #[test] fn sanitize_request_headers_drops_unknown_sensitive_headers() { let route = ResolvedRoute { From 9b328fe8f63b607359006416100acf328c90e96d Mon Sep 17 00:00:00 2001 From: Shiju Date: Fri, 5 Jun 2026 16:00:28 +0530 Subject: [PATCH 3/5] fix(inference): validate embeddings models against all advertised protocols A managed route resolves to its provider profile's full protocol set, so an embeddings model such as text-embedding-3-small lists openai_chat_completions alongside openai_embeddings. Route verification probed only the first writable protocol and stopped on its failure. It sent a chat probe with the embedding model, the provider rejected it as wrong-shape, and the route failed validation before the embeddings probe ran. Embeddings-only configs could not be verified. Try the advertised protocols in preference order. A request-shape rejection (HTTP 400, 404, 405, 422) falls through to the next protocol, so an embeddings model validates against /v1/embeddings even when the chat probe rejects it. Credential, rate-limit, connectivity, and upstream-health failures stay terminal and stop validation at the first probe, so a bad key or a down backend is reported as itself rather than masked by a later probe. validation_probe becomes validation_probes, which returns the ordered list, and the per-probe fallback retry (max_completion_tokens versus max_tokens) moves into a shared helper. Tests - cargo test -p openshell-router \ verify_embeddings_model_falls_through_chat_probe - cargo test -p openshell-router verify_stops_on_credentials_failure Signed-off-by: Shiju --- crates/openshell-router/src/backend.rs | 409 ++++++++++++++++++++----- 1 file changed, 328 insertions(+), 81 deletions(-) diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index a1e73badc..9eb63c88b 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -39,10 +39,12 @@ struct ValidationProbe { path: &'static str, protocol: &'static str, body: bytes::Bytes, - /// Alternate body to try when the primary probe fails with HTTP 400. - /// Used for `OpenAI` chat completions where newer models require - /// `max_completion_tokens` while legacy/self-hosted backends only - /// accept `max_tokens`. + /// Alternate body to try when the primary probe is rejected specifically + /// for `max_completion_tokens`. Used for `OpenAI` chat completions where + /// newer models require `max_completion_tokens` while legacy/self-hosted + /// backends only accept `max_tokens`. The retry is gated on the error + /// body naming that parameter, so an unrelated request-shape rejection + /// (wrong protocol for the model) falls through instead. fallback_body: Option, } @@ -297,11 +299,11 @@ async fn send_backend_request( route: &ResolvedRoute, method: &str, path: &str, - headers: Vec<(String, String)>, + headers: &[(String, String)], body: bytes::Bytes, ) -> Result { let (builder, url) = - prepare_backend_request(client, route, method, path, &headers, body, false)?; + prepare_backend_request(client, route, method, path, headers, body, false)?; builder .timeout(route.timeout) .send() @@ -319,23 +321,30 @@ async fn send_backend_request_streaming( route: &ResolvedRoute, method: &str, path: &str, - headers: Vec<(String, String)>, + headers: &[(String, String)], body: bytes::Bytes, ) -> Result { - let (builder, url) = - prepare_backend_request(client, route, method, path, &headers, body, true)?; + let (builder, url) = prepare_backend_request(client, route, method, path, headers, body, true)?; builder.send().await.map_err(|e| map_send_error(e, &url)) } -fn validation_probe(route: &ResolvedRoute) -> Result { - if route - .protocols - .iter() - .any(|protocol| protocol == "openai_chat_completions") - { +/// Validation probes for a route, in preference order. +/// +/// A managed route advertises every protocol in its provider profile, so an +/// embeddings model resolves to a route that also lists chat/completions. The +/// caller tries these in order and falls through to the next on a request-shape +/// rejection, so such a model validates against `/v1/embeddings` even though +/// the chat probe rejects it. Embeddings is ordered last so a genuinely +/// chat-capable route still validates against chat. Empty when the route +/// exposes no writable protocol. +fn validation_probes(route: &ResolvedRoute) -> Vec { + let has = |protocol: &str| route.protocols.iter().any(|p| p == protocol); + let mut probes = Vec::new(); + + if has("openai_chat_completions") { // Use max_completion_tokens (modern OpenAI parameter, required by GPT-5+) // with max_tokens as fallback for legacy/self-hosted backends. - return Ok(ValidationProbe { + probes.push(ValidationProbe { path: "/v1/chat/completions", protocol: "openai_chat_completions", body: bytes::Bytes::from_static( @@ -347,12 +356,8 @@ fn validation_probe(route: &ResolvedRoute) -> Result Result Result Result Result ValidationFailure { + ValidationFailure { kind: ValidationFailureKind::RequestShape, details: format!( "route '{}' does not expose a writable inference protocol for validation", route.name ), - }) + } } pub async fn verify_backend_endpoint( client: &reqwest::Client, route: &ResolvedRoute, ) -> Result { - let probe = validation_probe(route)?; - let headers = vec![("content-type".to_string(), "application/json".to_string())]; + let probes = validation_probes(route); + let Some(first) = probes.first() else { + return Err(no_writable_protocol_failure(route)); + }; if mock::is_mock_route(route) { return Ok(ValidatedEndpoint { - url: build_provider_url(route, &route.model, probe.path, false), - protocol: probe.protocol.to_string(), + url: build_provider_url(route, &route.model, first.path, false), + protocol: first.protocol.to_string(), }); } + let headers = vec![("content-type".to_string(), "application/json".to_string())]; + let mut last_shape_failure = None; + + for probe in &probes { + match try_validation_probe(client, route, probe, &headers).await { + Ok(endpoint) => return Ok(endpoint), + // A request-shape rejection means this protocol is wrong for the + // model (e.g. a chat probe against an embeddings model), so fall + // through to the next advertised protocol. Any other failure + // describes the backend itself (credentials, rate limit, + // connectivity, health) and is terminal across all protocols. + // + // Keep the first shape failure: it is the most-preferred protocol's + // rejection and the most actionable error to report. + Err(err) if err.kind == ValidationFailureKind::RequestShape => { + last_shape_failure.get_or_insert(err); + } + Err(err) => return Err(err), + } + } + + Err(last_shape_failure.unwrap_or_else(|| no_writable_protocol_failure(route))) +} + +/// Run one validation probe, retrying with its fallback body only when the +/// upstream specifically rejected `max_completion_tokens`. +/// +/// That retry exists for the GPT-5+ (`max_completion_tokens`) versus legacy +/// (`max_tokens`) chat split. Firing it for any request-shape rejection would +/// issue a second, pointless probe when the real signal is "wrong protocol for +/// this model", and a transient `429`/`5xx` on that retry could become a +/// terminal failure that stops the caller from reaching a protocol that would +/// have validated. +async fn try_validation_probe( + client: &reqwest::Client, + route: &ResolvedRoute, + probe: &ValidationProbe, + headers: &[(String, String)], +) -> Result { let result = try_validation_request( client, route, probe.path, probe.protocol, - headers.clone(), - probe.body, + headers, + probe.body.clone(), ) .await; - // If the primary probe failed with a request-shape error (HTTP 400) and - // there is a fallback body, retry with the alternate token parameter. - // This handles the split between `max_completion_tokens` (GPT-5+) and - // `max_tokens` (legacy/self-hosted backends). - if let (Err(err), Some(fallback_body)) = (&result, probe.fallback_body) + if let (Err(err), Some(fallback_body)) = (&result, &probe.fallback_body) && err.kind == ValidationFailureKind::RequestShape + && err.details.contains("max_completion_tokens") { return try_validation_request( client, @@ -450,7 +488,7 @@ pub async fn verify_backend_endpoint( probe.path, probe.protocol, headers, - fallback_body, + fallback_body.clone(), ) .await; } @@ -464,7 +502,7 @@ async fn try_validation_request( route: &ResolvedRoute, path: &str, protocol: &str, - headers: Vec<(String, String)>, + headers: &[(String, String)], body: bytes::Bytes, ) -> Result { let response = send_backend_request(client, route, "POST", path, headers, body) @@ -511,32 +549,57 @@ async fn try_validation_request( ) }; - let details = match status.as_u16() { - 400 | 404 | 405 | 422 => { - format!("upstream rejected the validation request with HTTP {status}.{body_suffix}") - } - 401 | 403 => { - format!("upstream rejected credentials with HTTP {status}.{body_suffix}") - } - 429 => { - format!("upstream rate-limited the validation request with HTTP {status}.{body_suffix}") - } - 500..=599 => format!("upstream returned HTTP {status}.{body_suffix}"), - _ => format!("upstream returned unexpected HTTP {status}.{body_suffix}"), + // Some OpenAI-compatible providers report an auth failure as 400/404/422 + // with an auth-shaped error body rather than 401/403. Classify those as a + // terminal credential failure so a bad key is not mistaken for a + // wrong-protocol probe and masked by a later probe that accepts it. + let kind = match status.as_u16() { + 401 | 403 => ValidationFailureKind::Credentials, + 400 | 404 | 422 if body_looks_like_auth_error(body) => ValidationFailureKind::Credentials, + 400 | 404 | 405 | 422 => ValidationFailureKind::RequestShape, + 429 => ValidationFailureKind::RateLimited, + 500..=599 => ValidationFailureKind::UpstreamHealth, + _ => ValidationFailureKind::Unexpected, + }; + + let summary = match kind { + ValidationFailureKind::Credentials => "upstream rejected credentials", + ValidationFailureKind::RateLimited => "upstream rate-limited the validation request", + ValidationFailureKind::UpstreamHealth => "upstream returned a server error", + ValidationFailureKind::RequestShape => "upstream rejected the validation request", + _ => "upstream returned an unexpected response", }; Err(ValidationFailure { - kind: match status.as_u16() { - 400 | 404 | 405 | 422 => ValidationFailureKind::RequestShape, - 401 | 403 => ValidationFailureKind::Credentials, - 429 => ValidationFailureKind::RateLimited, - 500..=599 => ValidationFailureKind::UpstreamHealth, - _ => ValidationFailureKind::Unexpected, - }, - details, + kind, + details: format!("{summary} with HTTP {status}.{body_suffix}"), }) } +/// Whether an upstream error body reads as an authentication or authorization +/// failure. Some OpenAI-compatible providers return these as HTTP 400/404/422 +/// rather than 401/403, so validation inspects the body to avoid classifying a +/// bad key as a wrong-protocol probe. Matching is conservative: only strong, +/// auth-specific phrases, lowercased, to avoid catching generic "invalid model" +/// request-shape errors. +fn body_looks_like_auth_error(body: &str) -> bool { + let body = body.to_ascii_lowercase(); + [ + "invalid_api_key", + "invalid api key", + "incorrect api key", + "invalid_authentication", + "authentication_error", + "authentication failed", + "unauthorized", + "permission_denied", + "permission denied", + "missing api key", + ] + .iter() + .any(|needle| body.contains(needle)) +} + /// Extract status and headers from a [`reqwest::Response`]. fn extract_response_metadata(response: &reqwest::Response) -> (u16, Vec<(String, String)>) { let status = response.status().as_u16(); @@ -561,7 +624,7 @@ pub async fn proxy_to_backend( headers: Vec<(String, String)>, body: bytes::Bytes, ) -> Result { - let response = send_backend_request(client, route, method, path, headers, body).await?; + let response = send_backend_request(client, route, method, path, &headers, body).await?; let (status, resp_headers) = extract_response_metadata(&response); let body = read_capped_response_body(response, MAX_BUFFERED_RESPONSE_BODY).await?; @@ -630,7 +693,7 @@ pub async fn proxy_to_backend_streaming( body: bytes::Bytes, ) -> Result { let response = - send_backend_request_streaming(client, route, method, path, headers, body).await?; + send_backend_request_streaming(client, route, method, path, &headers, body).await?; let (status, resp_headers) = extract_response_metadata(&response); Ok(StreamingProxyResponse { @@ -736,7 +799,8 @@ fn is_vertex_anthropic_rawpredict_route(route: &ResolvedRoute) -> bool { #[cfg(test)] mod tests { use super::{ - ValidationFailureKind, build_backend_url, build_provider_url, verify_backend_endpoint, + ValidationFailure, ValidationFailureKind, build_backend_url, build_provider_url, + verify_backend_endpoint, }; use crate::config::{DEFAULT_ROUTE_TIMEOUT, ResolvedRoute}; use openshell_core::inference::AuthHeader; @@ -1150,6 +1214,189 @@ mod tests { assert_eq!(validated.protocol, "openai_chat_completions"); } + /// A managed route for an embeddings model advertises the full provider + /// protocol set. The chat probe (tried first) rejects the embeddings model + /// as wrong-shape, so validation must fall through to the embeddings probe + /// rather than fail the route. + #[tokio::test] + async fn verify_embeddings_model_falls_through_chat_probe() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &[ + "openai_chat_completions", + "openai_completions", + "openai_responses", + "openai_embeddings", + "model_discovery", + ], + AuthHeader::Bearer, + ); + + // Chat, completions, and responses probes reject the embedding model. + for chat_path in ["/v1/chat/completions", "/v1/completions", "/v1/responses"] { + Mock::given(method("POST")) + .and(path(chat_path)) + .respond_with( + ResponseTemplate::new(400) + .set_body_string(r#"{"error":{"message":"not a chat model"}}"#), + ) + .mount(&mock_server) + .await; + } + // The embeddings probe accepts it. + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let validated = verify_backend_endpoint(&client, &route) + .await + .expect("embeddings model should validate via the embeddings probe"); + assert_eq!(validated.protocol, "openai_embeddings"); + } + + /// A non-request-shape failure (credentials) is terminal: validation must + /// stop at the first probe and not fall through to a protocol that would + /// succeed, so a bad key is reported as such rather than masked. + #[tokio::test] + async fn verify_stops_on_credentials_failure() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(401).set_body_string(r#"{"error":"bad key"}"#)) + .mount(&mock_server) + .await; + // Would succeed, but credentials failure on the first probe is terminal + // and this must never be reached. + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let err = verify_backend_endpoint(&client, &route) + .await + .expect_err("a 401 must fail validation"); + assert_eq!(err.kind, ValidationFailureKind::Credentials); + } + + /// A 429 on the first probe is terminal (`RateLimited`) and must not fall + /// through to a later probe that would succeed. + #[tokio::test] + async fn verify_stops_on_rate_limit() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(429).set_body_string(r#"{"error":"slow down"}"#)) + .mount(&mock_server) + .await; + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let err = reqwest_verify(&route).await; + assert_eq!(err.kind, ValidationFailureKind::RateLimited); + } + + /// An auth failure reported as HTTP 400 with an auth-shaped body is terminal + /// (`Credentials`), not a request-shape fall-through, so a bad key cannot be + /// masked by a later probe that accepts it. + #[tokio::test] + async fn verify_auth_error_as_400_is_terminal() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(400).set_body_string( + r#"{"error":{"code":"invalid_api_key","message":"Incorrect API key provided"}}"#, + )) + .mount(&mock_server) + .await; + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let err = reqwest_verify(&route).await; + assert_eq!(err.kind, ValidationFailureKind::Credentials); + } + + /// When every probe is rejected as request-shape, validation returns the + /// first (most-preferred protocol's) failure, not the last. + #[tokio::test] + async fn verify_all_probes_request_shape_returns_first() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(404).set_body_string(r#"{"error":"model not found: chat"}"#), + ) + .mount(&mock_server) + .await; + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(400) + .set_body_string(r#"{"error":"not an embeddings model"}"#), + ) + .mount(&mock_server) + .await; + + let err = reqwest_verify(&route).await; + assert_eq!(err.kind, ValidationFailureKind::RequestShape); + assert!( + err.details.contains("model not found: chat"), + "should report the first (chat) failure, got: {}", + err.details + ); + } + + /// Helper: run `verify_backend_endpoint` and return the expected failure. + async fn reqwest_verify(route: &ResolvedRoute) -> ValidationFailure { + verify_backend_endpoint(&reqwest::Client::new(), route) + .await + .expect_err("validation should fail") + } + /// Non-chat-completions probes (e.g. `anthropic_messages`) should not /// have a fallback — a 400 remains a hard failure. #[tokio::test] From 1a6e83598fbf965e778040285cdb1d52451f620e Mon Sep 17 00:00:00 2001 From: Shiju Date: Fri, 5 Jun 2026 15:44:25 +0530 Subject: [PATCH 4/5] fix(inference): serve model discovery responses buffered GET /v1/models returns a single JSON model list, the same response shape as embeddings. The sandbox inference proxy was routing it through the SSE streaming path. A streaming size-cap or idle-timeout truncation appends an SSE error frame to the body, which corrupts a payload the client parses as one JSON object. Make response framing a property of the protocol. A new ResponseFraming field on InferenceApiPattern is set once per pattern in default_patterns. model_discovery and openai_embeddings are now Buffered, while chat completions, completions, responses, and Anthropic messages stay Streaming. The proxy dispatch gates on pattern.is_buffered(), which replaces the stringly-typed protocol_returns_buffered_body predicate so the streaming-versus-buffered decision lives in one place and cannot drift across the sites that read it. Model discovery now flows through the same buffered path as embeddings, framed with an accurate Content-Length and bounded by the buffered-read size cap that path already enforces. Tests - cargo test -p openshell-sandbox protocol_framing_classification - cargo test -p openshell-sandbox \ inference_model_discovery_served_buffered_with_content_length Signed-off-by: Shiju --- crates/openshell-sandbox/src/l7/inference.rs | 101 ++++++-- crates/openshell-sandbox/src/proxy.rs | 242 ++++++++++++++++++- 2 files changed, 322 insertions(+), 21 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/inference.rs b/crates/openshell-sandbox/src/l7/inference.rs index d9b29586c..ec789ef95 100644 --- a/crates/openshell-sandbox/src/l7/inference.rs +++ b/crates/openshell-sandbox/src/l7/inference.rs @@ -7,6 +7,38 @@ //! HTTP request is a known inference API call and routes it through the local //! sandbox router. +/// How an inference protocol delivers its response to the sandboxed client. +/// +/// `Streaming` protocols (chat completions, completions, responses, Anthropic +/// messages) emit a Server-Sent Events token stream and are served through the +/// chunked transfer-encoding path so tokens reach the client incrementally. +/// +/// `Buffered` protocols (embeddings, model discovery) return a single JSON +/// object the client parses whole. They must be served in one piece with an +/// accurate `Content-Length`. Sending them through the streaming path is +/// unsafe: a mid-body truncation (the streaming size cap or idle timeout) +/// appends an SSE error event to bytes the client decodes as one JSON object, +/// silently corrupting it. +/// +/// Framing is a property of the protocol, declared once per pattern in +/// [`default_patterns`], so the streaming-vs-buffered decision cannot drift +/// across the dispatch sites that consume it. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseFraming { + /// SSE token stream, served via chunked transfer-encoding. + /// + /// The `OpenAI` completion-style protocols are classified streaming + /// unconditionally. They are dual-mode (a `stream: false` request returns + /// one buffered JSON object), but the dispatch keys framing off the + /// protocol alone and does not inspect the request body, so they are + /// streamed defensively. Their buffered responses tolerate chunked framing; + /// only the embeddings and model-discovery shapes (no streaming mode at + /// all) must be served buffered. + Streaming, + /// Single JSON object, served buffered with an accurate `Content-Length`. + Buffered, +} + /// An inference API pattern for detecting inference calls in intercepted traffic. #[derive(Debug, Clone)] pub struct InferenceApiPattern { @@ -14,6 +46,18 @@ pub struct InferenceApiPattern { pub path_glob: String, pub protocol: String, pub kind: String, + /// Response delivery mode for this protocol. Selects the buffered or + /// streaming proxy path; see [`ResponseFraming`]. + pub framing: ResponseFraming, +} + +impl InferenceApiPattern { + /// Whether this protocol's response must be served buffered (one JSON + /// object framed with an accurate `Content-Length`) rather than streamed. + #[must_use] + pub fn is_buffered(&self) -> bool { + matches!(self.framing, ResponseFraming::Buffered) + } } /// Default patterns for known inference APIs (`OpenAI`, Anthropic). @@ -24,42 +68,51 @@ pub fn default_patterns() -> Vec { path_glob: "/v1/chat/completions".to_string(), protocol: "openai_chat_completions".to_string(), kind: "chat_completion".to_string(), + framing: ResponseFraming::Streaming, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/completions".to_string(), protocol: "openai_completions".to_string(), kind: "completion".to_string(), + framing: ResponseFraming::Streaming, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/responses".to_string(), protocol: "openai_responses".to_string(), kind: "responses".to_string(), + framing: ResponseFraming::Streaming, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/embeddings".to_string(), protocol: "openai_embeddings".to_string(), kind: "embeddings".to_string(), + framing: ResponseFraming::Buffered, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/messages".to_string(), protocol: "anthropic_messages".to_string(), kind: "messages".to_string(), + framing: ResponseFraming::Streaming, }, + // Model discovery returns one JSON object (a model list), never an SSE + // stream, so it is served buffered for the same reason as embeddings. InferenceApiPattern { method: "GET".to_string(), path_glob: "/v1/models".to_string(), protocol: "model_discovery".to_string(), kind: "models_list".to_string(), + framing: ResponseFraming::Buffered, }, InferenceApiPattern { method: "GET".to_string(), path_glob: "/v1/models/*".to_string(), protocol: "model_discovery".to_string(), kind: "models_get".to_string(), + framing: ResponseFraming::Buffered, }, ] } @@ -88,21 +141,6 @@ pub fn detect_inference_pattern<'a>( }) } -/// Whether a detected inference protocol returns a single buffered JSON body -/// rather than a Server-Sent Events token stream. -/// -/// SSE-capable protocols (chat completions, completions, responses, Anthropic -/// messages) are served through the chunked streaming path so tokens reach the -/// client incrementally. Embeddings always return one JSON object, so they must -/// be framed with an accurate `Content-Length`. Streaming a buffered body would -/// let a mid-body truncation (the streaming size cap or idle timeout) append a -/// [`format_sse_error`] event to a payload the client parses as a single JSON -/// object, silently corrupting it. Add future non-streaming protocols here. -#[must_use] -pub fn protocol_returns_buffered_body(protocol: &str) -> bool { - matches!(protocol, "openai_embeddings") -} - /// A parsed HTTP request from the intercepted tunnel. pub struct ParsedHttpRequest { pub method: String, @@ -455,7 +493,11 @@ mod tests { let patterns = default_patterns(); let result = detect_inference_pattern("GET", "/v1/models", &patterns); assert!(result.is_some()); - assert_eq!(result.unwrap().protocol, "model_discovery"); + let pattern = result.unwrap(); + assert_eq!(pattern.protocol, "model_discovery"); + // A model list is one JSON object; it must be served buffered, never + // through the SSE streaming path that could append an error frame. + assert!(pattern.is_buffered()); } #[test] @@ -463,7 +505,9 @@ mod tests { let patterns = default_patterns(); let result = detect_inference_pattern("GET", "/v1/models/gpt-4.1", &patterns); assert!(result.is_some()); - assert_eq!(result.unwrap().protocol, "model_discovery"); + let pattern = result.unwrap(); + assert_eq!(pattern.protocol, "model_discovery"); + assert!(pattern.is_buffered()); } #[test] @@ -474,6 +518,29 @@ mod tests { let pattern = result.unwrap(); assert_eq!(pattern.protocol, "openai_embeddings"); assert_eq!(pattern.kind, "embeddings"); + assert!(pattern.is_buffered()); + } + + /// Every default pattern must declare framing consistent with how its + /// protocol actually responds: single-JSON-object protocols buffered, + /// SSE token streams streaming. A wrong classification routes a response + /// through the path that can corrupt or stall it. + #[test] + fn protocol_framing_classification() { + let patterns = default_patterns(); + for pattern in &patterns { + let expected_buffered = matches!( + pattern.protocol.as_str(), + "model_discovery" | "openai_embeddings" + ); + assert_eq!( + pattern.is_buffered(), + expected_buffered, + "{} ({}) has wrong framing", + pattern.protocol, + pattern.path_glob + ); + } } #[test] diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 3e922f4d2..aa8338433 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -1666,12 +1666,13 @@ async fn route_inference_request( return Ok(true); } - // Non-streaming protocols (embeddings) return a single JSON object, not - // an SSE token stream. Serve them buffered with an accurate + // Buffered protocols (embeddings, model discovery) return a single JSON + // object, not an SSE token stream. Serve them buffered with an accurate // Content-Length: the streaming path would append an SSE error frame to // the body on a size-cap or idle-timeout truncation, corrupting a - // payload the client parses as one JSON object. - if crate::l7::inference::protocol_returns_buffered_body(&pattern.protocol) { + // payload the client parses as one JSON object. Framing is declared per + // protocol on the matched pattern. + if pattern.is_buffered() { match ctx .router .proxy_with_candidates( @@ -5566,6 +5567,217 @@ network_policies: ); } + fn model_discovery_inference_route( + endpoint: String, + ) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "text-embedding-3-small".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["model_discovery".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + /// `GET /v1/models` (model discovery) returns one JSON object — a model + /// list — exactly like embeddings. It must be served buffered with + /// `Content-Length`, never through the chunked streaming path whose + /// truncation handlers would append an SSE `proxy_stream_error` frame into + /// the JSON body. This guards the framing classification for the protocol. + #[tokio::test] + async fn inference_model_discovery_served_buffered_with_content_length() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = + r#"{"object":"list","data":[{"id":"text-embedding-3-small","object":"model"}]}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + // GET model discovery carries no request body. + let request = "GET /v1/models HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + + server_task.await.unwrap().unwrap(); + upstream_task.await.unwrap(); + + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected buffered 200 response, got: {response}" + ); + let lower = response.to_ascii_lowercase(); + assert!( + lower.contains("content-length:"), + "model discovery response must be Content-Length framed, got: {response}" + ); + assert!( + !lower.contains("transfer-encoding: chunked"), + "model discovery response must NOT be chunked, got: {response}" + ); + assert!( + !response.contains("proxy_stream_error"), + "model discovery response must not carry an SSE error frame, got: {response}" + ); + assert!( + response.contains(r#""object":"list""#), + "model discovery JSON body must be forwarded intact, got: {response}" + ); + } + + /// `GET /v1/models/{id}` (model discovery glob) must forward the model id in + /// the path through the buffered path with the id intact, never streamed. + #[tokio::test] + async fn inference_model_discovery_glob_path_served_buffered() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = r#"{"id":"gpt-4.1","object":"model"}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + let forwarded = read_forwarded_request_line(&mut upstream).await; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + forwarded + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + let request = "GET /v1/models/gpt-4.1 HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + server_task.await.unwrap().unwrap(); + let (method, forwarded_path) = upstream_task.await.unwrap(); + + assert_eq!(method, "GET"); + assert_eq!( + forwarded_path, "/v1/models/gpt-4.1", + "the model id in the glob path must be forwarded intact" + ); + let lower = response.to_ascii_lowercase(); + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n") + && lower.contains("content-length:") + && !lower.contains("transfer-encoding: chunked") + && !response.contains("proxy_stream_error"), + "glob model discovery must be buffered and Content-Length framed, got: {response}" + ); + } + + /// A failed model-discovery upstream must produce a buffered, Content-Length + /// framed JSON error, never a chunked SSE `proxy_stream_error` frame. + #[tokio::test] + async fn inference_model_discovery_error_served_buffered() { + // A port with no listener so the upstream connection is refused. + let dead_addr = { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + addr + }; + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{dead_addr}" + ))], + vec![], + ); + + let request = "GET /v1/models HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + server_task.await.unwrap().unwrap(); + + let lower = response.to_ascii_lowercase(); + assert!( + response.starts_with("HTTP/1.1 5"), + "a refused upstream should yield a 5xx, got: {response}" + ); + assert!( + lower.contains("content-length:") + && !lower.contains("transfer-encoding: chunked") + && !response.contains("proxy_stream_error"), + "buffered model-discovery error must be Content-Length framed JSON, got: {response}" + ); + assert!( + response.contains("error"), + "error response should carry a JSON error body, got: {response}" + ); + } + async fn read_forwarded_inference_request(stream: &mut S) { use crate::l7::inference::{ParseResult, try_parse_http_request}; @@ -5586,6 +5798,28 @@ network_policies: } } + /// Like [`read_forwarded_inference_request`] but returns the forwarded + /// request line (method, path) so a test can assert the upstream URL path. + async fn read_forwarded_request_line(stream: &mut S) -> (String, String) { + use crate::l7::inference::{ParseResult, try_parse_http_request}; + + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + loop { + let n = stream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before completion"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(req, _) => return (req.method, req.path), + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + } + async fn run_live_streaming_inference(serve_upstream: F) -> String where F: FnOnce(TcpStream) -> Fut + Send + 'static, From 33d0924950e144be1008edd7e7b41ab72d79d43c Mon Sep 17 00:00:00 2001 From: Shiju Date: Sat, 6 Jun 2026 14:39:11 +0530 Subject: [PATCH 5/5] docs(inference): document embeddings route in supported patterns Signed-off-by: Shiju --- docs/sandboxes/inference-routing.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/sandboxes/inference-routing.mdx b/docs/sandboxes/inference-routing.mdx index 3d8c48cd8..0a4e9d726 100644 --- a/docs/sandboxes/inference-routing.mdx +++ b/docs/sandboxes/inference-routing.mdx @@ -42,6 +42,7 @@ Supported request patterns depend on the provider configured for `inference.loca | Chat Completions | `POST` | `/v1/chat/completions` | | Completions | `POST` | `/v1/completions` | | Responses | `POST` | `/v1/responses` | +| Embeddings | `POST` | `/v1/embeddings` | | Model Discovery | `GET` | `/v1/models` | | Model Discovery | `GET` | `/v1/models/*` |