feat(registry): make the Stack query providers for model listing (#2862)

This flips #2823 and #2805 by making the Stack periodically query the
providers for models rather than the providers going behind the back and
calling "register" on to the registry themselves. This also adds support
for model listing for all other providers via `ModelRegistryHelper`.
Once this is done, we do not need to manually list or register models
via `run.yaml` and it will remove both noise and annoyance (setting
`INFERENCE_MODEL` environment variables, for example) from the new user
experience.

In addition, it adds a configuration variable `allowed_models` which can
be used to optionally restrict the set of models exposed from a
provider.
This commit is contained in:
Ashwin Bharambe 2025-07-24 10:39:53 -07:00 committed by GitHub
parent 537dc693ee
commit 1463b79218
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 429 additions and 147 deletions

View file

@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify")
@classmethod

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
pass
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
pass
async def unregister_model(self, model_id: str) -> None:
pass