implement embedding generation in supported inference providers

This commit is contained in:
Dinesh Yeduguru 2024-12-09 12:48:56 -08:00
parent b896be2311
commit e167e9eb93
16 changed files with 383 additions and 29 deletions

View file

@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url,
request_has_media,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config
model = resolve_model(config.model)
ModelRegistryHelper.__init__(
self,
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model_registry_helper = ModelRegistryHelper(
[
build_model_alias(
model.descriptor(),
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
)
],
)
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model
# verify that the checkpoint actually is for this model lol
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
async def unregister_model(self, model_id: str) -> None:
pass
async def register_model(self, model: Model) -> Model:
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding_model:
self._get_embedding_model(model.provider_resource_id)
return model
async def completion(
self,
model_id: str,
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
for x in impl():
yield x
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest],

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
):
from .sentence_transformers import SentenceTransformersInferenceImpl
impl = SentenceTransformersInferenceImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel): ...

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def check_model(self, request) -> None:
if request.model != self.config.model:
raise RuntimeError(
f"Model mismatch: {request.model} != {self.config.model}"
)
async def register_model(self, model: Model) -> None:
_ = self._get_embedding_model(model.provider_resource_id)
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncGenerator]:
raise NotImplementedError("Sentence transformers don't support completion")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError("Sentence transformers don't support chat completion")