feat(registry): make the Stack query providers for model listing

This commit is contained in:
Ashwin Bharambe 2025-07-22 14:13:21 -07:00
parent cd8715d327
commit 2e5ffab4e3
12 changed files with 127 additions and 124 deletions

View file

@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description | | Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------| |-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | http://localhost:11434 | | | `url` | `<class 'str'>` | No | http://localhost:11434 | |
| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically | | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration

View file

@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `api_token` | `str \| None` | No | fake | The API token | | `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically | | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration

View file

@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol): class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ... async def get_model(self, identifier: str) -> Model: ...
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None: ...
class TextTruncation(Enum): class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import time import time
from typing import Any from typing import Any
@ -19,6 +20,47 @@ logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# Providers whose models have been listed at least once by the refresh loop.
# NOTE(review): class-level mutable attribute — shared across all
# ModelsRoutingTable instances; fine if the table is a singleton, confirm.
listed_providers: set[str] = set()
# Seconds to sleep between passes of the background refresh loop.
model_refresh_interval_seconds: int = 300
async def initialize(self) -> None:
    """Initialize the routing table and start the background model-refresh loop.

    The refresh task runs for the lifetime of the routing table; a done
    callback logs how it finished so silent failures are visible.
    """
    await super().initialize()
    # Keep a strong reference to the task: the event loop only holds weak
    # references to tasks, so a task referenced solely by a local variable
    # may be garbage-collected before it completes.
    self._model_refresh_task = asyncio.create_task(self._refresh_models())

    def cb(task):
        import traceback

        if task.cancelled():
            logger.error("Model refresh task cancelled")
        elif task.exception():
            logger.error(f"Model refresh task failed: {task.exception()}")
            traceback.print_exception(task.exception())
        else:
            logger.debug("Model refresh task completed")

    self._model_refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
    """Background loop: periodically query providers for their model lists.

    Runs forever; each pass asks every provider whether it wants a refresh
    and, when it does, re-registers the models the provider reports.
    """
    while True:
        for provider_id, provider in self.impls_by_provider_id.items():
            refresh = await provider.should_refresh_models()
            # Proceed when the provider opted in, or when we already listed it
            # once before. NOTE(review): the second clause keeps refreshing a
            # provider even if its refresh_models flag later turns False —
            # presumably intentional to keep the registry in sync; confirm.
            if not (refresh or provider_id in self.listed_providers):
                continue
            try:
                models = await provider.list_models()
            except Exception as e:
                # Best-effort: one failing provider must not stop the loop.
                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
                continue

            self.listed_providers.add(provider_id)
            # A provider may return None to signal "no dynamic model listing".
            if models is None:
                continue

            await self.update_registered_llm_models(provider_id, models)

        await asyncio.sleep(self.model_refresh_interval_seconds)
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) return ListModelsResponse(data=await self.get_all_with_type("model"))

View file

@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
async def unregister_model(self, model_id: str) -> None: ... async def unregister_model(self, model_id: str) -> None: ...
# The Stack router queries each provider for its list of models and, when the
# provider opts in via `should_refresh_models()`, calls this method
# periodically to refresh that list.
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...

# Return True if the Stack should re-invoke `list_models()` periodically.
async def should_refresh_models(self) -> bool: ...
class ShieldsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ... async def register_shield(self, shield: Shield) -> None: ...

View file

@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
self.generator.stop() self.generator.stop()
async def should_refresh_models(self) -> bool:
    # Locally-loaded meta-reference models do not change at runtime;
    # no periodic refresh is needed.
    return False

async def list_models(self) -> list[Model] | None:
    # No dynamic model discovery; models are registered explicitly.
    return None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass

View file

@ -50,6 +50,13 @@ class SentenceTransformersInferenceImpl(
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def should_refresh_models(self) -> bool:
    # The set of sentence-transformers models is static; no refresh needed.
    return False

async def list_models(self) -> list[Model] | None:
    # No dynamic listing yet.
    # TODO: add all-mini-lm models
    return None
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
return model return model

View file

@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel): class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically") refresh_models: bool = Field(
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models") default=False,
description="Whether to refresh models periodically",
)
@classmethod @classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import base64 import base64
import uuid import uuid
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
@ -121,59 +120,27 @@ class OllamaInferenceAdapter(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
) )
if self.config.refresh_models: async def should_refresh_models(self) -> bool:
logger.debug("ollama starting background model refresh task") return self.config.refresh_models
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__ provider_id = self.__provider_id__
while True: response = await self.client.list()
try: models = []
response = await self.client.list() for m in response.models:
except Exception as e: model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
logger.warning(f"Failed to list models: {str(e)}") if model_type == ModelType.embedding:
await asyncio.sleep(self.config.refresh_models_interval)
continue continue
models.append(
models = [] Model(
for m in response.models: identifier=m.model,
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm provider_resource_id=m.model,
if model_type == ModelType.embedding: provider_id=provider_id,
continue metadata={},
models.append( model_type=model_type,
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
) )
await self.model_store.update_registered_llm_models(provider_id, models) )
logger.debug(f"ollama refreshed model list ({len(models)} models)") return models
await asyncio.sleep(self.config.refresh_models_interval)
async def health(self) -> HealthResponse: async def health(self) -> HealthResponse:
""" """
@ -190,10 +157,6 @@ class OllamaInferenceAdapter(
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None: async def shutdown(self) -> None:
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None self._client = None
self._openai_client = None self._openai_client = None

View file

@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=False, default=False,
description="Whether to refresh models periodically", description="Whether to refresh models periodically",
) )
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify") @field_validator("tls_verify")
@classmethod @classmethod

View file

@ -3,7 +3,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import json import json
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider # automatically set by the resolver when instantiating the provider
__provider_id__: str __provider_id__: str
model_store: ModelStore | None = None model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None: def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None self.client = None
async def initialize(self) -> None: async def initialize(self) -> None:
if not self.config.url: pass
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
if self.config.refresh_models: async def should_refresh_models(self) -> bool:
self._refresh_task = asyncio.create_task(self._refresh_models()) return self.config.refresh_models
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client() self._lazy_initialize_client()
assert self.client is not None # mypy assert self.client is not None # mypy
while True: models = []
try: async for m in self.client.models.list():
models = [] model_type = ModelType.llm # unclear how to determine embedding vs. llm models
async for m in self.client.models.list(): models.append(
model_type = ModelType.llm # unclear how to determine embedding vs. llm models Model(
models.append( identifier=m.id,
Model( provider_resource_id=m.id,
identifier=m.id, provider_id=self.__provider_id__,
provider_resource_id=m.id, metadata={},
provider_id=provider_id, model_type=model_type,
metadata={}, )
model_type=model_type, )
) return models
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self._refresh_task: pass
self._refresh_task.cancel()
self._refresh_task = None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass

View file

@ -65,6 +65,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
class ModelRegistryHelper(ModelsProtocolPrivate): class ModelRegistryHelper(ModelsProtocolPrivate):
# Automatically set by the resolver when instantiating the provider;
# identifies which provider owns the models this helper registers.
__provider_id__: str
def __init__(self, model_entries: list[ProviderModelEntry]): def __init__(self, model_entries: list[ProviderModelEntry]):
self.alias_to_provider_id_map = {} self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {} self.provider_id_to_llama_model_map = {}
@ -79,6 +81,25 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
async def list_models(self) -> list[Model] | None:
    """Return one Model per known alias of every configured model entry.

    Each alias (plus the canonical provider model id itself) becomes a
    registry entry pointing at the same provider resource.
    """
    models = []
    for entry in self.model_entries:
        # Canonical id first, then aliases. Loop variable renamed so it does
        # not shadow the `id` builtin.
        for alias in [entry.provider_model_id, *entry.aliases]:
            models.append(
                Model(
                    model_id=alias,
                    provider_resource_id=entry.provider_model_id,
                    model_type=ModelType.llm,
                    metadata=entry.metadata,
                    provider_id=self.__provider_id__,
                )
            )
    return models
async def should_refresh_models(self) -> bool:
    # Entries are declared statically at construction; nothing to refresh.
    return False
def get_provider_model_id(self, identifier: str) -> str | None: def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None) return self.alias_to_provider_id_map.get(identifier, None)