mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This commit enhances the signal handling mechanism in the server by improving the `handle_signal` (previously handle_sigint) function. It now properly retrieves the signal name, ensuring clearer logging when a termination signal is received. Additionally, it cancels all running tasks and waits for their completion before stopping the event loop, allowing for a more graceful shutdown. Support for handling SIGTERM has also been added alongside SIGINT. Before the changes, handle_sigint used asyncio.run(run_shutdown()). However, asyncio.run() is meant to start a new event loop, and calling it inside an existing one (like when running Uvicorn) raises an error. The fix replaces asyncio.run(run_shutdown()) with an async function scheduled on the existing loop using loop.create_task(shutdown()). This ensures that the shutdown coroutine runs within the current event loop instead of trying to create a new one. Furthermore, this commit updates the project dependencies. `fastapi` and `uvicorn` have been added to the development dependencies in `pyproject.toml` and `uv.lock`, ensuring that the necessary packages are available for development and execution. 
Closes: https://github.com/meta-llama/llama-stack/issues/1043 Signed-off-by: Sébastien Han <seb@redhat.com> [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Run a server and send SIGINT: ``` INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml Using config file: llama_stack/templates/ollama/run.yaml Run configuration: apis: - agents - datasetio - eval - inference - safety - scoring - telemetry - tool_runtime - vector_io container_image: null datasets: [] eval_tasks: [] image_name: ollama metadata_store: db_path: /Users/leseb/.llama/distributions/ollama/registry.db namespace: null type: sqlite models: - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType - llm provider_id: ollama provider_model_id: null - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType - embedding provider_id: sentence-transformers provider_model_id: null providers: agents: - config: persistence_store: db_path: /Users/leseb/.llama/distributions/ollama/agents_store.db namespace: null type: sqlite provider_id: meta-reference provider_type: inline::meta-reference datasetio: - config: {} provider_id: huggingface provider_type: remote::huggingface - config: {} provider_id: localfs provider_type: inline::localfs eval: - config: {} provider_id: meta-reference provider_type: inline::meta-reference inference: - config: url: http://localhost:11434 provider_id: ollama provider_type: remote::ollama - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers safety: - config: {} provider_id: llama-guard provider_type: inline::llama-guard scoring: - config: {} provider_id: basic provider_type: inline::basic - config: {} provider_id: llm-as-judge 
provider_type: inline::llm-as-judge - config: openai_api_key: '********' provider_id: braintrust provider_type: inline::braintrust telemetry: - config: service_name: llama-stack sinks: console,sqlite sqlite_db_path: /Users/leseb/.llama/distributions/ollama/trace_store.db provider_id: meta-reference provider_type: inline::meta-reference tool_runtime: - config: api_key: '********' max_results: 3 provider_id: brave-search provider_type: remote::brave-search - config: api_key: '********' max_results: 3 provider_id: tavily-search provider_type: remote::tavily-search - config: {} provider_id: code-interpreter provider_type: inline::code-interpreter - config: {} provider_id: rag-runtime provider_type: inline::rag-runtime vector_io: - config: kvstore: db_path: /Users/leseb/.llama/distributions/ollama/faiss_store.db namespace: null type: sqlite provider_id: faiss provider_type: inline::faiss scoring_fns: [] server: port: 8321 tls_certfile: null tls_keyfile: null shields: [] tool_groups: - args: null mcp_endpoint: null provider_id: tavily-search toolgroup_id: builtin::websearch - args: null mcp_endpoint: null provider_id: rag-runtime toolgroup_id: builtin::rag - args: null mcp_endpoint: null provider_id: code-interpreter toolgroup_id: builtin::code_interpreter vector_dbs: [] version: '2' INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:213: Resolved 31 providers INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-inference => ollama INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-inference => sentence-transformers INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: models => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inference => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-vector_io => faiss INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-safety => llama-guard INFO 2025-02-12 
10:21:03,540 llama_stack.distribution.resolver:215: shields => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: safety => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: vector_dbs => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: vector_io => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => brave-search INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => tavily-search INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => code-interpreter INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => rag-runtime INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: tool_groups => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: tool_runtime => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: agents => meta-reference INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-datasetio => huggingface INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-datasetio => localfs INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: datasets => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: datasetio => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: telemetry => meta-reference INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-scoring => basic INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-scoring => llm-as-judge INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-scoring => braintrust INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: scoring_functions => __routing_table__ INFO 2025-02-12 10:21:03,540 
llama_stack.distribution.resolver:215: scoring => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-eval => meta-reference INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: eval_tasks => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: eval => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inspect => __builtin__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:216: INFO 2025-02-12 10:21:03,723 llama_stack.providers.remote.inference.ollama.ollama:148: checking connectivity to Ollama at `http://localhost:11434`... INFO 2025-02-12 10:21:03,734 httpx:1740: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK" INFO 2025-02-12 10:21:03,843 faiss.loader:148: Loading faiss. INFO 2025-02-12 10:21:03,865 faiss.loader:150: Successfully loaded faiss. INFO 2025-02-12 10:21:03,868 faiss:173: Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. Warning: `bwrap` is not available. Code interpreter tool will not work correctly. INFO 2025-02-12 10:21:04,315 datasets:54: PyTorch version 2.6.0 available. INFO 2025-02-12 10:21:04,556 httpx:1740: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK" INFO 2025-02-12 10:21:04,557 llama_stack.providers.utils.inference.embedding_mixin:42: Loading sentence transformer for all-MiniLM-L6-v2... 
INFO 2025-02-12 10:21:07,202 sentence_transformers.SentenceTransformer:210: Use pytorch device_name: mps INFO 2025-02-12 10:21:07,202 sentence_transformers.SentenceTransformer:218: Load pretrained SentenceTransformer: all-MiniLM-L6-v2 INFO 2025-02-12 10:21:09,500 llama_stack.distribution.stack:102: Models: all-MiniLM-L6-v2 served by sentence-transformers INFO 2025-02-12 10:21:09,500 llama_stack.distribution.stack:102: Models: meta-llama/Llama-3.2-3B-Instruct served by ollama INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::equality served by basic INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::regex_parser_multiple_choice_answer served by basic INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::subset_of served by basic INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-correctness served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-relevancy served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-similarity served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-entity-recall served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-precision served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-recall served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-relevancy served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::factuality served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::faithfulness served by braintrust INFO 
2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: llm-as-judge::405b-simpleqa served by llm-as-judge INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: llm-as-judge::base served by llm-as-judge INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::code_interpreter served by code-interpreter INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::rag served by rag-runtime INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::websearch served by tavily-search INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:106: Serving API eval POST /v1/eval/tasks/{task_id}/evaluations DELETE /v1/eval/tasks/{task_id}/jobs/{job_id} GET /v1/eval/tasks/{task_id}/jobs/{job_id}/result GET /v1/eval/tasks/{task_id}/jobs/{job_id} POST /v1/eval/tasks/{task_id}/jobs Serving API agents POST /v1/agents POST /v1/agents/{agent_id}/session POST /v1/agents/{agent_id}/session/{session_id}/turn DELETE /v1/agents/{agent_id} DELETE /v1/agents/{agent_id}/session/{session_id} GET /v1/agents/{agent_id}/session/{session_id} GET /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id} GET /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id} Serving API scoring_functions GET /v1/scoring-functions/{scoring_fn_id} GET /v1/scoring-functions POST /v1/scoring-functions Serving API safety POST /v1/safety/run-shield Serving API inspect GET /v1/health GET /v1/inspect/providers GET /v1/inspect/routes GET /v1/version Serving API tool_runtime POST /v1/tool-runtime/invoke GET /v1/tool-runtime/list-tools POST /v1/tool-runtime/rag-tool/insert POST /v1/tool-runtime/rag-tool/query Serving API datasetio POST /v1/datasetio/rows GET /v1/datasetio/rows Serving API shields GET /v1/shields/{identifier} GET /v1/shields POST /v1/shields Serving API eval_tasks GET /v1/eval-tasks/{eval_task_id} GET /v1/eval-tasks POST /v1/eval-tasks Serving API models GET 
/v1/models/{model_id} GET /v1/models POST /v1/models DELETE /v1/models/{model_id} Serving API datasets GET /v1/datasets/{dataset_id} GET /v1/datasets POST /v1/datasets DELETE /v1/datasets/{dataset_id} Serving API vector_io POST /v1/vector-io/insert POST /v1/vector-io/query Serving API inference POST /v1/inference/chat-completion POST /v1/inference/completion POST /v1/inference/embeddings Serving API tool_groups GET /v1/tools/{tool_name} GET /v1/toolgroups/{toolgroup_id} GET /v1/toolgroups GET /v1/tools POST /v1/toolgroups DELETE /v1/toolgroups/{toolgroup_id} Serving API vector_dbs GET /v1/vector-dbs/{vector_db_id} GET /v1/vector-dbs POST /v1/vector-dbs DELETE /v1/vector-dbs/{vector_db_id} Serving API scoring POST /v1/scoring/score POST /v1/scoring/score-batch Serving API telemetry GET /v1/telemetry/traces/{trace_id}/spans/{span_id} GET /v1/telemetry/spans/{span_id}/tree GET /v1/telemetry/traces/{trace_id} POST /v1/telemetry/events GET /v1/telemetry/spans GET /v1/telemetry/traces POST /v1/telemetry/spans/export Listening on ['::', '0.0.0.0']:5001 INFO: Started server process [65372] INFO: Waiting for application startup. INFO: ASGI 'lifespan' protocol appears unsupported. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5001 (Press CTRL+C to quit) ^CINFO: Shutting down INFO: Finished server process [65372] Received signal SIGINT (2). Exiting gracefully... 
INFO 2025-02-12 10:21:11,215 __main__:151: Shutting down ModelsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down InferenceRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ShieldsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down SafetyRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down VectorDBsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down VectorIORouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ToolGroupsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ToolRuntimeRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down MetaReferenceAgentsImpl INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DatasetsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DatasetIORouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down TelemetryAdapter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ScoringFunctionsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ScoringRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down EvalTasksRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down EvalRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DistributionInspectImpl ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) Signed-off-by: Sébastien Han <seb@redhat.com>
542 lines
22 KiB
Python
542 lines
22 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import TypeAdapter
|
|
|
|
from llama_stack.apis.common.content_types import URL
|
|
from llama_stack.apis.common.type_system import ParamType
|
|
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
|
|
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
|
|
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
|
from llama_stack.apis.resource import ResourceType
|
|
from llama_stack.apis.scoring_functions import (
|
|
ListScoringFunctionsResponse,
|
|
ScoringFn,
|
|
ScoringFnParams,
|
|
ScoringFunctions,
|
|
)
|
|
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
|
from llama_stack.apis.tools import (
|
|
ListToolGroupsResponse,
|
|
ListToolsResponse,
|
|
Tool,
|
|
ToolGroup,
|
|
ToolGroups,
|
|
ToolHost,
|
|
)
|
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
|
from llama_stack.distribution.datatypes import (
|
|
RoutableObject,
|
|
RoutableObjectWithProvider,
|
|
RoutedProtocol,
|
|
)
|
|
from llama_stack.distribution.store import DistributionRegistry
|
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
|
|
|
|
|
def get_impl_api(p: Any) -> Api:
    """Return the API that the given provider implementation serves."""
    spec = p.__provider_spec__
    return spec.api
|
|
|
|
|
|
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    """Register *obj* with provider implementation *p*.

    Dispatches to the provider's API-specific registration method and
    returns whatever that method returns. Raises ValueError for an API
    with no known registration method.
    """
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    # API -> name of the provider method that performs the registration.
    registrars = {
        Api.inference: "register_model",
        Api.safety: "register_shield",
        Api.vector_io: "register_vector_db",
        Api.datasetio: "register_dataset",
        Api.scoring: "register_scoring_function",
        Api.eval: "register_eval_task",
        Api.tool_runtime: "register_tool",
    }
    method_name = registrars.get(api)
    if method_name is None:
        raise ValueError(f"Unknown API {api} for registering object with provider")
    return await getattr(p, method_name)(obj)
|
|
|
|
|
|
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    """Remove *obj* from provider implementation *p*.

    Only a subset of APIs support unregistration; raises ValueError for
    the rest.
    """
    api = get_impl_api(p)
    # API -> name of the provider method that performs the unregistration.
    unregistrars = {
        Api.vector_io: "unregister_vector_db",
        Api.inference: "unregister_model",
        Api.datasetio: "unregister_dataset",
        Api.tool_runtime: "unregister_tool",
    }
    method_name = unregistrars.get(api)
    if method_name is None:
        raise ValueError(f"Unregister not supported for {api}")
    return await getattr(p, method_name)(obj.identifier)
|
|
|
|
|
|
# Mapping from object type (e.g. "model", "shield") to every registered
# routable object of that type, across all providers.
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
|
|
|
|
|
class CommonRoutingTableImpl(RoutingTable):
    """Shared base for all routing tables.

    Holds the provider implementations for one API plus the distribution-wide
    persistent registry; concrete subclasses add the per-resource
    list/get/register/unregister surface.
    """

    def __init__(
        self,
        impls_by_provider_id: Dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        # provider_id -> provider implementation for this table's API
        self.impls_by_provider_id = impls_by_provider_id
        # persistent store of all registered routable objects
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        """Wire each provider back to this table and persist provider-supplied objects."""

        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            # Persist each object in the registry, stamping it with provider_id.
            # When `cls` is given the object is re-built through the model class
            # so that provider_id is validated rather than mutated in place.
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            # Give the provider a back-reference to its routing table; only
            # scoring providers also contribute pre-defined objects here.
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.eval_task_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        """Shut down every provider implementation owned by this table."""
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
        """Resolve the provider implementation serving `routing_key`.

        Raises:
            ValueError: if no object is registered under `routing_key`, or if
                `provider_id` is given but does not match the registered
                object's provider.
        """

        def apiname_object():
            # (human-readable API name, registry object type) for error
            # messages and registry lookups, based on the concrete subclass.
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, EvalTasksRoutingTable):
                return ("Eval", "eval_task")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("Tools", "tool")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
        """Fetch a registered object by (type, identifier), or None if absent."""
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        """Remove `obj` from the registry and from its owning provider."""
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        """Register `obj` with its provider and persist it in the registry.

        Raises:
            ValueError: if the (explicit or defaulted) provider_id is unknown.
        """
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj

        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
        """Return all registered objects whose type equals `type`."""
        objs = await self.dist_registry.get_all()
        return [obj for obj in objs if obj.type == type]
|
|
|
|
|
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for inference models."""

    async def list_models(self) -> ListModelsResponse:
        """Return every registered model."""
        registered = await self.get_all_with_type("model")
        return ListModelsResponse(data=registered)

    async def get_model(self, model_id: str) -> Optional[Model]:
        """Look up a model by identifier, or None if unknown."""
        return await self.get_object_by_identifier("model", model_id)

    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        model_type: Optional[ModelType] = None,
    ) -> Model:
        """Register a model, filling in defaults for unspecified fields.

        Raises:
            ValueError: when the provider cannot be inferred, or when an
                embedding model is missing `embedding_dimension` metadata.
        """
        effective_provider_model_id = model_id if provider_model_id is None else provider_model_id

        if provider_id is None:
            # Defaulting the provider is only unambiguous with exactly one configured.
            if len(self.impls_by_provider_id) != 1:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
            provider_id = next(iter(self.impls_by_provider_id))

        effective_metadata = {} if metadata is None else metadata
        effective_type = ModelType.llm if model_type is None else model_type

        if effective_type == ModelType.embedding and "embedding_dimension" not in effective_metadata:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")

        return await self.register_object(
            Model(
                identifier=model_id,
                provider_resource_id=effective_provider_model_id,
                provider_id=provider_id,
                metadata=effective_metadata,
                model_type=effective_type,
            )
        )

    async def unregister_model(self, model_id: str) -> None:
        """Remove a previously registered model; raises ValueError if unknown."""
        registered = await self.get_model(model_id)
        if registered is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(registered)
|
|
|
|
|
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    """Routing table for safety shields."""

    async def list_shields(self) -> ListShieldsResponse:
        """Return every registered shield."""
        registered = await self.get_all_with_type(ResourceType.shield.value)
        return ListShieldsResponse(data=registered)

    async def get_shield(self, identifier: str) -> Optional[Shield]:
        """Look up a shield by identifier, or None if unknown."""
        return await self.get_object_by_identifier("shield", identifier)

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Shield:
        """Register a shield, defaulting the provider when only one exists.

        Raises:
            ValueError: when no provider is given and several are available.
        """
        resource_id = shield_id if provider_shield_id is None else provider_shield_id

        if provider_id is None:
            # Defaulting the provider is only unambiguous with exactly one configured.
            if len(self.impls_by_provider_id) != 1:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
            provider_id = next(iter(self.impls_by_provider_id))

        new_shield = Shield(
            identifier=shield_id,
            provider_resource_id=resource_id,
            provider_id=provider_id,
            params={} if params is None else params,
        )
        await self.register_object(new_shield)
        return new_shield
|
|
|
|
|
|
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    """Routing table for vector databases."""

    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        """Return every registered vector DB."""
        registered = await self.get_all_with_type("vector_db")
        return ListVectorDBsResponse(data=registered)

    async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]:
        """Look up a vector DB by identifier, or None if unknown."""
        return await self.get_object_by_identifier("vector_db", vector_db_id)

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: Optional[int] = 384,
        provider_id: Optional[str] = None,
        provider_vector_db_id: Optional[str] = None,
    ) -> VectorDB:
        """Register a vector DB backed by a registered embedding model.

        Raises:
            ValueError: when the provider cannot be inferred, when the
                embedding model is unknown or not an embedding model, or
                when the model lacks an embedding dimension.
        """
        resource_id = vector_db_id if provider_vector_db_id is None else provider_vector_db_id

        if provider_id is None:
            # Defaulting the provider is only unambiguous with exactly one configured.
            if len(self.impls_by_provider_id) != 1:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
            provider_id = next(iter(self.impls_by_provider_id))

        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            if embedding_model == "all-MiniLM-L6-v2":
                # Special-cased: this model used to be served implicitly.
                raise ValueError(
                    "Embeddings are now served via Inference providers. "
                    "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
                    "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
                )
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")

        # Validate through the VectorDB union type rather than constructing directly.
        vector_db = TypeAdapter(VectorDB).validate_python(
            {
                "identifier": vector_db_id,
                "type": ResourceType.vector_db.value,
                "provider_id": provider_id,
                "provider_resource_id": resource_id,
                "embedding_model": embedding_model,
                "embedding_dimension": model.metadata["embedding_dimension"],
            }
        )
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Remove a previously registered vector DB; raises ValueError if unknown."""
        registered = await self.get_vector_db(vector_db_id)
        if registered is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(registered)
|
|
|
|
|
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    """Routing table for datasets."""

    async def list_datasets(self) -> ListDatasetsResponse:
        """Return every registered dataset."""
        registered = await self.get_all_with_type(ResourceType.dataset.value)
        return ListDatasetsResponse(data=registered)

    async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
        """Look up a dataset by identifier, or None if unknown."""
        return await self.get_object_by_identifier("dataset", dataset_id)

    async def register_dataset(
        self,
        dataset_id: str,
        dataset_schema: Dict[str, ParamType],
        url: URL,
        provider_dataset_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Register a dataset, defaulting the provider when only one exists.

        Raises:
            ValueError: when no provider is given and several are available.
        """
        resource_id = dataset_id if provider_dataset_id is None else provider_dataset_id

        if provider_id is None:
            # Defaulting the provider is only unambiguous with exactly one configured.
            if len(self.impls_by_provider_id) != 1:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
            provider_id = next(iter(self.impls_by_provider_id))

        await self.register_object(
            Dataset(
                identifier=dataset_id,
                provider_resource_id=resource_id,
                provider_id=provider_id,
                dataset_schema=dataset_schema,
                url=url,
                metadata={} if metadata is None else metadata,
            )
        )

    async def unregister_dataset(self, dataset_id: str) -> None:
        """Remove a previously registered dataset; raises ValueError if unknown."""
        registered = await self.get_dataset(dataset_id)
        if registered is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(registered)
|
|
|
|
|
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    """Routing table for scoring functions."""

    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        """Return every registered scoring function."""
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
        """Look up a scoring function by identifier, or None if unknown."""
        return await self.get_object_by_identifier("scoring_function", scoring_fn_id)

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        params: Optional[ScoringFnParams] = None,
    ) -> None:
        """Register a scoring function, defaulting the provider when only one exists.

        Raises:
            ValueError: when no provider is given and several are available.
        """
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            # Defaulting the provider is only unambiguous with exactly one configured.
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFn(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
        # provider_id is set via the constructor above; the previous redundant
        # post-construction re-assignment has been removed.
            params=params,
        )
        await self.register_object(scoring_fn)
|
|
|
|
|
|
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
    """Routing table for evaluation tasks."""

    async def list_eval_tasks(self) -> ListEvalTasksResponse:
        """Return every registered eval task."""
        registered = await self.get_all_with_type("eval_task")
        return ListEvalTasksResponse(data=registered)

    async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
        """Look up an eval task by identifier, or None if unknown."""
        return await self.get_object_by_identifier("eval_task", eval_task_id)

    async def register_eval_task(
        self,
        eval_task_id: str,
        dataset_id: str,
        scoring_functions: List[str],
        metadata: Optional[Dict[str, Any]] = None,
        provider_eval_task_id: Optional[str] = None,
        provider_id: Optional[str] = None,
    ) -> None:
        """Register an eval task, defaulting the provider when only one exists.

        Raises:
            ValueError: when no provider is given and several are available.
        """
        if provider_id is None:
            # Defaulting the provider is only unambiguous with exactly one configured.
            if len(self.impls_by_provider_id) != 1:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
            provider_id = next(iter(self.impls_by_provider_id))

        resource_id = eval_task_id if provider_eval_task_id is None else provider_eval_task_id
        await self.register_object(
            EvalTask(
                identifier=eval_task_id,
                dataset_id=dataset_id,
                scoring_functions=scoring_functions,
                metadata={} if metadata is None else metadata,
                provider_id=provider_id,
                provider_resource_id=resource_id,
            )
        )
|
|
|
|
|
|
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    """Routing table for tool groups and their constituent tools."""

    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
        """Return registered tools, optionally filtered to one tool group."""
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        """Return every registered tool group."""
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        """Look up a tool group by identifier."""
        return await self.get_object_by_identifier("tool_group", toolgroup_id)

    async def get_tool(self, tool_name: str) -> Tool:
        """Look up a tool by name."""
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: Optional[URL] = None,
        args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Discover the group's tools from the provider and register them all.

        Raises:
            ValueError: when a tool with the same identifier but different
                definition already exists in the registry.
        """
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs:
            tools.append(
                Tool(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                # Re-registering an identical tool is a no-op; a conflicting
                # definition under the same identifier is an error.
                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroup(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        """Remove a tool group and all of its tools; raises ValueError if unknown."""
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        # BUGFIX: the coroutine must be awaited before accessing `.data`;
        # `await self.list_tools(...).data` tried to await the attribute of an
        # un-awaited coroutine and raised AttributeError at runtime.
        tools = (await self.list_tools(toolgroup_id)).data
        for tool in tools:
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        """Nothing to release; providers are shut down by their own tables."""
        pass
|