add support for embedding models and keep provider models separate

This commit is contained in:
Ashwin Bharambe 2025-07-23 16:13:47 -07:00
parent cf629f81fe
commit 8fb4feeba1
6 changed files with 264 additions and 18 deletions

View file

@ -20,7 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate, ModelType
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
@ -41,6 +41,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider,
ModelsProtocolPrivate,
):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
@ -54,8 +56,17 @@ class SentenceTransformersInferenceImpl(
return False
async def list_models(self) -> list[Model] | None:
# TODO: add all-mini-lm models
return None
return [
Model(
identifier="all-MiniLM-L6-v2",
provider_resource_id="all-MiniLM-L6-v2",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 384,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model:
return model

View file

@ -126,10 +126,44 @@ class OllamaInferenceAdapter(
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
response = await self.client.list()
models = []
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
if model_type == ModelType.embedding:
# kill embedding models since we don't know dimensions for them
if m.details.family in ["bert"]:
continue
models.append(
Model(
@ -137,7 +171,7 @@ class OllamaInferenceAdapter(
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
model_type=ModelType.llm,
)
)
return models