mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 21:19:48 +00:00
add embedding model by default
This commit is contained in:
parent
2a9b13dd52
commit
2f88006bd0
43 changed files with 446 additions and 85 deletions
|
|
@ -83,7 +83,7 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.model_registry_helper.register_model(model)
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
if model.model_type == ModelType.embedding:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -7,4 +7,8 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SentenceTransformersInferenceConfig(BaseModel): ...
|
||||
class SentenceTransformersInferenceConfig(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls) -> "SentenceTransformersInferenceConfig":
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
# ollama does not have embedding models running. Check if the model is in list of available models.
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding_model
|
||||
assert model.model_type == ModelType.embedding
|
||||
assert model.metadata.get("embedding_dimensions")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||
assert all(
|
||||
|
|
|
|||
|
|
@ -238,7 +238,7 @@ async def inference_stack(request, inference_model):
|
|||
model_type = ModelType.llm
|
||||
metadata = {}
|
||||
if os.getenv("EMBEDDING_DIMENSION"):
|
||||
model_type = ModelType.embedding_model
|
||||
model_type = ModelType.embedding
|
||||
metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
|
||||
|
||||
test_stack = await construct_stack_for_test(
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class TestEmbeddings:
|
|||
inference_impl, models_impl = inference_stack
|
||||
model = await models_impl.get_model(inference_model)
|
||||
|
||||
if model.model_type != ModelType.embedding_model:
|
||||
if model.model_type != ModelType.embedding:
|
||||
pytest.skip("This test is only applicable for embedding models")
|
||||
|
||||
response = await inference_impl.embeddings(
|
||||
|
|
@ -39,7 +39,7 @@ class TestEmbeddings:
|
|||
inference_impl, models_impl = inference_stack
|
||||
model = await models_impl.get_model(inference_model)
|
||||
|
||||
if model.model_type != ModelType.embedding_model:
|
||||
if model.model_type != ModelType.embedding:
|
||||
pytest.skip("This test is only applicable for embedding models")
|
||||
|
||||
texts = ["Hello, world!", "This is a test", "Testing embeddings"]
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ async def memory_stack(inference_model, request):
|
|||
models=[
|
||||
ModelInput(
|
||||
model_id=inference_model,
|
||||
model_type=ModelType.embedding_model,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return None
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue