From 63e6acd0c3a229555746e92ff36458de4c18f01c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 25 Feb 2025 22:07:33 -0800 Subject: [PATCH] feat: add (openai, anthropic, gemini) providers via litellm (#1267) # What does this PR do? This PR introduces more non-llama model support to llama stack. Providers introduced: openai, anthropic and gemini. All of these providers use essentially the same piece of code -- the implementation works via the `litellm` library. We will expose only specific models for providers we enable making sure they all work well and pass tests. This setup (instead of automatically enabling _all_ providers and models allowed by LiteLLM) ensures we can also perform any needed prompt tuning on a per-model basis as needed (just like we do it for llama models.) ## Test Plan ```bash #!/bin/bash args=("$@") for model in openai/gpt-4o anthropic/claude-3-5-sonnet-latest gemini/gemini-1.5-flash; do LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --embedding-model=all-MiniLM-L6-v2 \ --vision-inference-model="" \ --inference-model=$model "${args[@]}" done ``` --- distributions/dependencies.json | 36 +++ llama_stack/providers/registry/inference.py | 27 ++ .../remote/inference/anthropic/__init__.py | 23 ++ .../remote/inference/anthropic/anthropic.py | 22 ++ .../remote/inference/anthropic/config.py | 25 ++ .../remote/inference/anthropic/models.py | 35 +++ .../remote/inference/fireworks/config.py | 4 +- .../remote/inference/gemini/__init__.py | 23 ++ .../remote/inference/gemini/config.py | 25 ++ .../remote/inference/gemini/gemini.py | 22 ++ .../remote/inference/gemini/models.py | 24 ++ .../remote/inference/openai/__init__.py | 23 ++ .../remote/inference/openai/config.py | 25 ++ .../remote/inference/openai/models.py | 30 ++ .../remote/inference/openai/openai.py | 22 ++ .../test_cases/inference/chat_completion.json | 4 +- .../utils/inference/litellm_openai_mixin.py | 171 ++++++++++++ llama_stack/templates/ci-tests/ci_tests.py | 23 +- llama_stack/templates/ci-tests/run.yaml | 11 - llama_stack/templates/dev/__init__.py | 7 + llama_stack/templates/dev/build.yaml | 36 +++ llama_stack/templates/dev/dev.py | 174 ++++++++++++ llama_stack/templates/dev/run.yaml | 261 ++++++++++++++++++ tests/client-sdk/conftest.py | 8 +- .../inference/test_text_inference.py | 20 +- 25 files changed, 1048 insertions(+), 33 deletions(-) create mode 100644 llama_stack/providers/remote/inference/anthropic/__init__.py create mode 100644 llama_stack/providers/remote/inference/anthropic/anthropic.py create mode 100644 llama_stack/providers/remote/inference/anthropic/config.py create mode 100644 llama_stack/providers/remote/inference/anthropic/models.py create mode 100644 llama_stack/providers/remote/inference/gemini/__init__.py create mode 100644 llama_stack/providers/remote/inference/gemini/config.py create mode 100644 llama_stack/providers/remote/inference/gemini/gemini.py create mode 100644 llama_stack/providers/remote/inference/gemini/models.py create mode 100644 llama_stack/providers/remote/inference/openai/__init__.py create mode 100644 llama_stack/providers/remote/inference/openai/config.py create mode 100644 llama_stack/providers/remote/inference/openai/models.py create mode 100644 llama_stack/providers/remote/inference/openai/openai.py create mode 100644 llama_stack/providers/utils/inference/litellm_openai_mixin.py create mode 100644 llama_stack/templates/dev/__init__.py create mode 100644 llama_stack/templates/dev/build.yaml create mode 100644 llama_stack/templates/dev/dev.py create mode 100644 llama_stack/templates/dev/run.yaml diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 1ddedc148..b5adad332 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -136,6 +136,42 @@ "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" ], + "dev": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "chromadb-client", + "datasets", + "fastapi", + "fire", + "fireworks-ai", + "httpx", + "litellm", + "matplotlib", + "mcp", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pymongo", + "pypdf", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "sqlite-vec", + "tqdm", + "transformers", + "uvicorn", + "sentence-transformers --no-deps", + "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + ], "fireworks": [ "aiosqlite", "autoevals", diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index b0402f6a5..3ba634e9a 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -207,6 +207,33 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="openai", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.openai", + config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", + ), + ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="anthropic", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.anthropic", + config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", + ), + ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="gemini", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.gemini", + config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/anthropic/__init__.py b/llama_stack/providers/remote/inference/anthropic/__init__.py new file mode 100644 index 000000000..3075f856e --- /dev/null +++ b/llama_stack/providers/remote/inference/anthropic/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + +from .config import AnthropicConfig + + +class AnthropicProviderDataValidator(BaseModel): + anthropic_api_key: Optional[str] = None + + +async def get_adapter_impl(config: AnthropicConfig, _deps): + from .anthropic import AnthropicInferenceAdapter + + impl = AnthropicInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/anthropic/anthropic.py b/llama_stack/providers/remote/inference/anthropic/anthropic.py new file mode 100644 index 000000000..2b392b295 --- /dev/null +++ b/llama_stack/providers/remote/inference/anthropic/anthropic.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin + +from .config import AnthropicConfig +from .models import MODEL_ENTRIES + + +class AnthropicInferenceAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: AnthropicConfig) -> None: + LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES) + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/providers/remote/inference/anthropic/config.py b/llama_stack/providers/remote/inference/anthropic/config.py new file mode 100644 index 000000000..00323b1e7 --- /dev/null +++ b/llama_stack/providers/remote/inference/anthropic/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class AnthropicConfig(BaseModel): + api_key: Optional[str] = Field( + default=None, + description="API key for Anthropic models", + ) + + @classmethod + def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> Dict[str, Any]: + return { + "api_key": api_key, + } diff --git a/llama_stack/providers/remote/inference/anthropic/models.py b/llama_stack/providers/remote/inference/anthropic/models.py new file mode 100644 index 000000000..39cb64440 --- /dev/null +++ b/llama_stack/providers/remote/inference/anthropic/models.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.models.models import ModelType +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) + +LLM_MODEL_IDS = [ + "anthropic/claude-3-5-sonnet-latest", + "anthropic/claude-3-7-sonnet-latest", + "anthropic/claude-3-5-haiku-latest", +] + + +MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [ + ProviderModelEntry( + provider_model_id="anthropic/voyage-3", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 1024, "context_length": 32000}, + ), + ProviderModelEntry( + provider_model_id="anthropic/voyage-3-lite", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 512, "context_length": 32000}, + ), + ProviderModelEntry( + provider_model_id="anthropic/voyage-code-3", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 1024, "context_length": 32000}, + ), +] diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 005dfe829..c21ce4a40 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -23,8 +23,8 @@ class FireworksImplConfig(BaseModel): ) @classmethod - def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]: return { "url": "https://api.fireworks.ai/inference/v1", - "api_key": "${env.FIREWORKS_API_KEY}", + "api_key": api_key, } diff --git a/llama_stack/providers/remote/inference/gemini/__init__.py b/llama_stack/providers/remote/inference/gemini/__init__.py new file mode 100644 index 000000000..dd972f21c --- /dev/null +++ b/llama_stack/providers/remote/inference/gemini/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + +from .config import GeminiConfig + + +class GeminiProviderDataValidator(BaseModel): + gemini_api_key: Optional[str] = None + + +async def get_adapter_impl(config: GeminiConfig, _deps): + from .gemini import GeminiInferenceAdapter + + impl = GeminiInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/gemini/config.py b/llama_stack/providers/remote/inference/gemini/config.py new file mode 100644 index 000000000..cce8c756c --- /dev/null +++ b/llama_stack/providers/remote/inference/gemini/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class GeminiConfig(BaseModel): + api_key: Optional[str] = Field( + default=None, + description="API key for Gemini models", + ) + + @classmethod + def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> Dict[str, Any]: + return { + "api_key": api_key, + } diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py new file mode 100644 index 000000000..b269bc14a --- /dev/null +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin + +from .config import GeminiConfig +from .models import MODEL_ENTRIES + + +class GeminiInferenceAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: GeminiConfig) -> None: + LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES) + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/providers/remote/inference/gemini/models.py b/llama_stack/providers/remote/inference/gemini/models.py new file mode 100644 index 000000000..1d7b47315 --- /dev/null +++ b/llama_stack/providers/remote/inference/gemini/models.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.models.models import ModelType +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) + +LLM_MODEL_IDS = [ + "gemini/gemini-1.5-flash", + "gemini/gemini-1.5-pro", +] + + +MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [ + ProviderModelEntry( + provider_model_id="gemini/text-embedding-004", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 768, "context_length": 2048}, + ), +] diff --git a/llama_stack/providers/remote/inference/openai/__init__.py b/llama_stack/providers/remote/inference/openai/__init__.py new file mode 100644 index 000000000..000a03d33 --- /dev/null +++ b/llama_stack/providers/remote/inference/openai/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from pydantic import BaseModel + +from .config import OpenAIConfig + + +class OpenAIProviderDataValidator(BaseModel): + openai_api_key: Optional[str] = None + + +async def get_adapter_impl(config: OpenAIConfig, _deps): + from .openai import OpenAIInferenceAdapter + + impl = OpenAIInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/openai/config.py b/llama_stack/providers/remote/inference/openai/config.py new file mode 100644 index 000000000..07f96a3df --- /dev/null +++ b/llama_stack/providers/remote/inference/openai/config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class OpenAIConfig(BaseModel): + api_key: Optional[str] = Field( + default=None, + description="API key for OpenAI models", + ) + + @classmethod + def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> Dict[str, Any]: + return { + "api_key": api_key, + } diff --git a/llama_stack/providers/remote/inference/openai/models.py b/llama_stack/providers/remote/inference/openai/models.py new file mode 100644 index 000000000..657895f27 --- /dev/null +++ b/llama_stack/providers/remote/inference/openai/models.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.models.models import ModelType +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) + +LLM_MODEL_IDS = [ + "openai/gpt-4o", + "openai/gpt-4o-mini", + "openai/chatgpt-4o-latest", +] + + +MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [ + ProviderModelEntry( + provider_model_id="openai/text-embedding-3-small", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 1536}, + ), + ProviderModelEntry( + provider_model_id="openai/text-embedding-3-large", + model_type=ModelType.embedding, + metadata={"embedding_dimension": 3072}, + ), +] diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py new file mode 100644 index 000000000..80ab2943f --- /dev/null +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin + +from .config import OpenAIConfig +from .models import MODEL_ENTRIES + + +class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: OpenAIConfig) -> None: + LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES) + self.config = config + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/providers/tests/test_cases/inference/chat_completion.json b/llama_stack/providers/tests/test_cases/inference/chat_completion.json index 5e302b4fe..50f6b1c15 100644 --- a/llama_stack/providers/tests/test_cases/inference/chat_completion.json +++ b/llama_stack/providers/tests/test_cases/inference/chat_completion.json @@ -40,7 +40,7 @@ "tool_calling": { "data": { "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "Pretend you are a weather assistant."}, {"role": "user", "content": "What's the weather like in San Francisco?"} ], "tools": [ @@ -65,7 +65,7 @@ "messages": [ { "role": "system", - "content": "You are a helpful assistant." + "content": "Pretend you are a weather assistant." }, { "role": "user", diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py new file mode 100644 index 000000000..0f53b5b88 --- /dev/null +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import AsyncGenerator, AsyncIterator, List, Optional, Union + +import litellm + +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + EmbeddingsResponse, + EmbeddingTaskType, + Inference, + JsonSchemaResponseFormat, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + TextTruncation, + ToolChoice, + ToolConfig, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models.models import Model +from llama_stack.providers.utils.inference.model_registry import ( + ModelRegistryHelper, +) +from llama_stack.providers.utils.inference.openai_compat import ( + convert_message_to_openai_dict_new, + convert_openai_chat_completion_choice, + convert_openai_chat_completion_stream, + convert_tooldef_to_openai_tool, + get_sampling_options, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) + + +class LiteLLMOpenAIMixin( + ModelRegistryHelper, + Inference, +): + def __init__(self, model_entries) -> None: + self.model_entries = model_entries + ModelRegistryHelper.__init__(self, model_entries) + + async def register_model(self, model: Model) -> Model: + model_id = self.get_provider_model_id(model.provider_resource_id) + if model_id is None: + raise ValueError(f"Unsupported model: {model.provider_resource_id}") + return model + + async def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = SamplingParams(), + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + raise NotImplementedError("LiteLLM does not support completion requests") + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = SamplingParams(), + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, + ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + response_format=response_format, + stream=stream, + logprobs=logprobs, + tool_config=tool_config, + ) + + params = await self._get_params(request) + + # unfortunately, we need to use synchronous litellm.completion here because litellm + # caches various httpx.client objects in a non-eventloop aware manner + response = litellm.completion(**params) + if stream: + return self._stream_chat_completion(response) + else: + return convert_openai_chat_completion_choice(response.choices[0]) + + async def _stream_chat_completion( + self, response: litellm.ModelResponse + ) -> AsyncIterator[ChatCompletionResponseStreamChunk]: + async def _stream_generator(): + for chunk in response: + yield chunk + + async for chunk in convert_openai_chat_completion_stream( + _stream_generator(), enable_incremental_tool_calls=True + ): + yield chunk + + async def _get_params(self, request: ChatCompletionRequest) -> dict: + input_dict = {} + + input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages] + if fmt := request.response_format: + if not isinstance(fmt, JsonSchemaResponseFormat): + raise ValueError( + f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." + ) + + fmt = fmt.json_schema + name = fmt["title"] + del fmt["title"] + fmt["additionalProperties"] = False + input_dict["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": name, + "schema": fmt, + "strict": True, + }, + } + if request.tools: + input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] + if request.tool_config.tool_choice: + input_dict["tool_choice"] = request.tool_config.tool_choice.value + + return { + "model": request.model, + **input_dict, + "stream": request.stream, + **get_sampling_options(request.sampling_params), + } + + async def embeddings( + self, + model_id: str, + contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, + ) -> EmbeddingsResponse: + model = await self.model_store.get_model(model_id) + + response = litellm.embedding( + model=model.provider_resource_id, + input=[interleaved_content_as_str(content) for content in contents], + ) + + embeddings = [data["embedding"] for data in response["data"]] + return EmbeddingsResponse(embeddings=embeddings) diff --git a/llama_stack/templates/ci-tests/ci_tests.py b/llama_stack/templates/ci-tests/ci_tests.py index 992d9936e..a93cfff9c 100644 --- a/llama_stack/templates/ci-tests/ci_tests.py +++ b/llama_stack/templates/ci-tests/ci_tests.py @@ -57,17 +57,6 @@ def get_distribution_template() -> DistributionTemplate: config=SentenceTransformersInferenceConfig.sample_run_config(), ) - core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} - default_models = [ - ModelInput( - model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, - provider_model_id=m.provider_model_id, - provider_id="fireworks", - metadata=m.metadata, - model_type=m.model_type, - ) - for m in MODEL_ENTRIES - ] default_tool_groups = [ ToolGroupInput( toolgroup_id="builtin::websearch", @@ -82,6 +71,16 @@ def get_distribution_template() -> DistributionTemplate: provider_id="code-interpreter", ), ] + core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} + default_models = [ + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, + provider_id="fireworks", + model_type=m.model_type, + metadata=m.metadata, + ) + for m in MODEL_ENTRIES + ] embedding_model = ModelInput( model_id="all-MiniLM-L6-v2", provider_id="sentence-transformers", @@ -98,7 +97,7 @@ def get_distribution_template() -> DistributionTemplate: container_image=None, template_path=None, providers=providers, - default_models=default_models, + default_models=default_models + [embedding_model], run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 6696c8041..295d72e71 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -93,59 +93,48 @@ models: - metadata: {} model_id: meta-llama/Llama-3.1-8B-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-70B-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-1B-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-11B-Vision-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.2-90B-Vision-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-3.3-70B-Instruct provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct model_type: llm - metadata: {} model_id: meta-llama/Llama-Guard-3-8B provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-guard-3-8b model_type: llm - metadata: {} model_id: meta-llama/Llama-Guard-3-11B-Vision provider_id: fireworks - provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision model_type: llm - metadata: embedding_dimension: 768 context_length: 8192 model_id: nomic-ai/nomic-embed-text-v1.5 provider_id: fireworks - provider_model_id: nomic-ai/nomic-embed-text-v1.5 model_type: embedding - metadata: embedding_dimension: 384 diff --git a/llama_stack/templates/dev/__init__.py b/llama_stack/templates/dev/__init__.py new file mode 100644 index 000000000..cf966c2a6 --- /dev/null +++ b/llama_stack/templates/dev/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .dev import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/dev/build.yaml b/llama_stack/templates/dev/build.yaml new file mode 100644 index 000000000..96f588e8d --- /dev/null +++ b/llama_stack/templates/dev/build.yaml @@ -0,0 +1,36 @@ +version: '2' +distribution_spec: + description: Distribution for running e2e tests in CI + providers: + inference: + - remote::openai + - remote::fireworks + - remote::anthropic + - remote::gemini + - inline::sentence-transformers + vector_io: + - inline::sqlite-vec + - remote::chromadb + - remote::pgvector + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda diff --git a/llama_stack/templates/dev/dev.py b/llama_stack/templates/dev/dev.py new file mode 100644 index 000000000..7b449a0b4 --- /dev/null +++ b/llama_stack/templates/dev/dev.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List, Tuple + +from llama_stack.apis.models.models import ModelType +from llama_stack.distribution.datatypes import ( + ModelInput, + Provider, + ShieldInput, + ToolGroupInput, +) +from llama_stack.models.llama.sku_list import all_registered_models +from llama_stack.providers.inline.inference.sentence_transformers import ( + SentenceTransformersInferenceConfig, +) +from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig +from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig +from llama_stack.providers.remote.inference.anthropic.models import MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES +from llama_stack.providers.remote.inference.fireworks.config import FireworksImplConfig +from llama_stack.providers.remote.inference.fireworks.models import MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES +from llama_stack.providers.remote.inference.gemini.config import GeminiConfig +from llama_stack.providers.remote.inference.gemini.models import MODEL_ENTRIES as GEMINI_MODEL_ENTRIES +from llama_stack.providers.remote.inference.openai.config import OpenAIConfig +from llama_stack.providers.remote.inference.openai.models import MODEL_ENTRIES as OPENAI_MODEL_ENTRIES +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings + + +def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]: + # in this template, we allow each API key to be optional + providers = [ + ( + "openai", + OPENAI_MODEL_ENTRIES, + OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"), + ), + ( + "fireworks", + FIREWORKS_MODEL_ENTRIES, + FireworksImplConfig.sample_run_config(api_key="${env.FIREWORKS_API_KEY:}"), + ), + ( + "anthropic", + ANTHROPIC_MODEL_ENTRIES, + AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"), + ), + ( + "gemini", + GEMINI_MODEL_ENTRIES, + GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"), + ), + ] + inference_providers = [] + default_models = [] + core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} + for provider_id, model_entries, config in providers: + inference_providers.append( + Provider( + provider_id=provider_id, + provider_type=f"remote::{provider_id}", + config=config, + ) + ) + default_models.extend( + ModelInput( + model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, + provider_model_id=m.provider_model_id, + provider_id=provider_id, + model_type=m.model_type, + metadata=m.metadata, + ) + for m in model_entries + ) + return inference_providers, default_models + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": [ + "remote::openai", + "remote::fireworks", + "remote::anthropic", + "remote::gemini", + "inline::sentence-transformers", + ], + "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + name = "dev" + + vector_io_provider = Provider( + provider_id="sqlite-vec", + provider_type="inline::sqlite-vec", + config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), + ) + embedding_provider = Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + config=SentenceTransformersInferenceConfig.sample_run_config(), + ) + + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] + embedding_model = ModelInput( + model_id="all-MiniLM-L6-v2", + provider_id=embedding_provider.provider_id, + model_type=ModelType.embedding, + metadata={ + "embedding_dimension": 384, + }, + ) + inference_providers, default_models = get_inference_providers() + + return DistributionTemplate( + name=name, + distro_type="self_hosted", + description="Distribution for running e2e tests in CI", + container_image=None, + template_path=None, + providers=providers, + default_models=[], + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": inference_providers + [embedding_provider], + "vector_io": [vector_io_provider], + }, + default_models=default_models + [embedding_model], + default_tool_groups=default_tool_groups, + default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")], + ), + }, + run_config_env_vars={ + "LLAMA_STACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "FIREWORKS_API_KEY": ( + "", + "Fireworks API Key", + ), + "OPENAI_API_KEY": ( + "", + "OpenAI API Key", + ), + }, + ) diff --git a/llama_stack/templates/dev/run.yaml b/llama_stack/templates/dev/run.yaml new file mode 100644 index 000000000..ab54f1a57 --- /dev/null +++ b/llama_stack/templates/dev/run.yaml @@ -0,0 +1,261 @@ +version: '2' +image_name: dev +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY:} + - provider_id: fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY:} + - provider_id: anthropic + provider_type: remote::anthropic + config: + api_key: ${env.ANTHROPIC_API_KEY:} + - provider_id: gemini + provider_type: remote::gemini + config: + api_key: ${env.GEMINI_API_KEY:} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + config: + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: {} + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: ${env.OTEL_SERVICE_NAME:llama-stack} + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: {} + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: {} + - provider_id: localfs + provider_type: inline::localfs + config: {} + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db +models: +- metadata: {} + model_id: openai/gpt-4o + provider_id: openai + provider_model_id: openai/gpt-4o + model_type: llm +- metadata: {} + model_id: openai/gpt-4o-mini + provider_id: openai + provider_model_id: openai/gpt-4o-mini + model_type: llm +- metadata: {} + model_id: openai/chatgpt-4o-latest + provider_id: openai + provider_model_id: openai/chatgpt-4o-latest + model_type: llm +- metadata: + embedding_dimension: 1536 + model_id: openai/text-embedding-3-small + provider_id: openai + provider_model_id: openai/text-embedding-3-small + model_type: embedding +- metadata: + embedding_dimension: 3072 + model_id: openai/text-embedding-3-large + provider_id: openai + provider_model_id: openai/text-embedding-3-large + model_type: embedding +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-8B + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-11B-Vision + provider_id: fireworks + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 8192 + model_id: nomic-ai/nomic-embed-text-v1.5 + provider_id: fireworks + provider_model_id: nomic-ai/nomic-embed-text-v1.5 + model_type: embedding +- metadata: {} + model_id: anthropic/claude-3-5-sonnet-latest + provider_id: anthropic + provider_model_id: anthropic/claude-3-5-sonnet-latest + model_type: llm +- metadata: {} + model_id: anthropic/claude-3-7-sonnet-latest + provider_id: anthropic + provider_model_id: anthropic/claude-3-7-sonnet-latest + model_type: llm +- metadata: {} + model_id: anthropic/claude-3-5-haiku-latest + provider_id: anthropic + provider_model_id: anthropic/claude-3-5-haiku-latest + model_type: llm +- metadata: + embedding_dimension: 1024 + context_length: 32000 + model_id: anthropic/voyage-3 + provider_id: anthropic + provider_model_id: anthropic/voyage-3 + model_type: embedding +- metadata: + embedding_dimension: 512 + context_length: 32000 + model_id: anthropic/voyage-3-lite + provider_id: anthropic + provider_model_id: anthropic/voyage-3-lite + model_type: embedding +- metadata: + embedding_dimension: 1024 + context_length: 32000 + model_id: anthropic/voyage-code-3 + provider_id: anthropic + provider_model_id: anthropic/voyage-code-3 + model_type: embedding +- metadata: {} + model_id: gemini/gemini-1.5-flash + provider_id: gemini + provider_model_id: gemini/gemini-1.5-flash + model_type: llm +- metadata: {} + model_id: gemini/gemini-1.5-pro + provider_id: gemini + provider_model_id: gemini/gemini-1.5-pro + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 2048 + model_id: gemini/text-embedding-004 + provider_id: gemini + provider_model_id: gemini/text-embedding-004 + model_type: embedding +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: sentence-transformers + model_type: embedding +shields: +- shield_id: meta-llama/Llama-Guard-3-8B +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321 diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 13dee0ba3..961194a73 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -116,12 +116,14 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed providers = [p for p in client.providers.list() if p.api == "inference"] assert len(providers) > 0, "No inference providers found" inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] - if text_model_id: + + model_ids = [m.identifier for m in client.models.list()] + if text_model_id and text_model_id not in model_ids: client.models.register(model_id=text_model_id, provider_id=inference_providers[0]) - if vision_model_id: + if vision_model_id and vision_model_id not in model_ids: client.models.register(model_id=vision_model_id, provider_id=inference_providers[0]) - if embedding_model_id and embedding_dimension: + if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids: # try to find a provider that supports embeddings, if sentence-transformers is not available selected_provider = None for p in providers: diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index acda62f57..53afcaa4a 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -19,6 +19,16 @@ PROVIDER_TOOL_PROMPT_FORMAT = { PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"} +def skip_if_model_doesnt_support_completion(client_with_models, model_id): + models = {m.identifier: m for m in client_with_models.models.list()} + provider_id = models[model_id].provider_id + providers = {p.provider_id: p for p in client_with_models.providers.list()} + provider = providers[provider_id] + print(f"Provider: {provider.provider_type} for model {model_id}") + if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion") + + @pytest.fixture(scope="session") def provider_tool_format(inference_provider_type): return ( @@ -35,6 +45,7 @@ def provider_tool_format(inference_provider_type): ], ) def test_text_completion_non_streaming(client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_completion(client_with_models, text_model_id) tc = TestCase(test_case) response = client_with_models.inference.completion( @@ -56,6 +67,7 @@ def test_text_completion_non_streaming(client_with_models, text_model_id, test_c ], ) def test_text_completion_streaming(client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_completion(client_with_models, text_model_id) tc = TestCase(test_case) response = client_with_models.inference.completion( @@ -79,6 +91,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case) ], ) def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case): + skip_if_model_doesnt_support_completion(client_with_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") @@ -107,6 +120,7 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_ ], ) def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case): + skip_if_model_doesnt_support_completion(client_with_models, text_model_id) if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") @@ -139,6 +153,8 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id, ], ) def test_text_completion_structured_output(client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_completion(client_with_models, text_model_id) + class AnswerFormat(BaseModel): name: str year_born: str @@ -237,9 +253,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming( tool_prompt_format=tool_prompt_format, stream=False, ) - # No content is returned for the system message since we expect the - # response to be a tool call - assert response.completion_message.content == "" + # some models can return content for the response in addition to the tool call assert response.completion_message.role == "assistant" assert len(response.completion_message.tool_calls) == 1