add embedding model by default to distribution templates (#617)

# What does this PR do?
Adds the sentence transformer provider and the `all-MiniLM-L6-v2`
embedding model to the default models to register in the run.yaml for
all providers.

## Test Plan
llama stack build --template together --image-type conda
llama stack run
~/.llama/distributions/llamastack-together/together-run.yaml
This commit is contained in:
Dinesh Yeduguru 2024-12-13 12:48:00 -08:00 committed by GitHub
parent e893b22868
commit 516e1a3e59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 473 additions and 64 deletions

View file

@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
async def register_model(self, model: Model) -> Model:
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding_model:
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
return model

View file

@ -4,7 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel): ...
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {}

View file

@ -337,7 +337,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def register_model(self, model: Model) -> Model:
# ollama does not have embedding models running. Check if the model is in list of available models.
if model.model_type == ModelType.embedding_model:
if model.model_type == ModelType.embedding:
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id not in available_models:

View file

@ -207,7 +207,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
model = await self.model_store.get_model(model_id)
kwargs = {}
assert model.model_type == ModelType.embedding_model
assert model.model_type == ModelType.embedding
assert model.metadata.get("embedding_dimensions")
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
assert all(

View file

@ -238,7 +238,7 @@ async def inference_stack(request, inference_model):
model_type = ModelType.llm
metadata = {}
if os.getenv("EMBEDDING_DIMENSION"):
model_type = ModelType.embedding_model
model_type = ModelType.embedding
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
test_stack = await construct_stack_for_test(

View file

@ -18,7 +18,7 @@ class TestEmbeddings:
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding_model:
if model.model_type != ModelType.embedding:
pytest.skip("This test is only applicable for embedding models")
response = await inference_impl.embeddings(
@ -39,7 +39,7 @@ class TestEmbeddings:
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding_model:
if model.model_type != ModelType.embedding:
pytest.skip("This test is only applicable for embedding models")
texts = ["Hello, world!", "This is a test", "Testing embeddings"]

View file

@ -125,7 +125,7 @@ async def memory_stack(inference_model, request):
models=[
ModelInput(
model_id=inference_model,
model_type=ModelType.embedding_model,
model_type=ModelType.embedding,
metadata={
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
},

View file

@ -78,7 +78,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return None
async def register_model(self, model: Model) -> Model:
if model.model_type == ModelType.embedding_model:
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else: