Update embeddings signatures for all providers

commit 2c1e8b5956
parent e011491c6b
Author: Ashwin Bharambe
Date:   2025-02-20 22:21:20 -08:00
15 changed files with 116 additions and 10 deletions
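
The change is mechanical: every provider's `embeddings` method gains the same three optional parameters. Restated outside diff context (the signature is copied from the hunks below; the ellipsis body is a placeholder, since per-provider bodies differ and several adapters simply raise NotImplementedError()):

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        ...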

@@ -22,12 +22,14 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
-    async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
+    async def embeddings(
+        self,
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
+    ) -> EmbeddingsResponse:
         raise NotImplementedError()
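
For reference, a plausible sketch of the two new types imported from llama_stack.apis.inference. Only TextTruncation.none is visible in this commit's hunks; every other enum member below is an assumption about the API's shape, not something taken from the diff:

    from enum import Enum

    class TextTruncation(Enum):
        none = "none"    # confirmed: used as the default above
        start = "start"  # assumed member
        end = "end"      # assumed member

    class EmbeddingTaskType(Enum):
        query = "query"        # assumed member
        document = "document"  # assumed member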

@@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from botocore.client import BaseClient
 
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []

@@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from cerebras.cloud.sdk import AsyncCerebras
 
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional
 
 from openai import OpenAI
 
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolDefinition,
     ToolPromptFormat,
@@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from fireworks.client import Fireworks
 
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     CompletionResponse,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 

@@ -17,17 +17,23 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     InterleavedContent,
     InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.models.llama.sku_list import CoreModelId
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.utils.inference.model_registry import (
@@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 

@@ -23,14 +23,20 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
 )
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+    SamplingParams,
+    ToolDefinition,
+    ToolPromptFormat,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         if any(content_has_media(content) for content in contents):
             raise NotImplementedError("Media is not supported")

@@ -21,11 +21,13 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 

@@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
         self,
         model_id: str,
         contents: List[InterleavedContent],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         client = self._get_client()
         model = await self.model_store.get_model(model_id)
@@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
         return client.inference.embeddings(
             model_id=model.provider_resource_id,
             contents=contents,
+            text_truncation=text_truncation,
+            output_dimension=output_dimension,
+            task_type=task_type,
         )
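
The passthrough adapter forwards the new arguments verbatim, which shows the client-facing call shape. A hypothetical invocation (the model id and dimension are illustrative placeholders, not values from this commit):

    response = await inference.embeddings(
        model_id="my-embedding-model",        # placeholder model id
        contents=["first document", "second document"],
        text_truncation=TextTruncation.none,  # the only member this diff confirms
        output_dimension=256,                 # illustrative dimension
        task_type=None,                       # let the provider pick a default
    )
    vectors = response.embeddings  # field name assumed from EmbeddingsResponse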

@@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
     ChatCompletionResponse,
     CompletionMessage,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
@@ -27,6 +28,7 @@ from llama_stack.apis.inference import (
     SamplingParams,
     StopReason,
     SystemMessage,
+    TextTruncation,
     ToolCall,
     ToolChoice,
     ToolConfig,
@@ -140,6 +142,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 

@@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional
 
 from huggingface_hub import AsyncInferenceClient, HfApi
 
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 

@@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union
 
 from together import Together
 
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     CompletionRequest,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (

@@ -28,12 +28,14 @@ from llama_stack.apis.inference import (
     CompletionResponse,
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
+    EmbeddingTaskType,
     Inference,
     LogProbConfig,
     Message,
     ResponseFormat,
     ResponseFormatType,
     SamplingParams,
+    TextTruncation,
     ToolChoice,
     ToolConfig,
     ToolDefinition,
@@ -383,6 +385,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
 

@@ -5,12 +5,14 @@
 # the root directory of this source tree.
 
 import logging
-from typing import List
+from typing import List, Optional
 
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
+    EmbeddingTaskType,
     InterleavedContentItem,
     ModelStore,
+    TextTruncation,
 )
 
 EMBEDDING_MODELS = {}
@@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin:
         self,
         model_id: str,
         contents: List[str] | List[InterleavedContentItem],
+        text_truncation: Optional[TextTruncation] = TextTruncation.none,
+        output_dimension: Optional[int] = None,
+        task_type: Optional[EmbeddingTaskType] = None,
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)