diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
index ceb1ba2d9..5a810639e 100644
--- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
+++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml
@@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
 apis:
 - agents
 - inference
+- safety
 - telemetry
 - tool_runtime
 - vector_io
@@ -30,6 +31,11 @@ providers:
         db: ${env.POSTGRES_DB:=llamastack}
         user: ${env.POSTGRES_USER:=llamastack}
         password: ${env.POSTGRES_PASSWORD:=llamastack}
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
   agents:
   - provider_id: meta-reference
     provider_type: inline::meta-reference
@@ -95,6 +101,8 @@ models:
 - model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   model_type: llm
+shields:
+- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
 vector_dbs: []
 datasets: []
 scoring_fns: []
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index 8dcad85e3..045093fe0 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -527,7 +527,7 @@ class InferenceRouter(Inference):
 
         # Store the response with the ID that will be returned to the client
         if self.store:
-            await self.store.store_chat_completion(response, messages)
+            asyncio.create_task(self.store.store_chat_completion(response, messages))
 
         if self.telemetry:
             metrics = self._construct_metrics(
@@ -855,4 +855,4 @@ class InferenceRouter(Inference):
             object="chat.completion",
         )
         logger.debug(f"InferenceRouter.completion_response: {final_response}")
-        await self.store.store_chat_completion(final_response, messages)
+        asyncio.create_task(self.store.store_chat_completion(final_response, messages))
diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py
index b4e1c52ae..d3e875fec 100644
--- a/llama_stack/core/server/server.py
+++ b/llama_stack/core/server/server.py
@@ -141,6 +141,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
         return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
     elif isinstance(exc, PermissionError | AccessDeniedError):
         return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
+    elif isinstance(exc, ConnectionError | httpx.ConnectError):
+        return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
     elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
         return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index fb841afdf..50956f58c 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
         adapter=AdapterSpec(
             adapter_type="fireworks",
             pip_packages=[
-                "fireworks-ai<=0.18.0",
+                "fireworks-ai<=0.17.16",
             ],
             module="llama_stack.providers.remote.inference.fireworks",
             config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 65ba2854b..9bd0aa8ce 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import base64
 import struct
 from typing import TYPE_CHECKING
@@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
-        embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
-        embeddings = embedding_model.encode(
-            [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
+        embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
+        embeddings = await asyncio.to_thread(
+            embedding_model.encode,
+            [interleaved_content_as_str(content) for content in contents],
+            show_progress_bar=False,
         )
         return EmbeddingsResponse(embeddings=embeddings)
 
@@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
 
         # Get the model and generate embeddings
         model_obj = await self.model_store.get_model(model)
-        embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
-        embeddings = embedding_model.encode(input_list, show_progress_bar=False)
+        embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
 
         # Convert embeddings to the requested format
         data = []
@@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
             usage=usage,
         )
 
-    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
+    async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
 
         loaded_model = EMBEDDING_MODELS.get(model)
@@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
             return loaded_model
 
         log.info(f"Loading sentence transformer for {model}...")
-        from sentence_transformers import SentenceTransformer
 
-        loaded_model = SentenceTransformer(model)
+        def _load_model():
+            from sentence_transformers import SentenceTransformer
+
+            return SentenceTransformer(model)
+
+        loaded_model = await asyncio.to_thread(_load_model)
         EMBEDDING_MODELS[model] = loaded_model
         return loaded_model
diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py
index 02f7aaf8a..fc8e2f377 100644
--- a/llama_stack/providers/utils/tools/mcp.py
+++ b/llama_stack/providers/utils/tools/mcp.py
@@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
                     raise AuthenticationRequiredError(exc) from exc
             if i == len(connection_strategies) - 1:
                 raise
+        except* httpx.ConnectError as eg:
+            # Connection refused, server down, network unreachable
+            if i == len(connection_strategies) - 1:
+                error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
+                logger.error(f"MCP connection error: {error_msg}")
+                raise ConnectionError(error_msg) from eg
+            else:
+                logger.warning(
+                    f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                )
+        except* httpx.TimeoutException as eg:
+            # Request timeout, server too slow
+            if i == len(connection_strategies) - 1:
+                error_msg = f"MCP server at {endpoint} timed out"
+                logger.error(f"MCP timeout error: {error_msg}")
+                raise TimeoutError(error_msg) from eg
+            else:
+                logger.warning(
+                    f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                )
+        except* httpx.RequestError as eg:
+            # DNS resolution failures, network errors, invalid URLs
+            if i == len(connection_strategies) - 1:
+                # Get the first exception's message for the error string
+                exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
+                error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
+                logger.error(f"MCP network error: {error_msg}")
+                raise ConnectionError(error_msg) from eg
+            else:
+                logger.warning(
+                    f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
+                )
         except* McpError:
             if i < len(connection_strategies) - 1:
                 logger.warning(
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index 72137662d..62185e470 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 
 
+import time
+
 import pytest
 
 from ..test_cases.test_case import TestCase
@@ -323,8 +325,15 @@ def test_inference_store(compat_client, client_with_models, text_model_id, strea
     response_id = response.id
     content = response.choices[0].message.content
 
-    responses = client.chat.completions.list(limit=1000)
-    assert response_id in [r.id for r in responses.data]
+    tries = 0
+    while tries < 10:
+        responses = client.chat.completions.list(limit=1000)
+        if response_id in [r.id for r in responses.data]:
+            break
+        else:
+            tries += 1
+            time.sleep(0.1)
+    assert tries < 10, f"Response {response_id} not found after 1 second"
 
     retrieved_response = client.chat.completions.retrieve(response_id)
     assert retrieved_response.id == response_id
@@ -388,6 +397,18 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
     response_id = response.id
     content = response.choices[0].message.content
 
+    # wait for the response to be stored
+    tries = 0
+    while tries < 10:
+        responses = client.chat.completions.list(limit=1000)
+        if response_id in [r.id for r in responses.data]:
+            break
+        else:
+            tries += 1
+            time.sleep(0.1)
+
+    assert tries < 10, f"Response {response_id} not found after 1 second"
+
     responses = client.chat.completions.list(limit=1000)
     assert response_id in [r.id for r in responses.data]
 
diff --git a/tests/unit/server/test_server.py b/tests/unit/server/test_server.py
index 803111fc7..f21bbdd67 100644
--- a/tests/unit/server/test_server.py
+++ b/tests/unit/server/test_server.py
@@ -113,6 +113,15 @@ class TestTranslateException:
         assert result.status_code == 504
         assert result.detail == "Operation timed out: "
 
+    def test_translate_connection_error(self):
+        """Test that ConnectionError is translated to 502 HTTP status."""
+        exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
+        result = translate_exception(exc)
+
+        assert isinstance(result, HTTPException)
+        assert result.status_code == 502
+        assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
+
     def test_translate_not_implemented_error(self):
         """Test that NotImplementedError is translated to 501 HTTP status."""
         exc = NotImplementedError("Not implemented")