fix(api): update embeddings signature so inputs and outputs list align (#1161)
See Issue #922.

The change is slightly backwards incompatible, but no call site (in our client codebases or stack-apps) ever passes a depth-2 `List[List[InterleavedContentItem]]`, which is now disallowed.

## Test Plan

```bash
$ cd llama_stack/providers/tests/inference

$ pytest -s -v -k fireworks test_embeddings.py \
    --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784

$ pytest -s -v -k together test_embeddings.py \
    --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784

$ pytest -s -v -k ollama test_embeddings.py \
    --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

Also ran `tests/client-sdk/inference/test_embeddings.py`.
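For illustration, the shapes the new signature accepts and rejects look roughly like this. This is a hedged sketch, not code from this commit; `client` stands in for any async Llama Stack inference client, and the `TextContentItem(text=...)` constructor shape is assumed:

```python
# Sketch only: demonstrates the new `contents` union, assuming an
# already-constructed async client (`client` is hypothetical here).
from llama_stack.apis.common.content_types import TextContentItem


async def demo(client) -> None:
    # List[str]: allowed.
    await client.inference.embeddings(
        model_id="all-minilm:latest",
        contents=["first document", "second document"],
    )
    # List[InterleavedContentItem]: allowed (items can be multimodal).
    await client.inference.embeddings(
        model_id="all-minilm:latest",
        contents=[TextContentItem(text="first document")],
    )
    # List[List[InterleavedContentItem]] (depth-2) is now disallowed.
```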
Parent: cfa752fc92
Commit: 6f9d622340

17 changed files with 85 additions and 41 deletions
docs/_static/llama-stack-spec.html (vendored): 20 changes

@@ -4929,11 +4929,21 @@
           "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
         },
         "contents": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/InterleavedContent"
-          },
-          "description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
+          "oneOf": [
+            {
+              "type": "array",
+              "items": {
+                "type": "string"
+              }
+            },
+            {
+              "type": "array",
+              "items": {
+                "$ref": "#/components/schemas/InterleavedContentItem"
+              }
+            }
+          ],
+          "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
         }
       },
       "additionalProperties": false,
docs/_static/llama-stack-spec.yaml (vendored): 16 changes

@@ -3224,13 +3224,17 @@ components:
         The identifier of the model to use. The model must be an embedding model
         registered with Llama Stack and available via the /models endpoint.
       contents:
-        type: array
-        items:
-          $ref: '#/components/schemas/InterleavedContent'
+        oneOf:
+        - type: array
+          items:
+            type: string
+        - type: array
+          items:
+            $ref: '#/components/schemas/InterleavedContentItem'
         description: >-
-          List of contents to generate embeddings for. Note that content can be
-          multimodal. The behavior depends on the model and provider. Some models
-          may only support text.
+          List of contents to generate embeddings for. Each content can be a string
+          or an InterleavedContentItem (and hence can be multimodal). The behavior
+          depends on the model and provider. Some models may only support text.
       additionalProperties: false
       required:
       - model_id
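Concretely, request bodies that should validate against the updated `oneOf` schema look roughly like this. A sketch only: Python dicts stand in for the JSON wire format, and the `{"type": "text", ...}` item shape is an assumption based on how `TextContentItem` serializes:

```python
# Hypothetical request bodies for the embeddings endpoint (sketch only).
body_with_strings = {
    "model_id": "all-minilm:latest",
    "contents": ["first document", "second document"],  # first oneOf branch
}
body_with_items = {
    "model_id": "all-minilm:latest",
    "contents": [{"type": "text", "text": "first document"}],  # second branch
}
# A depth-2 nesting such as [[{"type": "text", ...}]] matches neither branch.
```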
@@ -20,7 +20,7 @@ from typing import (
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.

         :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-        :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+        :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
         :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
         """
         ...
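Because the union is untagged at the element level, provider implementations that only embed text have to branch on the element type. A minimal normalization sketch, assuming the existing `interleaved_content_as_str` helper from `llama_stack.providers.utils.inference.prompt_adapter`; the `contents_as_texts` wrapper itself is hypothetical, not part of this commit:

```python
from typing import List, Union

from llama_stack.apis.common.content_types import InterleavedContentItem
from llama_stack.providers.utils.inference.prompt_adapter import (
    interleaved_content_as_str,
)


def contents_as_texts(
    contents: Union[List[str], List[InterleavedContentItem]],
) -> List[str]:
    # Strings pass through untouched; content items are flattened to plain
    # text so a text-only embedding backend can consume them.
    return [
        item if isinstance(item, str) else interleaved_content_as_str(item)
        for item in contents
    ]
```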
@@ -6,7 +6,7 @@

 from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     BenchmarkConfig,
@@ -214,7 +214,7 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.routing_table.get_model(model_id)
         if model is None:
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+    async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 from botocore.client import BaseClient

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional

 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):

     async def embeddings(
         self,
-        model: str,
-        contents: List[InterleavedContent],
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     Inference,
     InterleavedContent,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -13,6 +13,7 @@ from ollama import AsyncClient
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import (
@@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
@@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
@@ -5,16 +5,36 @@
 # the root directory of this source tree.

 import json
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

 from openai import OpenAI

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionMessage,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    StopReason,
+    SystemMessage,
+    ToolCall,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.models.llama.datatypes import (
     GreedySamplingStrategy,
     TopKSamplingStrategy,
@@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional

 from huggingface_hub import AsyncInferenceClient, HfApi

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union

 from together import Together

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import StopReason, ToolCall
 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)

@@ -9,7 +9,7 @@ from typing import List

 from llama_stack.apis.inference import (
     EmbeddingsResponse,
-    InterleavedContent,
+    InterleavedContentItem,
     ModelStore,
 )

@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
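End to end, the kind of call the test plan exercises might look like this. A sketch only: it assumes a running stack serving an embedding model and the `llama_stack_client` SDK, with the base URL and model name purely illustrative:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative URL

response = client.inference.embeddings(
    model_id="all-minilm:latest",
    contents=["hello", "world"],  # the new List[str] form
)
# One embedding per input; each is a list of floats whose length is
# model-specific (check /models/{model_id} metadata for the dimension).
assert len(response.embeddings) == 2
```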