diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 65e54b40d..1af019bd4 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -132,7 +132,7 @@ class ToolGroups(Protocol): ... -class SpecialToolGroups(Enum): +class SpecialToolGroup(Enum): rag_tool = "rag_tool" diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index 3d71dea0f..745bcddea 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -9,7 +9,7 @@ from typing import Dict, List from pydantic import BaseModel -from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroups +from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION @@ -26,7 +26,7 @@ class ApiEndpoint(BaseModel): def toolgroup_protocol_map(): return { - SpecialToolGroups.rag_tool: RAGToolRuntime, + SpecialToolGroup.rag_tool: RAGToolRuntime, } @@ -39,7 +39,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: endpoints = [] protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction) if api == Api.tool_runtime: - for tool_group in SpecialToolGroups: + for tool_group in SpecialToolGroup: sub_protocol = toolgroup_protocols[tool_group] sub_protocol_methods = inspect.getmembers( sub_protocol, predicate=inspect.isfunction