mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 15:19:46 +00:00
implement embedding generation in supported inference providers
This commit is contained in:
parent
b896be2311
commit
e167e9eb93
16 changed files with 383 additions and 29 deletions
|
|
@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
|
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
||||
class MetaReferenceInferenceImpl(
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
[
|
||||
build_model_alias(
|
||||
model.descriptor(),
|
||||
|
|
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
)
|
||||
],
|
||||
)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model = model
|
||||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
|
|
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.model_registry_helper.register_model(model)
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
self._get_embedding_model(model.provider_resource_id)
|
||||
return model
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def request_with_localized_media(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue