mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: make backslash work in GET /models/{model_id:path} (#1068)
This commit is contained in:
parent
47fccf0d03
commit
2fa9e3c941
10 changed files with 32 additions and 24 deletions
|
@ -644,7 +644,9 @@ class Generator:
|
||||||
else:
|
else:
|
||||||
callbacks = None
|
callbacks = None
|
||||||
|
|
||||||
description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description]))
|
description = "\n".join(
|
||||||
|
filter(None, [doc_string.short_description, doc_string.long_description])
|
||||||
|
)
|
||||||
return Operation(
|
return Operation(
|
||||||
tags=[op.defining_class.__name__],
|
tags=[op.defining_class.__name__],
|
||||||
summary=None,
|
summary=None,
|
||||||
|
@ -681,6 +683,7 @@ class Generator:
|
||||||
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
|
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
|
||||||
|
|
||||||
route = op.get_route()
|
route = op.get_route()
|
||||||
|
route = route.replace(":path", "")
|
||||||
print(f"route: {route}")
|
print(f"route: {route}")
|
||||||
if route in paths:
|
if route in paths:
|
||||||
paths[route].update(pathItem)
|
paths[route].update(pathItem)
|
||||||
|
|
|
@ -130,6 +130,8 @@ class _FormatParameterExtractor:
|
||||||
|
|
||||||
def _get_route_parameters(route: str) -> List[str]:
|
def _get_route_parameters(route: str) -> List[str]:
|
||||||
extractor = _FormatParameterExtractor()
|
extractor = _FormatParameterExtractor()
|
||||||
|
# Replace all occurrences of ":path" with empty string
|
||||||
|
route = route.replace(":path", "")
|
||||||
route.format_map(extractor)
|
route.format_map(extractor)
|
||||||
return extractor.keys
|
return extractor.keys
|
||||||
|
|
||||||
|
|
|
@ -29,11 +29,11 @@ from llama_stack.apis.inference import (
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
|
ToolConfig,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
ToolResponse,
|
ToolResponse,
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
ToolConfig,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import SafetyViolation
|
from llama_stack.apis.safety import SafetyViolation
|
||||||
from llama_stack.apis.tools import ToolDef
|
from llama_stack.apis.tools import ToolDef
|
||||||
|
@ -318,7 +318,7 @@ class Agents(Protocol):
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
) -> AgentCreateResponse: ...
|
) -> AgentCreateResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
|
@webmethod(route="/agents/{agent_id:path}/session/{session_id:path}/turn", method="POST")
|
||||||
async def create_agent_turn(
|
async def create_agent_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
@ -335,7 +335,10 @@ class Agents(Protocol):
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
|
@webmethod(
|
||||||
|
route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}",
|
||||||
|
method="GET",
|
||||||
|
)
|
||||||
async def get_agents_turn(
|
async def get_agents_turn(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
@ -344,7 +347,7 @@ class Agents(Protocol):
|
||||||
) -> Turn: ...
|
) -> Turn: ...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(
|
||||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}/step/{step_id:path}",
|
||||||
method="GET",
|
method="GET",
|
||||||
)
|
)
|
||||||
async def get_agents_step(
|
async def get_agents_step(
|
||||||
|
@ -355,14 +358,14 @@ class Agents(Protocol):
|
||||||
step_id: str,
|
step_id: str,
|
||||||
) -> AgentStepResponse: ...
|
) -> AgentStepResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/{agent_id}/session", method="POST")
|
@webmethod(route="/agents/{agent_id:path}/session", method="POST")
|
||||||
async def create_agent_session(
|
async def create_agent_session(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_name: str,
|
session_name: str,
|
||||||
) -> AgentSessionCreateResponse: ...
|
) -> AgentSessionCreateResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
|
@webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="GET")
|
||||||
async def get_agents_session(
|
async def get_agents_session(
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
@ -370,14 +373,14 @@ class Agents(Protocol):
|
||||||
turn_ids: Optional[List[str]] = None,
|
turn_ids: Optional[List[str]] = None,
|
||||||
) -> Session: ...
|
) -> Session: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
|
@webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="DELETE")
|
||||||
async def delete_agents_session(
|
async def delete_agents_session(
|
||||||
self,
|
self,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/{agent_id}", method="DELETE")
|
@webmethod(route="/agents/{agent_id:path}", method="DELETE")
|
||||||
async def delete_agent(
|
async def delete_agent(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
|
|
|
@ -58,7 +58,7 @@ class Datasets(Protocol):
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/datasets/{dataset_id}", method="GET")
|
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
|
||||||
async def get_dataset(
|
async def get_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
|
@ -67,7 +67,7 @@ class Datasets(Protocol):
|
||||||
@webmethod(route="/datasets", method="GET")
|
@webmethod(route="/datasets", method="GET")
|
||||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/datasets/{dataset_id}", method="DELETE")
|
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||||
async def unregister_dataset(
|
async def unregister_dataset(
|
||||||
self,
|
self,
|
||||||
dataset_id: str,
|
dataset_id: str,
|
||||||
|
|
|
@ -62,7 +62,7 @@ class Models(Protocol):
|
||||||
@webmethod(route="/models", method="GET")
|
@webmethod(route="/models", method="GET")
|
||||||
async def list_models(self) -> ListModelsResponse: ...
|
async def list_models(self) -> ListModelsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/models/{model_id}", method="GET")
|
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||||
async def get_model(
|
async def get_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -78,7 +78,7 @@ class Models(Protocol):
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/{model_id}", method="DELETE")
|
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||||
async def unregister_model(
|
async def unregister_model(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -134,7 +134,7 @@ class ScoringFunctions(Protocol):
|
||||||
@webmethod(route="/scoring-functions", method="GET")
|
@webmethod(route="/scoring-functions", method="GET")
|
||||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
|
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
|
||||||
|
|
||||||
@webmethod(route="/scoring-functions", method="POST")
|
@webmethod(route="/scoring-functions", method="POST")
|
||||||
|
|
|
@ -48,7 +48,7 @@ class Shields(Protocol):
|
||||||
@webmethod(route="/shields", method="GET")
|
@webmethod(route="/shields", method="GET")
|
||||||
async def list_shields(self) -> ListShieldsResponse: ...
|
async def list_shields(self) -> ListShieldsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/{identifier}", method="GET")
|
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||||
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields", method="POST")
|
@webmethod(route="/shields", method="POST")
|
||||||
|
|
|
@ -13,8 +13,8 @@ from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
Union,
|
|
||||||
runtime_checkable,
|
runtime_checkable,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Primitive
|
from llama_models.llama3.api.datatypes import Primitive
|
||||||
|
@ -224,13 +224,13 @@ class Telemetry(Protocol):
|
||||||
order_by: Optional[List[str]] = None,
|
order_by: Optional[List[str]] = None,
|
||||||
) -> QueryTracesResponse: ...
|
) -> QueryTracesResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces/{trace_id}", method="GET")
|
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET")
|
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||||
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
|
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
|
||||||
|
|
||||||
@webmethod(route="/telemetry/spans/{span_id}/tree", method="GET")
|
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
|
||||||
async def get_span_tree(
|
async def get_span_tree(
|
||||||
self,
|
self,
|
||||||
span_id: str,
|
span_id: str,
|
||||||
|
|
|
@ -101,7 +101,7 @@ class ToolGroups(Protocol):
|
||||||
"""Register a tool group"""
|
"""Register a tool group"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
|
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
|
||||||
async def get_tool_group(
|
async def get_tool_group(
|
||||||
self,
|
self,
|
||||||
toolgroup_id: str,
|
toolgroup_id: str,
|
||||||
|
@ -117,13 +117,13 @@ class ToolGroups(Protocol):
|
||||||
"""List tools with optional tool group"""
|
"""List tools with optional tool group"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/tools/{tool_name}", method="GET")
|
@webmethod(route="/tools/{tool_name:path}", method="GET")
|
||||||
async def get_tool(
|
async def get_tool(
|
||||||
self,
|
self,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
) -> Tool: ...
|
) -> Tool: ...
|
||||||
|
|
||||||
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
|
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||||
async def unregister_toolgroup(
|
async def unregister_toolgroup(
|
||||||
self,
|
self,
|
||||||
toolgroup_id: str,
|
toolgroup_id: str,
|
||||||
|
|
|
@ -46,7 +46,7 @@ class VectorDBs(Protocol):
|
||||||
@webmethod(route="/vector-dbs", method="GET")
|
@webmethod(route="/vector-dbs", method="GET")
|
||||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
|
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||||
async def get_vector_db(
|
async def get_vector_db(
|
||||||
self,
|
self,
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
|
@ -62,5 +62,5 @@ class VectorDBs(Protocol):
|
||||||
provider_vector_db_id: Optional[str] = None,
|
provider_vector_db_id: Optional[str] = None,
|
||||||
) -> VectorDB: ...
|
) -> VectorDB: ...
|
||||||
|
|
||||||
@webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
|
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue