chore: more API validators (#2165)

# What does this PR do?

This PR:

* makes sure docstrings are present and include ':param' and ':returns:' entries
* fails validation if someone sets ':returns: None'
* fixes the APIs that failed the new checks

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-15 20:22:51 +02:00 committed by GitHub
parent e46de23be6
commit bb5fca9521
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1304 additions and 574 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -179,6 +179,35 @@ def _validate_has_ellipsis(method) -> str | None:
if "..." not in source and not "NotImplementedError" in source: if "..." not in source and not "NotImplementedError" in source:
return "does not contain ellipsis (...) in its implementation" return "does not contain ellipsis (...) in its implementation"
def _validate_has_return_in_docstring(method) -> str | None:
source = inspect.getsource(method)
return_type = method.__annotations__.get('return')
if return_type is not None and return_type != type(None) and ":returns:" not in source:
return "does not have a ':returns:' in its docstring"
def _validate_has_params_in_docstring(method) -> str | None:
source = inspect.getsource(method)
sig = inspect.signature(method)
# Only check if the method has more than one parameter
if len(sig.parameters) > 1 and ":param" not in source:
return "does not have a ':param' in its docstring"
def _validate_has_no_return_none_in_docstring(method) -> str | None:
source = inspect.getsource(method)
return_type = method.__annotations__.get('return')
if return_type is None and ":returns: None" in source:
return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
def _validate_docstring_lines_end_with_dot(method) -> str | None:
docstring = inspect.getdoc(method)
if docstring is None:
return None
lines = docstring.split('\n')
for line in lines:
line = line.strip()
if line and not any(line.endswith(char) for char in '.:{}[]()",'):
return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
_VALIDATORS = { _VALIDATORS = {
"GET": [ "GET": [
@ -186,13 +215,23 @@ _VALIDATORS = {
_validate_list_parameters_contain_data, _validate_list_parameters_contain_data,
_validate_api_method_doesnt_return_list, _validate_api_method_doesnt_return_list,
_validate_has_ellipsis, _validate_has_ellipsis,
_validate_has_return_in_docstring,
_validate_has_params_in_docstring,
_validate_docstring_lines_end_with_dot,
], ],
"DELETE": [ "DELETE": [
_validate_api_delete_method_returns_none, _validate_api_delete_method_returns_none,
_validate_has_ellipsis, _validate_has_ellipsis,
_validate_has_return_in_docstring,
_validate_has_params_in_docstring,
_validate_has_no_return_none_in_docstring
], ],
"POST": [ "POST": [
_validate_has_ellipsis, _validate_has_ellipsis,
_validate_has_return_in_docstring,
_validate_has_params_in_docstring,
_validate_has_no_return_none_in_docstring,
_validate_docstring_lines_end_with_dot,
], ],
} }

View file

@ -413,7 +413,7 @@ class Agents(Protocol):
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request. :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config. :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object. :returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
""" """
... ...
@ -509,6 +509,7 @@ class Agents(Protocol):
:param session_id: The ID of the session to get. :param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for. :param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by. :param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
""" """
... ...
@ -606,5 +607,6 @@ class Agents(Protocol):
:param input: Input message(s) to create the response. :param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions. :param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:returns: An OpenAIResponseObject.
""" """
... ...

View file

@ -38,7 +38,17 @@ class BatchInference(Protocol):
sampling_params: SamplingParams | None = None, sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> Job: ... ) -> Job:
"""Generate completions for a batch of content.
:param model: The model to use for the completion.
:param content_batch: The content to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param response_format: The response format to use for the completion.
:param logprobs: The logprobs to use for the completion.
:returns: A job for the completion.
"""
...
@webmethod(route="/batch-inference/chat-completion", method="POST") @webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion( async def chat_completion(
@ -52,4 +62,17 @@ class BatchInference(Protocol):
tool_prompt_format: ToolPromptFormat | None = None, tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> Job: ... ) -> Job:
"""Generate chat completions for a batch of messages.
:param model: The model to use for the chat completion.
:param messages_batch: The messages to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param tools: The tools to use for the chat completion.
:param tool_choice: The tool choice to use for the chat completion.
:param tool_prompt_format: The tool prompt format to use for the chat completion.
:param response_format: The response format to use for the chat completion.
:param logprobs: The logprobs to use for the chat completion.
:returns: A job for the chat completion.
"""
...

View file

@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Benchmarks(Protocol): class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET") @webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ... async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
:returns: A ListBenchmarksResponse.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark( async def get_benchmark(
self, self,
benchmark_id: str, benchmark_id: str,
) -> Benchmark: ... ) -> Benchmark:
"""Get a benchmark by its ID.
:param benchmark_id: The ID of the benchmark to get.
:returns: A Benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="POST") @webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark( async def register_benchmark(
@ -63,4 +74,14 @@ class Benchmarks(Protocol):
provider_benchmark_id: str | None = None, provider_benchmark_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> None: ... ) -> None:
"""Register a benchmark.
:param benchmark_id: The ID of the benchmark to register.
:param dataset_id: The ID of the dataset to use for the benchmark.
:param scoring_functions: The scoring functions to use for the benchmark.
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
:param provider_id: The ID of the provider to use for the benchmark.
:param metadata: The metadata to use for the benchmark.
"""
...

View file

@ -34,14 +34,21 @@ class DatasetIO(Protocol):
- limit: Number of items to return. If None or -1, returns all items. - limit: Number of items to return. If None or -1, returns all items.
The response includes: The response includes:
- data: List of items for the current page - data: List of items for the current page.
- has_more: Whether there are more items available after this set - has_more: Whether there are more items available after this set.
:param dataset_id: The ID of the dataset to get the rows from. :param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None. :param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get. :param limit: The number of rows to get.
:returns: A PaginatedResponse.
""" """
... ...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST") @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ... async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.
:param dataset_id: The ID of the dataset to append the rows to.
:param rows: The rows to append to the dataset.
"""
...

View file

@ -137,7 +137,8 @@ class Datasets(Protocol):
""" """
Register a new dataset. Register a new dataset.
:param purpose: The purpose of the dataset. One of :param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training. - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{ {
"messages": [ "messages": [
@ -188,8 +189,9 @@ class Datasets(Protocol):
] ]
} }
:param metadata: The metadata for the dataset. :param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"} - E.g. {"description": "My dataset"}.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated. :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
""" """
... ...
@ -197,13 +199,29 @@ class Datasets(Protocol):
async def get_dataset( async def get_dataset(
self, self,
dataset_id: str, dataset_id: str,
) -> Dataset: ... ) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets", method="GET") @webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ... async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE") @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset( async def unregister_dataset(
self, self,
dataset_id: str, dataset_id: str,
) -> None: ... ) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...

View file

@ -93,7 +93,7 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark. :param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation. :returns: The job that was created to run the evaluation.
""" """
... ...
@ -111,7 +111,7 @@ class Eval(Protocol):
:param input_rows: The rows to evaluate. :param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation. :param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark. :param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores :returns: EvaluateResponse object containing generations and scores.
""" """
... ...
@ -121,7 +121,7 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of. :param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob. :returns: The status of the evaluation job.
""" """
... ...
@ -140,6 +140,6 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of. :param job_id: The ID of the job to get the result of.
:return: The result of the job. :returns: The result of the job.
""" """
... ...

View file

@ -91,10 +91,11 @@ class Files(Protocol):
""" """
Create a new upload session for a file identified by a bucket and key. Create a new upload session for a file identified by a bucket and key.
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-) :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.) :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:param mime_type: MIME type of the file :param mime_type: MIME type of the file.
:param size: File size in bytes :param size: File size in bytes.
:returns: A FileUploadResponse.
""" """
... ...
@ -107,7 +108,8 @@ class Files(Protocol):
Upload file content to an existing upload session. Upload file content to an existing upload session.
On the server, request body will have the raw bytes that are uploaded. On the server, request body will have the raw bytes that are uploaded.
:param upload_id: ID of the upload session :param upload_id: ID of the upload session.
:returns: A FileResponse or None if the upload is not complete.
""" """
... ...
@ -117,9 +119,10 @@ class Files(Protocol):
upload_id: str, upload_id: str,
) -> FileUploadResponse: ) -> FileUploadResponse:
""" """
Returns information about an existsing upload session Returns information about an existing upload session.
:param upload_id: ID of the upload session :param upload_id: ID of the upload session.
:returns: A FileUploadResponse.
""" """
... ...
@ -130,6 +133,9 @@ class Files(Protocol):
) -> ListBucketResponse: ) -> ListBucketResponse:
""" """
List all buckets. List all buckets.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListBucketResponse.
""" """
... ...
@ -141,7 +147,8 @@ class Files(Protocol):
""" """
List all files in a bucket. List all files in a bucket.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListFileResponse.
""" """
... ...
@ -154,8 +161,9 @@ class Files(Protocol):
""" """
Get a file info identified by a bucket and key. Get a file info identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.) :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:returns: A FileResponse.
""" """
... ...
@ -168,7 +176,7 @@ class Files(Protocol):
""" """
Delete a file identified by a bucket and key. Delete a file identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.) :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
""" """
... ...

View file

@ -845,13 +845,13 @@ class Inference(Protocol):
"""Generate a completion for the given content using the specified model. """Generate a completion for the given content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content: The content to generate a completion for :param content: The content to generate a completion for.
:param sampling_params: (Optional) Parameters to control the sampling strategy :param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding :param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False. :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: If stream=False, returns a CompletionResponse with the full completion. :returns: If stream=False, returns a CompletionResponse with the full completion.
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
""" """
... ...
@ -864,6 +864,15 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse: ) -> BatchCompletionResponse:
"""Generate completions for a batch of content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content_batch: The content to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented") raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST") @webmethod(route="/inference/chat-completion", method="POST")
@ -883,9 +892,9 @@ class Inference(Protocol):
"""Generate a chat completion for the given messages using the specified model. """Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation :param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy :param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model :param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated:: .. deprecated::
Use tool_config instead. Use tool_config instead.
@ -902,7 +911,7 @@ class Inference(Protocol):
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use. :param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion. :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
""" """
... ...
@ -917,6 +926,17 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse: ) -> BatchChatCompletionResponse:
"""Generate chat completions for a batch of messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages_batch: The messages to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_config: (Optional) Configuration for tool use.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented") raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST") @webmethod(route="/inference/embeddings", method="POST")
@ -935,7 +955,7 @@ class Inference(Protocol):
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models. :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length. :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models. :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
""" """
... ...
@ -967,22 +987,23 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible completion for the given prompt using the specified model. """Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for :param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate :param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt :param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens :param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use :param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use :param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate :param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate :param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens :param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use :param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use :param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response :param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use :param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use :param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use :param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use :param user: (Optional) The user to use.
:returns: An OpenAICompletion.
""" """
... ...
@ -1016,27 +1037,28 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model. """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation :param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens :param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use :param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use :param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use :param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use :param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate :param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate :param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens :param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use :param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use :param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use :param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response :param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use :param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use :param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use :param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use :param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use :param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use :param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use :param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
""" """
... ...

View file

@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Inspect(Protocol): class Inspect(Protocol):
@webmethod(route="/inspect/routes", method="GET") @webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ... async def list_routes(self) -> ListRoutesResponse:
"""List all routes.
:returns: A ListRoutesResponse.
"""
...
@webmethod(route="/health", method="GET") @webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ... async def health(self) -> HealthInfo:
"""Get the health of the service.
:returns: A HealthInfo.
"""
...
@webmethod(route="/version", method="GET") @webmethod(route="/version", method="GET")
async def version(self) -> VersionInfo: ... async def version(self) -> VersionInfo:
"""Get the version of the service.
:returns: A VersionInfo.
"""
...

View file

@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
@trace_protocol @trace_protocol
class Models(Protocol): class Models(Protocol):
@webmethod(route="/models", method="GET") @webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ... async def list_models(self) -> ListModelsResponse:
"""List all models.
:returns: A ListModelsResponse.
"""
...
@webmethod(route="/openai/v1/models", method="GET") @webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse: ... async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: An OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET") @webmethod(route="/models/{model_id:path}", method="GET")
async def get_model( async def get_model(
self, self,
model_id: str, model_id: str,
) -> Model: ... ) -> Model:
"""Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
"""
...
@webmethod(route="/models", method="POST") @webmethod(route="/models", method="POST")
async def register_model( async def register_model(
@ -99,10 +115,25 @@ class Models(Protocol):
provider_id: str | None = None, provider_id: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None, model_type: ModelType | None = None,
) -> Model: ... ) -> Model:
"""Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param metadata: Any additional metadata for this model.
:param model_type: The type of model to register.
:returns: A Model.
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE") @webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model( async def unregister_model(
self, self,
model_id: str, model_id: str,
) -> None: ... ) -> None:
"""Unregister a model.
:param model_id: The identifier of the model to unregister.
"""
...

View file

@ -182,7 +182,19 @@ class PostTraining(Protocol):
), ),
checkpoint_dir: str | None = None, checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None, algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob:
"""Run supervised fine-tuning of a model.
:param job_uuid: The UUID of the job to create.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:param model: The model to fine-tune.
:param checkpoint_dir: The directory to save checkpoint(s) to.
:param algorithm_config: The algorithm configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize( async def preference_optimize(
@ -193,16 +205,49 @@ class PostTraining(Protocol):
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any], hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any], logger_config: dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob:
"""Run preference optimization of a model.
:param job_uuid: The UUID of the job to create.
:param finetuned_model: The model to fine-tune.
:param algorithm_config: The algorithm configuration.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/jobs", method="GET") @webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
:returns: A ListPostTrainingJobsResponse.
"""
...
@webmethod(route="/post-training/job/status", method="GET") @webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ... async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
:param job_uuid: The UUID of the job to get the status of.
:returns: A PostTrainingJobStatusResponse.
"""
...
@webmethod(route="/post-training/job/cancel", method="POST") @webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
:param job_uuid: The UUID of the job to cancel.
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET") @webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ... async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.
:param job_uuid: The UUID of the job to get the artifacts of.
:returns: A PostTrainingJobArtifactsResponse.
"""
...

View file

@ -32,7 +32,18 @@ class Providers(Protocol):
""" """
@webmethod(route="/providers", method="GET") @webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ... async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
...
@webmethod(route="/providers/{provider_id}", method="GET") @webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ... async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""
...

View file

@ -54,4 +54,12 @@ class Safety(Protocol):
shield_id: str, shield_id: str,
messages: list[Message], messages: list[Message],
params: dict[str, Any], params: dict[str, Any],
) -> RunShieldResponse: ... ) -> RunShieldResponse:
"""Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...

View file

@ -61,7 +61,15 @@ class Scoring(Protocol):
dataset_id: str, dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None], scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse:
"""Score a batch of rows.
:param dataset_id: The ID of the dataset to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param save_results_dataset: Whether to save the results to a dataset.
:returns: A ScoreBatchResponse.
"""
...
@webmethod(route="/scoring/score", method="POST") @webmethod(route="/scoring/score", method="POST")
async def score( async def score(
@ -73,6 +81,6 @@ class Scoring(Protocol):
:param input_rows: The rows to score. :param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring. :param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results :returns: A ScoreResponse object containing rows and aggregated results.
""" """
... ...

View file

@ -134,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
@runtime_checkable @runtime_checkable
class ScoringFunctions(Protocol): class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET") @webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
:returns: A ListScoringFunctionsResponse.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ... async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
:returns: A ScoringFn.
"""
...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(
@ -148,4 +159,14 @@ class ScoringFunctions(Protocol):
provider_scoring_fn_id: str | None = None, provider_scoring_fn_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
params: ScoringFnParams | None = None, params: ScoringFnParams | None = None,
) -> None: ... ) -> None:
"""Register a scoring function.
:param scoring_fn_id: The ID of the scoring function to register.
:param description: The description of the scoring function.
:param return_type: The return type of the scoring function.
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
:param provider_id: The ID of the provider to use for the scoring function.
:param params: The parameters for the scoring function for benchmark eval; these can be overridden for app eval.
"""
...

View file

@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
@trace_protocol @trace_protocol
class Shields(Protocol): class Shields(Protocol):
@webmethod(route="/shields", method="GET") @webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ... async def list_shields(self) -> ListShieldsResponse:
"""List all shields.
:returns: A ListShieldsResponse.
"""
...
@webmethod(route="/shields/{identifier:path}", method="GET") @webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Shield: ... async def get_shield(self, identifier: str) -> Shield:
"""Get a shield by its identifier.
:param identifier: The identifier of the shield to get.
:returns: A Shield.
"""
...
@webmethod(route="/shields", method="POST") @webmethod(route="/shields", method="POST")
async def register_shield( async def register_shield(
@ -58,4 +69,13 @@ class Shields(Protocol):
provider_shield_id: str | None = None, provider_shield_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> Shield: ... ) -> Shield:
"""Register a shield.
:param shield_id: The identifier of the shield to register.
:param provider_shield_id: The identifier of the shield in the provider.
:param provider_id: The identifier of the provider.
:param params: The parameters of the shield.
:returns: A Shield.
"""
...

View file

@ -247,7 +247,17 @@ class QueryMetricsResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST") @webmethod(route="/telemetry/events", method="POST")
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ... async def log_event(
self,
event: Event,
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
) -> None:
"""Log an event.
:param event: The event to log.
:param ttl_seconds: The time to live of the event.
"""
...
@webmethod(route="/telemetry/traces", method="POST") @webmethod(route="/telemetry/traces", method="POST")
async def query_traces( async def query_traces(
@ -256,13 +266,35 @@ class Telemetry(Protocol):
limit: int | None = 100, limit: int | None = 100,
offset: int | None = 0, offset: int | None = 0,
order_by: list[str] | None = None, order_by: list[str] | None = None,
) -> QueryTracesResponse: ... ) -> QueryTracesResponse:
"""Query traces.
:param attribute_filters: The attribute filters to apply to the traces.
:param limit: The maximum number of traces to return.
:param offset: The offset of the traces to return.
:param order_by: The ordering to apply to the returned traces.
:returns: A QueryTracesResponse.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
:param trace_id: The ID of the trace to get.
:returns: A Trace.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ... async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
:param trace_id: The ID of the trace to get the span from.
:param span_id: The ID of the span to get.
:returns: A Span.
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST") @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
async def get_span_tree( async def get_span_tree(
@ -270,7 +302,15 @@ class Telemetry(Protocol):
span_id: str, span_id: str,
attributes_to_return: list[str] | None = None, attributes_to_return: list[str] | None = None,
max_depth: int | None = None, max_depth: int | None = None,
) -> QuerySpanTreeResponse: ... ) -> QuerySpanTreeResponse:
"""Get a span tree by its ID.
:param span_id: The ID of the span to get the tree from.
:param attributes_to_return: The attributes to return in the tree.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpanTreeResponse.
"""
...
@webmethod(route="/telemetry/spans", method="POST") @webmethod(route="/telemetry/spans", method="POST")
async def query_spans( async def query_spans(
@ -278,7 +318,15 @@ class Telemetry(Protocol):
attribute_filters: list[QueryCondition], attribute_filters: list[QueryCondition],
attributes_to_return: list[str], attributes_to_return: list[str],
max_depth: int | None = None, max_depth: int | None = None,
) -> QuerySpansResponse: ... ) -> QuerySpansResponse:
"""Query spans.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_return: The attributes to return in the spans.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpansResponse.
"""
...
@webmethod(route="/telemetry/spans/export", method="POST") @webmethod(route="/telemetry/spans/export", method="POST")
async def save_spans_to_dataset( async def save_spans_to_dataset(
@ -287,7 +335,15 @@ class Telemetry(Protocol):
attributes_to_save: list[str], attributes_to_save: list[str],
dataset_id: str, dataset_id: str,
max_depth: int | None = None, max_depth: int | None = None,
) -> None: ... ) -> None:
"""Save spans to a dataset.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_save: The attributes to save to the dataset.
:param dataset_id: The ID of the dataset to save the spans to.
:param max_depth: The maximum depth of the tree.
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST") @webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
async def query_metrics( async def query_metrics(
@ -298,4 +354,15 @@ class Telemetry(Protocol):
granularity: str | None = "1d", granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE, query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None, label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse: ... ) -> QueryMetricsResponse:
"""Query metrics.
:param metric_name: The name of the metric to query.
:param start_time: The start time of the metric to query.
:param end_time: The end time of the metric to query.
:param granularity: The granularity of the metric to query.
:param query_type: The type of query to perform.
:param label_matchers: The label matchers to apply to the metric.
:returns: A QueryMetricsResponse.
"""
...

View file

@ -103,37 +103,65 @@ class ToolGroups(Protocol):
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Register a tool group""" """Register a tool group.
:param toolgroup_id: The ID of the tool group to register.
:param provider_id: The ID of the provider to use for the tool group.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:param args: A dictionary of arguments to pass to the tool group.
"""
... ...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET") @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group( async def get_tool_group(
self, self,
toolgroup_id: str, toolgroup_id: str,
) -> ToolGroup: ... ) -> ToolGroup:
"""Get a tool group by its ID.
:param toolgroup_id: The ID of the tool group to get.
:returns: A ToolGroup.
"""
...
@webmethod(route="/toolgroups", method="GET") @webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse: async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider""" """List tool groups with optional provider.
:returns: A ListToolGroupsResponse.
"""
... ...
@webmethod(route="/tools", method="GET") @webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group""" """List tools with optional tool group.
:param toolgroup_id: The ID of the tool group to list tools for.
:returns: A ListToolsResponse.
"""
... ...
@webmethod(route="/tools/{tool_name:path}", method="GET") @webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool( async def get_tool(
self, self,
tool_name: str, tool_name: str,
) -> Tool: ... ) -> Tool:
"""Get a tool by its name.
:param tool_name: The name of the tool to get.
:returns: A Tool.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE") @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup( async def unregister_toolgroup(
self, self,
toolgroup_id: str, toolgroup_id: str,
) -> None: ) -> None:
"""Unregister a tool group""" """Unregister a tool group.
:param toolgroup_id: The ID of the tool group to unregister.
"""
... ...
@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ... ) -> ListToolDefsResponse:
"""List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool.
:returns: A ToolInvocationResult.
"""
... ...

View file

@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
@trace_protocol @trace_protocol
class VectorDBs(Protocol): class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET") @webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ... async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db( async def get_vector_db(
self, self,
vector_db_id: str, vector_db_id: str,
) -> VectorDB: ... ) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST") @webmethod(route="/vector-dbs", method="POST")
async def register_vector_db( async def register_vector_db(
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
) -> VectorDB: ... ) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ... async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...

View file

@ -46,7 +46,14 @@ class VectorIO(Protocol):
vector_db_id: str, vector_db_id: str,
chunks: list[Chunk], chunks: list[Chunk],
ttl_seconds: int | None = None, ttl_seconds: int | None = None,
) -> None: ... ) -> None:
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param ttl_seconds: The time to live of the chunks.
"""
...
@webmethod(route="/vector-io/query", method="POST") @webmethod(route="/vector-io/query", method="POST")
async def query_chunks( async def query_chunks(
@ -54,4 +61,12 @@ class VectorIO(Protocol):
vector_db_id: str, vector_db_id: str,
query: InterleavedContent, query: InterleavedContent,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ... ) -> QueryChunksResponse:
"""Query chunks from a vector database.
:param vector_db_id: The identifier of the vector database to query.
:param query: The query to search for.
:param params: The parameters of the query.
:returns: A QueryChunksResponse.
"""
...