mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 01:03:59 +00:00
implement embedding generation in supported inference providers
This commit is contained in:
parent
b896be2311
commit
e167e9eb93
16 changed files with 383 additions and 29 deletions
|
@ -202,10 +202,15 @@ API responses, specify the adapter here.
|
||||||
return self.adapter.provider_data_validator
|
return self.adapter.provider_data_validator
|
||||||
|
|
||||||
|
|
||||||
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
def remote_provider_spec(
|
||||||
|
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
|
||||||
|
) -> RemoteProviderSpec:
|
||||||
|
if api_dependencies is None:
|
||||||
|
api_dependencies = []
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
provider_type=f"remote::{adapter.adapter_type}",
|
provider_type=f"remote::{adapter.adapter_type}",
|
||||||
config_class=adapter.config_class,
|
config_class=adapter.config_class,
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
|
api_dependencies=api_dependencies,
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
convert_image_media_to_url,
|
convert_image_media_to_url,
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
from .generation import Llama
|
from .generation import Llama
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
class MetaReferenceInferenceImpl(
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
Inference,
|
||||||
|
ModelsProtocolPrivate,
|
||||||
|
):
|
||||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
model = resolve_model(config.model)
|
model = resolve_model(config.model)
|
||||||
ModelRegistryHelper.__init__(
|
if model is None:
|
||||||
self,
|
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||||
|
self.model_registry_helper = ModelRegistryHelper(
|
||||||
[
|
[
|
||||||
build_model_alias(
|
build_model_alias(
|
||||||
model.descriptor(),
|
model.descriptor(),
|
||||||
|
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
if model is None:
|
|
||||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
|
||||||
self.model = model
|
self.model = model
|
||||||
# verify that the checkpoint actually is for this model lol
|
# verify that the checkpoint actually is for this model lol
|
||||||
|
|
||||||
|
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> Model:
|
||||||
|
model = await self.model_registry_helper.register_model(model)
|
||||||
|
if model.model_type == ModelType.embedding_model:
|
||||||
|
self._get_embedding_model(model.provider_resource_id)
|
||||||
|
return model
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
for x in impl():
|
for x in impl():
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
async def embeddings(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
contents: List[InterleavedTextMedia],
|
|
||||||
) -> EmbeddingsResponse:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
async def request_with_localized_media(
|
async def request_with_localized_media(
|
||||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||||
|
SentenceTransformersInferenceConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(
|
||||||
|
config: SentenceTransformersInferenceConfig,
|
||||||
|
_deps,
|
||||||
|
):
|
||||||
|
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||||
|
|
||||||
|
impl = SentenceTransformersInferenceImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,10 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class SentenceTransformersInferenceConfig(BaseModel): ...
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
CompletionResponse,
|
||||||
|
Inference,
|
||||||
|
LogProbConfig,
|
||||||
|
Message,
|
||||||
|
ResponseFormat,
|
||||||
|
SamplingParams,
|
||||||
|
ToolChoice,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
)
|
||||||
|
from .config import SentenceTransformersInferenceConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SentenceTransformersInferenceImpl(
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
Inference,
|
||||||
|
ModelsProtocolPrivate,
|
||||||
|
):
|
||||||
|
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_model(self, request) -> None:
|
||||||
|
if request.model != self.config.model:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Model mismatch: {request.model} != {self.config.model}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> None:
|
||||||
|
_ = self._get_embedding_model(model.provider_resource_id)
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content: str,
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> Union[CompletionResponse, AsyncGenerator]:
|
||||||
|
raise NotImplementedError("Sentence transformers don't support completion")
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
raise NotImplementedError("Sentence transformers don't support chat completion")
|
|
@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
|
||||||
"transformers",
|
"transformers",
|
||||||
"zmq",
|
"zmq",
|
||||||
"lm-format-enforcer",
|
"lm-format-enforcer",
|
||||||
|
"sentence-transformers",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
module="llama_stack.providers.inline.inference.vllm",
|
module="llama_stack.providers.inline.inference.vllm",
|
||||||
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
|
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
|
||||||
),
|
),
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.inference,
|
||||||
|
provider_type="inline::sentence-transformers",
|
||||||
|
pip_packages=["sentence-transformers"],
|
||||||
|
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||||
|
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||||
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import * # noqa: F403
|
from typing import * # noqa: F403
|
||||||
|
import json
|
||||||
|
|
||||||
from botocore.client import BaseClient
|
from botocore.client import BaseClient
|
||||||
from llama_models.datatypes import CoreModelId
|
from llama_models.datatypes import CoreModelId
|
||||||
|
@ -448,4 +449,18 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
model = await self.model_store.get_model(model_id)
|
||||||
|
embeddings = []
|
||||||
|
for content in contents:
|
||||||
|
input_text = str(content) if not isinstance(content, str) else content
|
||||||
|
input_body = {"inputText": input_text}
|
||||||
|
body = json.dumps(input_body)
|
||||||
|
response = self.client.invoke_model(
|
||||||
|
body=body,
|
||||||
|
modelId=model.provider_resource_id,
|
||||||
|
accept="application/json",
|
||||||
|
contentType="application/json",
|
||||||
|
)
|
||||||
|
response_body = json.loads(response.get("body").read())
|
||||||
|
embeddings.append(response_body.get("embedding"))
|
||||||
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class FireworksImplConfig(BaseModel):
|
class FireworksImplConfig(BaseModel):
|
||||||
url: str = Field(
|
url: str = Field(
|
||||||
default="https://api.fireworks.ai/inference",
|
default="https://api.fireworks.ai/inference/v1",
|
||||||
description="The URL for the Fireworks server",
|
description="The URL for the Fireworks server",
|
||||||
)
|
)
|
||||||
api_key: Optional[str] = Field(
|
api_key: Optional[str] = Field(
|
||||||
|
@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls) -> Dict[str, Any]:
|
def sample_run_config(cls) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.fireworks.ai/inference",
|
"url": "https://api.fireworks.ai/inference/v1",
|
||||||
"api_key": "${env.FIREWORKS_API_KEY}",
|
"api_key": "${env.FIREWORKS_API_KEY}",
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from fireworks.client import Fireworks
|
from fireworks.client import Fireworks
|
||||||
from llama_models.datatypes import CoreModelId
|
from llama_models.datatypes import CoreModelId
|
||||||
|
@ -12,6 +12,7 @@ from llama_models.datatypes import CoreModelId
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
from llama_models.llama3.api.datatypes import Message
|
from llama_models.llama3.api.datatypes import Message
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
from openai import OpenAI
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
|
@ -89,19 +90,24 @@ class FireworksInferenceAdapter(
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_client(self) -> Fireworks:
|
def _get_api_key(self) -> str:
|
||||||
fireworks_api_key = None
|
|
||||||
if self.config.api_key is not None:
|
if self.config.api_key is not None:
|
||||||
fireworks_api_key = self.config.api_key
|
return self.config.api_key
|
||||||
else:
|
else:
|
||||||
provider_data = self.get_request_provider_data()
|
provider_data = self.get_request_provider_data()
|
||||||
if provider_data is None or not provider_data.fireworks_api_key:
|
if provider_data is None or not provider_data.fireworks_api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
||||||
)
|
)
|
||||||
fireworks_api_key = provider_data.fireworks_api_key
|
return provider_data.fireworks_api_key
|
||||||
|
|
||||||
|
def _get_client(self) -> Fireworks:
|
||||||
|
fireworks_api_key = self._get_api_key()
|
||||||
return Fireworks(api_key=fireworks_api_key)
|
return Fireworks(api_key=fireworks_api_key)
|
||||||
|
|
||||||
|
def _get_openai_client(self) -> OpenAI:
|
||||||
|
return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -264,4 +270,15 @@ class FireworksInferenceAdapter(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
client = self._get_openai_client()
|
||||||
|
kwargs = {}
|
||||||
|
if model.metadata.get("embedding_dimensions"):
|
||||||
|
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model=model.provider_resource_id, input=contents, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = [data.embedding for data in response.data]
|
||||||
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
|
@ -321,9 +321,26 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
response = await self.client.embed(
|
||||||
|
model=model.provider_resource_id, input=contents
|
||||||
|
)
|
||||||
|
embeddings = response["embeddings"]
|
||||||
|
|
||||||
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
|
# ollama does not have embedding models running. Check if the model is in list of available models.
|
||||||
|
if model.model_type == ModelType.embedding_model:
|
||||||
|
response = await self.client.list()
|
||||||
|
available_models = [m["model"] for m in response["models"]]
|
||||||
|
if model.provider_resource_id not in available_models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model.provider_resource_id}' is not available in Ollama. "
|
||||||
|
f"Available models: {', '.join(available_models)}"
|
||||||
|
)
|
||||||
|
return model
|
||||||
model = await self.register_helper.register_model(model)
|
model = await self.register_helper.register_model(model)
|
||||||
models = await self.client.ps()
|
models = await self.client.ps()
|
||||||
available_models = [m["model"] for m in models["models"]]
|
available_models = [m["model"] for m in models["models"]]
|
||||||
|
|
|
@ -253,4 +253,9 @@ class TogetherInferenceAdapter(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
model = await self.model_store.get_model(model_id)
|
||||||
|
r = self._get_client().embeddings.create(
|
||||||
|
model=model.provider_resource_id, input=contents
|
||||||
|
)
|
||||||
|
embeddings = [item.embedding for item in r.data]
|
||||||
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
|
@ -203,4 +203,14 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
model = await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if model.metadata.get("embedding_dimensions"):
|
||||||
|
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||||
|
response = self.client.embeddings.create(
|
||||||
|
model=model.provider_resource_id, input=contents, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = [data.embedding for data in response.data]
|
||||||
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
|
@ -18,6 +18,12 @@ def pytest_addoption(parser):
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the inference model to use for testing",
|
help="Specify the inference model to use for testing",
|
||||||
)
|
)
|
||||||
|
parser.addoption(
|
||||||
|
"--embedding-model",
|
||||||
|
action="store",
|
||||||
|
default=None,
|
||||||
|
help="Specify the embedding model to use for testing",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
|
@ -78,3 +84,24 @@ def pytest_generate_tests(metafunc):
|
||||||
):
|
):
|
||||||
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
|
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
|
||||||
metafunc.parametrize("inference_stack", fixtures, indirect=True)
|
metafunc.parametrize("inference_stack", fixtures, indirect=True)
|
||||||
|
|
||||||
|
if "embedding_model" in metafunc.fixturenames:
|
||||||
|
model = metafunc.config.getoption("--embedding-model")
|
||||||
|
if not model:
|
||||||
|
raise ValueError(
|
||||||
|
"No embedding model specified. Please provide a valid embedding model."
|
||||||
|
)
|
||||||
|
params = [pytest.param(model, id="")]
|
||||||
|
|
||||||
|
metafunc.parametrize("embedding_model", params, indirect=True)
|
||||||
|
|
||||||
|
if "embedding_stack" in metafunc.fixturenames:
|
||||||
|
fixtures = INFERENCE_FIXTURES
|
||||||
|
if filtered_stacks := get_provider_fixture_overrides(
|
||||||
|
metafunc.config,
|
||||||
|
{
|
||||||
|
"inference": INFERENCE_FIXTURES,
|
||||||
|
},
|
||||||
|
):
|
||||||
|
fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
|
||||||
|
metafunc.parametrize("embedding_stack", fixtures, indirect=True)
|
||||||
|
|
|
@ -9,9 +9,9 @@ import os
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelInput
|
from llama_stack.apis.models import ModelInput, ModelType
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.inline.inference.meta_reference import (
|
from llama_stack.providers.inline.inference.meta_reference import (
|
||||||
MetaReferenceInferenceConfig,
|
MetaReferenceInferenceConfig,
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,13 @@ def inference_model(request):
|
||||||
return request.config.getoption("--inference-model", None)
|
return request.config.getoption("--inference-model", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def embedding_model(request):
|
||||||
|
if hasattr(request, "param"):
|
||||||
|
return request.param
|
||||||
|
return request.config.getoption("--embedding-model", None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def inference_remote() -> ProviderFixture:
|
def inference_remote() -> ProviderFixture:
|
||||||
return remote_stack_fixture()
|
return remote_stack_fixture()
|
||||||
|
@ -85,7 +92,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
|
||||||
inference_model = (
|
inference_model = (
|
||||||
[inference_model] if isinstance(inference_model, str) else inference_model
|
[inference_model] if isinstance(inference_model, str) else inference_model
|
||||||
)
|
)
|
||||||
if "Llama3.1-8B-Instruct" in inference_model:
|
if inference_model and "Llama3.1-8B-Instruct" in inference_model:
|
||||||
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
|
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
|
||||||
|
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
|
@ -240,3 +247,25 @@ async def inference_stack(request, inference_model):
|
||||||
)
|
)
|
||||||
|
|
||||||
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
|
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def embedding_stack(request, embedding_model):
|
||||||
|
fixture_name = request.param
|
||||||
|
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
||||||
|
test_stack = await construct_stack_for_test(
|
||||||
|
[Api.inference],
|
||||||
|
{"inference": inference_fixture.providers},
|
||||||
|
inference_fixture.provider_data,
|
||||||
|
models=[
|
||||||
|
ModelInput(
|
||||||
|
model_id=embedding_model,
|
||||||
|
model_type=ModelType.embedding_model,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
|
||||||
|
|
62
llama_stack/providers/tests/inference/test_embeddings.py
Normal file
62
llama_stack/providers/tests/inference/test_embeddings.py
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
|
||||||
|
|
||||||
|
# How to run this test:
|
||||||
|
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddings:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embeddings(self, embedding_model, embedding_stack):
|
||||||
|
inference_impl, models_impl = embedding_stack
|
||||||
|
model = await models_impl.get_model(embedding_model)
|
||||||
|
|
||||||
|
if model.model_type != ModelType.embedding_model:
|
||||||
|
pytest.skip("This test is only applicable for embedding models")
|
||||||
|
|
||||||
|
response = await inference_impl.embeddings(
|
||||||
|
model_id=embedding_model,
|
||||||
|
contents=["Hello, world!"],
|
||||||
|
)
|
||||||
|
assert isinstance(response, EmbeddingsResponse)
|
||||||
|
assert len(response.embeddings) > 0
|
||||||
|
assert all(isinstance(embedding, list) for embedding in response.embeddings)
|
||||||
|
assert all(
|
||||||
|
isinstance(value, float)
|
||||||
|
for embedding in response.embeddings
|
||||||
|
for value in embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_embeddings(self, embedding_model, embedding_stack):
|
||||||
|
inference_impl, models_impl = embedding_stack
|
||||||
|
model = await models_impl.get_model(embedding_model)
|
||||||
|
|
||||||
|
if model.model_type != ModelType.embedding_model:
|
||||||
|
pytest.skip("This test is only applicable for embedding models")
|
||||||
|
|
||||||
|
texts = ["Hello, world!", "This is a test", "Testing embeddings"]
|
||||||
|
|
||||||
|
response = await inference_impl.embeddings(
|
||||||
|
model_id=embedding_model,
|
||||||
|
contents=texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, EmbeddingsResponse)
|
||||||
|
assert len(response.embeddings) == len(texts)
|
||||||
|
assert all(isinstance(embedding, list) for embedding in response.embeddings)
|
||||||
|
assert all(
|
||||||
|
isinstance(value, float)
|
||||||
|
for embedding in response.embeddings
|
||||||
|
for value in embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_dim = len(response.embeddings[0])
|
||||||
|
assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
|
45
llama_stack/providers/utils/inference/embedding_mixin.py
Normal file
45
llama_stack/providers/utils/inference/embedding_mixin.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import InterleavedTextMedia
|
||||||
|
|
||||||
|
from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
|
||||||
|
|
||||||
|
EMBEDDING_MODELS = {}
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SentenceTransformerEmbeddingMixin:
|
||||||
|
model_store: ModelStore
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
contents: List[InterleavedTextMedia],
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
embedding_model = self._get_embedding_model(model.provider_resource_id)
|
||||||
|
embeddings = embedding_model.encode(contents)
|
||||||
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
||||||
|
def _get_embedding_model(self, model: str) -> "SentenceTransformer":
|
||||||
|
global EMBEDDING_MODELS
|
||||||
|
|
||||||
|
loaded_model = EMBEDDING_MODELS.get(model)
|
||||||
|
if loaded_model is not None:
|
||||||
|
return loaded_model
|
||||||
|
|
||||||
|
log.info(f"Loading sentence transformer for {model}...")
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
loaded_model = SentenceTransformer(model)
|
||||||
|
EMBEDDING_MODELS[model] = loaded_model
|
||||||
|
return loaded_model
|
Loading…
Add table
Add a link
Reference in a new issue