mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: add static embedding metadata to dynamic model listings for providers using OpenAIMixin (#3547)
# What does this PR do? - remove auto-download of ollama embedding models - add embedding model metadata to dynamic listing w/ unit test - add support and tests for allowed_models - removed inference provider models.py files where dynamic listing is enabled - store embedding metadata in embedding_model_metadata field on inference providers - make model_entries optional on ModelRegistryHelper and LiteLLMOpenAIMixin - make OpenAIMixin a ModelRegistryHelper - skip base64 embedding test for remote::ollama, always returns floats - only use OpenAI client for ollama model listing - remove unused build_model_entry function - remove unused get_huggingface_repo function ## Test Plan ci w/ new tests
This commit is contained in:
parent
a50b63906c
commit
b67aef2fc4
43 changed files with 368 additions and 1015 deletions
|
@ -45,8 +45,9 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
|
@ -55,6 +56,7 @@ from llama_stack.providers.datatypes import (
|
|||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@ -77,8 +79,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
request_has_media,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
|
@ -90,8 +90,44 @@ class OllamaInferenceAdapter(
|
|||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
"nomic-embed-text:latest": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:v1.5": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:137m-v1.5-fp16": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
# TODO: remove ModelRegistryHelper.__init__ when completion and
|
||||
# chat_completion are. this exists to satisfy the input /
|
||||
# output processing for llama models. specifically,
|
||||
# tool_calling is handled by raw template processing,
|
||||
# instead of using the /api/chat endpoint w/ tools=...
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=[
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.config = config
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
|
@ -120,59 +156,6 @@ class OllamaInferenceAdapter(
|
|||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
response = await self.ollama_client.list()
|
||||
|
||||
# always add the two embedding models which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text:latest",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
|
@ -301,7 +284,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
|
@ -409,37 +392,16 @@ class OllamaInferenceAdapter(
|
|||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
|
||||
model.provider_resource_id = f"{model.provider_model_id}:latest"
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
|
||||
)
|
||||
return model
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.ollama_client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.ollama_client.pull(model.provider_resource_id)
|
||||
|
||||
# we use list() here instead of ps() -
|
||||
# - ps() only lists running models, not available models
|
||||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.ollama_client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue