Update embeddings signature so inputs and outputs list align

Ashwin Bharambe 2025-02-19 16:14:22 -08:00
parent dd43494847
commit 25613953d5
15 changed files with 60 additions and 30 deletions

View file

@@ -20,7 +20,7 @@ from typing import (
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.
         :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-        :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+        :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
         :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
         """
         ...
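
For orientation, here is a minimal caller-side sketch of the updated protocol method. The helper function, the Inference instance, and the model id are placeholders; TextContentItem is one concrete InterleavedContentItem type (both imports appear in the hunks below):

from typing import List

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import EmbeddingsResponse, Inference


async def embed_both_ways(inference: Inference, model_id: str) -> List[EmbeddingsResponse]:
    # Plain strings are now accepted directly...
    as_strings = await inference.embeddings(
        model_id=model_id,
        contents=["first document", "second document"],
    )
    # ...and so is a list of content items (multimodal items depend on the provider).
    as_items = await inference.embeddings(
        model_id=model_id,
        contents=[TextContentItem(text="first document"), TextContentItem(text="second document")],
    )
    # One embedding (a list of floats) per input, in the same order as the inputs.
    assert len(as_strings.embeddings) == len(as_items.embeddings) == 2
    return [as_strings, as_items]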

View file

@@ -6,7 +6,7 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     BenchmarkConfig,
@@ -214,7 +214,7 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.routing_table.get_model(model_id)
         if model is None:
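
Conceptually, the router accepts the new union type and passes it straight through to the provider that serves the model. The class below is an illustrative sketch only (the providers mapping and error handling are hypothetical, not the routed implementation in this commit):

from typing import List

from llama_stack.apis.common.content_types import InterleavedContentItem
from llama_stack.apis.inference import EmbeddingsResponse


class ExampleEmbeddingsRouter:
    """Illustrative only: forwards embeddings calls to the provider registered for the model."""

    def __init__(self, routing_table, providers):
        self.routing_table = routing_table
        self.providers = providers  # hypothetical mapping: provider_id -> Inference impl

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
    ) -> EmbeddingsResponse:
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        # Forward contents unchanged; the provider decides which content types it supports.
        provider = self.providers[model.provider_id]
        return await provider.embeddings(model_id=model_id, contents=contents)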

View file

@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
-    async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+    async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional
 from openai import OpenAI
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
-        model: str,
-        contents: List[InterleavedContent],
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

View file

@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     Inference,
     InterleavedContent,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -13,6 +13,7 @@ from ollama import AsyncClient
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import (
@@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

View file

@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
@@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -5,16 +5,36 @@
 # the root directory of this source tree.
 import json
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional
 from openai import OpenAI
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionMessage,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    StopReason,
+    SystemMessage,
+    ToolCall,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.models.llama.datatypes import (
     GreedySamplingStrategy,
     TopKSamplingStrategy,
@@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional
 from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from together import Together
-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (

View file

@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import StopReason, ToolCall
 from openai import OpenAI
-from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

View file

@@ -9,7 +9,7 @@ from typing import List
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
-    InterleavedContent,
+    InterleavedContentItem,
     ModelStore,
 )
@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
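
On the provider side, implementations that only handle text (such as the sentence-transformers mixin above) now receive either plain strings or content items. Below is a hedged sketch of the kind of normalization that implies; the helper names are illustrative and not part of this commit:

from typing import List

from llama_stack.apis.common.content_types import InterleavedContentItem, TextContentItem


def contents_as_strings(contents: List[str] | List[InterleavedContentItem]) -> List[str]:
    """Illustrative only: flatten the union type into plain strings for a text-only backend."""
    texts: List[str] = []
    for item in contents:
        if isinstance(item, str):
            texts.append(item)
        elif isinstance(item, TextContentItem):
            texts.append(item.text)
        else:
            # e.g. an ImageContentItem sent to a text-only embedding model
            raise ValueError(f"Unsupported content type for text-only embeddings: {type(item)}")
    return texts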