diff --git a/crates/openshell-core/src/inference.rs b/crates/openshell-core/src/inference.rs index c04feb6b4..3071d53cd 100644 --- a/crates/openshell-core/src/inference.rs +++ b/crates/openshell-core/src/inference.rs @@ -56,6 +56,7 @@ const OPENAI_PROTOCOLS: &[&str] = &[ "openai_chat_completions", "openai_completions", "openai_responses", + "openai_embeddings", "model_discovery", ]; @@ -305,6 +306,17 @@ mod tests { assert!(profile_for("OpenAI").is_some()); // case insensitive } + #[test] + fn openai_compatible_profiles_include_embeddings() { + for provider_type in ["openai", "nvidia"] { + let profile = profile_for(provider_type).expect("provider profile should exist"); + assert!( + profile.protocols.contains(&"openai_embeddings"), + "{provider_type} should route OpenAI-compatible embeddings" + ); + } + } + #[test] fn profile_for_unknown_types() { assert!(profile_for("github").is_none()); diff --git a/crates/openshell-router/src/backend.rs b/crates/openshell-router/src/backend.rs index 272da265a..9eb63c88b 100644 --- a/crates/openshell-router/src/backend.rs +++ b/crates/openshell-router/src/backend.rs @@ -6,6 +6,13 @@ use crate::config::{AuthHeader, ResolvedRoute}; use crate::mock; use std::collections::HashSet; +/// Maximum buffered inference response body, in bytes. The buffered path +/// reads the whole response into memory; the route timeout bounds time, not +/// memory, so without this cap an oversized upstream could force unbounded +/// allocation. Mirrors the sandbox streaming byte cap. Over-cap responses fail +/// as an upstream protocol error. +const MAX_BUFFERED_RESPONSE_BODY: usize = 32 * 1024 * 1024; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedEndpoint { pub url: String, @@ -32,10 +39,12 @@ struct ValidationProbe { path: &'static str, protocol: &'static str, body: bytes::Bytes, - /// Alternate body to try when the primary probe fails with HTTP 400. - /// Used for `OpenAI` chat completions where newer models require - /// `max_completion_tokens` while legacy/self-hosted backends only - /// accept `max_tokens`. + /// Alternate body to try when the primary probe is rejected specifically + /// for `max_completion_tokens`. Used for `OpenAI` chat completions where + /// newer models require `max_completion_tokens` while legacy/self-hosted + /// backends only accept `max_tokens`. The retry is gated on the error + /// body naming that parameter, so an unrelated request-shape rejection + /// (wrong protocol for the model) falls through instead. fallback_body: Option, } @@ -290,11 +299,11 @@ async fn send_backend_request( route: &ResolvedRoute, method: &str, path: &str, - headers: Vec<(String, String)>, + headers: &[(String, String)], body: bytes::Bytes, ) -> Result { let (builder, url) = - prepare_backend_request(client, route, method, path, &headers, body, false)?; + prepare_backend_request(client, route, method, path, headers, body, false)?; builder .timeout(route.timeout) .send() @@ -312,23 +321,30 @@ async fn send_backend_request_streaming( route: &ResolvedRoute, method: &str, path: &str, - headers: Vec<(String, String)>, + headers: &[(String, String)], body: bytes::Bytes, ) -> Result { - let (builder, url) = - prepare_backend_request(client, route, method, path, &headers, body, true)?; + let (builder, url) = prepare_backend_request(client, route, method, path, headers, body, true)?; builder.send().await.map_err(|e| map_send_error(e, &url)) } -fn validation_probe(route: &ResolvedRoute) -> Result { - if route - .protocols - .iter() - .any(|protocol| protocol == "openai_chat_completions") - { +/// Validation probes for a route, in preference order. +/// +/// A managed route advertises every protocol in its provider profile, so an +/// embeddings model resolves to a route that also lists chat/completions. The +/// caller tries these in order and falls through to the next on a request-shape +/// rejection, so such a model validates against `/v1/embeddings` even though +/// the chat probe rejects it. Embeddings is ordered last so a genuinely +/// chat-capable route still validates against chat. Empty when the route +/// exposes no writable protocol. +fn validation_probes(route: &ResolvedRoute) -> Vec { + let has = |protocol: &str| route.protocols.iter().any(|p| p == protocol); + let mut probes = Vec::new(); + + if has("openai_chat_completions") { // Use max_completion_tokens (modern OpenAI parameter, required by GPT-5+) // with max_tokens as fallback for legacy/self-hosted backends. - return Ok(ValidationProbe { + probes.push(ValidationProbe { path: "/v1/chat/completions", protocol: "openai_chat_completions", body: bytes::Bytes::from_static( @@ -340,12 +356,8 @@ fn validation_probe(route: &ResolvedRoute) -> Result Result Result Result ValidationFailure { + ValidationFailure { kind: ValidationFailureKind::RequestShape, details: format!( "route '{}' does not expose a writable inference protocol for validation", route.name ), - }) + } } pub async fn verify_backend_endpoint( client: &reqwest::Client, route: &ResolvedRoute, ) -> Result { - let probe = validation_probe(route)?; - let headers = vec![("content-type".to_string(), "application/json".to_string())]; + let probes = validation_probes(route); + let Some(first) = probes.first() else { + return Err(no_writable_protocol_failure(route)); + }; if mock::is_mock_route(route) { return Ok(ValidatedEndpoint { - url: build_provider_url(route, &route.model, probe.path, false), - protocol: probe.protocol.to_string(), + url: build_provider_url(route, &route.model, first.path, false), + protocol: first.protocol.to_string(), }); } + let headers = vec![("content-type".to_string(), "application/json".to_string())]; + let mut last_shape_failure = None; + + for probe in &probes { + match try_validation_probe(client, route, probe, &headers).await { + Ok(endpoint) => return Ok(endpoint), + // A request-shape rejection means this protocol is wrong for the + // model (e.g. a chat probe against an embeddings model), so fall + // through to the next advertised protocol. Any other failure + // describes the backend itself (credentials, rate limit, + // connectivity, health) and is terminal across all protocols. + // + // Keep the first shape failure: it is the most-preferred protocol's + // rejection and the most actionable error to report. + Err(err) if err.kind == ValidationFailureKind::RequestShape => { + last_shape_failure.get_or_insert(err); + } + Err(err) => return Err(err), + } + } + + Err(last_shape_failure.unwrap_or_else(|| no_writable_protocol_failure(route))) +} + +/// Run one validation probe, retrying with its fallback body only when the +/// upstream specifically rejected `max_completion_tokens`. +/// +/// That retry exists for the GPT-5+ (`max_completion_tokens`) versus legacy +/// (`max_tokens`) chat split. Firing it for any request-shape rejection would +/// issue a second, pointless probe when the real signal is "wrong protocol for +/// this model", and a transient `429`/`5xx` on that retry could become a +/// terminal failure that stops the caller from reaching a protocol that would +/// have validated. +async fn try_validation_probe( + client: &reqwest::Client, + route: &ResolvedRoute, + probe: &ValidationProbe, + headers: &[(String, String)], +) -> Result { let result = try_validation_request( client, route, probe.path, probe.protocol, - headers.clone(), - probe.body, + headers, + probe.body.clone(), ) .await; - // If the primary probe failed with a request-shape error (HTTP 400) and - // there is a fallback body, retry with the alternate token parameter. - // This handles the split between `max_completion_tokens` (GPT-5+) and - // `max_tokens` (legacy/self-hosted backends). - if let (Err(err), Some(fallback_body)) = (&result, probe.fallback_body) + if let (Err(err), Some(fallback_body)) = (&result, &probe.fallback_body) && err.kind == ValidationFailureKind::RequestShape + && err.details.contains("max_completion_tokens") { return try_validation_request( client, @@ -427,7 +488,7 @@ pub async fn verify_backend_endpoint( probe.path, probe.protocol, headers, - fallback_body, + fallback_body.clone(), ) .await; } @@ -441,7 +502,7 @@ async fn try_validation_request( route: &ResolvedRoute, path: &str, protocol: &str, - headers: Vec<(String, String)>, + headers: &[(String, String)], body: bytes::Bytes, ) -> Result { let response = send_backend_request(client, route, "POST", path, headers, body) @@ -488,32 +549,57 @@ async fn try_validation_request( ) }; - let details = match status.as_u16() { - 400 | 404 | 405 | 422 => { - format!("upstream rejected the validation request with HTTP {status}.{body_suffix}") - } - 401 | 403 => { - format!("upstream rejected credentials with HTTP {status}.{body_suffix}") - } - 429 => { - format!("upstream rate-limited the validation request with HTTP {status}.{body_suffix}") - } - 500..=599 => format!("upstream returned HTTP {status}.{body_suffix}"), - _ => format!("upstream returned unexpected HTTP {status}.{body_suffix}"), + // Some OpenAI-compatible providers report an auth failure as 400/404/422 + // with an auth-shaped error body rather than 401/403. Classify those as a + // terminal credential failure so a bad key is not mistaken for a + // wrong-protocol probe and masked by a later probe that accepts it. + let kind = match status.as_u16() { + 401 | 403 => ValidationFailureKind::Credentials, + 400 | 404 | 422 if body_looks_like_auth_error(body) => ValidationFailureKind::Credentials, + 400 | 404 | 405 | 422 => ValidationFailureKind::RequestShape, + 429 => ValidationFailureKind::RateLimited, + 500..=599 => ValidationFailureKind::UpstreamHealth, + _ => ValidationFailureKind::Unexpected, + }; + + let summary = match kind { + ValidationFailureKind::Credentials => "upstream rejected credentials", + ValidationFailureKind::RateLimited => "upstream rate-limited the validation request", + ValidationFailureKind::UpstreamHealth => "upstream returned a server error", + ValidationFailureKind::RequestShape => "upstream rejected the validation request", + _ => "upstream returned an unexpected response", }; Err(ValidationFailure { - kind: match status.as_u16() { - 400 | 404 | 405 | 422 => ValidationFailureKind::RequestShape, - 401 | 403 => ValidationFailureKind::Credentials, - 429 => ValidationFailureKind::RateLimited, - 500..=599 => ValidationFailureKind::UpstreamHealth, - _ => ValidationFailureKind::Unexpected, - }, - details, + kind, + details: format!("{summary} with HTTP {status}.{body_suffix}"), }) } +/// Whether an upstream error body reads as an authentication or authorization +/// failure. Some OpenAI-compatible providers return these as HTTP 400/404/422 +/// rather than 401/403, so validation inspects the body to avoid classifying a +/// bad key as a wrong-protocol probe. Matching is conservative: only strong, +/// auth-specific phrases, lowercased, to avoid catching generic "invalid model" +/// request-shape errors. +fn body_looks_like_auth_error(body: &str) -> bool { + let body = body.to_ascii_lowercase(); + [ + "invalid_api_key", + "invalid api key", + "incorrect api key", + "invalid_authentication", + "authentication_error", + "authentication failed", + "unauthorized", + "permission_denied", + "permission denied", + "missing api key", + ] + .iter() + .any(|needle| body.contains(needle)) +} + /// Extract status and headers from a [`reqwest::Response`]. fn extract_response_metadata(response: &reqwest::Response) -> (u16, Vec<(String, String)>) { let status = response.status().as_u16(); @@ -538,20 +624,60 @@ pub async fn proxy_to_backend( headers: Vec<(String, String)>, body: bytes::Bytes, ) -> Result { - let response = send_backend_request(client, route, method, path, headers, body).await?; + let response = send_backend_request(client, route, method, path, &headers, body).await?; let (status, resp_headers) = extract_response_metadata(&response); - let resp_body = response - .bytes() - .await - .map_err(|e| RouterError::UpstreamProtocol(format!("failed to read response body: {e}")))?; + let body = read_capped_response_body(response, MAX_BUFFERED_RESPONSE_BODY).await?; Ok(ProxyResponse { status, headers: resp_headers, - body: resp_body, + body, }) } +/// Read a response body fully into memory, rejecting anything over `max` bytes. +/// +/// Used by the buffered proxy path so a misbehaving upstream cannot force +/// unbounded allocation. The `Content-Length` check is a fast early-out; the +/// chunk loop is the real guard and bounds an absent, chunked, or +/// under-reported length. The cap counts the bytes reqwest yields: with no +/// decompression features enabled (see `Cargo.toml`) those are wire bytes, so +/// enabling a compression feature later would change what the cap measures. +/// Over-cap responses fail as `UpstreamProtocol` and are never partially +/// returned. +async fn read_capped_response_body( + mut response: reqwest::Response, + max: usize, +) -> Result { + if let Some(len) = response.content_length() + && len > max as u64 + { + return Err(RouterError::UpstreamProtocol(format!( + "inference response body of {len} bytes exceeds the {max} byte cap" + ))); + } + + // Preallocate to the advertised length when it is within the cap; the loop + // still enforces the bound for an absent or under-reported length. + let mut body: Vec = match response.content_length() { + Some(len) if len <= max as u64 => Vec::with_capacity(usize::try_from(len).unwrap_or(max)), + _ => Vec::new(), + }; + while let Some(chunk) = response + .chunk() + .await + .map_err(|e| RouterError::UpstreamProtocol(format!("failed to read response body: {e}")))? + { + if body.len() + chunk.len() > max { + return Err(RouterError::UpstreamProtocol(format!( + "inference response body exceeds the {max} byte cap" + ))); + } + body.extend_from_slice(&chunk); + } + Ok(bytes::Bytes::from(body)) +} + /// Forward a raw HTTP request to the backend, returning response headers /// immediately without buffering the body. /// @@ -567,7 +693,7 @@ pub async fn proxy_to_backend_streaming( body: bytes::Bytes, ) -> Result { let response = - send_backend_request_streaming(client, route, method, path, headers, body).await?; + send_backend_request_streaming(client, route, method, path, &headers, body).await?; let (status, resp_headers) = extract_response_metadata(&response); Ok(StreamingProxyResponse { @@ -673,7 +799,8 @@ fn is_vertex_anthropic_rawpredict_route(route: &ResolvedRoute) -> bool { #[cfg(test)] mod tests { use super::{ - ValidationFailureKind, build_backend_url, build_provider_url, verify_backend_endpoint, + ValidationFailure, ValidationFailureKind, build_backend_url, build_provider_url, + verify_backend_endpoint, }; use crate::config::{DEFAULT_ROUTE_TIMEOUT, ResolvedRoute}; use openshell_core::inference::AuthHeader; @@ -723,6 +850,100 @@ mod tests { } } + /// The buffered path must reject an over-cap upstream response rather than + /// buffer it. Guards the DoS/OOM exposure of reading the body unbounded. + #[tokio::test] + async fn proxy_to_backend_rejects_over_cap_response_body() { + use super::{MAX_BUFFERED_RESPONSE_BODY, proxy_to_backend}; + + let mock_server = MockServer::start().await; + // One byte over the cap. wiremock sets an accurate Content-Length, so + // the size check rejects before the body is buffered. + let oversized = vec![b'a'; MAX_BUFFERED_RESPONSE_BODY + 1]; + Mock::given(method("GET")) + .and(path("/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_bytes(oversized)) + .mount(&mock_server) + .await; + + let route = test_route(&mock_server.uri(), &["model_discovery"], AuthHeader::Bearer); + let client = reqwest::Client::new(); + let result = proxy_to_backend( + &client, + &route, + "model_discovery", + "GET", + "/v1/models", + vec![], + bytes::Bytes::new(), + ) + .await; + + assert!( + matches!(result, Err(crate::RouterError::UpstreamProtocol(_))), + "over-cap response must fail as UpstreamProtocol, got: {result:?}" + ); + } + + /// Spawn a one-shot HTTP/1.1 upstream that replies with a chunked body and + /// no `Content-Length`, so the buffered read cannot pre-check a length and + /// must enforce the cap inside the chunk loop. + async fn spawn_chunked_upstream(chunks: &'static [&'static str]) -> std::net::SocketAddr { + use std::fmt::Write as _; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + let (mut sock, _) = listener.accept().await.unwrap(); + let mut buf = [0u8; 1024]; + let _ = sock.read(&mut buf).await; + let mut resp = String::from( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\n\r\n", + ); + for c in chunks { + let _ = write!(resp, "{:x}\r\n{c}\r\n", c.len()); + } + resp.push_str("0\r\n\r\n"); + sock.write_all(resp.as_bytes()).await.unwrap(); + }); + addr + } + + /// The chunk-accumulation guard (not the `Content-Length` pre-check) must + /// reject an over-cap body when the response advertises no length. + #[tokio::test] + async fn read_capped_response_body_rejects_over_cap_chunked() { + let addr = spawn_chunked_upstream(&["aaaa", "bbbb", "cccc"]).await; + let response = reqwest::Client::new() + .get(format!("http://{addr}/")) + .send() + .await + .unwrap(); + assert!( + response.content_length().is_none(), + "chunked response should advertise no Content-Length" + ); + let result = super::read_capped_response_body(response, 8).await; + assert!( + matches!(result, Err(crate::RouterError::UpstreamProtocol(_))), + "over-cap chunked body must be rejected by the loop, got: {result:?}" + ); + } + + /// A body exactly at the cap is accepted (inclusive bound) and returned + /// intact through the chunk loop. + #[tokio::test] + async fn read_capped_response_body_accepts_body_at_cap() { + let addr = spawn_chunked_upstream(&["aaaa", "bbbb"]).await; + let response = reqwest::Client::new() + .get(format!("http://{addr}/")) + .send() + .await + .unwrap(); + let body = super::read_capped_response_body(response, 8).await.unwrap(); + assert_eq!(&body[..], b"aaaabbbb"); + } + #[test] fn sanitize_request_headers_drops_unknown_sensitive_headers() { let route = ResolvedRoute { @@ -993,6 +1214,189 @@ mod tests { assert_eq!(validated.protocol, "openai_chat_completions"); } + /// A managed route for an embeddings model advertises the full provider + /// protocol set. The chat probe (tried first) rejects the embeddings model + /// as wrong-shape, so validation must fall through to the embeddings probe + /// rather than fail the route. + #[tokio::test] + async fn verify_embeddings_model_falls_through_chat_probe() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &[ + "openai_chat_completions", + "openai_completions", + "openai_responses", + "openai_embeddings", + "model_discovery", + ], + AuthHeader::Bearer, + ); + + // Chat, completions, and responses probes reject the embedding model. + for chat_path in ["/v1/chat/completions", "/v1/completions", "/v1/responses"] { + Mock::given(method("POST")) + .and(path(chat_path)) + .respond_with( + ResponseTemplate::new(400) + .set_body_string(r#"{"error":{"message":"not a chat model"}}"#), + ) + .mount(&mock_server) + .await; + } + // The embeddings probe accepts it. + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let validated = verify_backend_endpoint(&client, &route) + .await + .expect("embeddings model should validate via the embeddings probe"); + assert_eq!(validated.protocol, "openai_embeddings"); + } + + /// A non-request-shape failure (credentials) is terminal: validation must + /// stop at the first probe and not fall through to a protocol that would + /// succeed, so a bad key is reported as such rather than masked. + #[tokio::test] + async fn verify_stops_on_credentials_failure() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(401).set_body_string(r#"{"error":"bad key"}"#)) + .mount(&mock_server) + .await; + // Would succeed, but credentials failure on the first probe is terminal + // and this must never be reached. + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let client = reqwest::Client::new(); + let err = verify_backend_endpoint(&client, &route) + .await + .expect_err("a 401 must fail validation"); + assert_eq!(err.kind, ValidationFailureKind::Credentials); + } + + /// A 429 on the first probe is terminal (`RateLimited`) and must not fall + /// through to a later probe that would succeed. + #[tokio::test] + async fn verify_stops_on_rate_limit() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(429).set_body_string(r#"{"error":"slow down"}"#)) + .mount(&mock_server) + .await; + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let err = reqwest_verify(&route).await; + assert_eq!(err.kind, ValidationFailureKind::RateLimited); + } + + /// An auth failure reported as HTTP 400 with an auth-shaped body is terminal + /// (`Credentials`), not a request-shape fall-through, so a bad key cannot be + /// masked by a later probe that accepts it. + #[tokio::test] + async fn verify_auth_error_as_400_is_terminal() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(ResponseTemplate::new(400).set_body_string( + r#"{"error":{"code":"invalid_api_key","message":"Incorrect API key provided"}}"#, + )) + .mount(&mock_server) + .await; + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(200) + .set_body_json(serde_json::json!({"object": "list", "data": []})), + ) + .mount(&mock_server) + .await; + + let err = reqwest_verify(&route).await; + assert_eq!(err.kind, ValidationFailureKind::Credentials); + } + + /// When every probe is rejected as request-shape, validation returns the + /// first (most-preferred protocol's) failure, not the last. + #[tokio::test] + async fn verify_all_probes_request_shape_returns_first() { + let mock_server = MockServer::start().await; + let route = test_route( + &mock_server.uri(), + &["openai_chat_completions", "openai_embeddings"], + AuthHeader::Bearer, + ); + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with( + ResponseTemplate::new(404).set_body_string(r#"{"error":"model not found: chat"}"#), + ) + .mount(&mock_server) + .await; + Mock::given(method("POST")) + .and(path("/v1/embeddings")) + .respond_with( + ResponseTemplate::new(400) + .set_body_string(r#"{"error":"not an embeddings model"}"#), + ) + .mount(&mock_server) + .await; + + let err = reqwest_verify(&route).await; + assert_eq!(err.kind, ValidationFailureKind::RequestShape); + assert!( + err.details.contains("model not found: chat"), + "should report the first (chat) failure, got: {}", + err.details + ); + } + + /// Helper: run `verify_backend_endpoint` and return the expected failure. + async fn reqwest_verify(route: &ResolvedRoute) -> ValidationFailure { + verify_backend_endpoint(&reqwest::Client::new(), route) + .await + .expect_err("validation should fail") + } + /// Non-chat-completions probes (e.g. `anthropic_messages`) should not /// have a fallback — a 400 remains a hard failure. #[tokio::test] diff --git a/crates/openshell-router/src/mock.rs b/crates/openshell-router/src/mock.rs index 92f3671ba..a60f7dcfc 100644 --- a/crates/openshell-router/src/mock.rs +++ b/crates/openshell-router/src/mock.rs @@ -32,6 +32,7 @@ pub fn mock_response(route: &ResolvedRoute, source_protocol: &str) -> ProxyRespo let body = match protocol { "openai_chat_completions" => openai_chat_completion_body(&route.model), "openai_completions" => openai_completion_body(&route.model), + "openai_embeddings" => openai_embeddings_body(&route.model), "anthropic_messages" => anthropic_messages_body(&route.model), _ => generic_body(&route.model), }; @@ -90,6 +91,26 @@ fn openai_completion_body(model: &str) -> Vec { .expect("static JSON must serialize") } +fn openai_embeddings_body(model: &str) -> Vec { + // Shape must match the OpenAI embeddings response (`object: "list"` with a + // `data` array of `{object, index, embedding}`) so callers that deserialize + // into an embeddings type get a structurally valid — if canned — vector. + serde_json::to_vec(&serde_json::json!({ + "object": "list", + "data": [{ + "object": "embedding", + "index": 0, + "embedding": [0.0_f32, 0.0_f32, 0.0_f32] + }], + "model": model, + "usage": { + "prompt_tokens": 1, + "total_tokens": 1 + } + })) + .expect("static JSON must serialize") +} + fn anthropic_messages_body(model: &str) -> Vec { serde_json::to_vec(&serde_json::json!({ "id": "mock-msg-001", @@ -196,6 +217,26 @@ mod tests { ); } + #[test] + fn mock_openai_embeddings() { + let route = make_route( + "mock://test", + &["openai_embeddings"], + "text-embedding-3-small", + ); + let resp = mock_response(&route, "openai_embeddings"); + assert_eq!(resp.status, 200); + + let body: serde_json::Value = serde_json::from_slice(&resp.body).unwrap(); + assert_eq!(body["object"], "list"); + assert_eq!(body["model"], "text-embedding-3-small"); + assert_eq!(body["data"][0]["object"], "embedding"); + assert!( + body["data"][0]["embedding"].is_array(), + "embedding must be a numeric array, got: {body}" + ); + } + #[test] fn mock_generic_protocol() { let route = make_route("mock://test", &["unknown_protocol"], "some-model"); diff --git a/crates/openshell-sandbox/src/l7/inference.rs b/crates/openshell-sandbox/src/l7/inference.rs index acda0bb36..ec789ef95 100644 --- a/crates/openshell-sandbox/src/l7/inference.rs +++ b/crates/openshell-sandbox/src/l7/inference.rs @@ -7,6 +7,38 @@ //! HTTP request is a known inference API call and routes it through the local //! sandbox router. +/// How an inference protocol delivers its response to the sandboxed client. +/// +/// `Streaming` protocols (chat completions, completions, responses, Anthropic +/// messages) emit a Server-Sent Events token stream and are served through the +/// chunked transfer-encoding path so tokens reach the client incrementally. +/// +/// `Buffered` protocols (embeddings, model discovery) return a single JSON +/// object the client parses whole. They must be served in one piece with an +/// accurate `Content-Length`. Sending them through the streaming path is +/// unsafe: a mid-body truncation (the streaming size cap or idle timeout) +/// appends an SSE error event to bytes the client decodes as one JSON object, +/// silently corrupting it. +/// +/// Framing is a property of the protocol, declared once per pattern in +/// [`default_patterns`], so the streaming-vs-buffered decision cannot drift +/// across the dispatch sites that consume it. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ResponseFraming { + /// SSE token stream, served via chunked transfer-encoding. + /// + /// The `OpenAI` completion-style protocols are classified streaming + /// unconditionally. They are dual-mode (a `stream: false` request returns + /// one buffered JSON object), but the dispatch keys framing off the + /// protocol alone and does not inspect the request body, so they are + /// streamed defensively. Their buffered responses tolerate chunked framing; + /// only the embeddings and model-discovery shapes (no streaming mode at + /// all) must be served buffered. + Streaming, + /// Single JSON object, served buffered with an accurate `Content-Length`. + Buffered, +} + /// An inference API pattern for detecting inference calls in intercepted traffic. #[derive(Debug, Clone)] pub struct InferenceApiPattern { @@ -14,6 +46,18 @@ pub struct InferenceApiPattern { pub path_glob: String, pub protocol: String, pub kind: String, + /// Response delivery mode for this protocol. Selects the buffered or + /// streaming proxy path; see [`ResponseFraming`]. + pub framing: ResponseFraming, +} + +impl InferenceApiPattern { + /// Whether this protocol's response must be served buffered (one JSON + /// object framed with an accurate `Content-Length`) rather than streamed. + #[must_use] + pub fn is_buffered(&self) -> bool { + matches!(self.framing, ResponseFraming::Buffered) + } } /// Default patterns for known inference APIs (`OpenAI`, Anthropic). @@ -24,36 +68,51 @@ pub fn default_patterns() -> Vec { path_glob: "/v1/chat/completions".to_string(), protocol: "openai_chat_completions".to_string(), kind: "chat_completion".to_string(), + framing: ResponseFraming::Streaming, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/completions".to_string(), protocol: "openai_completions".to_string(), kind: "completion".to_string(), + framing: ResponseFraming::Streaming, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/responses".to_string(), protocol: "openai_responses".to_string(), kind: "responses".to_string(), + framing: ResponseFraming::Streaming, + }, + InferenceApiPattern { + method: "POST".to_string(), + path_glob: "/v1/embeddings".to_string(), + protocol: "openai_embeddings".to_string(), + kind: "embeddings".to_string(), + framing: ResponseFraming::Buffered, }, InferenceApiPattern { method: "POST".to_string(), path_glob: "/v1/messages".to_string(), protocol: "anthropic_messages".to_string(), kind: "messages".to_string(), + framing: ResponseFraming::Streaming, }, + // Model discovery returns one JSON object (a model list), never an SSE + // stream, so it is served buffered for the same reason as embeddings. InferenceApiPattern { method: "GET".to_string(), path_glob: "/v1/models".to_string(), protocol: "model_discovery".to_string(), kind: "models_list".to_string(), + framing: ResponseFraming::Buffered, }, InferenceApiPattern { method: "GET".to_string(), path_glob: "/v1/models/*".to_string(), protocol: "model_discovery".to_string(), kind: "models_get".to_string(), + framing: ResponseFraming::Buffered, }, ] } @@ -267,20 +326,37 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { .map(|offset| start + offset) } -/// Format an HTTP/1.1 response from status, headers, and body. -pub fn format_http_response(status: u16, headers: &[(String, String)], body: &[u8]) -> Vec { - use std::fmt::Write; - - let status_text = match status { +/// Reason phrase for an HTTP status code used on the inference proxy path. +/// +/// Covers the statuses produced by the router error mapping (400/401/403/500/ +/// 502/503) and the upstream codes an inference backend can pass through +/// verbatim (404/405 on unknown model or method, 422 on malformed embeddings +/// input, 429 on rate limit). Unknown codes fall back to `"Unknown"` so the +/// status line is still well-formed. +fn http_status_text(status: u16) -> &'static str { + match status { 200 => "OK", 400 => "Bad Request", + 401 => "Unauthorized", 403 => "Forbidden", + 404 => "Not Found", + 405 => "Method Not Allowed", 411 => "Length Required", 413 => "Payload Too Large", + 422 => "Unprocessable Entity", + 429 => "Too Many Requests", 500 => "Internal Server Error", 502 => "Bad Gateway", + 503 => "Service Unavailable", _ => "Unknown", - }; + } +} + +/// Format an HTTP/1.1 response from status, headers, and body. +pub fn format_http_response(status: u16, headers: &[(String, String)], body: &[u8]) -> Vec { + use std::fmt::Write; + + let status_text = http_status_text(status); let mut response = format!("HTTP/1.1 {status} {status_text}\r\n"); let mut has_content_length = false; @@ -310,17 +386,7 @@ pub fn format_http_response(status: u16, headers: &[(String, String)], body: &[u pub fn format_http_response_header(status: u16, headers: &[(String, String)]) -> Vec { use std::fmt::Write; - let status_text = match status { - 200 => "OK", - 400 => "Bad Request", - 403 => "Forbidden", - 411 => "Length Required", - 413 => "Payload Too Large", - 500 => "Internal Server Error", - 502 => "Bad Gateway", - 503 => "Service Unavailable", - _ => "Unknown", - }; + let status_text = http_status_text(status); let mut response = format!("HTTP/1.1 {status} {status_text}\r\n"); for (name, value) in headers { @@ -427,7 +493,11 @@ mod tests { let patterns = default_patterns(); let result = detect_inference_pattern("GET", "/v1/models", &patterns); assert!(result.is_some()); - assert_eq!(result.unwrap().protocol, "model_discovery"); + let pattern = result.unwrap(); + assert_eq!(pattern.protocol, "model_discovery"); + // A model list is one JSON object; it must be served buffered, never + // through the SSE streaming path that could append an error frame. + assert!(pattern.is_buffered()); } #[test] @@ -435,14 +505,42 @@ mod tests { let patterns = default_patterns(); let result = detect_inference_pattern("GET", "/v1/models/gpt-4.1", &patterns); assert!(result.is_some()); - assert_eq!(result.unwrap().protocol, "model_discovery"); + let pattern = result.unwrap(); + assert_eq!(pattern.protocol, "model_discovery"); + assert!(pattern.is_buffered()); } #[test] - fn no_match_for_embeddings() { + fn detect_openai_embeddings() { let patterns = default_patterns(); let result = detect_inference_pattern("POST", "/v1/embeddings", &patterns); - assert!(result.is_none()); + assert!(result.is_some()); + let pattern = result.unwrap(); + assert_eq!(pattern.protocol, "openai_embeddings"); + assert_eq!(pattern.kind, "embeddings"); + assert!(pattern.is_buffered()); + } + + /// Every default pattern must declare framing consistent with how its + /// protocol actually responds: single-JSON-object protocols buffered, + /// SSE token streams streaming. A wrong classification routes a response + /// through the path that can corrupt or stall it. + #[test] + fn protocol_framing_classification() { + let patterns = default_patterns(); + for pattern in &patterns { + let expected_buffered = matches!( + pattern.protocol.as_str(), + "model_discovery" | "openai_embeddings" + ); + assert_eq!( + pattern.is_buffered(), + expected_buffered, + "{} ({}) has wrong framing", + pattern.protocol, + pattern.path_glob + ); + } } #[test] diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index ae100d734..aa8338433 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -1666,6 +1666,35 @@ async fn route_inference_request( return Ok(true); } + // Buffered protocols (embeddings, model discovery) return a single JSON + // object, not an SSE token stream. Serve them buffered with an accurate + // Content-Length: the streaming path would append an SSE error frame to + // the body on a size-cap or idle-timeout truncation, corrupting a + // payload the client parses as one JSON object. Framing is declared per + // protocol on the matched pattern. + if pattern.is_buffered() { + match ctx + .router + .proxy_with_candidates( + &pattern.protocol, + &request.method, + &normalized_path, + request.headers.clone(), + bytes::Bytes::from(request.body.clone()), + &routes, + ) + .await + { + Ok(resp) => { + let resp_headers = sanitize_inference_response_headers(resp.headers); + let response = format_http_response(resp.status, &resp_headers, &resp.body); + write_all(tls_client, &response).await?; + } + Err(e) => write_inference_router_error(tls_client, &e).await?, + } + return Ok(true); + } + match ctx .router .proxy_with_candidates_streaming( @@ -1760,29 +1789,7 @@ async fn route_inference_request( // Terminate the chunked stream. write_all(tls_client, format_chunk_terminator()).await?; } - Err(e) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "inference endpoint detected but upstream service failed: {e}" - )) - .build(); - ocsf_emit!(event); - } - let (status, msg) = router_error_to_http(&e); - let body = serde_json::json!({"error": msg}); - let body_bytes = body.to_string(); - let response = format_http_response( - status, - &[("content-type".to_string(), "application/json".to_string())], - body_bytes.as_bytes(), - ); - write_all(tls_client, &response).await?; - } + Err(e) => write_inference_router_error(tls_client, &e).await?, } Ok(true) } else { @@ -1814,11 +1821,42 @@ async fn route_inference_request( } } +/// Emit an OCSF failure event and write a buffered JSON error response for a +/// router error hit while proxying an inference request. +/// +/// Shared by the streaming and buffered routing paths so both surface upstream +/// failures with the same status mapping and the same audit record. +async fn write_inference_router_error( + tls_client: &mut (impl tokio::io::AsyncWrite + Unpin), + err: &openshell_router::RouterError, +) -> Result<()> { + use crate::l7::inference::format_http_response; + + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "inference endpoint detected but upstream service failed: {err}" + )) + .build(); + ocsf_emit!(event); + + let (status, msg) = router_error_to_http(err); + let body = serde_json::json!({ "error": msg }).to_string(); + let response = format_http_response( + status, + &[("content-type".to_string(), "application/json".to_string())], + body.as_bytes(), + ); + write_all(tls_client, &response).await +} + /// Map router errors to HTTP status codes and sanitized messages. /// -/// Returns generic error messages instead of verbatim internal details. -/// Full error context (upstream URLs, hostnames, TLS details) is logged -/// server-side by the caller at `warn` level for debugging. +/// Returns generic, client-safe messages instead of verbatim internal details; +/// the full error is recorded in the OCSF failure event by the caller. fn router_error_to_http(err: &openshell_router::RouterError) -> (u16, String) { use openshell_router::RouterError; match err { @@ -5433,6 +5471,313 @@ network_policies: } } + fn embeddings_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "text-embedding-3-small".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_embeddings".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + /// Embeddings responses are a single buffered JSON object, not an SSE + /// stream. They must be framed with `Content-Length` and must never be sent + /// through the chunked streaming path, whose truncation handlers would + /// append an SSE `proxy_stream_error` frame into the JSON body. + #[tokio::test] + async fn inference_embeddings_served_buffered_with_content_length() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = r#"{"object":"list","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2]}],"model":"text-embedding-3-small"}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + // Buffered upstream response with Content-Length (no chunked TE). + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![embeddings_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + let body = r#"{"model":"text-embedding-3-small","input":"hello"}"#; + let request = format!( + "POST /v1/embeddings HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + + server_task.await.unwrap().unwrap(); + upstream_task.await.unwrap(); + + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected buffered 200 response, got: {response}" + ); + let lower = response.to_ascii_lowercase(); + assert!( + lower.contains("content-length:"), + "embeddings response must be Content-Length framed, got: {response}" + ); + assert!( + !lower.contains("transfer-encoding: chunked"), + "embeddings response must NOT be chunked, got: {response}" + ); + assert!( + !response.contains("proxy_stream_error"), + "embeddings response must not carry an SSE error frame, got: {response}" + ); + assert!( + response.contains(r#""object":"list""#), + "embeddings JSON body must be forwarded intact, got: {response}" + ); + } + + fn model_discovery_inference_route( + endpoint: String, + ) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "text-embedding-3-small".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["model_discovery".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + /// `GET /v1/models` (model discovery) returns one JSON object — a model + /// list — exactly like embeddings. It must be served buffered with + /// `Content-Length`, never through the chunked streaming path whose + /// truncation handlers would append an SSE `proxy_stream_error` frame into + /// the JSON body. This guards the framing classification for the protocol. + #[tokio::test] + async fn inference_model_discovery_served_buffered_with_content_length() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = + r#"{"object":"list","data":[{"id":"text-embedding-3-small","object":"model"}]}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + // GET model discovery carries no request body. + let request = "GET /v1/models HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + + server_task.await.unwrap().unwrap(); + upstream_task.await.unwrap(); + + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected buffered 200 response, got: {response}" + ); + let lower = response.to_ascii_lowercase(); + assert!( + lower.contains("content-length:"), + "model discovery response must be Content-Length framed, got: {response}" + ); + assert!( + !lower.contains("transfer-encoding: chunked"), + "model discovery response must NOT be chunked, got: {response}" + ); + assert!( + !response.contains("proxy_stream_error"), + "model discovery response must not carry an SSE error frame, got: {response}" + ); + assert!( + response.contains(r#""object":"list""#), + "model discovery JSON body must be forwarded intact, got: {response}" + ); + } + + /// `GET /v1/models/{id}` (model discovery glob) must forward the model id in + /// the path through the buffered path with the id intact, never streamed. + #[tokio::test] + async fn inference_model_discovery_glob_path_served_buffered() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = r#"{"id":"gpt-4.1","object":"model"}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + let forwarded = read_forwarded_request_line(&mut upstream).await; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + forwarded + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + let request = "GET /v1/models/gpt-4.1 HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + server_task.await.unwrap().unwrap(); + let (method, forwarded_path) = upstream_task.await.unwrap(); + + assert_eq!(method, "GET"); + assert_eq!( + forwarded_path, "/v1/models/gpt-4.1", + "the model id in the glob path must be forwarded intact" + ); + let lower = response.to_ascii_lowercase(); + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n") + && lower.contains("content-length:") + && !lower.contains("transfer-encoding: chunked") + && !response.contains("proxy_stream_error"), + "glob model discovery must be buffered and Content-Length framed, got: {response}" + ); + } + + /// A failed model-discovery upstream must produce a buffered, Content-Length + /// framed JSON error, never a chunked SSE `proxy_stream_error` frame. + #[tokio::test] + async fn inference_model_discovery_error_served_buffered() { + // A port with no listener so the upstream connection is refused. + let dead_addr = { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + addr + }; + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{dead_addr}" + ))], + vec![], + ); + + let request = "GET /v1/models HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + server_task.await.unwrap().unwrap(); + + let lower = response.to_ascii_lowercase(); + assert!( + response.starts_with("HTTP/1.1 5"), + "a refused upstream should yield a 5xx, got: {response}" + ); + assert!( + lower.contains("content-length:") + && !lower.contains("transfer-encoding: chunked") + && !response.contains("proxy_stream_error"), + "buffered model-discovery error must be Content-Length framed JSON, got: {response}" + ); + assert!( + response.contains("error"), + "error response should carry a JSON error body, got: {response}" + ); + } + async fn read_forwarded_inference_request(stream: &mut S) { use crate::l7::inference::{ParseResult, try_parse_http_request}; @@ -5453,6 +5798,28 @@ network_policies: } } + /// Like [`read_forwarded_inference_request`] but returns the forwarded + /// request line (method, path) so a test can assert the upstream URL path. + async fn read_forwarded_request_line(stream: &mut S) -> (String, String) { + use crate::l7::inference::{ParseResult, try_parse_http_request}; + + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + loop { + let n = stream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before completion"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(req, _) => return (req.method, req.path), + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + } + async fn run_live_streaming_inference(serve_upstream: F) -> String where F: FnOnce(TcpStream) -> Fut + Send + 'static, diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 58b5feb2a..13496cd99 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -1284,6 +1284,7 @@ mod tests { "openai_chat_completions".to_string(), "openai_completions".to_string(), "openai_responses".to_string(), + "openai_embeddings".to_string(), "model_discovery".to_string(), ] ); diff --git a/docs/sandboxes/inference-routing.mdx b/docs/sandboxes/inference-routing.mdx index 3d8c48cd8..0a4e9d726 100644 --- a/docs/sandboxes/inference-routing.mdx +++ b/docs/sandboxes/inference-routing.mdx @@ -42,6 +42,7 @@ Supported request patterns depend on the provider configured for `inference.loca | Chat Completions | `POST` | `/v1/chat/completions` | | Completions | `POST` | `/v1/completions` | | Responses | `POST` | `/v1/responses` | +| Embeddings | `POST` | `/v1/embeddings` | | Model Discovery | `GET` | `/v1/models` | | Model Discovery | `GET` | `/v1/models/*` |