mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat(api): Add options for supporting various embedding models (#1192)
We need to support: - asymmetric embedding models (#934) - truncation policies (#933) - varying dimensional output (#932) ## Test Plan ```bash $ cd llama_stack/providers/tests/inference $ pytest -s -v -k fireworks test_embeddings.py \ --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k together test_embeddings.py \ --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k ollama test_embeddings.py \ --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784 ```
This commit is contained in:
parent
6f9d622340
commit
81ce39a607
19 changed files with 202 additions and 11 deletions
21
docs/_static/llama-stack-spec.html
vendored
21
docs/_static/llama-stack-spec.html
vendored
|
@ -4944,6 +4944,27 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
|
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
|
||||||
|
},
|
||||||
|
"text_truncation": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"none",
|
||||||
|
"start",
|
||||||
|
"end"
|
||||||
|
],
|
||||||
|
"description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
|
||||||
|
},
|
||||||
|
"output_dimension": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
|
||||||
|
},
|
||||||
|
"task_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"query",
|
||||||
|
"document"
|
||||||
|
],
|
||||||
|
"description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|
22
docs/_static/llama-stack-spec.yaml
vendored
22
docs/_static/llama-stack-spec.yaml
vendored
|
@ -3235,6 +3235,28 @@ components:
|
||||||
List of contents to generate embeddings for. Each content can be a string
|
List of contents to generate embeddings for. Each content can be a string
|
||||||
or an InterleavedContentItem (and hence can be multimodal). The behavior
|
or an InterleavedContentItem (and hence can be multimodal). The behavior
|
||||||
depends on the model and provider. Some models may only support text.
|
depends on the model and provider. Some models may only support text.
|
||||||
|
text_truncation:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- none
|
||||||
|
- start
|
||||||
|
- end
|
||||||
|
description: >-
|
||||||
|
(Optional) Config for how to truncate text for embedding when text is
|
||||||
|
longer than the model's max sequence length.
|
||||||
|
output_dimension:
|
||||||
|
type: integer
|
||||||
|
description: >-
|
||||||
|
(Optional) Output dimensionality for the embeddings. Only supported by
|
||||||
|
Matryoshka models.
|
||||||
|
task_type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- query
|
||||||
|
- document
|
||||||
|
description: >-
|
||||||
|
(Optional) How is the embedding being used? This is only supported by
|
||||||
|
asymmetric embedding models.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- model_id
|
- model_id
|
||||||
|
|
|
@ -402,6 +402,30 @@ class ModelStore(Protocol):
|
||||||
def get_model(self, identifier: str) -> Model: ...
|
def get_model(self, identifier: str) -> Model: ...
|
||||||
|
|
||||||
|
|
||||||
|
class TextTruncation(Enum):
|
||||||
|
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
|
||||||
|
|
||||||
|
:cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error.
|
||||||
|
:cvar start: Truncate from the start
|
||||||
|
:cvar end: Truncate from the end
|
||||||
|
"""
|
||||||
|
|
||||||
|
none = "none"
|
||||||
|
start = "start"
|
||||||
|
end = "end"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingTaskType(Enum):
|
||||||
|
"""How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||||
|
|
||||||
|
:cvar query: Used for a query for semantic search.
|
||||||
|
:cvar document: Used at indexing time when ingesting documents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
query = "query"
|
||||||
|
document = "document"
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class Inference(Protocol):
|
class Inference(Protocol):
|
||||||
|
@ -482,11 +506,17 @@ class Inference(Protocol):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
"""Generate embeddings for content pieces using the specified model.
|
"""Generate embeddings for content pieces using the specified model.
|
||||||
|
|
||||||
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||||
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
||||||
|
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||||
|
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||||
|
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
|
@ -6,7 +6,11 @@
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
URL,
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||||
from llama_stack.apis.eval import (
|
from llama_stack.apis.eval import (
|
||||||
BenchmarkConfig,
|
BenchmarkConfig,
|
||||||
|
@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -215,6 +221,9 @@ class InferenceRouter(Inference):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.routing_table.get_model(model_id)
|
model = await self.routing_table.get_model(model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
|
@ -224,6 +233,9 @@ class InferenceRouter(Inference):
|
||||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
text_truncation=text_truncation,
|
||||||
|
output_dimension=output_dimension,
|
||||||
|
task_type=task_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,12 +22,14 @@ from llama_stack.apis.inference import (
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
CompletionResponseStreamChunk,
|
CompletionResponseStreamChunk,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
from botocore.client import BaseClient
|
from botocore.client import BaseClient
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
|
@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from cerebras.cloud.sdk import AsyncCerebras
|
from cerebras.cloud.sdk import AsyncCerebras
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
|
@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from fireworks.client import Fireworks
|
from fireworks.client import Fireworks
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
|
|
@ -17,17 +17,23 @@ from llama_stack.apis.inference import (
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
CompletionResponseStreamChunk,
|
CompletionResponseStreamChunk,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
SamplingParams,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
from llama_stack.models.llama.sku_list import CoreModelId
|
from llama_stack.models.llama.sku_list import CoreModelId
|
||||||
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
from llama_stack.providers.remote.inference.groq.config import GroqConfig
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
|
@ -23,14 +23,20 @@ from llama_stack.apis.inference import (
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
CompletionResponseStreamChunk,
|
CompletionResponseStreamChunk,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
SamplingParams,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
)
|
)
|
||||||
|
@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
if any(content_has_media(content) for content in contents):
|
if any(content_has_media(content) for content in contents):
|
||||||
raise NotImplementedError("Media is not supported")
|
raise NotImplementedError("Media is not supported")
|
||||||
|
|
|
@ -21,11 +21,13 @@ from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
|
|
@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedContent],
|
contents: List[InterleavedContent],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
|
@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
|
||||||
return client.inference.embeddings(
|
return client.inference.embeddings(
|
||||||
model_id=model.provider_resource_id,
|
model_id=model.provider_resource_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
text_truncation=text_truncation,
|
||||||
|
output_dimension=output_dimension,
|
||||||
|
task_type=task_type,
|
||||||
)
|
)
|
||||||
|
|
|
@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
|
@ -27,6 +28,7 @@ from llama_stack.apis.inference import (
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
StopReason,
|
StopReason,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
|
TextTruncation,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
@ -140,6 +142,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
|
@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
|
@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from together import Together
|
from together import Together
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
from llama_stack.apis.common.content_types import (
|
||||||
|
InterleavedContent,
|
||||||
|
InterleavedContentItem,
|
||||||
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
assert all(not content_has_media(content) for content in contents), (
|
assert all(not content_has_media(content) for content in contents), (
|
||||||
|
|
|
@ -28,12 +28,14 @@ from llama_stack.apis.inference import (
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
CompletionResponseStreamChunk,
|
CompletionResponseStreamChunk,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
Inference,
|
Inference,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
ResponseFormatType,
|
ResponseFormatType,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
|
@ -383,6 +385,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,14 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
ModelStore,
|
ModelStore,
|
||||||
|
TextTruncation,
|
||||||
)
|
)
|
||||||
|
|
||||||
EMBEDDING_MODELS = {}
|
EMBEDDING_MODELS = {}
|
||||||
|
@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[str] | List[InterleavedContentItem],
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue