diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index a6f7a78af..d03ea933a 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -22,12 +22,14 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
+    ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 69fb5dea2..b82a4c752 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient

-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 71b9155d5..4deeea630 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras

-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index e3acd4314..75751c8b1 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional

 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,
@@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 95fe65c39..b9b23584b 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks

-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 4b21ae81d..45c15a467 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -17,17 +17,23 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     InterleavedContent,
     InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
@@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 6f38230b2..ecd53e91c 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -23,14 +23,20 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         if any(content_has_media(content) for content in contents):
             raise NotImplementedError("Media is not supported")
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 0071aaa5d..62c8381a8 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -21,11 +21,13 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index a34c34f69..11da6bb9e 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         contents: List[InterleavedContent],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
@@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
         return client.inference.embeddings(
             model_id=model.provider_resource_id,
             contents=contents,
+            text_truncation=text_truncation,
+            output_dimension=output_dimension,
+            task_type=task_type,
         )
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index a5acc47f8..bd620aa64 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index b60954abc..57a296258 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     CompletionMessage,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
@@ -27,6 +28,7 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
+    TextTruncation,
     ToolCall,
     ToolChoice,
     ToolConfig,
@@ -140,6 +142,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index a52abd20d..d09ca241f 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional

 from huggingface_hub import AsyncInferenceClient, HfApi

-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index a2c4f1542..1fca54bb3 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union

 from together import Together

-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index bff5da8a7..b9422d85d 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -28,12 +28,14 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -383,6 +385,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 947a62f09..32aa5da3f 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -5,12 +5,14 @@
 # the root directory of this source tree.

 import logging
-from typing import List
+from typing import List, Optional

 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    EmbeddingTaskType,
     InterleavedContentItem,
     ModelStore,
+    TextTruncation,
 )

 EMBEDDING_MODELS = {}
@@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin:
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
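
Not part of the patch: a minimal caller-side sketch of the widened embeddings() signature, assuming an object that implements the updated Inference protocol. The helper name and model id below are placeholders; providers that have not implemented the new parameters still raise NotImplementedError.

# Illustrative sketch only -- exercises the extended embeddings() signature.
from llama_stack.apis.inference import TextTruncation


async def embed_texts(inference, texts: list[str]):
    # "all-MiniLM-L6-v2" is a placeholder model id, not something this diff registers.
    return await inference.embeddings(
        model_id="all-MiniLM-L6-v2",
        contents=texts,
        text_truncation=TextTruncation.none,  # default used throughout the diff
        output_dimension=None,                # keep the model's native dimension
        task_type=None,                       # no task-specific embedding requested
    )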