Repository: https://github.com/meta-llama/llama-stack.git
Commit: 2e5ffab4e3 (parent: cd8715d327)

feat(registry): make the Stack query providers for model listing

12 changed files with 127 additions and 124 deletions
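In short: each inference adapter used to push its model list into the registry from a provider-owned background task, calling back through `ModelStore.update_registered_llm_models`. After this change the Stack pulls instead: `ModelsRoutingTable` runs a single refresh loop and asks every provider for its catalog through two new `ModelsProtocolPrivate` hooks, `list_models()` and `should_refresh_models()`. The per-provider `refresh_models_interval` settings disappear; the polling cadence is now one Stack-wide value on the routing table.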
Ollama provider documentation:

```diff
@@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `url` | `<class 'str'>` | No | http://localhost:11434 | |
-| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
-| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
+| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
 
 ## Sample Configuration
```
vLLM provider documentation:

```diff
@@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
-| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
 
 ## Sample Configuration
```
Inference API — providers no longer push model updates through `ModelStore`:

```diff
@@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
 class ModelStore(Protocol):
     async def get_model(self, identifier: str) -> Model: ...
 
-    async def update_registered_llm_models(
-        self,
-        provider_id: str,
-        models: list[Model],
-    ) -> None: ...
-
 
 class TextTruncation(Enum):
     """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
```
Model routing table — new import for the Stack-side refresh loop:

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import time
 from typing import Any
 
```
```diff
@@ -19,6 +20,47 @@ logger = get_logger(name=__name__, category="core")
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+    listed_providers: set[str] = set()
+    model_refresh_interval_seconds: int = 300
+
+    async def initialize(self) -> None:
+        await super().initialize()
+        task = asyncio.create_task(self._refresh_models())
+
+        def cb(task):
+            import traceback
+
+            if task.cancelled():
+                logger.error("Model refresh task cancelled")
+            elif task.exception():
+                logger.error(f"Model refresh task failed: {task.exception()}")
+                traceback.print_exception(task.exception())
+            else:
+                logger.debug("Model refresh task completed")
+
+        task.add_done_callback(cb)
+
+    async def _refresh_models(self) -> None:
+        while True:
+            for provider_id, provider in self.impls_by_provider_id.items():
+                refresh = await provider.should_refresh_models()
+                if not (refresh or provider_id in self.listed_providers):
+                    continue
+
+                try:
+                    models = await provider.list_models()
+                except Exception as e:
+                    logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                    continue
+
+                self.listed_providers.add(provider_id)
+                if models is None:
+                    continue
+
+                await self.update_registered_llm_models(provider_id, models)
+
+            await asyncio.sleep(self.model_refresh_interval_seconds)
+
     async def list_models(self) -> ListModelsResponse:
         return ListModelsResponse(data=await self.get_all_with_type("model"))
 
```
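The `cb` done-callback above is what keeps a fire-and-forget `asyncio` task from failing silently. A minimal, self-contained sketch of the same pattern, standard library only (`boom` is a stand-in for `_refresh_models`, not part of the commit):

```python
import asyncio
import logging
import traceback

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


async def boom() -> None:
    raise RuntimeError("provider exploded")  # stand-in for a failing refresh loop


def cb(task: asyncio.Task) -> None:
    # Without a done-callback, the exception would sit unobserved in the Task.
    if task.cancelled():
        logger.error("task cancelled")
    elif task.exception():
        logger.error(f"task failed: {task.exception()}")
        traceback.print_exception(task.exception())  # single-arg form needs Python 3.10+
    else:
        logger.debug("task completed")


async def main() -> None:
    task = asyncio.create_task(boom())
    task.add_done_callback(cb)
    await asyncio.sleep(0.1)  # let the task run and the callback fire


asyncio.run(main())
```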
Provider-private protocol — the two hooks every inference provider now implements:

```diff
@@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
 
     async def unregister_model(self, model_id: str) -> None: ...
 
+    # the Stack router will query each provider for their list of models
+    # if a `refresh_interval_seconds` is provided, this method will be called
+    # periodically to refresh the list of models
+    #
+    # NOTE: each model returned will be registered with the model registry. this means
+    # a callback to the `register_model()` method will be made. this is duplicative and
+    # may be removed in the future.
+    async def list_models(self) -> list[Model] | None: ...
+
+    async def should_refresh_models(self) -> bool: ...
+
 
 class ShieldsProtocolPrivate(Protocol):
     async def register_shield(self, shield: Shield) -> None: ...
```
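To make the contract concrete, here is a hypothetical provider implementing both hooks. The class name, model ID, and the `llama_stack.apis.models` import path are assumptions for illustration; the `Model(...)` fields mirror the Ollama and vLLM hunks below:

```python
from llama_stack.apis.models import Model, ModelType  # import path assumed


class ToyProvider:
    __provider_id__ = "toy"  # normally set by the resolver

    async def should_refresh_models(self) -> bool:
        # Static catalog: the routing table lists us once and never re-polls.
        return False

    async def list_models(self) -> list[Model] | None:
        return [
            Model(
                identifier="toy/llm-1",
                provider_resource_id="toy/llm-1",
                provider_id=self.__provider_id__,
                metadata={},
                model_type=ModelType.llm,
            )
        ]
```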
Meta-reference inference — static catalog, no refresh:

```diff
@@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()
 
+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        return None
+
     async def unregister_model(self, model_id: str) -> None:
         pass
 
```
Sentence-transformers inference — same stubs:

```diff
@@ -50,6 +50,13 @@ class SentenceTransformersInferenceImpl(
     async def shutdown(self) -> None:
         pass
 
+    async def should_refresh_models(self) -> bool:
+        return False
+
+    async def list_models(self) -> list[Model] | None:
+        # TODO: add all-mini-lm models
+        return None
+
     async def register_model(self, model: Model) -> Model:
         return model
 
```
Ollama provider config — the interval knob is dropped:

```diff
@@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
 
 class OllamaImplConfig(BaseModel):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
-    refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically",
+    )
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
```
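A short sketch of what opting in looks like after this change (the import path is an assumption):

```python
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig  # path assumed

# url defaults to DEFAULT_OLLAMA_URL; only the boolean opt-in remains.
cfg = OllamaImplConfig(refresh_models=True)
# There is no refresh_models_interval here anymore: the cadence now lives on
# ModelsRoutingTable.model_refresh_interval_seconds (300s by default).
```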
Ollama adapter — `asyncio` is no longer needed:

```diff
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 
-import asyncio
 import base64
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
```
Ollama adapter — the background refresh loop collapses into a single-shot `list_models()`:

```diff
@@ -121,59 +120,27 @@ class OllamaInferenceAdapter(
                 "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
             )
 
-        if self.config.refresh_models:
-            logger.debug("ollama starting background model refresh task")
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                if task.cancelled():
-                    import traceback
-
-                    logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    logger.error(f"ollama background refresh task died: {task.exception()}")
-                else:
-                    logger.error("ollama background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        # Wait for model store to be available (with timeout)
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
-
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models
+
+    async def list_models(self) -> list[Model] | None:
         provider_id = self.__provider_id__
-        while True:
-            try:
-                response = await self.client.list()
-            except Exception as e:
-                logger.warning(f"Failed to list models: {str(e)}")
-                await asyncio.sleep(self.config.refresh_models_interval)
-                continue
-
-            models = []
-            for m in response.models:
-                model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
-                if model_type == ModelType.embedding:
-                    continue
-                models.append(
-                    Model(
-                        identifier=m.model,
-                        provider_resource_id=m.model,
-                        provider_id=provider_id,
-                        metadata={},
-                        model_type=model_type,
-                    )
-                )
-            await self.model_store.update_registered_llm_models(provider_id, models)
-            logger.debug(f"ollama refreshed model list ({len(models)} models)")
-
-            await asyncio.sleep(self.config.refresh_models_interval)
+        response = await self.client.list()
+        models = []
+        for m in response.models:
+            model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
+            if model_type == ModelType.embedding:
+                continue
+            models.append(
+                Model(
+                    identifier=m.model,
+                    provider_resource_id=m.model,
+                    provider_id=provider_id,
+                    metadata={},
+                    model_type=model_type,
+                )
+            )
+        return models
 
     async def health(self) -> HealthResponse:
         """
```
```diff
@@ -190,10 +157,6 @@ class OllamaInferenceAdapter(
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
     async def shutdown(self) -> None:
-        if hasattr(self, "_refresh_task") and not self._refresh_task.done():
-            logger.debug("ollama cancelling background refresh task")
-            self._refresh_task.cancel()
-
         self._client = None
         self._openai_client = None
 
```
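Note the shape change: the adapter no longer loops, sleeps, or retries. That responsibility now sits entirely in `ModelsRoutingTable._refresh_models`, so the Stack-side caller reduces to roughly this hedged sketch (`adapter` stands for any provider instance; the helper name is made up):

```python
async def refresh_if_needed(adapter):
    # Roughly what the routing table does per provider, minus bookkeeping:
    if await adapter.should_refresh_models():  # e.g. Ollama reads config.refresh_models
        return await adapter.list_models()     # a single client.list() round trip
    return None
```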
vLLM adapter config — interval knob dropped here too:

```diff
@@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
         default=False,
         description="Whether to refresh models periodically",
     )
-    refresh_models_interval: int = Field(
-        default=300,
-        description="Interval in seconds to refresh models",
-    )
 
     @field_validator("tls_verify")
     @classmethod
```
vLLM adapter — `asyncio` import removed:

```diff
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
 import json
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
```
```diff
@@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
-    _refresh_task: asyncio.Task | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
```
vLLM adapter — the same push-to-pull inversion as Ollama:

```diff
@@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         self.client = None
 
     async def initialize(self) -> None:
-        if not self.config.url:
-            # intentionally don't raise an error here, we want to allow the provider to be "dormant"
-            # or available in distributions like "starter" without causing a ruckus
-            return
-
-        if self.config.refresh_models:
-            self._refresh_task = asyncio.create_task(self._refresh_models())
-
-            def cb(task):
-                import traceback
-
-                if task.cancelled():
-                    log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
-                elif task.exception():
-                    # print the stack trace for the exception
-                    exc = task.exception()
-                    log.error(f"vLLM background refresh task died: {exc}")
-                    traceback.print_exception(exc)
-                else:
-                    log.error("vLLM background refresh task completed unexpectedly")
-
-            self._refresh_task.add_done_callback(cb)
-
-    async def _refresh_models(self) -> None:
-        provider_id = self.__provider_id__
-        waited_time = 0
-        while not self.model_store and waited_time < 60:
-            await asyncio.sleep(1)
-            waited_time += 1
-
-        if not self.model_store:
-            raise ValueError("Model store not set after waiting 60 seconds")
-
+        pass
+
+    async def should_refresh_models(self) -> bool:
+        return self.config.refresh_models
+
+    async def list_models(self) -> list[Model] | None:
         self._lazy_initialize_client()
         assert self.client is not None  # mypy
-        while True:
-            try:
-                models = []
-                async for m in self.client.models.list():
-                    model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
-                    models.append(
-                        Model(
-                            identifier=m.id,
-                            provider_resource_id=m.id,
-                            provider_id=provider_id,
-                            metadata={},
-                            model_type=model_type,
-                        )
-                    )
-                await self.model_store.update_registered_llm_models(provider_id, models)
-                log.debug(f"vLLM refreshed model list ({len(models)} models)")
-            except Exception as e:
-                log.error(f"vLLM background refresh task failed: {e}")
-            await asyncio.sleep(self.config.refresh_models_interval)
+        models = []
+        async for m in self.client.models.list():
+            model_type = ModelType.llm  # unclear how to determine embedding vs. llm models
+            models.append(
+                Model(
+                    identifier=m.id,
+                    provider_resource_id=m.id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=model_type,
+                )
+            )
+        return models
 
     async def shutdown(self) -> None:
-        if self._refresh_task:
-            self._refresh_task.cancel()
-            self._refresh_task = None
+        pass
 
     async def unregister_model(self, model_id: str) -> None:
         pass
```
Model registry helper — gains a provider id and implements the new hooks:

```diff
@@ -65,6 +65,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
 
 
 class ModelRegistryHelper(ModelsProtocolPrivate):
+    __provider_id__: str
+
     def __init__(self, model_entries: list[ProviderModelEntry]):
         self.alias_to_provider_id_map = {}
         self.provider_id_to_llama_model_map = {}
```
```diff
@@ -79,6 +81,25 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
             self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
             self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
 
+    async def list_models(self) -> list[Model] | None:
+        models = []
+        for entry in self.model_entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for id in ids:
+                models.append(
+                    Model(
+                        model_id=id,
+                        provider_resource_id=entry.provider_model_id,
+                        model_type=ModelType.llm,
+                        metadata=entry.metadata,
+                        provider_id=self.__provider_id__,
+                    )
+                )
+        return models
+
+    async def should_refresh_models(self) -> bool:
+        return False
+
     def get_provider_model_id(self, identifier: str) -> str | None:
         return self.alias_to_provider_id_map.get(identifier, None)
 
```
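The alias expansion in the new `list_models()` is worth spelling out: the canonical `provider_model_id` and every alias each become a separate `Model` row resolving to the same provider resource. A pure-Python illustration (the model names are made up):

```python
provider_model_id = "meta-llama/Llama-3.3-70B-Instruct"
aliases = ["llama3.3-70b"]

# Mirrors `ids = [entry.provider_model_id] + entry.aliases` above.
rows = [(model_id, provider_model_id) for model_id in [provider_model_id] + aliases]
for model_id, resource in rows:
    print(f"{model_id} -> {resource}")
# meta-llama/Llama-3.3-70B-Instruct -> meta-llama/Llama-3.3-70B-Instruct
# llama3.3-70b -> meta-llama/Llama-3.3-70B-Instruct
```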