mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 06:07:43 +00:00
feat(registry): make the Stack query providers for model listing (#2862)
This flips #2823 and #2805 by making the Stack periodically query the providers for models rather than the providers going behind the back and calling "register" on to the registry themselves. This also adds support for model listing for all other providers via `ModelRegistryHelper`. Once this is done, we do not need to manually list or register models via `run.yaml` and it will remove both noise and annoyance (setting `INFERENCE_MODEL` environment variables, for example) from the new user experience. In addition, it adds a configuration variable `allowed_models` which can be used to optionally restrict the set of models exposed from a provider.
This commit is contained in:
parent
537dc693ee
commit
1463b79218
23 changed files with 429 additions and 147 deletions
|
@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
|
|||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
||||
# the Stack router will query each provider for their list of models
|
||||
# if a `refresh_interval_seconds` is provided, this method will be called
|
||||
# periodically to refresh the list of models
|
||||
#
|
||||
# NOTE: each model returned will be registered with the model registry. this means
|
||||
# a callback to the `register_model()` method will be made. this is duplicative and
|
||||
# may be removed in the future.
|
||||
async def list_models(self) -> list[Model] | None: ...
|
||||
|
||||
async def should_refresh_models(self) -> bool: ...
|
||||
|
||||
|
||||
class ShieldsProtocolPrivate(Protocol):
|
||||
async def register_shield(self, shield: Shield) -> None: ...
|
||||
|
|
|
@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return [
|
||||
Model(
|
||||
identifier="all-MiniLM-L6-v2",
|
||||
provider_resource_id="all-MiniLM-L6-v2",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
|
|
|
@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
|||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
|
||||
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -98,14 +98,16 @@ class OllamaInferenceAdapter(
|
|||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
|
||||
self._openai_client = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = AsyncClient(host=self.config.url)
|
||||
return self._client
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
|
@ -121,59 +123,61 @@ class OllamaInferenceAdapter(
|
|||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
if self.config.refresh_models:
|
||||
logger.debug("ollama starting background model refresh task")
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
if task.cancelled():
|
||||
import traceback
|
||||
|
||||
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
logger.error(f"ollama background refresh task died: {task.exception()}")
|
||||
else:
|
||||
logger.error("ollama background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
# Wait for model store to be available (with timeout)
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
while True:
|
||||
try:
|
||||
response = await self.client.list()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list models: {str(e)}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
response = await self.client.list()
|
||||
|
||||
# always add the two embedding models which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if m.details.family in ["bert"]:
|
||||
continue
|
||||
|
||||
models = []
|
||||
for m in response.models:
|
||||
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
|
||||
if model_type == ModelType.embedding:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
logger.debug(f"ollama refreshed model list ({len(models)} models)")
|
||||
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
@ -190,12 +194,7 @@ class OllamaInferenceAdapter(
|
|||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
|
||||
logger.debug("ollama cancelling background refresh task")
|
||||
self._refresh_task.cancel()
|
||||
|
||||
self._client = None
|
||||
self._openai_client = None
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(BaseModel):
|
||||
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
|
|
|
@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
refresh_models_interval: int = Field(
|
||||
default=300,
|
||||
description="Interval in seconds to refresh models",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
_refresh_task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
|
@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
|
||||
# or available in distributions like "starter" without causing a ruckus
|
||||
return
|
||||
pass
|
||||
|
||||
if self.config.refresh_models:
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
# print the stack trace for the exception
|
||||
exc = task.exception()
|
||||
log.error(f"vLLM background refresh task died: {exc}")
|
||||
traceback.print_exception(exc)
|
||||
else:
|
||||
log.error("vLLM background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
provider_id = self.__provider_id__
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None # mypy
|
||||
while True:
|
||||
try:
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
log.debug(f"vLLM refreshed model list ({len(models)} models)")
|
||||
except Exception as e:
|
||||
log.error(f"vLLM background refresh task failed: {e}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
|
@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
allowed_models: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
# more closer to the Model class.
|
||||
class ProviderModelEntry(BaseModel):
|
||||
|
@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
|||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
def __init__(self, model_entries: list[ProviderModelEntry]):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||
self.allowed_models = allowed_models
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
|
@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
|
||||
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
models = []
|
||||
for entry in self.model_entries:
|
||||
ids = [entry.provider_model_id] + entry.aliases
|
||||
for id in ids:
|
||||
if self.allowed_models and id not in self.allowed_models:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
model_id=id,
|
||||
provider_resource_id=entry.provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
metadata=entry.metadata,
|
||||
provider_id=self.__provider_id__,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_provider_model_id(self, identifier: str) -> str | None:
|
||||
return self.alias_to_provider_id_map.get(identifier, None)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue