diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py
index ad8f2952e..f4238f6f8 100644
--- a/docs/openapi_generator/pyopenapi/operations.py
+++ b/docs/openapi_generator/pyopenapi/operations.py
@@ -315,7 +315,20 @@ def get_endpoint_operations(
)
else:
event_type = None
- response_type = return_type
+
+ def process_type(t):
+ if typing.get_origin(t) is collections.abc.AsyncIterator:
+                    # NOTE(ashwin): this is SSE and there is no way to represent it in OpenAPI. We can
+                    # either model it as a List or as the item type; I am choosing the latter.
+ args = typing.get_args(t)
+ return args[0]
+ elif typing.get_origin(t) is typing.Union:
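+                    # a Union return may mix an AsyncIterator arm with plain response types, so unwrap each arm recursively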
+ types = [process_type(a) for a in typing.get_args(t)]
+ return typing._UnionGenericAlias(typing.Union, tuple(types))
+ else:
+ return t
+
+ response_type = process_type(return_type)
# set HTTP request method based on type of request and presence of payload
if not request_params:
diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 886634fba..363d968f9 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
},
"servers": [
{
@@ -320,11 +320,18 @@
"post": {
"responses": {
"200": {
- "description": "OK",
+ "description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.",
"content": {
"text/event-stream": {
"schema": {
- "$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/Turn"
+ },
+ {
+ "$ref": "#/components/schemas/AgentTurnResponseStreamChunk"
+ }
+ ]
}
}
}
@@ -934,7 +941,7 @@
"schema": {
"oneOf": [
{
- "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+ "$ref": "#/components/schemas/ScoringFnDefWithProvider"
},
{
"type": "null"
@@ -1555,7 +1562,7 @@
"content": {
"application/jsonl": {
"schema": {
- "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+ "$ref": "#/components/schemas/ScoringFnDefWithProvider"
}
}
}
@@ -2762,7 +2769,7 @@
"const": "json_schema",
"default": "json_schema"
},
- "schema": {
+ "json_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
@@ -2791,7 +2798,7 @@
"additionalProperties": false,
"required": [
"type",
- "schema"
+ "json_schema"
]
},
{
@@ -3018,7 +3025,7 @@
"const": "json_schema",
"default": "json_schema"
},
- "schema": {
+ "json_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
@@ -3047,7 +3054,7 @@
"additionalProperties": false,
"required": [
"type",
- "schema"
+ "json_schema"
]
},
{
@@ -4002,7 +4009,8 @@
"additionalProperties": false,
"required": [
"event"
- ]
+ ],
+ "title": "streamed agent turn completion response."
},
"AgentTurnResponseTurnCompletePayload": {
"type": "object",
@@ -5004,24 +5012,6 @@
"type"
]
},
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "custom",
- "default": "custom"
- },
- "validator_class": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "validator_class"
- ]
- },
{
"type": "object",
"properties": {
@@ -5304,24 +5294,6 @@
"type"
]
},
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "custom",
- "default": "custom"
- },
- "validator_class": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "validator_class"
- ]
- },
{
"type": "object",
"properties": {
@@ -5376,7 +5348,7 @@
"type"
]
},
- "ScoringFunctionDefWithProvider": {
+ "ScoringFnDefWithProvider": {
"type": "object",
"properties": {
"identifier": {
@@ -5516,24 +5488,6 @@
"type"
]
},
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "custom",
- "default": "custom"
- },
- "validator_class": {
- "type": "string"
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "validator_class"
- ]
- },
{
"type": "object",
"properties": {
@@ -5586,6 +5540,12 @@
},
"prompt_template": {
"type": "string"
+ },
+ "judge_score_regex": {
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
}
},
"additionalProperties": false,
@@ -6339,10 +6299,10 @@
"finetuned_model": {
"$ref": "#/components/schemas/URL"
},
- "dataset": {
+ "dataset_id": {
"type": "string"
},
- "validation_dataset": {
+ "validation_dataset_id": {
"type": "string"
},
"algorithm": {
@@ -6412,8 +6372,8 @@
"required": [
"job_uuid",
"finetuned_model",
- "dataset",
- "validation_dataset",
+ "dataset_id",
+ "validation_dataset_id",
"algorithm",
"algorithm_config",
"optimizer_config",
@@ -6595,7 +6555,7 @@
"type": "object",
"properties": {
"function_def": {
- "$ref": "#/components/schemas/ScoringFunctionDefWithProvider"
+ "$ref": "#/components/schemas/ScoringFnDefWithProvider"
}
},
"additionalProperties": false,
@@ -6893,10 +6853,10 @@
"model": {
"type": "string"
},
- "dataset": {
+ "dataset_id": {
"type": "string"
},
- "validation_dataset": {
+ "validation_dataset_id": {
"type": "string"
},
"algorithm": {
@@ -6976,8 +6936,8 @@
"required": [
"job_uuid",
"model",
- "dataset",
- "validation_dataset",
+ "dataset_id",
+ "validation_dataset_id",
"algorithm",
"algorithm_config",
"optimizer_config",
@@ -7102,57 +7062,57 @@
}
],
"tags": [
- {
- "name": "Eval"
- },
- {
- "name": "ScoringFunctions"
- },
- {
- "name": "SyntheticDataGeneration"
- },
- {
- "name": "Inspect"
- },
- {
- "name": "PostTraining"
- },
- {
- "name": "Models"
- },
- {
- "name": "Safety"
- },
- {
- "name": "MemoryBanks"
- },
- {
- "name": "DatasetIO"
- },
{
"name": "Memory"
},
- {
- "name": "Scoring"
- },
- {
- "name": "Shields"
- },
- {
- "name": "Datasets"
- },
{
"name": "Inference"
},
{
- "name": "Telemetry"
+ "name": "Eval"
+ },
+ {
+ "name": "MemoryBanks"
+ },
+ {
+ "name": "Models"
},
{
"name": "BatchInference"
},
+ {
+ "name": "PostTraining"
+ },
{
"name": "Agents"
},
+ {
+ "name": "Shields"
+ },
+ {
+ "name": "Telemetry"
+ },
+ {
+ "name": "Inspect"
+ },
+ {
+ "name": "DatasetIO"
+ },
+ {
+ "name": "SyntheticDataGeneration"
+ },
+ {
+ "name": "Datasets"
+ },
+ {
+ "name": "Scoring"
+ },
+ {
+ "name": "ScoringFunctions"
+ },
+ {
+ "name": "Safety"
+ },
{
"name": "BuiltinTool",
"description": ""
@@ -7355,7 +7315,7 @@
},
{
"name": "AgentTurnResponseStreamChunk",
- "description": ""
+ "description": "streamed agent turn completion response.\n\n"
},
{
"name": "AgentTurnResponseTurnCompletePayload",
@@ -7486,8 +7446,8 @@
"description": ""
},
{
- "name": "ScoringFunctionDefWithProvider",
- "description": ""
+ "name": "ScoringFnDefWithProvider",
+ "description": ""
},
{
"name": "ShieldDefWithProvider",
@@ -7805,7 +7765,7 @@
"ScoreBatchResponse",
"ScoreRequest",
"ScoreResponse",
- "ScoringFunctionDefWithProvider",
+ "ScoringFnDefWithProvider",
"ScoringResult",
"SearchToolDefinition",
"Session",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 9dcdbb028..7dd231965 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -190,6 +190,7 @@ components:
$ref: '#/components/schemas/AgentTurnResponseEvent'
required:
- event
+ title: streamed agent turn completion response.
type: object
AgentTurnResponseTurnCompletePayload:
additionalProperties: false
@@ -360,7 +361,7 @@ components:
oneOf:
- additionalProperties: false
properties:
- schema:
+ json_schema:
additionalProperties:
oneOf:
- type: 'null'
@@ -376,7 +377,7 @@ components:
type: string
required:
- type
- - schema
+ - json_schema
type: object
- additionalProperties: false
properties:
@@ -541,7 +542,7 @@ components:
oneOf:
- additionalProperties: false
properties:
- schema:
+ json_schema:
additionalProperties:
oneOf:
- type: 'null'
@@ -557,7 +558,7 @@ components:
type: string
required:
- type
- - schema
+ - json_schema
type: object
- additionalProperties: false
properties:
@@ -747,18 +748,6 @@ components:
required:
- type
type: object
- - additionalProperties: false
- properties:
- type:
- const: custom
- default: custom
- type: string
- validator_class:
- type: string
- required:
- - type
- - validator_class
- type: object
- additionalProperties: false
properties:
type:
@@ -1575,18 +1564,6 @@ components:
required:
- type
type: object
- - additionalProperties: false
- properties:
- type:
- const: custom
- default: custom
- type: string
- validator_class:
- type: string
- required:
- - type
- - validator_class
- type: object
- additionalProperties: false
properties:
type:
@@ -1724,7 +1701,7 @@ components:
$ref: '#/components/schemas/RLHFAlgorithm'
algorithm_config:
$ref: '#/components/schemas/DPOAlignmentConfig'
- dataset:
+ dataset_id:
type: string
finetuned_model:
$ref: '#/components/schemas/URL'
@@ -1754,13 +1731,13 @@ components:
$ref: '#/components/schemas/OptimizerConfig'
training_config:
$ref: '#/components/schemas/TrainingConfig'
- validation_dataset:
+ validation_dataset_id:
type: string
required:
- job_uuid
- finetuned_model
- - dataset
- - validation_dataset
+ - dataset_id
+ - validation_dataset_id
- algorithm
- algorithm_config
- optimizer_config
@@ -1899,7 +1876,7 @@ components:
additionalProperties: false
properties:
function_def:
- $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+ $ref: '#/components/schemas/ScoringFnDefWithProvider'
required:
- function_def
type: object
@@ -2121,7 +2098,7 @@ components:
required:
- results
type: object
- ScoringFunctionDefWithProvider:
+ ScoringFnDefWithProvider:
additionalProperties: false
properties:
context:
@@ -2129,6 +2106,10 @@ components:
properties:
judge_model:
type: string
+ judge_score_regex:
+ items:
+ type: string
+ type: array
prompt_template:
type: string
required:
@@ -2219,18 +2200,6 @@ components:
required:
- type
type: object
- - additionalProperties: false
- properties:
- type:
- const: custom
- default: custom
- type: string
- validator_class:
- type: string
- required:
- - type
- - validator_class
- type: object
- additionalProperties: false
properties:
type:
@@ -2484,7 +2453,7 @@ components:
- $ref: '#/components/schemas/LoraFinetuningConfig'
- $ref: '#/components/schemas/QLoraFinetuningConfig'
- $ref: '#/components/schemas/DoraFinetuningConfig'
- dataset:
+ dataset_id:
type: string
hyperparam_search_config:
additionalProperties:
@@ -2514,13 +2483,13 @@ components:
$ref: '#/components/schemas/OptimizerConfig'
training_config:
$ref: '#/components/schemas/TrainingConfig'
- validation_dataset:
+ validation_dataset_id:
type: string
required:
- job_uuid
- model
- - dataset
- - validation_dataset
+ - dataset_id
+ - validation_dataset_id
- algorithm
- algorithm_config
- optimizer_config
@@ -3029,7 +2998,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117"
+ \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3222,8 +3191,11 @@ paths:
content:
text/event-stream:
schema:
- $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
- description: OK
+ oneOf:
+ - $ref: '#/components/schemas/Turn'
+ - $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
+ description: A single turn in an interaction with an Agentic System. **OR**
+ streamed agent turn completion response.
tags:
- Agents
/agents/turn/get:
@@ -4122,7 +4094,7 @@ paths:
application/json:
schema:
oneOf:
- - $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+ - $ref: '#/components/schemas/ScoringFnDefWithProvider'
- type: 'null'
description: OK
tags:
@@ -4142,7 +4114,7 @@ paths:
content:
application/jsonl:
schema:
- $ref: '#/components/schemas/ScoringFunctionDefWithProvider'
+ $ref: '#/components/schemas/ScoringFnDefWithProvider'
description: OK
tags:
- ScoringFunctions
@@ -4308,23 +4280,23 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Eval
-- name: ScoringFunctions
-- name: SyntheticDataGeneration
-- name: Inspect
-- name: PostTraining
-- name: Models
-- name: Safety
-- name: MemoryBanks
-- name: DatasetIO
- name: Memory
-- name: Scoring
-- name: Shields
-- name: Datasets
- name: Inference
-- name: Telemetry
+- name: Eval
+- name: MemoryBanks
+- name: Models
- name: BatchInference
+- name: PostTraining
- name: Agents
+- name: Shields
+- name: Telemetry
+- name: Inspect
+- name: DatasetIO
+- name: SyntheticDataGeneration
+- name: Datasets
+- name: Scoring
+- name: ScoringFunctions
+- name: Safety
- description:
name: BuiltinTool
- description:
name: AgentTurnResponseStepStartPayload
-- description:
+- description: 'streamed agent turn completion response.
+
+
+ '
name: AgentTurnResponseStreamChunk
- description:
@@ -4577,9 +4552,9 @@ tags:
name: PaginatedRowsResult
- description:
name: Parameter
-- description:
-  name: ScoringFunctionDefWithProvider
+- description:
+  name: ScoringFnDefWithProvider
- description:
name: ShieldDefWithProvider
@@ -4844,7 +4819,7 @@ x-tagGroups:
- ScoreBatchResponse
- ScoreRequest
- ScoreResponse
- - ScoringFunctionDefWithProvider
+ - ScoringFnDefWithProvider
- ScoringResult
- SearchToolDefinition
- Session
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index e0eaacf51..613844f5e 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -8,6 +8,7 @@ from datetime import datetime
from enum import Enum
from typing import (
Any,
+ AsyncIterator,
Dict,
List,
Literal,
@@ -405,6 +406,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
+ """streamed agent turn completion response."""
+
event: AgentTurnResponseEvent
@@ -434,7 +437,7 @@ class Agents(Protocol):
],
attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False,
- ) -> AgentTurnResponseStreamChunk: ...
+ ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/turn/get")
async def get_agents_turn(
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index eb2c41d32..4b6530f63 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -6,7 +6,15 @@
from enum import Enum
-from typing import List, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+ AsyncIterator,
+ List,
+ Literal,
+ Optional,
+ Protocol,
+ runtime_checkable,
+ Union,
+)
from llama_models.schema_utils import json_schema_type, webmethod
@@ -224,7 +232,7 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
- ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
+ ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...
@webmethod(route="/inference/chat_completion")
async def chat_completion(
@@ -239,7 +247,9 @@ class Inference(Protocol):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
- ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ...
+ ) -> Union[
+ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+ ]: ...
@webmethod(route="/inference/embeddings")
async def embeddings(
diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh
index 8044dda28..ae2b17d9e 100755
--- a/llama_stack/distribution/build_container.sh
+++ b/llama_stack/distribution/build_container.sh
@@ -77,9 +77,9 @@ if [ -n "$LLAMA_STACK_DIR" ]; then
# Install in editable format. We will mount the source code into the container
# so that changes will be reflected in the container without having to do a
# rebuild. This is just for development convenience.
- add_to_docker "RUN pip install -e $stack_mount"
+ add_to_docker "RUN pip install --no-cache -e $stack_mount"
else
- add_to_docker "RUN pip install llama-stack"
+ add_to_docker "RUN pip install --no-cache llama-stack"
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
@@ -90,19 +90,19 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py
new file mode 100644
--- /dev/null
+++ b/llama_stack/distribution/client.py
+def create_api_client_class(protocol, additional_protocol=None) -> Type:
+ if protocol in _CLIENT_CLASSES:
+ return _CLIENT_CLASSES[protocol]
+
+ protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
+
+ class APIClient:
+ def __init__(self, base_url: str):
+ print(f"({protocol.__name__}) Connecting to {base_url}")
+ self.base_url = base_url.rstrip("/")
+ self.routes = {}
+
+ # Store routes for this protocol
+ for p in protocols:
+ for name, method in inspect.getmembers(p):
+ if hasattr(method, "__webmethod__"):
+ sig = inspect.signature(method)
+ self.routes[name] = (method.__webmethod__, sig)
+
+ async def initialize(self):
+ pass
+
+ async def shutdown(self):
+ pass
+
+ async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
+ assert method_name in self.routes, f"Unknown endpoint: {method_name}"
+
+ # TODO: make this more precise, same thing needs to happen in server.py
+ is_streaming = kwargs.get("stream", False)
+ if is_streaming:
+ return self._call_streaming(method_name, *args, **kwargs)
+ else:
+ return await self._call_non_streaming(method_name, *args, **kwargs)
+
+ async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
+ _, sig = self.routes[method_name]
+
+ if sig.return_annotation is None:
+ return_type = None
+ else:
+ return_type = extract_non_async_iterator_type(sig.return_annotation)
+ assert (
+ return_type
+ ), f"Could not extract return type for {sig.return_annotation}"
+
+ async with httpx.AsyncClient() as client:
+ params = self.httpx_request_params(method_name, *args, **kwargs)
+ response = await client.request(**params)
+ response.raise_for_status()
+
+ j = response.json()
+ if j is None:
+ return None
+ return parse_obj_as(return_type, j)
+
+ async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
+ webmethod, sig = self.routes[method_name]
+
+ return_type = extract_async_iterator_type(sig.return_annotation)
+ assert (
+ return_type
+ ), f"Could not extract return type for {sig.return_annotation}"
+
+ async with httpx.AsyncClient() as client:
+ params = self.httpx_request_params(method_name, *args, **kwargs)
+ async with client.stream(**params) as response:
+ response.raise_for_status()
+
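+                    # the endpoint streams server-sent events; payload lines are prefixed with "data: "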
+ async for line in response.aiter_lines():
+ if line.startswith("data:"):
+ data = line[len("data: ") :]
+ try:
+ if "error" in data:
+ cprint(data, "red")
+ continue
+
+ yield parse_obj_as(return_type, json.loads(data))
+ except Exception as e:
+ print(data)
+ print(f"Error with parsing or validation: {e}")
+
+ def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
+ webmethod, sig = self.routes[method_name]
+
+ parameters = list(sig.parameters.values())[1:] # skip `self`
+ for i, param in enumerate(parameters):
+ if i >= len(args):
+ break
+ kwargs[param.name] = args[i]
+
+ url = f"{self.base_url}{webmethod.route}"
+
+ def convert(value):
+ if isinstance(value, list):
+ return [convert(v) for v in value]
+ elif isinstance(value, dict):
+ return {k: convert(v) for k, v in value.items()}
+ elif isinstance(value, BaseModel):
+ return json.loads(value.model_dump_json())
+ elif isinstance(value, Enum):
+ return value.value
+ else:
+ return value
+
+ params = {}
+ data = {}
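+            # GET requests send arguments as query parameters; other methods send them as a JSON body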
+ if webmethod.method == "GET":
+ params.update(kwargs)
+ else:
+ data.update(convert(kwargs))
+
+ return dict(
+ method=webmethod.method or "POST",
+ url=url,
+ headers={"Content-Type": "application/json"},
+ params=params,
+ json=data,
+ timeout=30,
+ )
+
+ # Add protocol methods to the wrapper
+ for p in protocols:
+ for name, method in inspect.getmembers(p):
+ if hasattr(method, "__webmethod__"):
+
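+                # bind the loop variable through a keyword default so each generated method keeps its own route name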
+ async def method_impl(self, *args, method_name=name, **kwargs):
+ return await self.__acall__(method_name, *args, **kwargs)
+
+ method_impl.__name__ = name
+ method_impl.__qualname__ = f"APIClient.{name}"
+ method_impl.__signature__ = inspect.signature(method)
+ setattr(APIClient, name, method_impl)
+
+ # Name the class after the protocol
+ APIClient.__name__ = f"{protocol.__name__}Client"
+ _CLIENT_CLASSES[protocol] = APIClient
+ return APIClient
+
+
+# not quite general these methods are
+def extract_non_async_iterator_type(type_hint):
+ if get_origin(type_hint) is Union:
+ args = get_args(type_hint)
+ for arg in args:
+ if not issubclass(get_origin(arg) or arg, AsyncIterator):
+ return arg
+ return type_hint
+
+
+def extract_async_iterator_type(type_hint):
+ if get_origin(type_hint) is Union:
+ args = get_args(type_hint)
+ for arg in args:
+ if issubclass(get_origin(arg) or arg, AsyncIterator):
+ inner_args = get_args(arg)
+ return inner_args[0]
+ return None
+
+
+async def example(model: str = None):
+ from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
+ from llama_stack.apis.inference.event_logger import EventLogger
+
+ client_class = create_api_client_class(Inference)
+ client = client_class("http://localhost:5003")
+
+ if not model:
+ model = "Llama3.2-3B-Instruct"
+
+ message = UserMessage(
+ content="hello world, write me a 2 sentence poem about the moon"
+ )
+ cprint(f"User>{message.content}", "green")
+
+ stream = True
+ iterator = await client.chat_completion(
+ model=model,
+ messages=[message],
+ stream=stream,
+ )
+
+ async for log in EventLogger().log(iterator):
+ log.print()
+
+
+if __name__ == "__main__":
+ import asyncio
+
+ asyncio.run(example())
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index bab807da9..a93cc1183 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.safety: Safety,
Api.shields: Shields,
Api.telemetry: Telemetry,
- Api.datasets: Datasets,
Api.datasetio: DatasetIO,
- Api.scoring_functions: ScoringFunctions,
+ Api.datasets: Datasets,
Api.scoring: Scoring,
+ Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
}
def additional_protocols_map() -> Dict[Api, Any]:
return {
- Api.inference: ModelsProtocolPrivate,
- Api.memory: MemoryBanksProtocolPrivate,
- Api.safety: ShieldsProtocolPrivate,
+ Api.inference: (ModelsProtocolPrivate, Models),
+ Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
+ Api.safety: (ShieldsProtocolPrivate, Shields),
+ Api.datasetio: (DatasetsProtocolPrivate, Datasets),
+ Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
}
@@ -112,8 +114,6 @@ async def resolve_impls(
if info.router_api.value not in apis_to_serve:
continue
- available_providers = providers_with_specs[f"inner-{info.router_api.value}"]
-
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__routing_table__",
@@ -246,14 +246,21 @@ async def instantiate_provider(
args = []
if isinstance(provider_spec, RemoteProviderSpec):
- if provider_spec.adapter:
- method = "get_adapter_impl"
- else:
- method = "get_client_impl"
-
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
- args = [config, deps]
+
+ if provider_spec.adapter:
+ method = "get_adapter_impl"
+ args = [config, deps]
+ else:
+ method = "get_client_impl"
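+            # remote passthrough providers get a client generated from the API protocol (and any provider-private protocol)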
+ protocol = protocols[provider_spec.api]
+ if provider_spec.api in additional_protocols:
+ _, additional_protocol = additional_protocols[provider_spec.api]
+ else:
+ additional_protocol = None
+ args = [protocol, additional_protocol, config, deps]
+
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@@ -282,7 +289,7 @@ async def instantiate_provider(
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
- additional_api = additional_protocols[provider_spec.api]
+ additional_api, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 3e07b9162..4e462c54b 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api:
async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
+
+ if obj.provider_id == "remote":
+ # if this is just a passthrough, we want to let the remote
+ # end actually do the registration with the correct provider
+ obj = obj.model_copy(deep=True)
+ obj.provider_id = ""
+
if api == Api.inference:
await p.register_model(obj)
elif api == Api.safety:
@@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable):
async def initialize(self) -> None:
self.registry: Registry = {}
- def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
+ def add_objects(
+ objs: List[RoutableObjectWithProvider], provider_id: str, cls
+ ) -> None:
for obj in objs:
if obj.identifier not in self.registry:
self.registry[obj.identifier] = []
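+                # cls is None for Annotated-union defs (e.g. memory banks); just set provider_id in place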
+ if cls is None:
+ obj.provider_id = provider_id
+ else:
+ if provider_id == "remote":
+ # if this is just a passthrough, we got the *WithProvider object
+ # so we should just override the provider in-place
+ obj.provider_id = provider_id
+ else:
+ obj = cls(**obj.model_dump(), provider_id=provider_id)
self.registry[obj.identifier].append(obj)
for pid, p in self.impls_by_provider_id.items():
@@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable):
if api == Api.inference:
p.model_store = self
models = await p.list_models()
- add_objects(
- [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
- )
+ add_objects(models, pid, ModelDefWithProvider)
elif api == Api.safety:
p.shield_store = self
shields = await p.list_shields()
- add_objects(
- [
- ShieldDefWithProvider(**s.dict(), provider_id=pid)
- for s in shields
- ]
- )
+ add_objects(shields, pid, ShieldDefWithProvider)
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
-
- # do in-memory updates due to pesky Annotated unions
- for m in memory_banks:
- m.provider_id = pid
-
- add_objects(memory_banks)
+ add_objects(memory_banks, pid, None)
elif api == Api.datasetio:
p.dataset_store = self
datasets = await p.list_datasets()
-
- # do in-memory updates due to pesky Annotated unions
- for d in datasets:
- d.provider_id = pid
+ add_objects(datasets, pid, DatasetDefWithProvider)
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
- add_objects(
- [
- ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
- for s in scoring_functions
- ]
- )
+ add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
index 3800c0496..caf886c0b 100644
--- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py
@@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
- ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
+ ) -> AsyncGenerator:
raise NotImplementedError()
@staticmethod
@@ -290,23 +290,130 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
- # zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
- ) -> (
- AsyncGenerator
- ): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]:
- bedrock_model = self.map_to_provider_model(model)
- inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
- sampling_params
+ ) -> Union[
+ ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+ ]:
+ request = ChatCompletionRequest(
+ model=model,
+ messages=messages,
+ sampling_params=sampling_params,
+ tools=tools or [],
+ tool_choice=tool_choice,
+ tool_prompt_format=tool_prompt_format,
+ response_format=response_format,
+ stream=stream,
+ logprobs=logprobs,
)
- tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice)
+ if stream:
+ return self._stream_chat_completion(request)
+ else:
+ return await self._nonstream_chat_completion(request)
+
+ async def _nonstream_chat_completion(
+ self, request: ChatCompletionRequest
+ ) -> ChatCompletionResponse:
+ params = self._get_params_for_chat_completion(request)
+ converse_api_res = self.client.converse(**params)
+
+ output_message = BedrockInferenceAdapter._bedrock_message_to_message(
+ converse_api_res
+ )
+
+ return ChatCompletionResponse(
+ completion_message=output_message,
+ logprobs=None,
+ )
+
+ async def _stream_chat_completion(
+ self, request: ChatCompletionRequest
+ ) -> AsyncGenerator:
+ params = self._get_params_for_chat_completion(request)
+ converse_stream_api_res = self.client.converse_stream(**params)
+ event_stream = converse_stream_api_res["stream"]
+
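+        # translate Bedrock Converse stream events into llama-stack chat-completion chunks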
+ for chunk in event_stream:
+ if "messageStart" in chunk:
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=ChatCompletionResponseEventType.start,
+ delta="",
+ )
+ )
+ elif "contentBlockStart" in chunk:
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=ChatCompletionResponseEventType.progress,
+ delta=ToolCallDelta(
+ content=ToolCall(
+ tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
+ call_id=chunk["contentBlockStart"]["toolUse"][
+ "toolUseId"
+ ],
+ ),
+ parse_status=ToolCallParseStatus.started,
+ ),
+ )
+ )
+ elif "contentBlockDelta" in chunk:
+ if "text" in chunk["contentBlockDelta"]["delta"]:
+ delta = chunk["contentBlockDelta"]["delta"]["text"]
+ else:
+ delta = ToolCallDelta(
+ content=ToolCall(
+ arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
+ "input"
+ ]
+ ),
+ parse_status=ToolCallParseStatus.success,
+ )
+
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=ChatCompletionResponseEventType.progress,
+ delta=delta,
+ )
+ )
+ elif "contentBlockStop" in chunk:
+ # Ignored
+ pass
+ elif "messageStop" in chunk:
+ stop_reason = (
+ BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
+ chunk["messageStop"]["stopReason"]
+ )
+ )
+
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=ChatCompletionResponseEventType.complete,
+ delta="",
+ stop_reason=stop_reason,
+ )
+ )
+ elif "metadata" in chunk:
+ # Ignored
+ pass
+ else:
+ # Ignored
+ pass
+
+ def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+ bedrock_model = self.map_to_provider_model(request.model)
+ inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
+ request.sampling_params
+ )
+
+ tool_config = BedrockInferenceAdapter._tools_to_tool_config(
+ request.tools, request.tool_choice
+ )
bedrock_messages, system_bedrock_messages = (
- BedrockInferenceAdapter._messages_to_bedrock_messages(messages)
+ BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
)
converse_api_params = {
@@ -317,93 +424,12 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
converse_api_params["inferenceConfig"] = inference_config
# Tool use is not supported in streaming mode
- if tool_config and not stream:
+ if tool_config and not request.stream:
converse_api_params["toolConfig"] = tool_config
if system_bedrock_messages:
converse_api_params["system"] = system_bedrock_messages
- if not stream:
- converse_api_res = self.client.converse(**converse_api_params)
-
- output_message = BedrockInferenceAdapter._bedrock_message_to_message(
- converse_api_res
- )
-
- yield ChatCompletionResponse(
- completion_message=output_message,
- logprobs=None,
- )
- else:
- converse_stream_api_res = self.client.converse_stream(**converse_api_params)
- event_stream = converse_stream_api_res["stream"]
-
- for chunk in event_stream:
- if "messageStart" in chunk:
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.start,
- delta="",
- )
- )
- elif "contentBlockStart" in chunk:
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.progress,
- delta=ToolCallDelta(
- content=ToolCall(
- tool_name=chunk["contentBlockStart"]["toolUse"][
- "name"
- ],
- call_id=chunk["contentBlockStart"]["toolUse"][
- "toolUseId"
- ],
- ),
- parse_status=ToolCallParseStatus.started,
- ),
- )
- )
- elif "contentBlockDelta" in chunk:
- if "text" in chunk["contentBlockDelta"]["delta"]:
- delta = chunk["contentBlockDelta"]["delta"]["text"]
- else:
- delta = ToolCallDelta(
- content=ToolCall(
- arguments=chunk["contentBlockDelta"]["delta"][
- "toolUse"
- ]["input"]
- ),
- parse_status=ToolCallParseStatus.success,
- )
-
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.progress,
- delta=delta,
- )
- )
- elif "contentBlockStop" in chunk:
- # Ignored
- pass
- elif "messageStop" in chunk:
- stop_reason = (
- BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
- chunk["messageStop"]["stopReason"]
- )
- )
-
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.complete,
- delta="",
- stop_reason=stop_reason,
- )
- )
- elif "metadata" in chunk:
- # Ignored
- pass
- else:
- # Ignored
- pass
+ return converse_api_params
async def embeddings(
self,
diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py
index 4687618fa..4cf55035c 100644
--- a/llama_stack/providers/adapters/inference/vllm/vllm.py
+++ b/llama_stack/providers/adapters/inference/vllm/vllm.py
@@ -75,7 +75,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
for model in self.client.models.list()
]
- def completion(
+ async def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -86,7 +86,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
raise NotImplementedError()
- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: List[Message],
@@ -111,7 +111,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if stream:
return self._stream_chat_completion(request, self.client)
else:
- return self._nonstream_chat_completion(request, self.client)
+ return await self._nonstream_chat_completion(request, self.client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index eace0ea1a..9a37a28a9 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def list_datasets(self) -> List[DatasetDef]: ...
- async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
+ async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
@@ -171,7 +171,7 @@ as being "Llama Stack compatible"
def module(self) -> str:
if self.adapter:
return self.adapter.module
- return f"llama_stack.apis.{self.api.value}.client"
+ return "llama_stack.distribution.client"
@property
def pip_packages(self) -> List[str]:
diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift b/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift
index 89f24a561..84da42d1b 100644
--- a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift
+++ b/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift
@@ -81,7 +81,9 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay
switch (m.content) {
case .case1(let c):
prompt += _processContent(c)
- case .case2(let c):
+ case .ImageMedia(let c):
+ prompt += _processContent(c)
+ case .case3(let c):
prompt += _processContent(c)
}
case .CompletionMessage(let m):
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 9c34c3a28..c09db3d20 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -26,6 +26,7 @@ from dotenv import load_dotenv
#
# ```bash
# PROVIDER_ID= \
+# MODEL_ID= \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/agents/test_agents.py \
# --tb=short --disable-warnings
@@ -44,7 +45,7 @@ async def agents_settings():
"impl": impls[Api.agents],
"memory_impl": impls[Api.memory],
"common_params": {
- "model": "Llama3.1-8B-Instruct",
+                "model": os.environ.get("MODEL_ID") or "Llama3.1-8B-Instruct",
"instructions": "You are a helpful assistant.",
},
}
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index b26bf75a7..d83601de1 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import os
import pytest
import pytest_asyncio
@@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
- provider_id=os.environ["PROVIDER_ID"],
)
await banks_impl.register_memory_bank(bank)