From 37b330b4ef24a9253abba07cf906bf4ac3af6e55 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 31 Oct 2024 14:46:25 -0700
Subject: [PATCH] add dynamic clients for all APIs (#348)

* add dynamic clients for all APIs

* fix openapi generator

* inference + memory + agents tests now pass with "remote" providers

* Add docstring which fixes openapi generator :/
---
 .../openapi_generator/pyopenapi/operations.py |  15 +-
 docs/resources/llama-stack-spec.html          |  48 ++--
 docs/resources/llama-stack-spec.yaml          |  35 +--
 llama_stack/apis/agents/agents.py             |   5 +-
 llama_stack/apis/inference/inference.py       |  16 +-
 llama_stack/distribution/client.py            | 221 ++++++++++++++++++
 llama_stack/distribution/resolver.py          |  35 +--
 .../distribution/routers/routing_tables.py    |  50 ++--
 llama_stack/providers/datatypes.py            |   4 +-
 .../providers/tests/agents/test_agents.py     |   3 +-
 .../providers/tests/memory/test_memory.py     |   2 -
 11 files changed, 350 insertions(+), 84 deletions(-)
 create mode 100644 llama_stack/distribution/client.py

diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py
index ad8f2952e..f4238f6f8 100644
--- a/docs/openapi_generator/pyopenapi/operations.py
+++ b/docs/openapi_generator/pyopenapi/operations.py
@@ -315,7 +315,20 @@ def get_endpoint_operations(
             )
         else:
             event_type = None
-            response_type = return_type
+
+            def process_type(t):
+                if typing.get_origin(t) is collections.abc.AsyncIterator:
+                    # NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List
+                    # or the item type. I am choosing it to be the latter
+                    args = typing.get_args(t)
+                    return args[0]
+                elif typing.get_origin(t) is typing.Union:
+                    types = [process_type(a) for a in typing.get_args(t)]
+                    return typing._UnionGenericAlias(typing.Union, tuple(types))
+                else:
+                    return t
+
+            response_type = process_type(return_type)

         # set HTTP request method based on type of request and presence of payload
         if not request_params:
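To make the generator change concrete: `process_type` strips `AsyncIterator[...]` (the SSE case) down to its item type and rebuilds unions recursively. Below is a self-contained sketch of the same unwrapping that uses the public `Union[...]` constructor instead of `typing`'s private `_UnionGenericAlias`; the example annotation is made up:

```python
# Sketch of the return-type unwrapping done by the generator above.
import collections.abc
import typing
from typing import AsyncIterator, Union


def process_type(t):
    if typing.get_origin(t) is collections.abc.AsyncIterator:
        # SSE streams cannot be represented in OpenAPI; surface the item type
        return typing.get_args(t)[0]
    elif typing.get_origin(t) is Union:
        return Union[tuple(process_type(a) for a in typing.get_args(t))]
    return t


# Union[str, AsyncIterator[int]] flattens to Union[str, int]
assert process_type(Union[str, AsyncIterator[int]]) == Union[str, int]
```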
         },
         "AgentTurnResponseTurnCompletePayload": {
             "type": "object",
@@ -7054,30 +7062,27 @@
         }
     ],
     "tags": [
-        {
-            "name": "Inference"
-        },
         {
             "name": "Memory"
         },
         {
-            "name": "Inspect"
+            "name": "Inference"
         },
         {
-            "name": "PostTraining"
+            "name": "Eval"
+        },
+        {
+            "name": "MemoryBanks"
         },
         {
             "name": "Models"
         },
-        {
-            "name": "Scoring"
-        },
-        {
-            "name": "DatasetIO"
-        },
         {
             "name": "BatchInference"
         },
+        {
+            "name": "PostTraining"
+        },
         {
             "name": "Agents"
         },
@@ -7085,19 +7090,22 @@
             "name": "Shields"
         },
         {
-            "name": "MemoryBanks"
+            "name": "Telemetry"
         },
         {
-            "name": "Datasets"
+            "name": "Inspect"
+        },
+        {
+            "name": "DatasetIO"
         },
         {
             "name": "SyntheticDataGeneration"
         },
         {
-            "name": "Eval"
+            "name": "Datasets"
         },
         {
-            "name": "Telemetry"
+            "name": "Scoring"
         },
         {
             "name": "ScoringFunctions"
@@ -7307,7 +7315,7 @@
         },
         {
             "name": "AgentTurnResponseStreamChunk",
-            "description": ""
+            "description": "streamed agent turn completion response.\n\n"
         },
         {
             "name": "AgentTurnResponseTurnCompletePayload",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 67181ab42..7dd231965 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -190,6 +190,7 @@ components:
         $ref: '#/components/schemas/AgentTurnResponseEvent'
       required:
       - event
+      title: streamed agent turn completion response.
       type: object
   AgentTurnResponseTurnCompletePayload:
     additionalProperties: false
@@ -2997,7 +2998,7 @@ info:
   description: "This is the specification of the llama stack that provides\n\
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n Generated at 2024-10-30 16:17:03.919702"
+    \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3190,8 +3191,11 @@ paths:
           content:
             text/event-stream:
               schema:
-                $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
-          description: OK
+                oneOf:
+                - $ref: '#/components/schemas/Turn'
+                - $ref: '#/components/schemas/AgentTurnResponseStreamChunk'
+          description: A single turn in an interaction with an Agentic System. **OR**
+            streamed agent turn completion response.
       tags:
       - Agents
   /agents/turn/get:
@@ -4276,21 +4280,21 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Inference
 - name: Memory
-- name: Inspect
-- name: PostTraining
+- name: Inference
+- name: Eval
+- name: MemoryBanks
 - name: Models
-- name: Scoring
-- name: DatasetIO
 - name: BatchInference
+- name: PostTraining
 - name: Agents
 - name: Shields
-- name: MemoryBanks
-- name: Datasets
-- name: SyntheticDataGeneration
-- name: Eval
 - name: Telemetry
+- name: Inspect
+- name: DatasetIO
+- name: SyntheticDataGeneration
+- name: Datasets
+- name: Scoring
 - name: ScoringFunctions
 - name: Safety
@@ -4451,8 +4455,11 @@ tags:
 - description: 
   name: AgentTurnResponseStepStartPayload
-- description: 
+- description: 'streamed agent turn completion response.
+ + + ' name: AgentTurnResponseStreamChunk - description: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index e0eaacf51..613844f5e 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -8,6 +8,7 @@ from datetime import datetime from enum import Enum from typing import ( Any, + AsyncIterator, Dict, List, Literal, @@ -405,6 +406,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): @json_schema_type class AgentTurnResponseStreamChunk(BaseModel): + """streamed agent turn completion response.""" + event: AgentTurnResponseEvent @@ -434,7 +437,7 @@ class Agents(Protocol): ], attachments: Optional[List[Attachment]] = None, stream: Optional[bool] = False, - ) -> AgentTurnResponseStreamChunk: ... + ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") async def get_agents_turn( diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index eb2c41d32..4b6530f63 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -6,7 +6,15 @@ from enum import Enum -from typing import List, Literal, Optional, Protocol, runtime_checkable, Union +from typing import ( + AsyncIterator, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod @@ -224,7 +232,7 @@ class Inference(Protocol): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ... @webmethod(route="/inference/chat_completion") async def chat_completion( @@ -239,7 +247,9 @@ class Inference(Protocol): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: ... @webmethod(route="/inference/embeddings") async def embeddings( diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py new file mode 100644 index 000000000..acc871f01 --- /dev/null +++ b/llama_stack/distribution/client.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+
+import inspect
+import json
+from collections.abc import AsyncIterator
+from enum import Enum
+from typing import Any, get_args, get_origin, Optional, Type, Union
+
+import httpx
+from pydantic import BaseModel, parse_obj_as
+from termcolor import cprint
+
+from llama_stack.providers.datatypes import RemoteProviderConfig
+
+_CLIENT_CLASSES = {}
+
+
+async def get_client_impl(
+    protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
+):
+    client_class = create_api_client_class(protocol, additional_protocol)
+    impl = client_class(config.url)
+    await impl.initialize()
+    return impl
+
+
+def create_api_client_class(protocol, additional_protocol) -> Type:
+    if protocol in _CLIENT_CLASSES:
+        return _CLIENT_CLASSES[protocol]
+
+    protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
+
+    class APIClient:
+        def __init__(self, base_url: str):
+            print(f"({protocol.__name__}) Connecting to {base_url}")
+            self.base_url = base_url.rstrip("/")
+            self.routes = {}
+
+            # Store routes for this protocol
+            for p in protocols:
+                for name, method in inspect.getmembers(p):
+                    if hasattr(method, "__webmethod__"):
+                        sig = inspect.signature(method)
+                        self.routes[name] = (method.__webmethod__, sig)
+
+        async def initialize(self):
+            pass
+
+        async def shutdown(self):
+            pass
+
+        async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
+            assert method_name in self.routes, f"Unknown endpoint: {method_name}"
+
+            # TODO: make this more precise; the same check needs to happen in server.py
+            is_streaming = kwargs.get("stream", False)
+            if is_streaming:
+                return self._call_streaming(method_name, *args, **kwargs)
+            else:
+                return await self._call_non_streaming(method_name, *args, **kwargs)
+
+        async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
+            _, sig = self.routes[method_name]
+
+            if sig.return_annotation is None:
+                return_type = None
+            else:
+                return_type = extract_non_async_iterator_type(sig.return_annotation)
+                assert (
+                    return_type
+                ), f"Could not extract return type for {sig.return_annotation}"
+
+            async with httpx.AsyncClient() as client:
+                params = self.httpx_request_params(method_name, *args, **kwargs)
+                response = await client.request(**params)
+                response.raise_for_status()
+
+                j = response.json()
+                if j is None:
+                    return None
+                return parse_obj_as(return_type, j)
+
+        async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
+            webmethod, sig = self.routes[method_name]
+
+            return_type = extract_async_iterator_type(sig.return_annotation)
+            assert (
+                return_type
+            ), f"Could not extract return type for {sig.return_annotation}"
+
+            async with httpx.AsyncClient() as client:
+                params = self.httpx_request_params(method_name, *args, **kwargs)
+                async with client.stream(**params) as response:
+                    response.raise_for_status()
+
+                    async for line in response.aiter_lines():
+                        if line.startswith("data:"):
+                            # strip the SSE field name, with or without a space after the colon
+                            data = line[len("data:") :].lstrip()
+                            try:
+                                if "error" in data:
+                                    cprint(data, "red")
+                                    continue
+
+                                yield parse_obj_as(return_type, json.loads(data))
+                            except Exception as e:
+                                print(data)
+                                print(f"Error with parsing or validation: {e}")
+
+        def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
+            webmethod, sig = self.routes[method_name]
+
+            parameters = list(sig.parameters.values())[1:]  # skip `self`
+            for i, param in enumerate(parameters):
+                if i >= len(args):
+                    break
+                kwargs[param.name] = args[i]
+
+            url = f"{self.base_url}{webmethod.route}"
+
+            def convert(value):
+                if isinstance(value, list):
+                    return [convert(v) for v in value]
+                elif isinstance(value, dict):
+                    return {k: convert(v) for k, v in value.items()}
+                elif isinstance(value, BaseModel):
+                    return json.loads(value.model_dump_json())
+                elif isinstance(value, Enum):
+                    return value.value
+                else:
+                    return value
+
+            params = {}
+            data = {}
+            if webmethod.method == "GET":
+                params.update(kwargs)
+            else:
+                data.update(convert(kwargs))
+
+            return dict(
+                method=webmethod.method or "POST",
+                url=url,
+                headers={"Content-Type": "application/json"},
+                params=params,
+                json=data,
+                timeout=30,
+            )
+
+    # Add protocol methods to the wrapper
+    for p in protocols:
+        for name, method in inspect.getmembers(p):
+            if hasattr(method, "__webmethod__"):
+
+                async def method_impl(self, *args, method_name=name, **kwargs):
+                    return await self.__acall__(method_name, *args, **kwargs)
+
+                method_impl.__name__ = name
+                method_impl.__qualname__ = f"APIClient.{name}"
+                method_impl.__signature__ = inspect.signature(method)
+                setattr(APIClient, name, method_impl)
+
+    # Name the class after the protocol
+    APIClient.__name__ = f"{protocol.__name__}Client"
+    _CLIENT_CLASSES[protocol] = APIClient
+    return APIClient
+
+
+# NOTE: these helpers are not fully general; they only handle the Union and
+# AsyncIterator shapes that appear in our API method signatures
+def extract_non_async_iterator_type(type_hint):
+    if get_origin(type_hint) is Union:
+        args = get_args(type_hint)
+        for arg in args:
+            if not issubclass(get_origin(arg) or arg, AsyncIterator):
+                return arg
+    return type_hint
+
+
+def extract_async_iterator_type(type_hint):
+    if get_origin(type_hint) is Union:
+        args = get_args(type_hint)
+        for arg in args:
+            if issubclass(get_origin(arg) or arg, AsyncIterator):
+                inner_args = get_args(arg)
+                return inner_args[0]
+    return None
+
+
+async def example(model: Optional[str] = None):
+    from llama_stack.apis.inference import Inference, UserMessage  # noqa: F403
+    from llama_stack.apis.inference.event_logger import EventLogger
+
+    client_class = create_api_client_class(Inference, None)
+    client = client_class("http://localhost:5003")
+
+    if not model:
+        model = "Llama3.2-3B-Instruct"
+
+    message = UserMessage(
+        content="hello world, write me a 2 sentence poem about the moon"
+    )
+    cprint(f"User>{message.content}", "green")
+
+    stream = True
+    iterator = await client.chat_completion(
+        model=model,
+        messages=[message],
+        stream=stream,
+    )
+
+    async for log in EventLogger().log(iterator):
+        log.print()
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(example())
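The per-method wrapper above depends on binding `method_name=name` as a default argument: defaults are evaluated at definition time, so each generated method remembers its own route instead of capturing the loop variable and dispatching to the last one. A standalone sketch of that pattern, with a hypothetical toy protocol:

```python
# Toy illustration of the closure pattern used in create_api_client_class.
import asyncio
import inspect


class ToyProtocol:
    def ping(self): ...
    def pong(self): ...


def make_client(protocol):
    class Client:
        async def _dispatch(self, method_name):
            return f"would call route for {method_name!r}"

    for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):
        # method_name=name freezes the current loop value for this wrapper
        async def method_impl(self, *args, method_name=name, **kwargs):
            return await self._dispatch(method_name)

        method_impl.__name__ = name
        setattr(Client, name, method_impl)
    return Client


client = make_client(ToyProtocol)()
assert asyncio.run(client.ping()) == "would call route for 'ping'"
assert asyncio.run(client.pong()) == "would call route for 'pong'"
```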
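Condensing the branch above: an adapter-backed remote provider keeps the old `[config, deps]` calling convention, while an adapter-less one now receives the protocol pair that the dynamic client needs. A minimal restatement under those assumptions (the helper name is hypothetical):

```python
# Sketch of how instantiate_provider picks the factory and its arguments for
# a RemoteProviderSpec, mirroring the diff above.
def remote_provider_call(provider_spec, config, deps, protocols, additional_protocols):
    if provider_spec.adapter:
        return "get_adapter_impl", [config, deps]
    protocol = protocols[provider_spec.api]
    additional = None
    if provider_spec.api in additional_protocols:
        # the map now stores (private_protocol, public_protocol) tuples
        _, additional = additional_protocols[provider_spec.api]
    return "get_client_impl", [protocol, additional, config, deps]
```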
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 3e07b9162..4e462c54b 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api:
 
 async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
     api = get_impl_api(p)
+
+    if obj.provider_id == "remote":
+        # if this is just a passthrough, we want to let the remote
+        # end actually do the registration with the correct provider
+        obj = obj.model_copy(deep=True)
+        obj.provider_id = ""
+
     if api == Api.inference:
         await p.register_model(obj)
     elif api == Api.safety:
@@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable):
     async def initialize(self) -> None:
         self.registry: Registry = {}
 
-        def add_objects(objs: List[RoutableObjectWithProvider]) -> None:
+        def add_objects(
+            objs: List[RoutableObjectWithProvider], provider_id: str, cls
+        ) -> None:
             for obj in objs:
                 if obj.identifier not in self.registry:
                     self.registry[obj.identifier] = []
 
+                if cls is None:
+                    obj.provider_id = provider_id
+                else:
+                    if provider_id == "remote":
+                        # if this is just a passthrough, we got the *WithProvider object
+                        # so we should just override the provider in-place
+                        obj.provider_id = provider_id
+                    else:
+                        obj = cls(**obj.model_dump(), provider_id=provider_id)
                 self.registry[obj.identifier].append(obj)
 
         for pid, p in self.impls_by_provider_id.items():
@@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable):
             if api == Api.inference:
                 p.model_store = self
                 models = await p.list_models()
-                add_objects(
-                    [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models]
-                )
+                add_objects(models, pid, ModelDefWithProvider)
 
             elif api == Api.safety:
                 p.shield_store = self
                 shields = await p.list_shields()
-                add_objects(
-                    [
-                        ShieldDefWithProvider(**s.dict(), provider_id=pid)
-                        for s in shields
-                    ]
-                )
+                add_objects(shields, pid, ShieldDefWithProvider)
 
             elif api == Api.memory:
                 p.memory_bank_store = self
                 memory_banks = await p.list_memory_banks()
-
-                # do in-memory updates due to pesky Annotated unions
-                for m in memory_banks:
-                    m.provider_id = pid
-
-                add_objects(memory_banks)
+                add_objects(memory_banks, pid, None)
 
             elif api == Api.datasetio:
                 p.dataset_store = self
                 datasets = await p.list_datasets()
-
-                # do in-memory updates due to pesky Annotated unions
-                for d in datasets:
-                    d.provider_id = pid
+                add_objects(datasets, pid, DatasetDefWithProvider)
 
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
-                add_objects(
-                    [
-                        ScoringFnDefWithProvider(**s.dict(), provider_id=pid)
-                        for s in scoring_functions
-                    ]
-                )
+                add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
 
     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
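The `"remote"` provider id acts as a sentinel here: locally it marks objects that arrived through a passthrough provider, and `register_object_with_provider` blanks it on a deep copy before forwarding, so the remote stack picks its own real provider. A toy restatement of that rule, with hypothetical model and function names:

```python
# Toy restatement of the "remote" sentinel handling: the copy keeps local
# state intact while the forwarded object lets the remote end choose its
# own provider.
from pydantic import BaseModel


class RoutableObj(BaseModel):
    identifier: str
    provider_id: str = ""


def prepare_for_registration(obj: RoutableObj) -> RoutableObj:
    if obj.provider_id == "remote":
        obj = obj.model_copy(deep=True)
        obj.provider_id = ""  # let the remote stack fill this in
    return obj


bank = RoutableObj(identifier="docs", provider_id="remote")
assert prepare_for_registration(bank).provider_id == ""
assert bank.provider_id == "remote"  # original is untouched
```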
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index eace0ea1a..9a37a28a9 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol):
 class DatasetsProtocolPrivate(Protocol):
     async def list_datasets(self) -> List[DatasetDef]: ...
 
-    async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
+    async def register_dataset(self, dataset_def: DatasetDef) -> None: ...
 
 
 class ScoringFunctionsProtocolPrivate(Protocol):
@@ -171,7 +171,7 @@ as being "Llama Stack compatible"
     def module(self) -> str:
         if self.adapter:
             return self.adapter.module
-        return f"llama_stack.apis.{self.api.value}.client"
+        return "llama_stack.distribution.client"
 
     @property
     def pip_packages(self) -> List[str]:
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index 9c34c3a28..c09db3d20 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -26,6 +26,7 @@ from dotenv import load_dotenv
 #
 # ```bash
 # PROVIDER_ID= \
+# MODEL_ID= \
 # PROVIDER_CONFIG=provider_config.yaml \
 # pytest -s llama_stack/providers/tests/agents/test_agents.py \
 #  --tb=short --disable-warnings
@@ -44,7 +45,7 @@ async def agents_settings():
         "impl": impls[Api.agents],
         "memory_impl": impls[Api.memory],
         "common_params": {
-            "model": "Llama3.1-8B-Instruct",
+            "model": os.environ.get("MODEL_ID") or "Llama3.1-8B-Instruct",
             "instructions": "You are a helpful assistant.",
         },
     }
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
index b26bf75a7..d83601de1 100644
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ b/llama_stack/providers/tests/memory/test_memory.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
 
 import pytest
 import pytest_asyncio
@@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks):
         embedding_model="all-MiniLM-L6-v2",
         chunk_size_in_tokens=512,
         overlap_size_in_tokens=64,
-        provider_id=os.environ["PROVIDER_ID"],
     )
 
     await banks_impl.register_memory_bank(bank)
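Putting the pieces together: once a stack is serving the Inference API, the generated client can be driven directly. A minimal sketch of the non-streaming path; the port and model identifier are assumptions, while `create_api_client_class` and `UserMessage` are the real names from this patch:

```python
# Minimal end-to-end sketch of the dynamic client, non-streaming path.
# Assumes a Llama Stack server on localhost:5003 serving the Inference API.
import asyncio

from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.distribution.client import create_api_client_class


async def main():
    client = create_api_client_class(Inference, None)("http://localhost:5003")
    response = await client.chat_completion(
        model="Llama3.1-8B-Instruct",
        messages=[UserMessage(content="Say hello in one sentence.")],
        stream=False,  # non-streaming: the reply parses as ChatCompletionResponse
    )
    print(response.completion_message.content)


if __name__ == "__main__":
    asyncio.run(main())
```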