commit 5ae5a4da8a ("more fixes")
parent 145da06fdf
4 changed files with 26 additions and 16 deletions
@@ -318,8 +318,10 @@ async def construct_stack(
     await register_resources(run_config, impls)

+    await refresh_registry_once(impls)
+
     global REGISTRY_REFRESH_TASK
-    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
+    REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))

     def cb(task):
         import traceback
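The hunk above launches the registry refresh loop as a background asyncio task with a done-callback that prints tracebacks. A minimal, self-contained sketch of that pattern, assuming Python 3.10+; the names refresh_forever and main are illustrative, not part of the codebase:

import asyncio
import traceback


async def refresh_forever() -> None:
    # stand-in for the project's refresh_registry_task(impls)
    while True:
        await asyncio.sleep(1)


async def main() -> None:
    task = asyncio.create_task(refresh_forever())

    def cb(task: asyncio.Task) -> None:
        # ignore the expected cancellation at shutdown, surface anything else
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            traceback.print_exception(exc)

    task.add_done_callback(cb)

    await asyncio.sleep(0.1)  # the server would run here
    task.cancel()             # mirrors shutdown_stack() cancelling REGISTRY_REFRESH_TASK
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())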
@@ -355,11 +357,17 @@ async def shutdown_stack(impls: dict[Api, Any]):
         REGISTRY_REFRESH_TASK.cancel()


-async def refresh_registry(impls: dict[Api, Any]):
+async def refresh_registry_once(impls: dict[Api, Any]):
+    logger.info("refreshing registry")
     routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
+    for routing_table in routing_tables:
+        await routing_table.refresh()
+
+
+async def refresh_registry_task(impls: dict[Api, Any]):
+    logger.info("starting registry refresh task")
     while True:
-        for routing_table in routing_tables:
-            await routing_table.refresh()
+        await refresh_registry_once(impls)

         await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
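Read together with the previous hunk, this refactor splits the old loop into a one-shot refresh that construct_stack can await eagerly and a periodic task that reuses it. A rough stand-alone sketch, with a fake routing table and interval in place of CommonRoutingTableImpl and REGISTRY_REFRESH_INTERVAL_SECONDS:

import asyncio

REFRESH_INTERVAL_SECONDS = 2  # stand-in for REGISTRY_REFRESH_INTERVAL_SECONDS


class FakeRoutingTable:
    # stand-in for CommonRoutingTableImpl: only the refresh() coroutine matters here
    def __init__(self, name: str) -> None:
        self.name = name

    async def refresh(self) -> None:
        print(f"refreshed {self.name}")


async def refresh_once(tables: list[FakeRoutingTable]) -> None:
    # mirrors refresh_registry_once: a single pass over all routing tables
    for table in tables:
        await table.refresh()


async def refresh_task(tables: list[FakeRoutingTable]) -> None:
    # mirrors refresh_registry_task: the loop now reuses the one-shot helper
    while True:
        await refresh_once(tables)
        await asyncio.sleep(REFRESH_INTERVAL_SECONDS)


async def main() -> None:
    tables = [FakeRoutingTable("models"), FakeRoutingTable("shields")]
    await refresh_once(tables)  # eager first refresh, as construct_stack now does
    task = asyncio.create_task(refresh_task(tables))
    await asyncio.sleep(0.1)
    task.cancel()


asyncio.run(main())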
@@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol):
        -> Provider uses provider-model-id for inference
    """

+    # this should be called `on_model_register` or something like that.
+    # the provider should _not_ be able to change the object in this
+    # callback
    async def register_model(self, model: Model) -> Model: ...

    async def unregister_model(self, model_id: str) -> None: ...
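The added comments describe register_model as a notification-style hook whose return value should be the unmodified object. A toy illustration of a provider conforming to that protocol; Model here is a trimmed dataclass stand-in for the real resource type:

from dataclasses import dataclass
from typing import Protocol


@dataclass
class Model:
    # trimmed stand-in for the real Model resource
    identifier: str
    provider_resource_id: str | None = None


class ModelsProtocolPrivate(Protocol):
    async def register_model(self, model: Model) -> Model: ...

    async def unregister_model(self, model_id: str) -> None: ...


class ExampleProvider:
    """Treats register_model as an 'on_model_register' notification and
    returns the object untouched, per the comment in the hunk above."""

    async def register_model(self, model: Model) -> Model:
        # validation or warm-up could happen here; the object is not rewritten
        return model

    async def unregister_model(self, model_id: str) -> None:
        return None


provider: ModelsProtocolPrivate = ExampleProvider()  # structural typing: no inheritance needed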
@@ -420,9 +420,6 @@ class OllamaInferenceAdapter(
         except ValueError:
             pass  # Ignore statically unknown model, will check live listing

-        if model.provider_resource_id is None:
-            raise ValueError("Model provider_resource_id cannot be None")
-
         if model.model_type == ModelType.embedding:
             response = await self.client.list()
             if model.provider_resource_id not in [m.model for m in response.models]:
@@ -433,9 +430,9 @@ class OllamaInferenceAdapter(
         # - models not currently running are run by the ollama server as needed
         response = await self.client.list()
         available_models = [m.model for m in response.models]
-        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
-        if provider_resource_id is None:
-            provider_resource_id = model.provider_resource_id
+        provider_resource_id = model.provider_resource_id
+        assert provider_resource_id is not None  # mypy
         if provider_resource_id not in available_models:
             available_models_latest = [m.model.split(":latest")[0] for m in response.models]
             if provider_resource_id in available_models_latest:
@@ -443,7 +440,9 @@ class OllamaInferenceAdapter(
                     f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
                 )
                 return model
-            raise UnsupportedModelError(model.provider_resource_id, available_models)
+            raise UnsupportedModelError(provider_resource_id, available_models)

+        # mutating this should be considered an anti-pattern
         model.provider_resource_id = provider_resource_id

         return model
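Taken together, the three Ollama hunks drop the register-helper lookup, use the model's own provider_resource_id, check it against the live listing, and keep the ':latest' fallback. A stand-alone sketch of that resolution logic (a plain function with ValueError in place of UnsupportedModelError, and made-up model names):

def resolve_ollama_model(provider_resource_id: str, available_models: list[str]) -> str:
    # exact match against the live listing first
    if provider_resource_id in available_models:
        return provider_resource_id
    # the adapter's fallback: an untagged id matches a ':latest' entry
    available_latest = [m.split(":latest")[0] for m in available_models]
    if provider_resource_id in available_latest:
        return f"{provider_resource_id}:latest"
    raise ValueError(f"{provider_resource_id!r} not served by Ollama; available: {available_models}")


# 'llama3.2' is imprecise but resolves to 'llama3.2:latest'; an unknown id raises
print(resolve_ollama_model("llama3.2", ["llama3.2:latest", "nomic-embed-text:latest"]))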
@@ -188,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return model

     async def unregister_model(self, model_id: str) -> None:
-        # TODO: should we block unregistering base supported provider model IDs?
-        if model_id not in self.alias_to_provider_id_map:
-            raise ValueError(f"Model id '{model_id}' is not registered.")
-
-        del self.alias_to_provider_id_map[model_id]
+        # model_id is the identifier, not the provider_resource_id
+        # unfortunately, this ID can be of the form provider_id/model_id which
+        # we never registered. TODO: fix this by significantly rewriting
+        # registration and registry helper
+        pass
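The rewritten unregister_model points at an identifier mismatch: the incoming ID can be namespaced as provider_id/model_id, which never appears as a key in alias_to_provider_id_map, so the old lookup-and-delete would raise for such IDs. A toy reproduction of that mismatch (the dict contents and IDs are made up):

# keys are the aliases / provider model ids the helper registered
alias_to_provider_id_map = {"llama3.2:3b": "llama3.2:3b"}

incoming_id = "ollama/llama3.2:3b"  # hypothetical namespaced identifier
if incoming_id not in alias_to_provider_id_map:
    # the previous implementation raised here; the commit makes unregister a no-op
    print(f"Model id '{incoming_id}' is not registered.")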