fix: make backslash work in GET /models/{model_id:path} (#1068)

This commit is contained in:
Xi Yan 2025-02-13 08:46:43 -08:00 committed by GitHub
parent 47fccf0d03
commit 2fa9e3c941
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 32 additions and 24 deletions

View file

@ -644,7 +644,9 @@ class Generator:
else:
callbacks = None
description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description]))
description = "\n".join(
filter(None, [doc_string.short_description, doc_string.long_description])
)
return Operation(
tags=[op.defining_class.__name__],
summary=None,
@ -681,6 +683,7 @@ class Generator:
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
route = op.get_route()
route = route.replace(":path", "")
print(f"route: {route}")
if route in paths:
paths[route].update(pathItem)

View file

@ -130,6 +130,8 @@ class _FormatParameterExtractor:
def _get_route_parameters(route: str) -> List[str]:
extractor = _FormatParameterExtractor()
# Replace all occurrences of ":path" with empty string
route = route.replace(":path", "")
route.format_map(extractor)
return extractor.keys

View file

@ -29,11 +29,11 @@ from llama_stack.apis.inference import (
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@ -318,7 +318,7 @@ class Agents(Protocol):
agent_config: AgentConfig,
) -> AgentCreateResponse: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
@webmethod(route="/agents/{agent_id:path}/session/{session_id:path}/turn", method="POST")
async def create_agent_turn(
self,
agent_id: str,
@ -335,7 +335,10 @@ class Agents(Protocol):
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
@webmethod(
route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}",
method="GET",
)
async def get_agents_turn(
self,
agent_id: str,
@ -344,7 +347,7 @@ class Agents(Protocol):
) -> Turn: ...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}/step/{step_id:path}",
method="GET",
)
async def get_agents_step(
@ -355,14 +358,14 @@ class Agents(Protocol):
step_id: str,
) -> AgentStepResponse: ...
@webmethod(route="/agents/{agent_id}/session", method="POST")
@webmethod(route="/agents/{agent_id:path}/session", method="POST")
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
@webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="GET")
async def get_agents_session(
self,
session_id: str,
@ -370,14 +373,14 @@ class Agents(Protocol):
turn_ids: Optional[List[str]] = None,
) -> Session: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
@webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="DELETE")
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None: ...
@webmethod(route="/agents/{agent_id}", method="DELETE")
@webmethod(route="/agents/{agent_id:path}", method="DELETE")
async def delete_agent(
self,
agent_id: str,

View file

@ -58,7 +58,7 @@ class Datasets(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/datasets/{dataset_id}", method="GET")
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset(
self,
dataset_id: str,
@ -67,7 +67,7 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
@webmethod(route="/datasets/{dataset_id}", method="DELETE")
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,

View file

@ -62,7 +62,7 @@ class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/models/{model_id}", method="GET")
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,
model_id: str,
@ -78,7 +78,7 @@ class Models(Protocol):
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/{model_id}", method="DELETE")
@webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model(
self,
model_id: str,

View file

@ -134,7 +134,7 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")

View file

@ -48,7 +48,7 @@ class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier}", method="GET")
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields", method="POST")

View file

@ -13,8 +13,8 @@ from typing import (
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
Union,
)
from llama_models.llama3.api.datatypes import Primitive
@ -224,13 +224,13 @@ class Telemetry(Protocol):
order_by: Optional[List[str]] = None,
) -> QueryTracesResponse: ...
@webmethod(route="/telemetry/traces/{trace_id}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
@webmethod(route="/telemetry/spans/{span_id}/tree", method="GET")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
async def get_span_tree(
self,
span_id: str,

View file

@ -101,7 +101,7 @@ class ToolGroups(Protocol):
"""Register a tool group"""
...
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
@ -117,13 +117,13 @@ class ToolGroups(Protocol):
"""List tools with optional tool group"""
...
@webmethod(route="/tools/{tool_name}", method="GET")
@webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,

View file

@ -46,7 +46,7 @@ class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
@webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db(
self,
vector_db_id: str,
@ -62,5 +62,5 @@ class VectorDBs(Protocol):
provider_vector_db_id: Optional[str] = None,
) -> VectorDB: ...
@webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ...