diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 40c167685..638f7bb7b 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4929,11 +4929,21 @@
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
},
"contents": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/InterleavedContent"
- },
- "description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
+ "oneOf": [
+ {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ },
+ {
+ "type": "array",
+ "items": {
+ "$ref": "#/components/schemas/InterleavedContentItem"
+ }
+ }
+ ],
+ "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index c5043665b..08effe7cf 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -3224,13 +3224,17 @@ components:
The identifier of the model to use. The model must be an embedding model
registered with Llama Stack and available via the /models endpoint.
contents:
- type: array
- items:
- $ref: '#/components/schemas/InterleavedContent'
+ oneOf:
+ - type: array
+ items:
+ type: string
+ - type: array
+ items:
+ $ref: '#/components/schemas/InterleavedContentItem'
description: >-
- List of contents to generate embeddings for. Note that content can be
- multimodal. The behavior depends on the model and provider. Some models
- may only support text.
+ List of contents to generate embeddings for. Each content can be a string
+ or an InterleavedContentItem (and hence can be multimodal). The behavior
+ depends on the model and provider. Some models may only support text.
additionalProperties: false
required:
- model_id
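
Editor's note: the spec now accepts two homogeneous list shapes for `contents` — all strings, or all `InterleavedContentItem` objects; mixing the two in one request does not match either branch of the `oneOf`. A minimal sketch of both request payloads, written as Python dicts; the model identifier is a placeholder, and the item shape follows the `InterleavedContentItem` schema referenced above:

```python
# All-string form: matches the first branch of the oneOf.
text_request = {
    "model_id": "my-embedding-model",  # placeholder
    "contents": ["first passage to embed", "second passage to embed"],
}

# All-item form: matches the second branch. A text item is shown here;
# image items are also allowed by the schema, but whether a given model
# accepts them depends on the model and provider.
item_request = {
    "model_id": "my-embedding-model",  # placeholder
    "contents": [
        {"type": "text", "text": "a caption to embed"},
    ],
}
```
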
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index a3fb69477..2dfe55977 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -20,7 +20,7 @@ from typing import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ class Inference(Protocol):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
- :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+ :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
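
Editor's note: a short usage sketch of the updated protocol method. It assumes `impl` is any object implementing `Inference`, runs inside an async function, and uses a placeholder model identifier:

```python
from llama_stack.apis.common.content_types import TextContentItem

# List[str]: the common text-only case.
response = await impl.embeddings(
    model_id="my-embedding-model",
    contents=["hello world", "goodbye world"],
)

# List[InterleavedContentItem]: structured items, potentially multimodal.
response = await impl.embeddings(
    model_id="my-embedding-model",
    contents=[TextContentItem(text="hello world")],
)

# One embedding (a list of floats) per content, per the docstring above.
assert len(response.embeddings) == 1
```
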
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 016ca4984..d885ebc09 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@@ -214,7 +214,7 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 5b0df91e7..a6f7a78af 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
+ InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
- async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+ async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 9c5a291db..69fb5dea2 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 0a27d81d7..71b9155d5 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index de13638f5..e3acd4314 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional
from openai import OpenAI
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
- model: str,
- contents: List[InterleavedContent],
+ model_id: str,
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 3f455da3c..95fe65c39 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index c75e92dfe..4b21ae81d 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
Inference,
InterleavedContent,
+ InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 1dbcbc294..0071aaa5d 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -13,6 +13,7 @@ from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
+ InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
@@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index 09122a8e6..a5acc47f8 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
@@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
-        model: str,
+        model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index c05284d7d..b60954abc 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -5,16 +5,36 @@
# the root directory of this source tree.
import json
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional
from openai import OpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
+ InterleavedContentItem,
TextContentItem,
)
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.inference import (
+ ChatCompletionRequest,
+ ChatCompletionResponse,
+ CompletionMessage,
+ EmbeddingsResponse,
+ Inference,
+ LogProbConfig,
+ Message,
+ ResponseFormat,
+ SamplingParams,
+ StopReason,
+ SystemMessage,
+ ToolCall,
+ ToolChoice,
+ ToolConfig,
+ ToolDefinition,
+ ToolPromptFormat,
+ ToolResponseMessage,
+ UserMessage,
+)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
@@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 1a50e3b61..a52abd20d 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 8afd3e85b..a2c4f1542 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index d1793c524..bff5da8a7 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI
-from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+ TextDelta,
+ ToolCallDelta,
+ ToolCallParseStatus,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index a84c2eecb..947a62f09 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -9,7 +9,7 @@ from typing import List
from llama_stack.apis.inference import (
EmbeddingsResponse,
- InterleavedContent,
+ InterleavedContentItem,
ModelStore,
)
@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
- contents: List[InterleavedContent],
+ contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
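
Editor's note: text-only providers like this mixin have to reduce either accepted list shape to plain strings before encoding. A minimal sketch of that normalization, assuming the existing `interleaved_content_as_str` helper in `llama_stack.providers.utils.inference.prompt_adapter` (plain strings pass through unchanged; text items are unwrapped to their text):

```python
from typing import List, Union

from llama_stack.apis.common.content_types import InterleavedContentItem
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)


def contents_as_strings(
    contents: Union[List[str], List[InterleavedContentItem]],
) -> List[str]:
    # The helper already accepts both plain strings and content items,
    # so both branches of the union normalize the same way.
    return [interleaved_content_as_str(c) for c in contents]
```
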