From 140ee7d3372d930d39be01bfdbbb9c6a6d6e125f Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 1 Aug 2025 12:09:14 -0400 Subject: [PATCH 1/3] fix: sambanova inference provider (#2996) # What does this PR do? closes #2995 update SambaNovaInferenceAdapter to efficiently use LiteLLMOpenAIMixin ## Test Plan ``` $ uv run pytest -s -v tests/integration/inference --stack-config inference=sambanova --text-model sambanova/Meta-Llama-3.1-8B-Instruct ... ======================== 10 passed, 84 skipped, 3 xfailed, 51 warnings in 8.14s ======================== ``` --- .../remote/inference/sambanova/sambanova.py | 253 +----------------- .../utils/inference/litellm_openai_mixin.py | 12 +- .../utils/inference/openai_compat.py | 5 +- 3 files changed, 17 insertions(+), 253 deletions(-) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 8ba705f59..96469acac 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -4,178 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from collections.abc import Iterable - -import requests -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) -from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, -) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call_param import ( - Function as OpenAIFunction, -) - -from llama_stack.apis.common.content_types import ( - ImageContentItem, - InterleavedContent, - TextContentItem, -) -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionMessage, - JsonSchemaResponseFormat, - Message, - SystemMessage, - ToolChoice, - ToolResponseMessage, - UserMessage, -) -from llama_stack.apis.models import Model -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from llama_stack.providers.utils.inference.openai_compat import ( - convert_tooldef_to_openai_tool, - get_sampling_options, -) -from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") - - -async def convert_message_to_openai_dict_with_b64_images( - message: Message | dict, -) -> OpenAIChatCompletionMessage: - """ - 
Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. - if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] - async def _convert_message_content( - content: InterleavedContent, - ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: - async def impl( - content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content_, str): - return content_ - elif isinstance(content_, TextContentItem): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content_.text, - ) - elif isinstance(content_, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - type="image_url", - image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_, download=True)), - ) - elif isinstance(content_, list): - return [await impl(item) for item in content_] - else: - raise ValueError(f"Unsupported content type: {type(content_)}") - - ret = await impl(content) - - # OpenAI*Message expects a str or list - if isinstance(ret, str) or isinstance(ret, list): - return ret - else: - return [ret] - - out: OpenAIChatCompletionMessage = None - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=await _convert_message_content(message.content), - tool_calls=[ - OpenAIChatCompletionMessageToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value, - arguments=json.dumps(tool.arguments), - ), - type="function", - ) - for tool in message.tool_calls - ] - or None, - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=await _convert_message_content(message.content), - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=await _convert_message_content(message.content), - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): - _config: SambaNovaImplConfig - def __init__(self, config: SambaNovaImplConfig): self.config = config self.environment_available_models = [] @@ -185,89 +20,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): 
litellm_provider_name="sambanova", api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None, provider_data_api_key_field="sambanova_api_key", + openai_compat_api_base=self.config.url, + download_images=True, # SambaNova requires base64 image encoding + json_schema_strict=False, # SambaNova doesn't support strict=True yet ) - - def _get_api_key(self) -> str: - config_api_key = self.config.api_key if self.config.api_key else None - if config_api_key: - return config_api_key.get_secret_value() - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.sambanova_api_key: - raise ValueError( - 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' - ) - return provider_data.sambanova_api_key - - async def _get_params(self, request: ChatCompletionRequest) -> dict: - input_dict = {} - - input_dict["messages"] = [await convert_message_to_openai_dict_with_b64_images(m) for m in request.messages] - if fmt := request.response_format: - if not isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) - - fmt = fmt.json_schema - name = fmt["title"] - del fmt["title"] - fmt["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt = self._add_additional_properties_recursive(fmt) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt, - "strict": False, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config.tool_choice: - input_dict["tool_choice"] = ( - request.tool_config.tool_choice.value - if isinstance(request.tool_config.tool_choice, ToolChoice) - else request.tool_config.tool_choice - ) - - provider_data = self.get_request_provider_data() - key_field = self.provider_data_api_key_field - if provider_data and getattr(provider_data, key_field, None): - api_key = getattr(provider_data, key_field) - else: - api_key = self._get_api_key() - - return { - "model": request.model, - "api_key": api_key, - "api_base": self.config.url, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - - async def register_model(self, model: Model) -> Model: - model_id = self.get_provider_model_id(model.provider_resource_id) - - list_models_url = self.config.url + "/models" - if len(self.environment_available_models) == 0: - try: - response = requests.get(list_models_url) - response.raise_for_status() - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Request to {list_models_url} failed") from e - self.environment_available_models = [model.get("id") for model in response.json().get("data", {})] - - if model_id.split("sambanova/")[-1] not in self.environment_available_models: - logger.warning(f"Model {model_id} not available in {list_models_url}") - return model - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index e9a41fcf3..befb4b092 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -72,6 +72,8 @@ class LiteLLMOpenAIMixin( api_key_from_config: str 
| None, provider_data_api_key_field: str, openai_compat_api_base: str | None = None, + download_images: bool = False, + json_schema_strict: bool = True, ): """ Initialize the LiteLLMOpenAIMixin. @@ -81,6 +83,8 @@ class LiteLLMOpenAIMixin( :param provider_data_api_key_field: The field in the provider data that contains the API key. :param litellm_provider_name: The name of the provider, used for model lookups. :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility. + :param download_images: Whether to download images and convert to base64 for message conversion. + :param json_schema_strict: Whether to use strict mode for JSON schema validation. """ ModelRegistryHelper.__init__(self, model_entries) @@ -88,6 +92,8 @@ class LiteLLMOpenAIMixin( self.api_key_from_config = api_key_from_config self.provider_data_api_key_field = provider_data_api_key_field self.api_base = openai_compat_api_base + self.download_images = download_images + self.json_schema_strict = json_schema_strict if openai_compat_api_base: self.is_openai_compat = True @@ -206,7 +212,9 @@ class LiteLLMOpenAIMixin( async def _get_params(self, request: ChatCompletionRequest) -> dict: input_dict = {} - input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages] + input_dict["messages"] = [ + await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages + ] if fmt := request.response_format: if not isinstance(fmt, JsonSchemaResponseFormat): raise ValueError( @@ -226,7 +234,7 @@ class LiteLLMOpenAIMixin( "json_schema": { "name": name, "schema": fmt, - "strict": True, + "strict": self.json_schema_strict, }, } if request.tools: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 47144ee0e..e6e5ccc8a 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -564,6 +564,7 @@ class UnparseableToolCall(BaseModel): async def convert_message_to_openai_dict_new( message: Message | dict, + download_images: bool = False, ) -> OpenAIChatCompletionMessage: """ Convert a Message to an OpenAI API-compatible dictionary. @@ -607,7 +608,9 @@ async def convert_message_to_openai_dict_new( elif isinstance(content_, ImageContentItem): return OpenAIChatCompletionContentPartImageParam( type="image_url", - image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)), + image_url=OpenAIImageURL( + url=await convert_image_content_to_url(content_, download=download_images) + ), ) elif isinstance(content_, list): return [await impl(item) for item in content_] From 6ac710f3b09494e501dc09c8885ee1f31f9abcfc Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 1 Aug 2025 16:23:54 -0700 Subject: [PATCH 2/3] fix(recording): endpoint resolution (#3013) # What does this PR do? 
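The inference recorder previously guessed the request endpoint from the patched method (string matching on `str(original_method)` for the OpenAI client, and a method-name lookup for the Ollama client), which could resolve to the wrong path. Each patched client method now passes its endpoint explicitly to `_patched_inference_method` (e.g. `/v1/chat/completions`, `/api/generate`), so recordings are stored and replayed against the correct endpoint.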
## Test Plan --- llama_stack/testing/inference_recorder.py | 62 +- tests/integration/recordings/index.sqlite | Bin 45056 -> 45056 bytes .../recordings/responses/00ba04f74a96.json | 10 +- .../recordings/responses/0b27fd737699.json | 10 +- .../recordings/responses/0b3f2e4754ff.json | 235 + .../recordings/responses/0e8f2b001dd9.json | 56 + .../recordings/responses/10eea8c15ddc.json | 10 +- .../recordings/responses/17253d7cc667.json | 10 +- .../recordings/responses/173ecb3aab28.json | 251 + .../recordings/responses/174458ad71b2.json | 10 +- .../recordings/responses/178016edef0e.json | 10 +- .../recordings/responses/197228e26971.json | 10 +- .../recordings/responses/198ef7208389.json | 10 +- .../recordings/responses/1adfaa0e062e.json | 10 +- .../recordings/responses/1b8394f90636.json | 10 +- .../recordings/responses/211b1562d4e6.json | 10 +- .../recordings/responses/2afe3b38ca01.json | 34 +- .../recordings/responses/2d187a11704c.json | 208 +- .../recordings/responses/325a72db5755.json | 544 ++ .../recordings/responses/3c3f13cb7794.json | 30 +- .../recordings/responses/3ca695048bee.json | 104 + .../recordings/responses/40f524d1934a.json | 30 +- .../recordings/responses/48d2fb183a2a.json | 10 +- .../recordings/responses/4a3a4447b16b.json | 40 +- .../recordings/responses/50340cd4d253.json | 10 +- .../recordings/responses/545d86510a80.json | 34 +- .../recordings/responses/554de3cd986f.json | 46 +- .../recordings/responses/6906a6e71988.json | 10 +- .../recordings/responses/6cc063bbd7d3.json | 48 +- .../recordings/responses/6d35c91287e2.json | 34 +- .../recordings/responses/6f96090aa955.json | 648 +++ .../recordings/responses/6fbea1abca7c.json | 46 +- .../recordings/responses/6fe1d4fedf12.json | 4603 +++++++++++++++++ .../recordings/responses/70adef2c30c4.json | 10 +- .../recordings/responses/75d0dd9d0fa3.json | 10 +- .../recordings/responses/7b4815aba6c5.json | 46 +- .../recordings/responses/80e4404d8987.json | 28 +- .../recordings/responses/81a91f79c51d.json | 108 + .../recordings/responses/836f51dfb3c5.json | 10 +- .../recordings/responses/840fbb380b73.json | 10 +- .../recordings/responses/84cab42e1f5c.json | 974 ++-- .../recordings/responses/85594a69d74a.json | 10 +- .../recordings/responses/97d3812bfccb.json | 10 +- .../recordings/responses/97e259c0d3e5.json | 46 +- .../recordings/responses/9b812cbcb88d.json | 10 +- .../recordings/responses/9c140a29ae09.json | 34 +- .../recordings/responses/9e7a83d3d596.json | 12 +- .../recordings/responses/9fadf5a3d68f.json | 10 +- .../recordings/responses/a0c4df33879f.json | 1740 +++++++ .../recordings/responses/a4c8d19bb1eb.json | 56 + .../recordings/responses/a59d0d7c1485.json | 10 +- .../recordings/responses/a6810c23eda8.json | 94 +- .../recordings/responses/ae6835cfe70e.json | 10 +- .../recordings/responses/b91f1fb4aedb.json | 30 +- .../recordings/responses/bbd0637dce16.json | 466 +- .../recordings/responses/c9cba6f3ee38.json | 10 +- .../recordings/responses/d0ac68cbde69.json | 30 +- .../recordings/responses/d4c86ac355fb.json | 10 +- .../recordings/responses/dd226d71f844.json | 34 +- .../recordings/responses/dd9e7d5913e9.json | 12 +- .../recordings/responses/e96152610712.json | 10 +- .../recordings/responses/e9c8a0e4f0e0.json | 56 + .../recordings/responses/eee47930e3ae.json | 46 +- .../recordings/responses/ef59cbff54d0.json | 10 +- .../recordings/responses/f477c2fe1332.json | 50 +- .../recordings/responses/f70f30f54211.json | 84 + .../recordings/responses/fcdef245da95.json | 10 +- 67 files changed, 9880 insertions(+), 1409 deletions(-) create mode 100644 
tests/integration/recordings/responses/0b3f2e4754ff.json create mode 100644 tests/integration/recordings/responses/0e8f2b001dd9.json create mode 100644 tests/integration/recordings/responses/173ecb3aab28.json create mode 100644 tests/integration/recordings/responses/325a72db5755.json create mode 100644 tests/integration/recordings/responses/3ca695048bee.json create mode 100644 tests/integration/recordings/responses/6f96090aa955.json create mode 100644 tests/integration/recordings/responses/6fe1d4fedf12.json create mode 100644 tests/integration/recordings/responses/81a91f79c51d.json create mode 100644 tests/integration/recordings/responses/a0c4df33879f.json create mode 100644 tests/integration/recordings/responses/a4c8d19bb1eb.json create mode 100644 tests/integration/recordings/responses/e9c8a0e4f0e0.json create mode 100644 tests/integration/recordings/responses/f70f30f54211.json diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py index abfefa0ce..478f77773 100644 --- a/llama_stack/testing/inference_recorder.py +++ b/llama_stack/testing/inference_recorder.py @@ -217,55 +217,21 @@ class ResponseStorage: return cast(dict[str, Any], data) -async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs): +async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs): global _current_mode, _current_storage if _current_mode == InferenceMode.LIVE or _current_storage is None: # Normal operation return await original_method(self, *args, **kwargs) - # Get base URL and endpoint based on client type + # Get base URL based on client type if client_type == "openai": base_url = str(self._client.base_url) - - # Determine endpoint based on the method's module/class path - method_str = str(original_method) - if "chat.completions" in method_str: - endpoint = "/v1/chat/completions" - elif "embeddings" in method_str: - endpoint = "/v1/embeddings" - elif "completions" in method_str: - endpoint = "/v1/completions" - else: - # Fallback - try to guess from the self object - if hasattr(self, "_resource") and hasattr(self._resource, "_resource"): - resource_name = getattr(self._resource._resource, "_resource", "unknown") - if "chat" in str(resource_name): - endpoint = "/v1/chat/completions" - elif "embeddings" in str(resource_name): - endpoint = "/v1/embeddings" - else: - endpoint = "/v1/completions" - else: - endpoint = "/v1/completions" - elif client_type == "ollama": # Get base URL from the client (Ollama client uses host attribute) base_url = getattr(self, "host", "http://localhost:11434") if not base_url.startswith("http"): base_url = f"http://{base_url}" - - # Determine endpoint based on method name - if method_name == "generate": - endpoint = "/api/generate" - elif method_name == "chat": - endpoint = "/api/chat" - elif method_name == "embed": - endpoint = "/api/embeddings" - elif method_name == "list": - endpoint = "/api/tags" - else: - endpoint = f"/api/{method_name}" else: raise ValueError(f"Unknown client type: {client_type}") @@ -366,14 +332,18 @@ def patch_inference_clients(): # Create patched methods for OpenAI client async def patched_chat_completions_create(self, *args, **kwargs): return await _patched_inference_method( - _original_methods["chat_completions_create"], self, "openai", *args, **kwargs + _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs ) async def patched_completions_create(self, *args, **kwargs): - return await 
_patched_inference_method(_original_methods["completions_create"], self, "openai", *args, **kwargs) + return await _patched_inference_method( + _original_methods["completions_create"], self, "openai", "/v1/completions", *args, **kwargs + ) async def patched_embeddings_create(self, *args, **kwargs): - return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", *args, **kwargs) + return await _patched_inference_method( + _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs + ) # Apply OpenAI patches AsyncChatCompletions.create = patched_chat_completions_create @@ -383,30 +353,32 @@ def patch_inference_clients(): # Create patched methods for Ollama client async def patched_ollama_generate(self, *args, **kwargs): return await _patched_inference_method( - _original_methods["ollama_generate"], self, "ollama", "generate", *args, **kwargs + _original_methods["ollama_generate"], self, "ollama", "/api/generate", *args, **kwargs ) async def patched_ollama_chat(self, *args, **kwargs): return await _patched_inference_method( - _original_methods["ollama_chat"], self, "ollama", "chat", *args, **kwargs + _original_methods["ollama_chat"], self, "ollama", "/api/chat", *args, **kwargs ) async def patched_ollama_embed(self, *args, **kwargs): return await _patched_inference_method( - _original_methods["ollama_embed"], self, "ollama", "embed", *args, **kwargs + _original_methods["ollama_embed"], self, "ollama", "/api/embeddings", *args, **kwargs ) async def patched_ollama_ps(self, *args, **kwargs): - return await _patched_inference_method(_original_methods["ollama_ps"], self, "ollama", "ps", *args, **kwargs) + return await _patched_inference_method( + _original_methods["ollama_ps"], self, "ollama", "/api/ps", *args, **kwargs + ) async def patched_ollama_pull(self, *args, **kwargs): return await _patched_inference_method( - _original_methods["ollama_pull"], self, "ollama", "pull", *args, **kwargs + _original_methods["ollama_pull"], self, "ollama", "/api/pull", *args, **kwargs ) async def patched_ollama_list(self, *args, **kwargs): return await _patched_inference_method( - _original_methods["ollama_list"], self, "ollama", "list", *args, **kwargs + _original_methods["ollama_list"], self, "ollama", "/api/tags", *args, **kwargs ) # Apply Ollama patches diff --git a/tests/integration/recordings/index.sqlite b/tests/integration/recordings/index.sqlite index 2e1256d210d5818a229505ec22c6936ea3b4bf1e..72a1dae9ba02100162c1804f90ae7d35d5b62651 100644 GIT binary patch delta 4998 zcmbtX3yf6N8NPGxo%f5ktRgO2tGGPQz31H5l;SM7JQM`2;+CRvUuRL#b(UpC*Ro6z zW{ou63=C=uDaGiUC1 zzyCk~Iro2^HHXAChs3_~Gjf9<2ub`@|N1Z2M<6n#>!Eb^s32w@6fzHH|4 zktfoZ>BLt;JaSk_6ygWtEwLA34bioc!`Wq-J*T^13VCs`AV^$FQ(d5xm5u$zL8>U7!hL40>4;E%hezBnxi51gRwJT}NX0GWemg6YCMgyf8>;7MhNfA*uDQfkJ!<**@Aw*X43}vRwQa&3sxsGLj_r85%~e-3T}JhweBHED z-S#{0Uha3^)_&KD*4Y&W-^1QCRY@b zx(?2)QO__{msqOFJj$qM;AREn&%y>;Y*@eg_&&I{ZfS&>j_*5;tx~2_!@$AUJff=H z@+^;7wysj5*+}p`?kk4wsJ5k3SLHMye-<{VXv6(OHt5{aiA8K1H$l^FoRHa^JG$;N z2aO@VP6)FM+rSS~@D(Y#%5{TLS3zg-c~Jf=Y#?aEsv#R};xdm@Y8n>D`D{iMQ>9cT zF4r8(^Z#AGhMFweIg-*K7Y1m(ZI4XTMYbPf4~Z@H#Te8!1SY)eJY zb2(LXOJ{~bJ&kLM;VPZ-v$4>QH7O_ZHZ;WJP5uZ&rD-SIusbBrc1 z*%`tvA#xzLTUsZ*9@-?{CT@?d&F+j2gyu&ZvrVyS(Ib(K(YDM7nJx0-WM5`c_~lGC zT@)`)=OWjnwq-jYx3uws)MxQ0LI>lm^2qq`^lPDruq(7c5N656w779&9u^el!JjVA z_elS#LW)6sx006})$}>zI#V2%=&t3t%*N56x~FL9!i;c_+6px-su(Im4>lcx7}Ald z8sbo{`pnQ6=iDc{=PMM2Ir^Hbsg(OFK2bFfvz}s@rmyv2&T}Xu4p$Y74psWRsz=9U z;4zL#K~YW3(rjj7ND#;OT#g1|Rs+wRjJ$*86 
zJ99RcepXvnAR&I2get8mZ+GmFC><@2^j-+Ts400-`Ud)%TB6dY)fr%oj?gFY`^FYg z+5-nT4u65hHXQatTV?n&474Hrof>^fAJX53Y0L1*Td=k=?1SCQHbkYHYlJ-`k=PVU)i5Er#KXL$nu_n?PEcKOb%naFZabBp z2irS_LZ;?(d2Ce;TssqeY8Jd#Xo*ObW_WggVP<9f%XfE74~x=4*nSUs+q*UL!sn5E z02WNZ`1%4K#z7tp?@hp%8c`)+CKaWz zDuV-WUeW@6-S0%D&YJ#1kE8repj?XF;jrpbRGh6n`KB2oFbea&k8(?4*lHxm@GjoT zB22ifMHDB&01nbua0CZgfN7Vv%n%_R-y)2XJB7&8k+Ie#3+8&>OAsyj!Gm8H9&d)P_fic%XCw(_7B?Al_bp8(lb?>`LD#qiQ; z&>zE`=erw@&xIaX;kH)x;<6?hgZFtJ#tqODb0C8v<0e=-P(McnT?8MW2E7qhy-?4s zc=-)*Ltpp&;LK`TZWhJ}ty1#gq>#8h{&wuQvCE>5MB~y9=@N017!Gv_9}2CJ`@;VU zFOlEK7BUAi&6TT`83%JCLqLw}(+RZOmDmGv}D#`PTP%EPWF z-+>*!{sqi=_UGlqZ+?KacT$Du19IGfSyJKdHjsxAy6k1N3>~J|Wwg{$rY?9!+JEs}dh47R8>)yq~!~+86mczB@iW z;>E^=-wd1bW;raii~9x(kBNAq?j1Yie&?EjgYtFmS2NJqbIzw=ym(B{&;)jh1lJn6|pmIY!Svy*{yBw=1&hk*qx k!Pyl&cl9P*-M1^K1w>Aadn_PTl4%xL;}v{y9#4}uS7I{*Lx delta 2333 zcmZ9Ndu$X{6o+T-ymxnY`VuLYWede3rS0y{?zW5z9Vkmx)DmNa)Jn?3((;lmVt|5l zVVCk|%i*C3f(AsWr7cWiM4+e`gM=a`@CO13CYS&MF$zJ98qe&s+m=aY&)o0)?%6Z< z&N;W~25GuM5<{F!kfJC%-V)xVWmGT&Nzr>7WaTf4I5$wv4a%9Lk-<)5y>lg<#whnu zM%k77$^+$Z<(83^_LK6AiFz+KmW7X8tPFo`@z61MmjXYI(mh0$ZI7u8>NaJINLQt> z*db;MYx#TpQtn?a!d_)7nKMiw*-h;9QWV{WBF_mnD#o(l_twVrl0wz%_LgXVzvlNv zGh&&r^7A+h+apRy)Kb)CKt|gDooedRU~gMemySAL5Q!{C<>LG@RWh@Jl}E<)M)3wQ ziXF}!Y0OUjo2vOeNNveE1*!-?&DP=T4z<5REy1W#Dr#4;4qKXgqNsuOQp)qZqz2{w3&FeGn1)VJ6grgCLE0q}+l}%SuWIg={U9G(9 zyybjN{$AD`2ONX#YixIIuS+MTXT^l@mr%o>JFx&_N++kveQVtoD6S zLCtG^K5n>hct9kt8EF{&*B1$2Q;fM%kzOs(;o4{YA_dHh#n@1ISr5YI_OP7VaP+8V z4IOQ|MGv+){-laXC&gb8&oOc4Jn0nID#wL(aMcWQI2HM<9F;YGilYrC)wrF~K506e zDW!9#P<4NsTC^%&anzo_fB!ub^ov@H2t_$QV z>p6Y94oZE6Kx(Y<#Y~j1hJDYWe3*q6A4D`93P&Sqve1lnL_^_(i83wJxF1mgyjXx4 zxF9iqGa+^08iRs4mZ17+9aKibn5F9T?BN>kx$~i*LKxi2bzlF7? z`35R;(TI?~9L;FBZjSO`pE)`XVs01;R#XsD5A%m383`v=Or}+MY!2L@k3drkiAFSU z!-WxH$XR(=AW>^sZp^?Os=zxEy(1vHD$J8y=daY{mdM0pp53d^pAE|U=${R5m}U@s zWST6<%}4bZtoxMO$Vasu9_FJz145(1P6FsRFB)&&`p&P>f{deQsHq@Kbrt*hJv_@T zU@x&D<^ba$P4ol0j=DxoHKz@{qGCJ*%eM73;XtVt7>o4zROL;m!zLW-t2|SIH5}J~ zjrY~CNlTE9g~SsSt2AB;IqMoz6>4e&M(a!8vLsH!#c5{uki?Q?TZ~gsJKX|7RAEM2eUt>r^P*RO_&s9d>h%=7Gig5=QrqT-tau;iBRG&TYo>mb1i=c83V8*p)CY ze>jcc;T|uC1DpEo*jNqUT#Qcy&*p(NfOG Date: Fri, 1 Aug 2025 17:38:49 -0700 Subject: [PATCH 3/3] chore: create integration-tests script (#3016) --- .../actions/run-and-record-tests/action.yml | 128 +--------- scripts/integration-tests.sh | 240 ++++++++++++++++++ 2 files changed, 246 insertions(+), 122 deletions(-) create mode 100755 scripts/integration-tests.sh diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index a6acc5ce6..573148e46 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -29,132 +29,16 @@ runs: free -h df -h - - name: Set environment variables - shell: bash - run: | - echo "LLAMA_STACK_CLIENT_TIMEOUT=300" >> $GITHUB_ENV - echo "LLAMA_STACK_TEST_INFERENCE_MODE=${{ inputs.inference-mode }}" >> $GITHUB_ENV - - # Configure provider-specific settings - if [ "${{ inputs.provider }}" == "ollama" ]; then - echo "OLLAMA_URL=http://0.0.0.0:11434" >> $GITHUB_ENV - echo "TEXT_MODEL=ollama/llama3.2:3b-instruct-fp16" >> $GITHUB_ENV - echo "SAFETY_MODEL=ollama/llama-guard3:1b" >> $GITHUB_ENV - else - echo "VLLM_URL=http://localhost:8000/v1" >> $GITHUB_ENV - echo "TEXT_MODEL=vllm/meta-llama/Llama-3.2-1B-Instruct" >> $GITHUB_ENV - fi - - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then - echo "LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings/vision" >> $GITHUB_ENV - else - echo 
"LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings" >> $GITHUB_ENV - fi - - - name: Run Llama Stack Server - if: ${{ contains(inputs.stack-config, 'server:') }} - shell: bash - run: | - # Run this so pytest in a loop doesn't start-stop servers in a loop - echo "Starting Llama Stack Server" - nohup uv run llama stack run ci-tests --image-type venv > server.log 2>&1 & - - echo "Waiting for Llama Stack Server to start" - for i in {1..30}; do - if curl -s http://localhost:8321/v1/health | grep -q "OK"; then - echo "Llama Stack Server started" - exit 0 - fi - sleep 1 - done - - echo "Llama Stack Server failed to start" - cat server.log - exit 1 - - name: Run Integration Tests shell: bash run: | - stack_config="${{ inputs.stack-config }}" - EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag" + ./scripts/integration-tests.sh \ + --stack-config '${{ inputs.stack-config }}' \ + --provider '${{ inputs.provider }}' \ + --test-types '${{ inputs.test-types }}' \ + --inference-mode '${{ inputs.inference-mode }}' \ + ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} - # Configure provider-specific settings - if [ "${{ inputs.provider }}" == "ollama" ]; then - EXTRA_PARAMS="--safety-shield=llama-guard" - else - EXTRA_PARAMS="" - EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls" - fi - - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then - if uv run pytest -s -v tests/integration/inference/test_vision_inference.py --stack-config=${stack_config} \ - -k "not( ${EXCLUDE_TESTS} )" \ - --vision-model=ollama/llama3.2-vision:11b \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes ${EXTRA_PARAMS} \ - --capture=tee-sys | tee pytest-${{ inputs.inference-mode }}-vision.log; then - echo "✅ Tests completed for vision" - else - echo "❌ Tests failed for vision" - exit 1 - fi - - exit 0 - fi - - # Run non-vision tests - TEST_TYPES='${{ inputs.test-types }}' - echo "Test types to run: $TEST_TYPES" - - # Collect all test files for the specified test types - TEST_FILES="" - for test_type in $(echo "$TEST_TYPES" | jq -r '.[]'); do - # if provider is vllm, exclude the following tests: (safety, post_training, tool_runtime) - if [ "${{ inputs.provider }}" == "vllm" ]; then - if [ "$test_type" == "safety" ] || [ "$test_type" == "post_training" ] || [ "$test_type" == "tool_runtime" ]; then - echo "Skipping $test_type for vllm provider" - continue - fi - fi - - if [ -d "tests/integration/$test_type" ]; then - # Find all Python test files in this directory - test_files=$(find tests/integration/$test_type -name "test_*.py" -o -name "*_test.py") - if [ -n "$test_files" ]; then - TEST_FILES="$TEST_FILES $test_files" - echo "Added test files from $test_type: $(echo $test_files | wc -w) files" - fi - else - echo "Warning: Directory tests/integration/$test_type does not exist" - fi - done - - if [ -z "$TEST_FILES" ]; then - echo "No test files found for the specified test types" - exit 1 - fi - - echo "=== Running all collected tests in a single pytest command ===" - echo "Total test files: $(echo $TEST_FILES | wc -w)" - - if uv run pytest -s -v $TEST_FILES --stack-config=${stack_config} \ - -k "not( ${EXCLUDE_TESTS} )" \ - --text-model=$TEXT_MODEL \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes ${EXTRA_PARAMS} \ - --capture=tee-sys | tee pytest-${{ inputs.inference-mode }}-all.log; then - echo "✅ All tests completed successfully" - else - echo "❌ Tests failed" - exit 1 - fi - - - name: Check Storage 
and Memory Available After Tests - if: ${{ always() }} - shell: bash - run: | - free -h - df -h - name: Commit and push recordings if: ${{ inputs.inference-mode == 'record' }} diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh new file mode 100755 index 000000000..08d16db51 --- /dev/null +++ b/scripts/integration-tests.sh @@ -0,0 +1,240 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +set -e + +# Integration test runner script for Llama Stack +# This script extracts the integration test logic from GitHub Actions +# to allow developers to run integration tests locally + +# Default values +STACK_CONFIG="" +PROVIDER="" +TEST_TYPES='["inference"]' +RUN_VISION_TESTS="false" +INFERENCE_MODE="replay" +EXTRA_PARAMS="" + +# Function to display usage +usage() { + cat << EOF +Usage: $0 [OPTIONS] + +Options: + --stack-config STRING Stack configuration to use (required) + --provider STRING Provider to use (ollama, vllm, etc.) (required) + --test-types JSON JSON array of test types to run (default: '["inference"]') + --run-vision-tests Run vision tests instead of regular tests + --inference-mode STRING Inference mode: record or replay (default: replay) + --help Show this help message + +Examples: + # Basic inference tests with ollama + $0 --stack-config server:ollama --provider ollama + + # Multiple test types with vllm + $0 --stack-config server:vllm --provider vllm --test-types '["inference", "agents"]' + + # Vision tests with ollama + $0 --stack-config server:ollama --provider ollama --run-vision-tests + + # Record mode for updating test recordings + $0 --stack-config server:ollama --provider ollama --inference-mode record +EOF +} + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --stack-config) + STACK_CONFIG="$2" + shift 2 + ;; + --provider) + PROVIDER="$2" + shift 2 + ;; + --test-types) + TEST_TYPES="$2" + shift 2 + ;; + --run-vision-tests) + RUN_VISION_TESTS="true" + shift + ;; + --inference-mode) + INFERENCE_MODE="$2" + shift 2 + ;; + --help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" + usage + exit 1 + ;; + esac +done + + +# Validate required parameters +if [[ -z "$STACK_CONFIG" ]]; then + echo "Error: --stack-config is required" + usage + exit 1 +fi + +if [[ -z "$PROVIDER" ]]; then + echo "Error: --provider is required" + usage + exit 1 +fi + +echo "=== Llama Stack Integration Test Runner ===" +echo "Stack Config: $STACK_CONFIG" +echo "Provider: $PROVIDER" +echo "Test Types: $TEST_TYPES" +echo "Vision Tests: $RUN_VISION_TESTS" +echo "Inference Mode: $INFERENCE_MODE" +echo "" + +# Check storage and memory before tests +echo "=== System Resources Before Tests ===" +free -h 2>/dev/null || echo "free command not available" +df -h +echo "" + +# Set environment variables +export LLAMA_STACK_CLIENT_TIMEOUT=300 +export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE" + +# Configure provider-specific settings +if [[ "$PROVIDER" == "ollama" ]]; then + export OLLAMA_URL="http://0.0.0.0:11434" + export TEXT_MODEL="ollama/llama3.2:3b-instruct-fp16" + export SAFETY_MODEL="ollama/llama-guard3:1b" + EXTRA_PARAMS="--safety-shield=llama-guard" +else + export VLLM_URL="http://localhost:8000/v1" + export TEXT_MODEL="vllm/meta-llama/Llama-3.2-1B-Instruct" + EXTRA_PARAMS="" +fi + +# Set recording directory +if [[ "$RUN_VISION_TESTS" == "true" ]]; then + export 
LLAMA_STACK_TEST_RECORDING_DIR="tests/integration/recordings/vision" +else + export LLAMA_STACK_TEST_RECORDING_DIR="tests/integration/recordings" +fi + +# Start Llama Stack Server if needed +if [[ "$STACK_CONFIG" == *"server:"* ]]; then + echo "=== Starting Llama Stack Server ===" + nohup uv run llama stack run ci-tests --image-type venv > server.log 2>&1 & + + echo "Waiting for Llama Stack Server to start..." + for i in {1..30}; do + if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then + echo "✅ Llama Stack Server started successfully" + break + fi + if [[ $i -eq 30 ]]; then + echo "❌ Llama Stack Server failed to start" + echo "Server logs:" + cat server.log + exit 1 + fi + sleep 1 + done + echo "" +fi + +# Run tests +echo "=== Running Integration Tests ===" +EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag" + +# Additional exclusions for vllm provider +if [[ "$PROVIDER" == "vllm" ]]; then + EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls" +fi + +# Run vision tests if specified +if [[ "$RUN_VISION_TESTS" == "true" ]]; then + echo "Running vision tests..." + if uv run pytest -s -v tests/integration/inference/test_vision_inference.py \ + --stack-config="$STACK_CONFIG" \ + -k "not( $EXCLUDE_TESTS )" \ + --vision-model=ollama/llama3.2-vision:11b \ + --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ + --color=yes $EXTRA_PARAMS \ + --capture=tee-sys | tee pytest-${INFERENCE_MODE}-vision.log; then + echo "✅ Vision tests completed successfully" + else + echo "❌ Vision tests failed" + exit 1 + fi + exit 0 +fi + +# Run regular tests +echo "Test types to run: $TEST_TYPES" + +# Collect all test files for the specified test types +TEST_FILES="" +for test_type in $(echo "$TEST_TYPES" | jq -r '.[]'); do + # Skip certain test types for vllm provider + if [[ "$PROVIDER" == "vllm" ]]; then + if [[ "$test_type" == "safety" ]] || [[ "$test_type" == "post_training" ]] || [[ "$test_type" == "tool_runtime" ]]; then + echo "Skipping $test_type for vllm provider" + continue + fi + fi + + if [[ -d "tests/integration/$test_type" ]]; then + # Find all Python test files in this directory + test_files=$(find tests/integration/$test_type -name "test_*.py" -o -name "*_test.py") + if [[ -n "$test_files" ]]; then + TEST_FILES="$TEST_FILES $test_files" + echo "Added test files from $test_type: $(echo $test_files | wc -w) files" + fi + else + echo "Warning: Directory tests/integration/$test_type does not exist" + fi +done + +if [[ -z "$TEST_FILES" ]]; then + echo "No test files found for the specified test types" + exit 1 +fi + +echo "" +echo "=== Running all collected tests in a single pytest command ===" +echo "Total test files: $(echo $TEST_FILES | wc -w)" + +if uv run pytest -s -v $TEST_FILES \ + --stack-config="$STACK_CONFIG" \ + -k "not( $EXCLUDE_TESTS )" \ + --text-model="$TEXT_MODEL" \ + --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ + --color=yes $EXTRA_PARAMS \ + --capture=tee-sys | tee pytest-${INFERENCE_MODE}-all.log; then + echo "✅ All tests completed successfully" +else + echo "❌ Tests failed" + exit 1 +fi + +# Check storage and memory after tests +echo "" +echo "=== System Resources After Tests ===" +free -h 2>/dev/null || echo "free command not available" +df -h + +echo "" +echo "=== Integration Tests Complete ==="