Update OpenAPI generator to add param and field documentation (#896)

We desperately need to document our APIs. This is a basic requirement of
having a spec :)

This PR updates the OpenAPI generator so that documentation for request
parameters and object fields can be properly added to the generated OpenAPI specs.
From there, it should get picked up by Stainless, etc.
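
For reference, the documentation the generator consumes is written as ReST-style
`:param` / `:cvar` lines in the docstrings of the API protocols and Pydantic models;
the generator turns these into `description` fields in the spec, which is what lets
downstream tooling like Stainless surface them in generated clients. A trimmed-down
example of the convention (adapted from the inference API changes below; the
`stop_reason: str` type is a simplification):

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class TokenLogProbs(BaseModel):
    """Log probabilities for generated tokens.

    :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
    """

    logprobs_by_token: Dict[str, float]


class CompletionResponse(BaseModel):
    """Response from a completion request.

    :param content: The generated completion text
    :param stop_reason: Reason why generation stopped
    :param logprobs: Optional log probabilities for generated tokens
    """

    content: str
    stop_reason: str  # simplified; the real field is a StopReason enum
    logprobs: Optional[List[TokenLogProbs]] = None
```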

## Test Plan:

Updated the client SDK (see
https://github.com/meta-llama/llama-stack-client-python/pull/104) and
then ran:

```bash
cd tests/client-sdk
LLAMA_STACK_CONFIG=../../llama_stack/templates/fireworks/run.yaml pytest -s -v inference/test_inference.py agents/test_agents.py
```
Ashwin Bharambe committed on 2025-01-29 10:04:30 -08:00 (via GitHub)
parent 53721e91ad
commit 0d96070af9
11 changed files with 1059 additions and 2122 deletions


@@ -36,6 +36,16 @@ from .pyopenapi.specification import Info, Server  # noqa: E402
 from .pyopenapi.utility import Specification  # noqa: E402

+def str_presenter(dumper, data):
+    if data.startswith(f"/{LLAMA_STACK_API_VERSION}") or data.startswith(
+        "#/components/schemas/"
+    ):
+        style = None
+    else:
+        style = ">" if "\n" in data or len(data) > 40 else None
+    return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=style)

 def main(output_dir: str):
     output_dir = Path(output_dir)
     if not output_dir.exists():
@@ -69,7 +79,8 @@ def main(output_dir: str):
     y.sequence_dash_offset = 2
     y.width = 80
     y.allow_unicode = True
-    y.explicit_start = True
+    y.representer.add_representer(str, str_presenter)
     y.dump(
         spec.get_json(),
         fp,
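
For context, `str_presenter` above controls how strings are rendered in the emitted
YAML. A minimal sketch of the effect, assuming `y` is a `ruamel.yaml.YAML` instance
(which the `sequence_dash_offset` / `representer.add_representer` calls suggest) and
with the `LLAMA_STACK_API_VERSION` prefix check simplified away:

```python
import sys

from ruamel.yaml import YAML


def str_presenter(dumper, data):
    # Fold long or multi-line strings into block scalars; keep short strings
    # and schema references on a single line.
    if data.startswith("#/components/schemas/"):
        style = None
    else:
        style = ">" if "\n" in data or len(data) > 40 else None
    return dumper.represent_scalar("tag:yaml.org,2002:str", data, style=style)


y = YAML()
y.width = 80
y.representer.add_representer(str, str_presenter)
y.dump(
    {
        "$ref": "#/components/schemas/CompletionResponse",
        "description": "The identifier of the model to use. The model must be registered with Llama Stack.",
    },
    sys.stdout,
)
# The long description should be emitted as a folded (">") block scalar rather than
# one very long line, which keeps the generated YAML spec readable.
```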


@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import collections
 import hashlib
 import ipaddress
 import typing
+from dataclasses import make_dataclass
 from typing import Any, Dict, Set, Union

 from ..strong_typing.core import JsonType
@@ -276,6 +276,20 @@ class StatusResponse:
     examples: List[Any] = dataclasses.field(default_factory=list)

+def create_docstring_for_request(
+    request_name: str, fields: List[Tuple[str, type, Any]], doc_params: Dict[str, str]
+) -> str:
+    """Creates a ReST-style docstring for a dynamically generated request dataclass."""
+    lines = ["\n"]  # Short description
+    # Add parameter documentation in ReST format
+    for name, type_ in fields:
+        desc = doc_params.get(name, "")
+        lines.append(f":param {name}: {desc}")
+    return "\n".join(lines)

 class ResponseBuilder:
     content_builder: ContentBuilder
@@ -493,11 +507,24 @@ class Generator:
             first = next(iter(op.request_params))
             request_name, request_type = first
-            from dataclasses import make_dataclass
             op_name = "".join(word.capitalize() for word in op.name.split("_"))
             request_name = f"{op_name}Request"
-            request_type = make_dataclass(request_name, op.request_params)
+            fields = [
+                (
+                    name,
+                    type_,
+                )
+                for name, type_ in op.request_params
+            ]
+            request_type = make_dataclass(
+                request_name,
+                fields,
+                namespace={
+                    "__doc__": create_docstring_for_request(
+                        request_name, fields, doc_params
+                    )
+                },
+            )

             requestBody = RequestBody(
                 content={
@@ -650,12 +677,6 @@ class Generator:
                 )
             )
-        # types that are produced/consumed by operations
-        type_tags = [
-            self._build_type_tag(ref, schema)
-            for ref, schema in self.schema_builder.schemas.items()
-        ]
         # types that are emitted by events
         event_tags: List[Tag] = []
         events = get_endpoint_events(self.endpoint)
@@ -682,7 +703,6 @@ class Generator:
         # list all operations and types
         tags: List[Tag] = []
         tags.extend(operation_tags)
-        tags.extend(type_tags)
         tags.extend(event_tags)
         for extra_tag_group in extra_tag_groups.values():
             tags.extend(extra_tag_group)
@@ -697,13 +717,6 @@ class Generator:
                     tags=sorted(tag.name for tag in operation_tags),
                 )
             )
-        if type_tags:
-            tag_groups.append(
-                TagGroup(
-                    name=self.options.map("Types"),
-                    tags=sorted(tag.name for tag in type_tags),
-                )
-            )
         if event_tags:
             tag_groups.append(
                 TagGroup(
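
For context on the `make_dataclass` change above: the dynamically built request type
now carries a ReST docstring assembled from the per-parameter docs, which is how the
schema generator can attach parameter descriptions to request bodies. A standalone
sketch of the same idea, with illustrative field names and descriptions rather than
the generator's real `op.request_params` / `doc_params` inputs:

```python
from dataclasses import make_dataclass
from typing import List, Optional


def create_docstring_for_request(request_name, fields, doc_params):
    # Same idea as the helper above: one ":param" line per request field.
    lines = ["\n"]
    for name, _type in fields:
        lines.append(f":param {name}: {doc_params.get(name, '')}")
    return "\n".join(lines)


# Illustrative stand-ins for op.request_params and doc_params
fields = [("model_id", str), ("messages", List[str]), ("stream", Optional[bool])]
doc_params = {
    "model_id": "The identifier of the model to use.",
    "messages": "List of messages in the conversation.",
}

ChatCompletionRequest = make_dataclass(
    "ChatCompletionRequest",
    fields,
    namespace={
        "__doc__": create_docstring_for_request("ChatCompletionRequest", fields, doc_params)
    },
)

# The injected docstring survives because dataclass() only assigns a default
# __doc__ when none is present, so downstream docstring parsing sees the
# ":param" entries for each request field.
print(ChatCompletionRequest.__doc__)
```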


@@ -531,6 +531,7 @@ class JsonSchemaGenerator:
             # add property docstring if available
             property_doc = property_docstrings.get(property_name)
             if property_doc:
+                # print(output_name, property_doc)
                 property_def.pop("title", None)
                 property_def["description"] = property_doc

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -297,6 +297,16 @@ class AgentStepResponse(BaseModel):
 @runtime_checkable
 @trace_protocol
 class Agents(Protocol):
+    """Agents API for creating and interacting with agentic systems.
+    Main functionalities provided by this API:
+    - Create agents with specific instructions and ability to use tools.
+    - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
+    - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
+    - Agents can be provided with various shields (see the Safety API for more details).
+    - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
+    """
     @webmethod(route="/agents", method="POST")
     async def create_agent(
         self,


@@ -7,13 +7,15 @@
 from typing import List, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field
+from pydantic import BaseModel

 from llama_stack.apis.inference import (
-    CompletionMessage,
+    ChatCompletionResponse,
+    CompletionResponse,
     InterleavedContent,
     LogProbConfig,
     Message,
+    ResponseFormat,
     SamplingParams,
     ToolChoice,
     ToolDefinition,
@@ -21,35 +23,14 @@ from llama_stack.apis.inference import (
 )

-@json_schema_type
-class BatchCompletionRequest(BaseModel):
-    model: str
-    content_batch: List[InterleavedContent]
-    sampling_params: Optional[SamplingParams] = SamplingParams()
-    logprobs: Optional[LogProbConfig] = None

 @json_schema_type
 class BatchCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[CompletionResponse]

-@json_schema_type
-class BatchChatCompletionRequest(BaseModel):
-    model: str
-    messages_batch: List[List[Message]]
-    sampling_params: Optional[SamplingParams] = SamplingParams()
-    # zero-shot tool definitions as input to the model
-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    logprobs: Optional[LogProbConfig] = None

 @json_schema_type
 class BatchChatCompletionResponse(BaseModel):
-    completion_message_batch: List[CompletionMessage]
+    batch: List[ChatCompletionResponse]

 @runtime_checkable
@@ -60,6 +41,7 @@ class BatchInference(Protocol):
         model: str,
         content_batch: List[InterleavedContent],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
@@ -73,5 +55,6 @@ class BatchInference(Protocol):
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...


@@ -77,7 +77,6 @@ class ImageDelta(BaseModel):
     image: bytes

-@json_schema_type
 class ToolCallParseStatus(Enum):
     started = "started"
     in_progress = "in_progress"


@@ -35,11 +35,22 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 class LogProbConfig(BaseModel):
+    """
+    :param top_k: How many tokens (for each position) to return log probabilities for.
+    """
     top_k: Optional[int] = 0

-@json_schema_type
 class QuantizationType(Enum):
+    """Type of model quantization to run inference with.
+    :cvar bf16: BFloat16 typically this means _no_ quantization
+    :cvar fp8: 8-bit floating point quantization
+    :cvar int4: 4-bit integer quantization
+    """
     bf16 = "bf16"
     fp8 = "fp8"
     int4 = "int4"
@@ -57,6 +68,12 @@ class Bf16QuantizationConfig(BaseModel):
 @json_schema_type
 class Int4QuantizationConfig(BaseModel):
+    """Configuration for 4-bit integer quantization.
+    :param type: Must be "int4" to identify this quantization type
+    :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
+    """
     type: Literal["int4"] = "int4"
     scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
@@ -69,6 +86,13 @@ QuantizationConfig = Annotated[
 @json_schema_type
 class UserMessage(BaseModel):
+    """A message from the user in a chat conversation.
+    :param role: Must be "user" to identify this as a user message
+    :param content: The content of the message, which can include text and other media
+    :param context: (Optional) This field is used internally by Llama Stack to pass RAG context. This field may be removed in the API in the future.
+    """
     role: Literal["user"] = "user"
     content: InterleavedContent
     context: Optional[InterleavedContent] = None
@@ -76,15 +100,27 @@ class UserMessage(BaseModel):
 @json_schema_type
 class SystemMessage(BaseModel):
+    """A system message providing instructions or context to the model.
+    :param role: Must be "system" to identify this as a system message
+    :param content: The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions).
+    """
     role: Literal["system"] = "system"
     content: InterleavedContent

 @json_schema_type
 class ToolResponseMessage(BaseModel):
+    """A message representing the result of a tool invocation.
+    :param role: Must be "tool" to identify this as a tool response
+    :param call_id: Unique identifier for the tool call this response is for
+    :param tool_name: Name of the tool that was called
+    :param content: The response content from the tool
+    """
     role: Literal["tool"] = "tool"
-    # it was nice to re-use the ToolResponse type, but having all messages
-    # have a `content` type makes things nicer too
     call_id: str
     tool_name: Union[BuiltinTool, str]
     content: InterleavedContent
@@ -92,6 +128,17 @@ class ToolResponseMessage(BaseModel):
 @json_schema_type
 class CompletionMessage(BaseModel):
+    """A message containing the model's (assistant) response in a chat conversation.
+    :param role: Must be "assistant" to identify this as the model's response
+    :param content: The content of the model's response
+    :param stop_reason: Reason why the model stopped generating. Options are:
+        - `StopReason.end_of_turn`: The model finished generating the entire response.
+        - `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response.
+        - `StopReason.out_of_tokens`: The model ran out of token budget.
+    :param tool_calls: List of tool calls. Each tool call is a ToolCall object.
+    """
     role: Literal["assistant"] = "assistant"
     content: InterleavedContent
     stop_reason: StopReason
@@ -129,19 +176,35 @@ class ToolResponse(BaseModel):
         return v

-@json_schema_type
 class ToolChoice(Enum):
+    """Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
+    :cvar auto: The model may use tools if it determines that is appropriate.
+    :cvar required: The model must use tools.
+    """
     auto = "auto"
     required = "required"

 @json_schema_type
 class TokenLogProbs(BaseModel):
+    """Log probabilities for generated tokens.
+    :param logprobs_by_token: Dictionary mapping tokens to their log probabilities
+    """
     logprobs_by_token: Dict[str, float]

-@json_schema_type
 class ChatCompletionResponseEventType(Enum):
+    """Types of events that can occur during chat completion.
+    :cvar start: Inference has started
+    :cvar complete: Inference is complete and a full response is available
+    :cvar progress: Inference is in progress and a partial response is available
+    """
     start = "start"
     complete = "complete"
     progress = "progress"
@@ -149,7 +212,13 @@ class ChatCompletionResponseEventType(Enum):
 @json_schema_type
 class ChatCompletionResponseEvent(BaseModel):
-    """Chat completion response event."""
+    """An event during chat completion generation.
+    :param event_type: Type of the event
+    :param delta: Content generated since last event. This can be one or more tokens, or a tool call.
+    :param logprobs: Optional log probabilities for generated tokens
+    :param stop_reason: Optional reason why generation stopped, if complete
+    """
     event_type: ChatCompletionResponseEventType
     delta: ContentDelta
@@ -157,14 +226,25 @@ class ChatCompletionResponseEvent(BaseModel):
     stop_reason: Optional[StopReason] = None

-@json_schema_type
 class ResponseFormatType(Enum):
+    """Types of formats for structured (guided) decoding.
+    :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
+    :cvar grammar: Response should conform to a BNF grammar
+    """
     json_schema = "json_schema"
     grammar = "grammar"

 @json_schema_type
 class JsonSchemaResponseFormat(BaseModel):
+    """Configuration for JSON schema-guided response generation.
+    :param type: Must be "json_schema" to identify this format type
+    :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
+    """
     type: Literal[ResponseFormatType.json_schema.value] = (
         ResponseFormatType.json_schema.value
     )
@@ -173,6 +253,12 @@ class JsonSchemaResponseFormat(BaseModel):
 @json_schema_type
 class GrammarResponseFormat(BaseModel):
+    """Configuration for grammar-guided response generation.
+    :param type: Must be "grammar" to identify this format type
+    :param bnf: The BNF grammar specification the response should conform to
+    """
     type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
     bnf: Dict[str, Any]
@@ -186,20 +272,24 @@ ResponseFormat = register_schema(
 )

-@json_schema_type
+# This is an internally used class
 class CompletionRequest(BaseModel):
     model: str
     content: InterleavedContent
     sampling_params: Optional[SamplingParams] = SamplingParams()
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None

 @json_schema_type
 class CompletionResponse(BaseModel):
-    """Completion response."""
+    """Response from a completion request.
+    :param content: The generated completion text
+    :param stop_reason: Reason why generation stopped
+    :param logprobs: Optional log probabilities for generated tokens
+    """
     content: str
     stop_reason: StopReason
@@ -208,80 +298,60 @@ class CompletionResponse(BaseModel):
 @json_schema_type
 class CompletionResponseStreamChunk(BaseModel):
-    """streamed completion response."""
+    """A chunk of a streamed completion response.
+    :param delta: New content generated since last chunk. This can be one or more tokens.
+    :param stop_reason: Optional reason why generation stopped, if complete
+    :param logprobs: Optional log probabilities for generated tokens
+    """
     delta: str
     stop_reason: Optional[StopReason] = None
     logprobs: Optional[List[TokenLogProbs]] = None

-@json_schema_type
-class BatchCompletionRequest(BaseModel):
-    model: str
-    content_batch: List[InterleavedContent]
-    sampling_params: Optional[SamplingParams] = SamplingParams()
-    response_format: Optional[ResponseFormat] = None
-    logprobs: Optional[LogProbConfig] = None
-@json_schema_type
-class BatchCompletionResponse(BaseModel):
-    """Batch completion response."""
-    batch: List[CompletionResponse]
-@json_schema_type
+# This is an internally used class
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
     sampling_params: Optional[SamplingParams] = SamplingParams()
-    # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
     logprobs: Optional[LogProbConfig] = None

 @json_schema_type
 class ChatCompletionResponseStreamChunk(BaseModel):
-    """SSE-stream of these events."""
+    """A chunk of a streamed chat completion response.
+    :param event: The event containing the new content
+    """
     event: ChatCompletionResponseEvent

 @json_schema_type
 class ChatCompletionResponse(BaseModel):
-    """Chat completion response."""
+    """Response from a chat completion request.
+    :param completion_message: The complete response message
+    :param logprobs: Optional log probabilities for generated tokens
+    """
     completion_message: CompletionMessage
     logprobs: Optional[List[TokenLogProbs]] = None

-@json_schema_type
-class BatchChatCompletionRequest(BaseModel):
-    model: str
-    messages_batch: List[List[Message]]
-    sampling_params: Optional[SamplingParams] = SamplingParams()
-    # zero-shot tool definitions as input to the model
-    tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
-    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
-    logprobs: Optional[LogProbConfig] = None
-@json_schema_type
-class BatchChatCompletionResponse(BaseModel):
-    batch: List[ChatCompletionResponse]

 @json_schema_type
 class EmbeddingsResponse(BaseModel):
+    """Response containing generated embeddings.
+    :param embeddings: List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
+    """
     embeddings: List[List[float]]
@@ -292,6 +362,13 @@ class ModelStore(Protocol):
 @runtime_checkable
 @trace_protocol
 class Inference(Protocol):
+    """Llama Stack Inference API for generating completions, chat completions, and embeddings.
+    This API provides the raw interface to the underlying models. Two kinds of models are supported:
+    - LLM models: these models generate "raw" and "chat" (conversational) completions.
+    - Embedding models: these models generate embeddings to be used for semantic search.
+    """
     model_store: ModelStore

     @webmethod(route="/inference/completion", method="POST")
@@ -303,7 +380,19 @@ class Inference(Protocol):
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
+    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
+        """Generate a completion for the given content using the specified model.
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param content: The content to generate a completion for
+        :param sampling_params: (Optional) Parameters to control the sampling strategy
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding
+        :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: If stream=False, returns a CompletionResponse with the full completion.
+            If stream=True, returns an SSE event stream of CompletionResponseStreamChunk
+        """
+        ...

     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(
@@ -311,7 +400,6 @@ class Inference(Protocol):
         model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
@@ -320,11 +408,38 @@ class Inference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
         ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
-    ]: ...
+    ]:
+        """Generate a chat completion for the given messages using the specified model.
+        :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param messages: List of messages in the conversation
+        :param sampling_params: Parameters to control the sampling strategy
+        :param tools: (Optional) List of tool definitions available to the model
+        :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
+        :param tool_prompt_format: (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model.
+            - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+            - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag.
+            - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+        :param response_format: (Optional) Grammar specification for guided (structured) decoding. There are two options:
+            - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format.
+            - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it.
+        :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
+        :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
+        :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
+            If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk
+        """
+        ...

     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(
         self,
         model_id: str,
         contents: List[InterleavedContent],
-    ) -> EmbeddingsResponse: ...
+    ) -> EmbeddingsResponse:
+        """Generate embeddings for content pieces using the specified model.
+        :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
+        :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
+        """
+        ...


@@ -6,11 +6,9 @@
 from enum import Enum

-from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

-@json_schema_type
 class ResourceType(Enum):
     model = "model"
     shield = "shield"


@@ -339,7 +339,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 method=options.method,
                 url=options.url,
                 params=options.params,
-                headers=options.headers,
+                headers=options.headers or {},
                 json=options.json_data,
             ),
         )
@@ -388,7 +388,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                 method=options.method,
                 url=options.url,
                 params=options.params,
-                headers=options.headers,
+                headers=options.headers or {},
                 json=options.json_data,
             ),
         )