From 2fa9e3c941d4b1d3183f45d3fa883637b1aa4110 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 13 Feb 2025 08:46:43 -0800 Subject: [PATCH] fix: make backslash work in GET /models/{model_id:path} (#1068) --- docs/openapi_generator/pyopenapi/generator.py | 5 ++++- .../openapi_generator/pyopenapi/operations.py | 2 ++ llama_stack/apis/agents/agents.py | 19 +++++++++++-------- llama_stack/apis/datasets/datasets.py | 4 ++-- llama_stack/apis/models/models.py | 4 ++-- .../scoring_functions/scoring_functions.py | 2 +- llama_stack/apis/shields/shields.py | 2 +- llama_stack/apis/telemetry/telemetry.py | 8 ++++---- llama_stack/apis/tools/tools.py | 6 +++--- llama_stack/apis/vector_dbs/vector_dbs.py | 4 ++-- 10 files changed, 32 insertions(+), 24 deletions(-) diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index f0d30a0e6..a0385cae0 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -644,7 +644,9 @@ class Generator: else: callbacks = None - description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description])) + description = "\n".join( + filter(None, [doc_string.short_description, doc_string.long_description]) + ) return Operation( tags=[op.defining_class.__name__], summary=None, @@ -681,6 +683,7 @@ class Generator: raise NotImplementedError(f"unknown HTTP method: {op.http_method}") route = op.get_route() + route = route.replace(":path", "") print(f"route: {route}") if route in paths: paths[route].update(pathItem) diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index abeb16936..bf4d35c87 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -130,6 +130,8 @@ class _FormatParameterExtractor: def _get_route_parameters(route: str) -> List[str]: extractor = _FormatParameterExtractor() + # Replace all occurrences of ":path" with empty string + route = route.replace(":path", "") route.format_map(extractor) return extractor.keys diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 785248633..b20145be9 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -29,11 +29,11 @@ from llama_stack.apis.inference import ( SamplingParams, ToolCall, ToolChoice, + ToolConfig, ToolPromptFormat, ToolResponse, ToolResponseMessage, UserMessage, - ToolConfig, ) from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef @@ -318,7 +318,7 @@ class Agents(Protocol): agent_config: AgentConfig, ) -> AgentCreateResponse: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") + @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}/turn", method="POST") async def create_agent_turn( self, agent_id: str, @@ -335,7 +335,10 @@ class Agents(Protocol): tool_config: Optional[ToolConfig] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET") + @webmethod( + route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}", + method="GET", + ) async def get_agents_turn( self, agent_id: str, @@ -344,7 +347,7 @@ class Agents(Protocol): ) -> Turn: ... @webmethod( - route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}", + route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}/step/{step_id:path}", method="GET", ) async def get_agents_step( @@ -355,14 +358,14 @@ class Agents(Protocol): step_id: str, ) -> AgentStepResponse: ... - @webmethod(route="/agents/{agent_id}/session", method="POST") + @webmethod(route="/agents/{agent_id:path}/session", method="POST") async def create_agent_session( self, agent_id: str, session_name: str, ) -> AgentSessionCreateResponse: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET") + @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="GET") async def get_agents_session( self, session_id: str, @@ -370,14 +373,14 @@ class Agents(Protocol): turn_ids: Optional[List[str]] = None, ) -> Session: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE") + @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="DELETE") async def delete_agents_session( self, session_id: str, agent_id: str, ) -> None: ... - @webmethod(route="/agents/{agent_id}", method="DELETE") + @webmethod(route="/agents/{agent_id:path}", method="DELETE") async def delete_agent( self, agent_id: str, diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 5ad5bdcdb..5e2b38697 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -58,7 +58,7 @@ class Datasets(Protocol): metadata: Optional[Dict[str, Any]] = None, ) -> None: ... - @webmethod(route="/datasets/{dataset_id}", method="GET") + @webmethod(route="/datasets/{dataset_id:path}", method="GET") async def get_dataset( self, dataset_id: str, @@ -67,7 +67,7 @@ class Datasets(Protocol): @webmethod(route="/datasets", method="GET") async def list_datasets(self) -> ListDatasetsResponse: ... - @webmethod(route="/datasets/{dataset_id}", method="DELETE") + @webmethod(route="/datasets/{dataset_id:path}", method="DELETE") async def unregister_dataset( self, dataset_id: str, diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 3361c2836..7e6d9854f 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -62,7 +62,7 @@ class Models(Protocol): @webmethod(route="/models", method="GET") async def list_models(self) -> ListModelsResponse: ... - @webmethod(route="/models/{model_id}", method="GET") + @webmethod(route="/models/{model_id:path}", method="GET") async def get_model( self, model_id: str, @@ -78,7 +78,7 @@ class Models(Protocol): model_type: Optional[ModelType] = None, ) -> Model: ... - @webmethod(route="/models/{model_id}", method="DELETE") + @webmethod(route="/models/{model_id:path}", method="DELETE") async def unregister_model( self, model_id: str, diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 325979583..3fa40ffbf 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -134,7 +134,7 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring-functions", method="GET") async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... - @webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET") + @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... @webmethod(route="/scoring-functions", method="POST") diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 3dd685b14..ae316ee53 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -48,7 +48,7 @@ class Shields(Protocol): @webmethod(route="/shields", method="GET") async def list_shields(self) -> ListShieldsResponse: ... - @webmethod(route="/shields/{identifier}", method="GET") + @webmethod(route="/shields/{identifier:path}", method="GET") async def get_shield(self, identifier: str) -> Optional[Shield]: ... @webmethod(route="/shields", method="POST") diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 6272cc40b..5622aaeac 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -13,8 +13,8 @@ from typing import ( Literal, Optional, Protocol, - Union, runtime_checkable, + Union, ) from llama_models.llama3.api.datatypes import Primitive @@ -224,13 +224,13 @@ class Telemetry(Protocol): order_by: Optional[List[str]] = None, ) -> QueryTracesResponse: ... - @webmethod(route="/telemetry/traces/{trace_id}", method="GET") + @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") async def get_trace(self, trace_id: str) -> Trace: ... - @webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET") + @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") async def get_span(self, trace_id: str, span_id: str) -> Span: ... - @webmethod(route="/telemetry/spans/{span_id}/tree", method="GET") + @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET") async def get_span_tree( self, span_id: str, diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index d6d806c53..a8e946b08 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -101,7 +101,7 @@ class ToolGroups(Protocol): """Register a tool group""" ... - @webmethod(route="/toolgroups/{toolgroup_id}", method="GET") + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET") async def get_tool_group( self, toolgroup_id: str, @@ -117,13 +117,13 @@ class ToolGroups(Protocol): """List tools with optional tool group""" ... - @webmethod(route="/tools/{tool_name}", method="GET") + @webmethod(route="/tools/{tool_name:path}", method="GET") async def get_tool( self, tool_name: str, ) -> Tool: ... - @webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE") + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE") async def unregister_toolgroup( self, toolgroup_id: str, diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 4b782e2d5..1da2c128c 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -46,7 +46,7 @@ class VectorDBs(Protocol): @webmethod(route="/vector-dbs", method="GET") async def list_vector_dbs(self) -> ListVectorDBsResponse: ... - @webmethod(route="/vector-dbs/{vector_db_id}", method="GET") + @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET") async def get_vector_db( self, vector_db_id: str, @@ -62,5 +62,5 @@ class VectorDBs(Protocol): provider_vector_db_id: Optional[str] = None, ) -> VectorDB: ... - @webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE") + @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") async def unregister_vector_db(self, vector_db_id: str) -> None: ...