more fixes

This commit is contained in:
Ashwin Bharambe 2025-07-25 14:01:21 -07:00
parent 145da06fdf
commit 5ae5a4da8a
4 changed files with 26 additions and 16 deletions

View file

@ -318,8 +318,10 @@ async def construct_stack(
await register_resources(run_config, impls) await register_resources(run_config, impls)
await refresh_registry_once(impls)
global REGISTRY_REFRESH_TASK global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls)) REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
def cb(task): def cb(task):
import traceback import traceback
@ -355,12 +357,18 @@ async def shutdown_stack(impls: dict[Api, Any]):
REGISTRY_REFRESH_TASK.cancel() REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry(impls: dict[Api, Any]): async def refresh_registry_once(impls: dict[Api, Any]):
logger.info("refreshing registry")
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True:
for routing_table in routing_tables: for routing_table in routing_tables:
await routing_table.refresh() await routing_table.refresh()
async def refresh_registry_task(impls: dict[Api, Any]):
logger.info("starting registry refresh task")
while True:
await refresh_registry_once(impls)
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS) await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)

View file

@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol):
-> Provider uses provider-model-id for inference -> Provider uses provider-model-id for inference
""" """
# TODO: rename this to something like `on_model_register`.
# The provider should _not_ be able to modify the model object in this
# callback.
async def register_model(self, model: Model) -> Model: ... async def register_model(self, model: Model) -> Model: ...
async def unregister_model(self, model_id: str) -> None: ... async def unregister_model(self, model_id: str) -> None: ...

View file

@ -420,9 +420,6 @@ class OllamaInferenceAdapter(
except ValueError: except ValueError:
pass # Ignore statically unknown model, will check live listing pass # Ignore statically unknown model, will check live listing
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
response = await self.client.list() response = await self.client.list()
if model.provider_resource_id not in [m.model for m in response.models]: if model.provider_resource_id not in [m.model for m in response.models]:
@ -433,9 +430,9 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed # - models not currently running are run by the ollama server as needed
response = await self.client.list() response = await self.client.list()
available_models = [m.model for m in response.models] available_models = [m.model for m in response.models]
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id provider_resource_id = model.provider_resource_id
assert provider_resource_id is not None # mypy
if provider_resource_id not in available_models: if provider_resource_id not in available_models:
available_models_latest = [m.model.split(":latest")[0] for m in response.models] available_models_latest = [m.model.split(":latest")[0] for m in response.models]
if provider_resource_id in available_models_latest: if provider_resource_id in available_models_latest:
@ -443,7 +440,9 @@ class OllamaInferenceAdapter(
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
) )
return model return model
raise UnsupportedModelError(model.provider_resource_id, available_models) raise UnsupportedModelError(provider_resource_id, available_models)
# mutating the passed-in model object here should be considered an anti-pattern
model.provider_resource_id = provider_resource_id model.provider_resource_id = provider_resource_id
return model return model

View file

@ -188,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return model return model
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
# TODO: should we block unregistering base supported provider model IDs? # model_id is the identifier, not the provider_resource_id
if model_id not in self.alias_to_provider_id_map: # unfortunately, this ID can be of the form provider_id/model_id which
raise ValueError(f"Model id '{model_id}' is not registered.") # we never registered. TODO: fix this by significantly rewriting
# registration and registry helper
del self.alias_to_provider_id_map[model_id] pass