forked from phoenix-oss/llama-stack-mirror
chore: more API validators (#2165)
# What does this PR do? We added: * make sure docstrings are present with 'params' and 'returns' * fail if someone sets 'returns: None' * fix the failing APIs Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
e46de23be6
commit
bb5fca9521
23 changed files with 1304 additions and 574 deletions
645
docs/_static/llama-stack-spec.html
vendored
645
docs/_static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
534
docs/_static/llama-stack-spec.yaml
vendored
534
docs/_static/llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
@ -179,6 +179,35 @@ def _validate_has_ellipsis(method) -> str | None:
|
|||
if "..." not in source and not "NotImplementedError" in source:
|
||||
return "does not contain ellipsis (...) in its implementation"
|
||||
|
||||
def _validate_has_return_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
return_type = method.__annotations__.get('return')
|
||||
if return_type is not None and return_type != type(None) and ":returns:" not in source:
|
||||
return "does not have a ':returns:' in its docstring"
|
||||
|
||||
def _validate_has_params_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
sig = inspect.signature(method)
|
||||
# Only check if the method has more than one parameter
|
||||
if len(sig.parameters) > 1 and ":param" not in source:
|
||||
return "does not have a ':param' in its docstring"
|
||||
|
||||
def _validate_has_no_return_none_in_docstring(method) -> str | None:
|
||||
source = inspect.getsource(method)
|
||||
return_type = method.__annotations__.get('return')
|
||||
if return_type is None and ":returns: None" in source:
|
||||
return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
|
||||
|
||||
def _validate_docstring_lines_end_with_dot(method) -> str | None:
|
||||
docstring = inspect.getdoc(method)
|
||||
if docstring is None:
|
||||
return None
|
||||
|
||||
lines = docstring.split('\n')
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and not any(line.endswith(char) for char in '.:{}[]()",'):
|
||||
return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
|
||||
|
||||
_VALIDATORS = {
|
||||
"GET": [
|
||||
|
@ -186,13 +215,23 @@ _VALIDATORS = {
|
|||
_validate_list_parameters_contain_data,
|
||||
_validate_api_method_doesnt_return_list,
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_docstring_lines_end_with_dot,
|
||||
],
|
||||
"DELETE": [
|
||||
_validate_api_delete_method_returns_none,
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_has_no_return_none_in_docstring
|
||||
],
|
||||
"POST": [
|
||||
_validate_has_ellipsis,
|
||||
_validate_has_return_in_docstring,
|
||||
_validate_has_params_in_docstring,
|
||||
_validate_has_no_return_none_in_docstring,
|
||||
_validate_docstring_lines_end_with_dot,
|
||||
],
|
||||
}
|
||||
|
||||
|
|
|
@ -413,7 +413,7 @@ class Agents(Protocol):
|
|||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -509,6 +509,7 @@ class Agents(Protocol):
|
|||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
:returns: A Session.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -606,5 +607,6 @@ class Agents(Protocol):
|
|||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -38,7 +38,17 @@ class BatchInference(Protocol):
|
|||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job: ...
|
||||
) -> Job:
|
||||
"""Generate completions for a batch of content.
|
||||
|
||||
:param model: The model to use for the completion.
|
||||
:param content_batch: The content to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param response_format: The response format to use for the completion.
|
||||
:param logprobs: The logprobs to use for the completion.
|
||||
:returns: A job for the completion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
|
@ -52,4 +62,17 @@ class BatchInference(Protocol):
|
|||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job: ...
|
||||
) -> Job:
|
||||
"""Generate chat completions for a batch of messages.
|
||||
|
||||
:param model: The model to use for the chat completion.
|
||||
:param messages_batch: The messages to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param tools: The tools to use for the chat completion.
|
||||
:param tool_choice: The tool choice to use for the chat completion.
|
||||
:param tool_prompt_format: The tool prompt format to use for the chat completion.
|
||||
:param response_format: The response format to use for the chat completion.
|
||||
:param logprobs: The logprobs to use for the chat completion.
|
||||
:returns: A job for the chat completion.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class Benchmarks(Protocol):
|
||||
@webmethod(route="/eval/benchmarks", method="GET")
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
"""List all benchmarks.
|
||||
|
||||
:returns: A ListBenchmarksResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
) -> Benchmark: ...
|
||||
) -> Benchmark:
|
||||
"""Get a benchmark by its ID.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to get.
|
||||
:returns: A Benchmark.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
async def register_benchmark(
|
||||
|
@ -63,4 +74,14 @@ class Benchmarks(Protocol):
|
|||
provider_benchmark_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Register a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to register.
|
||||
:param dataset_id: The ID of the dataset to use for the benchmark.
|
||||
:param scoring_functions: The scoring functions to use for the benchmark.
|
||||
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
|
||||
:param provider_id: The ID of the provider to use for the benchmark.
|
||||
:param metadata: The metadata to use for the benchmark.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -34,14 +34,21 @@ class DatasetIO(Protocol):
|
|||
- limit: Number of items to return. If None or -1, returns all items.
|
||||
|
||||
The response includes:
|
||||
- data: List of items for the current page
|
||||
- has_more: Whether there are more items available after this set
|
||||
- data: List of items for the current page.
|
||||
- has_more: Whether there are more items available after this set.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get the rows from.
|
||||
:param start_index: Index into dataset for the first row to get. Get all rows if None.
|
||||
:param limit: The number of rows to get.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ...
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Append rows to a dataset.
|
||||
|
||||
:param dataset_id: The ID of the dataset to append the rows to.
|
||||
:param rows: The rows to append to the dataset.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -137,7 +137,8 @@ class Datasets(Protocol):
|
|||
"""
|
||||
Register a new dataset.
|
||||
|
||||
:param purpose: The purpose of the dataset. One of
|
||||
:param purpose: The purpose of the dataset.
|
||||
One of:
|
||||
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
|
||||
{
|
||||
"messages": [
|
||||
|
@ -188,8 +189,9 @@ class Datasets(Protocol):
|
|||
]
|
||||
}
|
||||
:param metadata: The metadata for the dataset.
|
||||
- E.g. {"description": "My dataset"}
|
||||
- E.g. {"description": "My dataset"}.
|
||||
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -197,13 +199,29 @@ class Datasets(Protocol):
|
|||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> Dataset: ...
|
||||
) -> Dataset:
|
||||
"""Get a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to get.
|
||||
:returns: A Dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
async def list_datasets(self) -> ListDatasetsResponse: ...
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
:returns: A ListDatasetsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Unregister a dataset by its ID.
|
||||
|
||||
:param dataset_id: The ID of the dataset to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -93,7 +93,7 @@ class Eval(Protocol):
|
|||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: The job that was created to run the evaluation.
|
||||
:returns: The job that was created to run the evaluation.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -111,7 +111,7 @@ class Eval(Protocol):
|
|||
:param input_rows: The rows to evaluate.
|
||||
:param scoring_functions: The scoring functions to use for the evaluation.
|
||||
:param benchmark_config: The configuration for the benchmark.
|
||||
:return: EvaluateResponse object containing generations and scores
|
||||
:returns: EvaluateResponse object containing generations and scores.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -121,7 +121,7 @@ class Eval(Protocol):
|
|||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the status of.
|
||||
:return: The status of the evaluationjob.
|
||||
:returns: The status of the evaluation job.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -140,6 +140,6 @@ class Eval(Protocol):
|
|||
|
||||
:param benchmark_id: The ID of the benchmark to run the evaluation on.
|
||||
:param job_id: The ID of the job to get the result of.
|
||||
:return: The result of the job.
|
||||
:returns: The result of the job.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -91,10 +91,11 @@ class Files(Protocol):
|
|||
"""
|
||||
Create a new upload session for a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param mime_type: MIME type of the file
|
||||
:param size: File size in bytes
|
||||
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:param mime_type: MIME type of the file.
|
||||
:param size: File size in bytes.
|
||||
:returns: A FileUploadResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -107,7 +108,8 @@ class Files(Protocol):
|
|||
Upload file content to an existing upload session.
|
||||
On the server, request body will have the raw bytes that are uploaded.
|
||||
|
||||
:param upload_id: ID of the upload session
|
||||
:param upload_id: ID of the upload session.
|
||||
:returns: A FileResponse or None if the upload is not complete.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -117,9 +119,10 @@ class Files(Protocol):
|
|||
upload_id: str,
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
Returns information about an existsing upload session
|
||||
Returns information about an existsing upload session.
|
||||
|
||||
:param upload_id: ID of the upload session
|
||||
:param upload_id: ID of the upload session.
|
||||
:returns: A FileUploadResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -130,6 +133,9 @@ class Files(Protocol):
|
|||
) -> ListBucketResponse:
|
||||
"""
|
||||
List all buckets.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:returns: A ListBucketResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -141,7 +147,8 @@ class Files(Protocol):
|
|||
"""
|
||||
List all files in a bucket.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:returns: A ListFileResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -154,8 +161,9 @@ class Files(Protocol):
|
|||
"""
|
||||
Get a file info identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
:returns: A FileResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -168,7 +176,7 @@ class Files(Protocol):
|
|||
"""
|
||||
Delete a file identified by a bucket and key.
|
||||
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-)
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
|
||||
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -845,13 +845,13 @@ class Inference(Protocol):
|
|||
"""Generate a completion for the given content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content: The content to generate a completion for
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding
|
||||
:param content: The content to generate a completion for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: If stream=False, returns a CompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
|
||||
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -864,6 +864,15 @@ class Inference(Protocol):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
"""Generate completions for a batch of content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content_batch: The content to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
|
||||
@webmethod(route="/inference/chat-completion", method="POST")
|
||||
|
@ -883,9 +892,9 @@ class Inference(Protocol):
|
|||
"""Generate a chat completion for the given messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation
|
||||
:param sampling_params: Parameters to control the sampling strategy
|
||||
:param tools: (Optional) List of tool definitions available to the model
|
||||
:param messages: List of messages in the conversation.
|
||||
:param sampling_params: Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
|
||||
.. deprecated::
|
||||
Use tool_config instead.
|
||||
|
@ -902,7 +911,7 @@ class Inference(Protocol):
|
|||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:returns: If stream=False, returns a ChatCompletionResponse with the full completion.
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
|
||||
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -917,6 +926,17 @@ class Inference(Protocol):
|
|||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
"""Generate chat completions for a batch of messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages_batch: The messages to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchChatCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
|
||||
@webmethod(route="/inference/embeddings", method="POST")
|
||||
|
@ -935,7 +955,7 @@ class Inference(Protocol):
|
|||
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -967,22 +987,23 @@ class Inference(Protocol):
|
|||
"""Generate an OpenAI-compatible completion for the given prompt using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param prompt: The prompt to generate a completion for
|
||||
:param best_of: (Optional) The number of completions to generate
|
||||
:param echo: (Optional) Whether to echo the prompt
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens
|
||||
:param logit_bias: (Optional) The logit bias to use
|
||||
:param logprobs: (Optional) The log probabilities to use
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate
|
||||
:param n: (Optional) The number of completions to generate
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens
|
||||
:param seed: (Optional) The seed to use
|
||||
:param stop: (Optional) The stop tokens to use
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param stream_options: (Optional) The stream options to use
|
||||
:param temperature: (Optional) The temperature to use
|
||||
:param top_p: (Optional) The top p to use
|
||||
:param user: (Optional) The user to use
|
||||
:param prompt: The prompt to generate a completion for.
|
||||
:param best_of: (Optional) The number of completions to generate.
|
||||
:param echo: (Optional) Whether to echo the prompt.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:returns: An OpenAICompletion.
|
||||
"""
|
||||
...
|
||||
|
||||
|
@ -1016,27 +1037,28 @@ class Inference(Protocol):
|
|||
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model.
|
||||
|
||||
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages: List of messages in the conversation
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens
|
||||
:param function_call: (Optional) The function call to use
|
||||
:param functions: (Optional) List of functions to use
|
||||
:param logit_bias: (Optional) The logit bias to use
|
||||
:param logprobs: (Optional) The log probabilities to use
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate
|
||||
:param n: (Optional) The number of completions to generate
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens
|
||||
:param response_format: (Optional) The response format to use
|
||||
:param seed: (Optional) The seed to use
|
||||
:param stop: (Optional) The stop tokens to use
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param stream_options: (Optional) The stream options to use
|
||||
:param temperature: (Optional) The temperature to use
|
||||
:param tool_choice: (Optional) The tool choice to use
|
||||
:param tools: (Optional) The tools to use
|
||||
:param top_logprobs: (Optional) The top log probabilities to use
|
||||
:param top_p: (Optional) The top p to use
|
||||
:param user: (Optional) The user to use
|
||||
:param messages: List of messages in the conversation.
|
||||
:param frequency_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param function_call: (Optional) The function call to use.
|
||||
:param functions: (Optional) List of functions to use.
|
||||
:param logit_bias: (Optional) The logit bias to use.
|
||||
:param logprobs: (Optional) The log probabilities to use.
|
||||
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param max_tokens: (Optional) The maximum number of tokens to generate.
|
||||
:param n: (Optional) The number of completions to generate.
|
||||
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
|
||||
:param presence_penalty: (Optional) The penalty for repeated tokens.
|
||||
:param response_format: (Optional) The response format to use.
|
||||
:param seed: (Optional) The seed to use.
|
||||
:param stop: (Optional) The stop tokens to use.
|
||||
:param stream: (Optional) Whether to stream the response.
|
||||
:param stream_options: (Optional) The stream options to use.
|
||||
:param temperature: (Optional) The temperature to use.
|
||||
:param tool_choice: (Optional) The tool choice to use.
|
||||
:param tools: (Optional) The tools to use.
|
||||
:param top_logprobs: (Optional) The top log probabilities to use.
|
||||
:param top_p: (Optional) The top p to use.
|
||||
:param user: (Optional) The user to use.
|
||||
:returns: An OpenAIChatCompletion.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
async def list_routes(self) -> ListRoutesResponse: ...
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all routes.
|
||||
|
||||
:returns: A ListRoutesResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
async def health(self) -> HealthInfo: ...
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the health of the service.
|
||||
|
||||
:returns: A HealthInfo.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/version", method="GET")
|
||||
async def version(self) -> VersionInfo: ...
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
|
||||
:returns: A VersionInfo.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models", method="GET")
|
||||
async def list_models(self) -> ListModelsResponse: ...
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
||||
:returns: A ListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET")
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse: ...
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
:returns: A OpenAIListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> Model: ...
|
||||
) -> Model:
|
||||
"""Get a model by its identifier.
|
||||
|
||||
:param model_id: The identifier of the model to get.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
async def register_model(
|
||||
|
@ -99,10 +115,25 @@ class Models(Protocol):
|
|||
provider_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
) -> Model: ...
|
||||
) -> Model:
|
||||
"""Register a model.
|
||||
|
||||
:param model_id: The identifier of the model to register.
|
||||
:param provider_model_id: The identifier of the model in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param metadata: Any additional metadata for this model.
|
||||
:param model_type: The type of model to register.
|
||||
:returns: A Model.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Unregister a model.
|
||||
|
||||
:param model_id: The identifier of the model to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -182,7 +182,19 @@ class PostTraining(Protocol):
|
|||
),
|
||||
checkpoint_dir: str | None = None,
|
||||
algorithm_config: AlgorithmConfig | None = None,
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
"""Run supervised fine-tuning of a model.
|
||||
|
||||
:param job_uuid: The UUID of the job to create.
|
||||
:param training_config: The training configuration.
|
||||
:param hyperparam_search_config: The hyperparam search configuration.
|
||||
:param logger_config: The logger configuration.
|
||||
:param model: The model to fine-tune.
|
||||
:param checkpoint_dir: The directory to save checkpoint(s) to.
|
||||
:param algorithm_config: The algorithm configuration.
|
||||
:returns: A PostTrainingJob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
async def preference_optimize(
|
||||
|
@ -193,16 +205,49 @@ class PostTraining(Protocol):
|
|||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: dict[str, Any],
|
||||
logger_config: dict[str, Any],
|
||||
) -> PostTrainingJob: ...
|
||||
) -> PostTrainingJob:
|
||||
"""Run preference optimization of a model.
|
||||
|
||||
:param job_uuid: The UUID of the job to create.
|
||||
:param finetuned_model: The model to fine-tune.
|
||||
:param algorithm_config: The algorithm configuration.
|
||||
:param training_config: The training configuration.
|
||||
:param hyperparam_search_config: The hyperparam search configuration.
|
||||
:param logger_config: The logger configuration.
|
||||
:returns: A PostTrainingJob.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
"""Get all training jobs.
|
||||
|
||||
:returns: A ListPostTrainingJobsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ...
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
"""Get the status of a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to get the status of.
|
||||
:returns: A PostTrainingJobStatusResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
"""Cancel a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to cancel.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ...
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
"""Get the artifacts of a training job.
|
||||
|
||||
:param job_uuid: The UUID of the job to get the artifacts of.
|
||||
:returns: A PostTrainingJobArtifactsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -32,7 +32,18 @@ class Providers(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
async def list_providers(self) -> ListProvidersResponse: ...
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List all available providers.
|
||||
|
||||
:returns: A ListProvidersResponse containing information about all providers.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ...
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get detailed information about a specific provider.
|
||||
|
||||
:param provider_id: The ID of the provider to inspect.
|
||||
:returns: A ProviderInfo object containing the provider's details.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -54,4 +54,12 @@ class Safety(Protocol):
|
|||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse: ...
|
||||
) -> RunShieldResponse:
|
||||
"""Run a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to run.
|
||||
:param messages: The messages to run the shield on.
|
||||
:param params: The parameters of the shield.
|
||||
:returns: A RunShieldResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -61,7 +61,15 @@ class Scoring(Protocol):
|
|||
dataset_id: str,
|
||||
scoring_functions: dict[str, ScoringFnParams | None],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse: ...
|
||||
) -> ScoreBatchResponse:
|
||||
"""Score a batch of rows.
|
||||
|
||||
:param dataset_id: The ID of the dataset to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:param save_results_dataset: Whether to save the results to a dataset.
|
||||
:returns: A ScoreBatchResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/score", method="POST")
|
||||
async def score(
|
||||
|
@ -73,6 +81,6 @@ class Scoring(Protocol):
|
|||
|
||||
:param input_rows: The rows to score.
|
||||
:param scoring_functions: The scoring functions to use for the scoring.
|
||||
:return: ScoreResponse object containing rows and aggregated results
|
||||
:returns: A ScoreResponse object containing rows and aggregated results.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -134,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET")
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
"""List all scoring functions.
|
||||
|
||||
:returns: A ListScoringFunctionsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
|
||||
"""Get a scoring function by its ID.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to get.
|
||||
:returns: A ScoringFn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
async def register_scoring_function(
|
||||
|
@ -148,4 +159,14 @@ class ScoringFunctions(Protocol):
|
|||
provider_scoring_fn_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: ScoringFnParams | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Register a scoring function.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to register.
|
||||
:param description: The description of the scoring function.
|
||||
:param return_type: The return type of the scoring function.
|
||||
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
|
||||
:param provider_id: The ID of the provider to use for the scoring function.
|
||||
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET")
|
||||
async def list_shields(self) -> ListShieldsResponse: ...
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
"""List all shields.
|
||||
|
||||
:returns: A ListShieldsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
async def get_shield(self, identifier: str) -> Shield: ...
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
"""Get a shield by its identifier.
|
||||
|
||||
:param identifier: The identifier of the shield to get.
|
||||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
async def register_shield(
|
||||
|
@ -58,4 +69,13 @@ class Shields(Protocol):
|
|||
provider_shield_id: str | None = None,
|
||||
provider_id: str | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> Shield: ...
|
||||
) -> Shield:
|
||||
"""Register a shield.
|
||||
|
||||
:param shield_id: The identifier of the shield to register.
|
||||
:param provider_shield_id: The identifier of the shield in the provider.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param params: The parameters of the shield.
|
||||
:returns: A Shield.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -247,7 +247,17 @@ class QueryMetricsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/events", method="POST")
|
||||
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ...
|
||||
async def log_event(
|
||||
self,
|
||||
event: Event,
|
||||
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
|
||||
) -> None:
|
||||
"""Log an event.
|
||||
|
||||
:param event: The event to log.
|
||||
:param ttl_seconds: The time to live of the event.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST")
|
||||
async def query_traces(
|
||||
|
@ -256,13 +266,35 @@ class Telemetry(Protocol):
|
|||
limit: int | None = 100,
|
||||
offset: int | None = 0,
|
||||
order_by: list[str] | None = None,
|
||||
) -> QueryTracesResponse: ...
|
||||
) -> QueryTracesResponse:
|
||||
"""Query traces.
|
||||
|
||||
:param attribute_filters: The attribute filters to apply to the traces.
|
||||
:param limit: The limit of traces to return.
|
||||
:param offset: The offset of the traces to return.
|
||||
:param order_by: The order by of the traces to return.
|
||||
:returns: A QueryTracesResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
:param trace_id: The ID of the trace to get.
|
||||
:returns: A Trace.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
||||
:param trace_id: The ID of the trace to get the span from.
|
||||
:param span_id: The ID of the span to get.
|
||||
:returns: A Span.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
||||
async def get_span_tree(
|
||||
|
@ -270,7 +302,15 @@ class Telemetry(Protocol):
|
|||
span_id: str,
|
||||
attributes_to_return: list[str] | None = None,
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpanTreeResponse: ...
|
||||
) -> QuerySpanTreeResponse:
|
||||
"""Get a span tree by its ID.
|
||||
|
||||
:param span_id: The ID of the span to get the tree from.
|
||||
:param attributes_to_return: The attributes to return in the tree.
|
||||
:param max_depth: The maximum depth of the tree.
|
||||
:returns: A QuerySpanTreeResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST")
|
||||
async def query_spans(
|
||||
|
@ -278,7 +318,15 @@ class Telemetry(Protocol):
|
|||
attribute_filters: list[QueryCondition],
|
||||
attributes_to_return: list[str],
|
||||
max_depth: int | None = None,
|
||||
) -> QuerySpansResponse: ...
|
||||
) -> QuerySpansResponse:
|
||||
"""Query spans.
|
||||
|
||||
:param attribute_filters: The attribute filters to apply to the spans.
|
||||
:param attributes_to_return: The attributes to return in the spans.
|
||||
:param max_depth: The maximum depth of the tree.
|
||||
:returns: A QuerySpansResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/export", method="POST")
|
||||
async def save_spans_to_dataset(
|
||||
|
@ -287,7 +335,15 @@ class Telemetry(Protocol):
|
|||
attributes_to_save: list[str],
|
||||
dataset_id: str,
|
||||
max_depth: int | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Save spans to a dataset.
|
||||
|
||||
:param attribute_filters: The attribute filters to apply to the spans.
|
||||
:param attributes_to_save: The attributes to save to the dataset.
|
||||
:param dataset_id: The ID of the dataset to save the spans to.
|
||||
:param max_depth: The maximum depth of the tree.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
|
||||
async def query_metrics(
|
||||
|
@ -298,4 +354,15 @@ class Telemetry(Protocol):
|
|||
granularity: str | None = "1d",
|
||||
query_type: MetricQueryType = MetricQueryType.RANGE,
|
||||
label_matchers: list[MetricLabelMatcher] | None = None,
|
||||
) -> QueryMetricsResponse: ...
|
||||
) -> QueryMetricsResponse:
|
||||
"""Query metrics.
|
||||
|
||||
:param metric_name: The name of the metric to query.
|
||||
:param start_time: The start time of the metric to query.
|
||||
:param end_time: The end time of the metric to query.
|
||||
:param granularity: The granularity of the metric to query.
|
||||
:param query_type: The type of query to perform.
|
||||
:param label_matchers: The label matchers to apply to the metric.
|
||||
:returns: A QueryMetricsResponse.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -103,37 +103,65 @@ class ToolGroups(Protocol):
|
|||
mcp_endpoint: URL | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Register a tool group"""
|
||||
"""Register a tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to register.
|
||||
:param provider_id: The ID of the provider to use for the tool group.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:param args: A dictionary of arguments to pass to the tool group.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
|
||||
async def get_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> ToolGroup: ...
|
||||
) -> ToolGroup:
|
||||
"""Get a tool group by its ID.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to get.
|
||||
:returns: A ToolGroup.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups", method="GET")
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
"""List tool groups with optional provider"""
|
||||
"""List tool groups with optional provider.
|
||||
|
||||
:returns: A ListToolGroupsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET")
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
"""List tools with optional tool group"""
|
||||
"""List tools with optional tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to list tools for.
|
||||
:returns: A ListToolsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET")
|
||||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
) -> Tool: ...
|
||||
) -> Tool:
|
||||
"""Get a tool by its name.
|
||||
|
||||
:param tool_name: The name of the tool to get.
|
||||
:returns: A Tool.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
) -> None:
|
||||
"""Unregister a tool group"""
|
||||
"""Unregister a tool group.
|
||||
|
||||
:param toolgroup_id: The ID of the tool group to unregister.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
|
@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
|
|||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse: ...
|
||||
) -> ListToolDefsResponse:
|
||||
"""List all tools in the runtime.
|
||||
|
||||
:param tool_group_id: The ID of the tool group to list tools for.
|
||||
:param mcp_endpoint: The MCP endpoint to use for the tool group.
|
||||
:returns: A ListToolDefsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments"""
|
||||
"""Run a tool with the given arguments.
|
||||
|
||||
:param tool_name: The name of the tool to invoke.
|
||||
:param kwargs: A dictionary of arguments to pass to the tool.
|
||||
:returns: A ToolInvocationResult.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET")
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
"""List all vector databases.
|
||||
|
||||
:returns: A ListVectorDBsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
) -> VectorDB: ...
|
||||
) -> VectorDB:
|
||||
"""Get a vector database by its identifier.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to get.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
async def register_vector_db(
|
||||
|
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
|
|||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorDB: ...
|
||||
) -> VectorDB:
|
||||
"""Register a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to register.
|
||||
:param embedding_model: The embedding model to use.
|
||||
:param embedding_dimension: The dimension of the embedding model.
|
||||
:param provider_id: The identifier of the provider.
|
||||
:param provider_vector_db_id: The identifier of the vector database in the provider.
|
||||
:returns: A VectorDB.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None: ...
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
"""Unregister a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -46,7 +46,14 @@ class VectorIO(Protocol):
|
|||
vector_db_id: str,
|
||||
chunks: list[Chunk],
|
||||
ttl_seconds: int | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
"""Insert chunks into a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to insert the chunks into.
|
||||
:param chunks: The chunks to insert.
|
||||
:param ttl_seconds: The time to live of the chunks.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-io/query", method="POST")
|
||||
async def query_chunks(
|
||||
|
@ -54,4 +61,12 @@ class VectorIO(Protocol):
|
|||
vector_db_id: str,
|
||||
query: InterleavedContent,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse: ...
|
||||
) -> QueryChunksResponse:
|
||||
"""Query chunks from a vector database.
|
||||
|
||||
:param vector_db_id: The identifier of the vector database to query.
|
||||
:param query: The query to search for.
|
||||
:param params: The parameters of the query.
|
||||
:returns: A QueryChunksResponse.
|
||||
"""
|
||||
...
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue