diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index ad8f2952e..f4238f6f8 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -315,7 +315,20 @@ def get_endpoint_operations( ) else: event_type = None - response_type = return_type + + def process_type(t): + if typing.get_origin(t) is collections.abc.AsyncIterator: + # NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List + # or the item type. I am choosing it to be the latter + args = typing.get_args(t) + return args[0] + elif typing.get_origin(t) is typing.Union: + types = [process_type(a) for a in typing.get_args(t)] + return typing._UnionGenericAlias(typing.Union, tuple(types)) + else: + return t + + response_type = process_type(return_type) # set HTTP request method based on type of request and presence of payload if not request_params: diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 886634fba..363d968f9 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905" }, "servers": [ { @@ -320,11 +320,18 @@ "post": { "responses": { "200": { - "description": "OK", + "description": "A single turn in an interaction with an Agentic System. **OR** streamed agent turn completion response.", "content": { "text/event-stream": { "schema": { - "$ref": "#/components/schemas/AgentTurnResponseStreamChunk" + "oneOf": [ + { + "$ref": "#/components/schemas/Turn" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseStreamChunk" + } + ] } } } @@ -934,7 +941,7 @@ "schema": { "oneOf": [ { - "$ref": "#/components/schemas/ScoringFunctionDefWithProvider" + "$ref": "#/components/schemas/ScoringFnDefWithProvider" }, { "type": "null" @@ -1555,7 +1562,7 @@ "content": { "application/jsonl": { "schema": { - "$ref": "#/components/schemas/ScoringFunctionDefWithProvider" + "$ref": "#/components/schemas/ScoringFnDefWithProvider" } } } @@ -2762,7 +2769,7 @@ "const": "json_schema", "default": "json_schema" }, - "schema": { + "json_schema": { "type": "object", "additionalProperties": { "oneOf": [ @@ -2791,7 +2798,7 @@ "additionalProperties": false, "required": [ "type", - "schema" + "json_schema" ] }, { @@ -3018,7 +3025,7 @@ "const": "json_schema", "default": "json_schema" }, - "schema": { + "json_schema": { "type": "object", "additionalProperties": { "oneOf": [ @@ -3047,7 +3054,7 @@ "additionalProperties": false, "required": [ "type", - "schema" + "json_schema" ] }, { @@ -4002,7 +4009,8 @@ "additionalProperties": false, "required": [ "event" - ] + ], + "title": "streamed agent turn completion response." 
}, "AgentTurnResponseTurnCompletePayload": { "type": "object", @@ -5004,24 +5012,6 @@ "type" ] }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "custom", - "default": "custom" - }, - "validator_class": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "validator_class" - ] - }, { "type": "object", "properties": { @@ -5304,24 +5294,6 @@ "type" ] }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "custom", - "default": "custom" - }, - "validator_class": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "validator_class" - ] - }, { "type": "object", "properties": { @@ -5376,7 +5348,7 @@ "type" ] }, - "ScoringFunctionDefWithProvider": { + "ScoringFnDefWithProvider": { "type": "object", "properties": { "identifier": { @@ -5516,24 +5488,6 @@ "type" ] }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "custom", - "default": "custom" - }, - "validator_class": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "validator_class" - ] - }, { "type": "object", "properties": { @@ -5586,6 +5540,12 @@ }, "prompt_template": { "type": "string" + }, + "judge_score_regex": { + "type": "array", + "items": { + "type": "string" + } } }, "additionalProperties": false, @@ -6339,10 +6299,10 @@ "finetuned_model": { "$ref": "#/components/schemas/URL" }, - "dataset": { + "dataset_id": { "type": "string" }, - "validation_dataset": { + "validation_dataset_id": { "type": "string" }, "algorithm": { @@ -6412,8 +6372,8 @@ "required": [ "job_uuid", "finetuned_model", - "dataset", - "validation_dataset", + "dataset_id", + "validation_dataset_id", "algorithm", "algorithm_config", "optimizer_config", @@ -6595,7 +6555,7 @@ "type": "object", "properties": { "function_def": { - "$ref": "#/components/schemas/ScoringFunctionDefWithProvider" + "$ref": "#/components/schemas/ScoringFnDefWithProvider" } }, "additionalProperties": false, @@ -6893,10 +6853,10 @@ "model": { "type": "string" }, - "dataset": { + "dataset_id": { "type": "string" }, - "validation_dataset": { + "validation_dataset_id": { "type": "string" }, "algorithm": { @@ -6976,8 +6936,8 @@ "required": [ "job_uuid", "model", - "dataset", - "validation_dataset", + "dataset_id", + "validation_dataset_id", "algorithm", "algorithm_config", "optimizer_config", @@ -7102,57 +7062,57 @@ } ], "tags": [ - { - "name": "Eval" - }, - { - "name": "ScoringFunctions" - }, - { - "name": "SyntheticDataGeneration" - }, - { - "name": "Inspect" - }, - { - "name": "PostTraining" - }, - { - "name": "Models" - }, - { - "name": "Safety" - }, - { - "name": "MemoryBanks" - }, - { - "name": "DatasetIO" - }, { "name": "Memory" }, - { - "name": "Scoring" - }, - { - "name": "Shields" - }, - { - "name": "Datasets" - }, { "name": "Inference" }, { - "name": "Telemetry" + "name": "Eval" + }, + { + "name": "MemoryBanks" + }, + { + "name": "Models" }, { "name": "BatchInference" }, + { + "name": "PostTraining" + }, { "name": "Agents" }, + { + "name": "Shields" + }, + { + "name": "Telemetry" + }, + { + "name": "Inspect" + }, + { + "name": "DatasetIO" + }, + { + "name": "SyntheticDataGeneration" + }, + { + "name": "Datasets" + }, + { + "name": "Scoring" + }, + { + "name": "ScoringFunctions" + }, + { + "name": "Safety" + }, { "name": "BuiltinTool", "description": "" @@ -7355,7 +7315,7 @@ }, { "name": "AgentTurnResponseStreamChunk", - "description": "" + "description": "streamed agent 
turn completion response.\n\n" }, { "name": "AgentTurnResponseTurnCompletePayload", @@ -7486,8 +7446,8 @@ "description": "" }, { - "name": "ScoringFunctionDefWithProvider", - "description": "" + "name": "ScoringFnDefWithProvider", + "description": "" }, { "name": "ShieldDefWithProvider", @@ -7805,7 +7765,7 @@ "ScoreBatchResponse", "ScoreRequest", "ScoreResponse", - "ScoringFunctionDefWithProvider", + "ScoringFnDefWithProvider", "ScoringResult", "SearchToolDefinition", "Session", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 9dcdbb028..7dd231965 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -190,6 +190,7 @@ components: $ref: '#/components/schemas/AgentTurnResponseEvent' required: - event + title: streamed agent turn completion response. type: object AgentTurnResponseTurnCompletePayload: additionalProperties: false @@ -360,7 +361,7 @@ components: oneOf: - additionalProperties: false properties: - schema: + json_schema: additionalProperties: oneOf: - type: 'null' @@ -376,7 +377,7 @@ components: type: string required: - type - - schema + - json_schema type: object - additionalProperties: false properties: @@ -541,7 +542,7 @@ components: oneOf: - additionalProperties: false properties: - schema: + json_schema: additionalProperties: oneOf: - type: 'null' @@ -557,7 +558,7 @@ components: type: string required: - type - - schema + - json_schema type: object - additionalProperties: false properties: @@ -747,18 +748,6 @@ components: required: - type type: object - - additionalProperties: false - properties: - type: - const: custom - default: custom - type: string - validator_class: - type: string - required: - - type - - validator_class - type: object - additionalProperties: false properties: type: @@ -1575,18 +1564,6 @@ components: required: - type type: object - - additionalProperties: false - properties: - type: - const: custom - default: custom - type: string - validator_class: - type: string - required: - - type - - validator_class - type: object - additionalProperties: false properties: type: @@ -1724,7 +1701,7 @@ components: $ref: '#/components/schemas/RLHFAlgorithm' algorithm_config: $ref: '#/components/schemas/DPOAlignmentConfig' - dataset: + dataset_id: type: string finetuned_model: $ref: '#/components/schemas/URL' @@ -1754,13 +1731,13 @@ components: $ref: '#/components/schemas/OptimizerConfig' training_config: $ref: '#/components/schemas/TrainingConfig' - validation_dataset: + validation_dataset_id: type: string required: - job_uuid - finetuned_model - - dataset - - validation_dataset + - dataset_id + - validation_dataset_id - algorithm - algorithm_config - optimizer_config @@ -1899,7 +1876,7 @@ components: additionalProperties: false properties: function_def: - $ref: '#/components/schemas/ScoringFunctionDefWithProvider' + $ref: '#/components/schemas/ScoringFnDefWithProvider' required: - function_def type: object @@ -2121,7 +2098,7 @@ components: required: - results type: object - ScoringFunctionDefWithProvider: + ScoringFnDefWithProvider: additionalProperties: false properties: context: @@ -2129,6 +2106,10 @@ components: properties: judge_model: type: string + judge_score_regex: + items: + type: string + type: array prompt_template: type: string required: @@ -2219,18 +2200,6 @@ components: required: - type type: object - - additionalProperties: false - properties: - type: - const: custom - default: custom - type: string - validator_class: - type: string - required: - - type - - 
validator_class - type: object - additionalProperties: false properties: type: @@ -2484,7 +2453,7 @@ components: - $ref: '#/components/schemas/LoraFinetuningConfig' - $ref: '#/components/schemas/QLoraFinetuningConfig' - $ref: '#/components/schemas/DoraFinetuningConfig' - dataset: + dataset_id: type: string hyperparam_search_config: additionalProperties: @@ -2514,13 +2483,13 @@ components: $ref: '#/components/schemas/OptimizerConfig' training_config: $ref: '#/components/schemas/TrainingConfig' - validation_dataset: + validation_dataset_id: type: string required: - job_uuid - model - - dataset - - validation_dataset + - dataset_id + - validation_dataset_id - algorithm - algorithm_config - optimizer_config @@ -3029,7 +2998,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-10-24 17:40:59.576117" + \ draft and subject to change.\n Generated at 2024-10-31 14:28:52.128905" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -3222,8 +3191,11 @@ paths: content: text/event-stream: schema: - $ref: '#/components/schemas/AgentTurnResponseStreamChunk' - description: OK + oneOf: + - $ref: '#/components/schemas/Turn' + - $ref: '#/components/schemas/AgentTurnResponseStreamChunk' + description: A single turn in an interaction with an Agentic System. **OR** + streamed agent turn completion response. tags: - Agents /agents/turn/get: @@ -4122,7 +4094,7 @@ paths: application/json: schema: oneOf: - - $ref: '#/components/schemas/ScoringFunctionDefWithProvider' + - $ref: '#/components/schemas/ScoringFnDefWithProvider' - type: 'null' description: OK tags: @@ -4142,7 +4114,7 @@ paths: content: application/jsonl: schema: - $ref: '#/components/schemas/ScoringFunctionDefWithProvider' + $ref: '#/components/schemas/ScoringFnDefWithProvider' description: OK tags: - ScoringFunctions @@ -4308,23 +4280,23 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Eval -- name: ScoringFunctions -- name: SyntheticDataGeneration -- name: Inspect -- name: PostTraining -- name: Models -- name: Safety -- name: MemoryBanks -- name: DatasetIO - name: Memory -- name: Scoring -- name: Shields -- name: Datasets - name: Inference -- name: Telemetry +- name: Eval +- name: MemoryBanks +- name: Models - name: BatchInference +- name: PostTraining - name: Agents +- name: Shields +- name: Telemetry +- name: Inspect +- name: DatasetIO +- name: SyntheticDataGeneration +- name: Datasets +- name: Scoring +- name: ScoringFunctions +- name: Safety - description: name: BuiltinTool - description: name: AgentTurnResponseStepStartPayload -- description: +- description: 'streamed agent turn completion response. 
+ + + ' name: AgentTurnResponseStreamChunk - description: @@ -4577,9 +4552,9 @@ tags: name: PaginatedRowsResult - description: name: Parameter -- description: - name: ScoringFunctionDefWithProvider + name: ScoringFnDefWithProvider - description: name: ShieldDefWithProvider @@ -4844,7 +4819,7 @@ x-tagGroups: - ScoreBatchResponse - ScoreRequest - ScoreResponse - - ScoringFunctionDefWithProvider + - ScoringFnDefWithProvider - ScoringResult - SearchToolDefinition - Session diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index e0eaacf51..613844f5e 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -8,6 +8,7 @@ from datetime import datetime from enum import Enum from typing import ( Any, + AsyncIterator, Dict, List, Literal, @@ -405,6 +406,8 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): @json_schema_type class AgentTurnResponseStreamChunk(BaseModel): + """streamed agent turn completion response.""" + event: AgentTurnResponseEvent @@ -434,7 +437,7 @@ class Agents(Protocol): ], attachments: Optional[List[Attachment]] = None, stream: Optional[bool] = False, - ) -> AgentTurnResponseStreamChunk: ... + ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod(route="/agents/turn/get") async def get_agents_turn( diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index eb2c41d32..4b6530f63 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -6,7 +6,15 @@ from enum import Enum -from typing import List, Literal, Optional, Protocol, runtime_checkable, Union +from typing import ( + AsyncIterator, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod @@ -224,7 +232,7 @@ class Inference(Protocol): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ... + ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ... @webmethod(route="/inference/chat_completion") async def chat_completion( @@ -239,7 +247,9 @@ class Inference(Protocol): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: ... + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: ... @webmethod(route="/inference/embeddings") async def embeddings( diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 8044dda28..ae2b17d9e 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -77,9 +77,9 @@ if [ -n "$LLAMA_STACK_DIR" ]; then # Install in editable format. We will mount the source code into the container # so that changes will be reflected in the container without having to do a # rebuild. This is just for development convenience. 
- add_to_docker "RUN pip install -e $stack_mount" + add_to_docker "RUN pip install --no-cache -e $stack_mount" else - add_to_docker "RUN pip install llama-stack" + add_to_docker "RUN pip install --no-cache llama-stack" fi if [ -n "$LLAMA_MODELS_DIR" ]; then @@ -90,19 +90,19 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then add_to_docker < Type: + if protocol in _CLIENT_CLASSES: + return _CLIENT_CLASSES[protocol] + + protocols = [protocol, additional_protocol] if additional_protocol else [protocol] + + class APIClient: + def __init__(self, base_url: str): + print(f"({protocol.__name__}) Connecting to {base_url}") + self.base_url = base_url.rstrip("/") + self.routes = {} + + # Store routes for this protocol + for p in protocols: + for name, method in inspect.getmembers(p): + if hasattr(method, "__webmethod__"): + sig = inspect.signature(method) + self.routes[name] = (method.__webmethod__, sig) + + async def initialize(self): + pass + + async def shutdown(self): + pass + + async def __acall__(self, method_name: str, *args, **kwargs) -> Any: + assert method_name in self.routes, f"Unknown endpoint: {method_name}" + + # TODO: make this more precise, same thing needs to happen in server.py + is_streaming = kwargs.get("stream", False) + if is_streaming: + return self._call_streaming(method_name, *args, **kwargs) + else: + return await self._call_non_streaming(method_name, *args, **kwargs) + + async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any: + _, sig = self.routes[method_name] + + if sig.return_annotation is None: + return_type = None + else: + return_type = extract_non_async_iterator_type(sig.return_annotation) + assert ( + return_type + ), f"Could not extract return type for {sig.return_annotation}" + + async with httpx.AsyncClient() as client: + params = self.httpx_request_params(method_name, *args, **kwargs) + response = await client.request(**params) + response.raise_for_status() + + j = response.json() + if j is None: + return None + return parse_obj_as(return_type, j) + + async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any: + webmethod, sig = self.routes[method_name] + + return_type = extract_async_iterator_type(sig.return_annotation) + assert ( + return_type + ), f"Could not extract return type for {sig.return_annotation}" + + async with httpx.AsyncClient() as client: + params = self.httpx_request_params(method_name, *args, **kwargs) + async with client.stream(**params) as response: + response.raise_for_status() + + async for line in response.aiter_lines(): + if line.startswith("data:"): + data = line[len("data: ") :] + try: + if "error" in data: + cprint(data, "red") + continue + + yield parse_obj_as(return_type, json.loads(data)) + except Exception as e: + print(data) + print(f"Error with parsing or validation: {e}") + + def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: + webmethod, sig = self.routes[method_name] + + parameters = list(sig.parameters.values())[1:] # skip `self` + for i, param in enumerate(parameters): + if i >= len(args): + break + kwargs[param.name] = args[i] + + url = f"{self.base_url}{webmethod.route}" + + def convert(value): + if isinstance(value, list): + return [convert(v) for v in value] + elif isinstance(value, dict): + return {k: convert(v) for k, v in value.items()} + elif isinstance(value, BaseModel): + return json.loads(value.model_dump_json()) + elif isinstance(value, Enum): + return value.value + else: + return value + + params = {} + data = {} + if webmethod.method == "GET": + 
params.update(kwargs) + else: + data.update(convert(kwargs)) + + return dict( + method=webmethod.method or "POST", + url=url, + headers={"Content-Type": "application/json"}, + params=params, + json=data, + timeout=30, + ) + + # Add protocol methods to the wrapper + for p in protocols: + for name, method in inspect.getmembers(p): + if hasattr(method, "__webmethod__"): + + async def method_impl(self, *args, method_name=name, **kwargs): + return await self.__acall__(method_name, *args, **kwargs) + + method_impl.__name__ = name + method_impl.__qualname__ = f"APIClient.{name}" + method_impl.__signature__ = inspect.signature(method) + setattr(APIClient, name, method_impl) + + # Name the class after the protocol + APIClient.__name__ = f"{protocol.__name__}Client" + _CLIENT_CLASSES[protocol] = APIClient + return APIClient + + +# not quite general these methods are +def extract_non_async_iterator_type(type_hint): + if get_origin(type_hint) is Union: + args = get_args(type_hint) + for arg in args: + if not issubclass(get_origin(arg) or arg, AsyncIterator): + return arg + return type_hint + + +def extract_async_iterator_type(type_hint): + if get_origin(type_hint) is Union: + args = get_args(type_hint) + for arg in args: + if issubclass(get_origin(arg) or arg, AsyncIterator): + inner_args = get_args(arg) + return inner_args[0] + return None + + +async def example(model: str = None): + from llama_stack.apis.inference import Inference, UserMessage # noqa: F403 + from llama_stack.apis.inference.event_logger import EventLogger + + client_class = create_api_client_class(Inference) + client = client_class("http://localhost:5003") + + if not model: + model = "Llama3.2-3B-Instruct" + + message = UserMessage( + content="hello world, write me a 2 sentence poem about the moon" + ) + cprint(f"User>{message.content}", "green") + + stream = True + iterator = await client.chat_completion( + model=model, + messages=[message], + stream=stream, + ) + + async for log in EventLogger().log(iterator): + log.print() + + +if __name__ == "__main__": + import asyncio + + asyncio.run(example()) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index bab807da9..a93cc1183 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]: Api.safety: Safety, Api.shields: Shields, Api.telemetry: Telemetry, - Api.datasets: Datasets, Api.datasetio: DatasetIO, - Api.scoring_functions: ScoringFunctions, + Api.datasets: Datasets, Api.scoring: Scoring, + Api.scoring_functions: ScoringFunctions, Api.eval: Eval, } def additional_protocols_map() -> Dict[Api, Any]: return { - Api.inference: ModelsProtocolPrivate, - Api.memory: MemoryBanksProtocolPrivate, - Api.safety: ShieldsProtocolPrivate, + Api.inference: (ModelsProtocolPrivate, Models), + Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks), + Api.safety: (ShieldsProtocolPrivate, Shields), + Api.datasetio: (DatasetsProtocolPrivate, Datasets), + Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions), } @@ -112,8 +114,6 @@ async def resolve_impls( if info.router_api.value not in apis_to_serve: continue - available_providers = providers_with_specs[f"inner-{info.router_api.value}"] - providers_with_specs[info.routing_table_api.value] = { "__builtin__": ProviderWithSpec( provider_id="__routing_table__", @@ -246,14 +246,21 @@ async def instantiate_provider( args = [] if isinstance(provider_spec, RemoteProviderSpec): - if provider_spec.adapter: - method = 
"get_adapter_impl" - else: - method = "get_client_impl" - config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider.config) - args = [config, deps] + + if provider_spec.adapter: + method = "get_adapter_impl" + args = [config, deps] + else: + method = "get_client_impl" + protocol = protocols[provider_spec.api] + if provider_spec.api in additional_protocols: + _, additional_protocol = additional_protocols[provider_spec.api] + else: + additional_protocol = None + args = [protocol, additional_protocol, config, deps] + elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" @@ -282,7 +289,7 @@ async def instantiate_provider( not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols ): - additional_api = additional_protocols[provider_spec.api] + additional_api, _ = additional_protocols[provider_spec.api] check_protocol_compliance(impl, additional_api) return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3e07b9162..4e462c54b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api: async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: api = get_impl_api(p) + + if obj.provider_id == "remote": + # if this is just a passthrough, we want to let the remote + # end actually do the registration with the correct provider + obj = obj.model_copy(deep=True) + obj.provider_id = "" + if api == Api.inference: await p.register_model(obj) elif api == Api.safety: @@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable): async def initialize(self) -> None: self.registry: Registry = {} - def add_objects(objs: List[RoutableObjectWithProvider]) -> None: + def add_objects( + objs: List[RoutableObjectWithProvider], provider_id: str, cls + ) -> None: for obj in objs: if obj.identifier not in self.registry: self.registry[obj.identifier] = [] + if cls is None: + obj.provider_id = provider_id + else: + if provider_id == "remote": + # if this is just a passthrough, we got the *WithProvider object + # so we should just override the provider in-place + obj.provider_id = provider_id + else: + obj = cls(**obj.model_dump(), provider_id=provider_id) self.registry[obj.identifier].append(obj) for pid, p in self.impls_by_provider_id.items(): @@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable): if api == Api.inference: p.model_store = self models = await p.list_models() - add_objects( - [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models] - ) + add_objects(models, pid, ModelDefWithProvider) elif api == Api.safety: p.shield_store = self shields = await p.list_shields() - add_objects( - [ - ShieldDefWithProvider(**s.dict(), provider_id=pid) - for s in shields - ] - ) + add_objects(shields, pid, ShieldDefWithProvider) elif api == Api.memory: p.memory_bank_store = self memory_banks = await p.list_memory_banks() - - # do in-memory updates due to pesky Annotated unions - for m in memory_banks: - m.provider_id = pid - - add_objects(memory_banks) + add_objects(memory_banks, pid, None) elif api == Api.datasetio: p.dataset_store = self datasets = await p.list_datasets() - - # do in-memory updates due to pesky Annotated unions - for d in datasets: - d.provider_id = pid + add_objects(datasets, pid, DatasetDefWithProvider) elif api == Api.scoring: p.scoring_function_store = self 
scoring_functions = await p.list_scoring_functions() - add_objects( - [ - ScoringFnDefWithProvider(**s.dict(), provider_id=pid) - for s in scoring_functions - ] - ) + add_objects(scoring_functions, pid, ScoringFnDefWithProvider) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index 3800c0496..caf886c0b 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -55,7 +55,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: + ) -> AsyncGenerator: raise NotImplementedError() @staticmethod @@ -290,23 +290,130 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), response_format: Optional[ResponseFormat] = None, - # zero-shot tool definitions as input to the model tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, - ) -> ( - AsyncGenerator - ): # Union[ChatCompletionResponse, ChatCompletionResponseStreamChunk]: - bedrock_model = self.map_to_provider_model(model) - inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( - sampling_params + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: + request = ChatCompletionRequest( + model=model, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + tool_choice=tool_choice, + tool_prompt_format=tool_prompt_format, + response_format=response_format, + stream=stream, + logprobs=logprobs, ) - tool_config = BedrockInferenceAdapter._tools_to_tool_config(tools, tool_choice) + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletionResponse: + params = self._get_params_for_chat_completion(request) + converse_api_res = self.client.converse(**params) + + output_message = BedrockInferenceAdapter._bedrock_message_to_message( + converse_api_res + ) + + return ChatCompletionResponse( + completion_message=output_message, + logprobs=None, + ) + + async def _stream_chat_completion( + self, request: ChatCompletionRequest + ) -> AsyncGenerator: + params = self._get_params_for_chat_completion(request) + converse_stream_api_res = self.client.converse_stream(**params) + event_stream = converse_stream_api_res["stream"] + + for chunk in event_stream: + if "messageStart" in chunk: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) + ) + elif "contentBlockStart" in chunk: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=ToolCall( + tool_name=chunk["contentBlockStart"]["toolUse"]["name"], + call_id=chunk["contentBlockStart"]["toolUse"][ + "toolUseId" + ], + ), + 
parse_status=ToolCallParseStatus.started, + ), + ) + ) + elif "contentBlockDelta" in chunk: + if "text" in chunk["contentBlockDelta"]["delta"]: + delta = chunk["contentBlockDelta"]["delta"]["text"] + else: + delta = ToolCallDelta( + content=ToolCall( + arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][ + "input" + ] + ), + parse_status=ToolCallParseStatus.success, + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + ) + ) + elif "contentBlockStop" in chunk: + # Ignored + pass + elif "messageStop" in chunk: + stop_reason = ( + BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( + chunk["messageStop"]["stopReason"] + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) + elif "metadata" in chunk: + # Ignored + pass + else: + # Ignored + pass + + def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: + bedrock_model = self.map_to_provider_model(request.model) + inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( + request.sampling_params + ) + + tool_config = BedrockInferenceAdapter._tools_to_tool_config( + request.tools, request.tool_choice + ) bedrock_messages, system_bedrock_messages = ( - BedrockInferenceAdapter._messages_to_bedrock_messages(messages) + BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages) ) converse_api_params = { @@ -317,93 +424,12 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): converse_api_params["inferenceConfig"] = inference_config # Tool use is not supported in streaming mode - if tool_config and not stream: + if tool_config and not request.stream: converse_api_params["toolConfig"] = tool_config if system_bedrock_messages: converse_api_params["system"] = system_bedrock_messages - if not stream: - converse_api_res = self.client.converse(**converse_api_params) - - output_message = BedrockInferenceAdapter._bedrock_message_to_message( - converse_api_res - ) - - yield ChatCompletionResponse( - completion_message=output_message, - logprobs=None, - ) - else: - converse_stream_api_res = self.client.converse_stream(**converse_api_params) - event_stream = converse_stream_api_res["stream"] - - for chunk in event_stream: - if "messageStart" in chunk: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) - ) - elif "contentBlockStart" in chunk: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=ToolCall( - tool_name=chunk["contentBlockStart"]["toolUse"][ - "name" - ], - call_id=chunk["contentBlockStart"]["toolUse"][ - "toolUseId" - ], - ), - parse_status=ToolCallParseStatus.started, - ), - ) - ) - elif "contentBlockDelta" in chunk: - if "text" in chunk["contentBlockDelta"]["delta"]: - delta = chunk["contentBlockDelta"]["delta"]["text"] - else: - delta = ToolCallDelta( - content=ToolCall( - arguments=chunk["contentBlockDelta"]["delta"][ - "toolUse" - ]["input"] - ), - parse_status=ToolCallParseStatus.success, - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - ) - ) - elif "contentBlockStop" in chunk: - # Ignored - pass - elif 
"messageStop" in chunk: - stop_reason = ( - BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( - chunk["messageStop"]["stopReason"] - ) - ) - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) - elif "metadata" in chunk: - # Ignored - pass - else: - # Ignored - pass + return converse_api_params async def embeddings( self, diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index 4687618fa..4cf55035c 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -75,7 +75,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): for model in self.client.models.list() ] - def completion( + async def completion( self, model: str, content: InterleavedTextMedia, @@ -86,7 +86,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: raise NotImplementedError() - def chat_completion( + async def chat_completion( self, model: str, messages: List[Message], @@ -111,7 +111,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): if stream: return self._stream_chat_completion(request, self.client) else: - return self._nonstream_chat_completion(request, self.client) + return await self._nonstream_chat_completion(request, self.client) async def _nonstream_chat_completion( self, request: ChatCompletionRequest, client: OpenAI diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index eace0ea1a..9a37a28a9 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol): async def list_datasets(self) -> List[DatasetDef]: ... - async def register_datasets(self, dataset_def: DatasetDef) -> None: ... + async def register_dataset(self, dataset_def: DatasetDef) -> None: ... 
class ScoringFunctionsProtocolPrivate(Protocol): @@ -171,7 +171,7 @@ as being "Llama Stack compatible" def module(self) -> str: if self.adapter: return self.adapter.module - return f"llama_stack.apis.{self.api.value}.client" + return "llama_stack.distribution.client" @property def pip_packages(self) -> List[str]: diff --git a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift b/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift index 89f24a561..84da42d1b 100644 --- a/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift +++ b/llama_stack/providers/impls/ios/inference/LocalInferenceImpl/Parsing.swift @@ -81,7 +81,9 @@ func encodeMessage(message: Components.Schemas.ChatCompletionRequest.messagesPay switch (m.content) { case .case1(let c): prompt += _processContent(c) - case .case2(let c): + case .ImageMedia(let c): + prompt += _processContent(c) + case .case3(let c): prompt += _processContent(c) } case .CompletionMessage(let m): diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 9c34c3a28..c09db3d20 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -26,6 +26,7 @@ from dotenv import load_dotenv # # ```bash # PROVIDER_ID= \ +# MODEL_ID= \ # PROVIDER_CONFIG=provider_config.yaml \ # pytest -s llama_stack/providers/tests/agents/test_agents.py \ # --tb=short --disable-warnings @@ -44,7 +45,7 @@ async def agents_settings(): "impl": impls[Api.agents], "memory_impl": impls[Api.memory], "common_params": { - "model": "Llama3.1-8B-Instruct", + "model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct", "instructions": "You are a helpful assistant.", }, } diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index b26bf75a7..d83601de1 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os import pytest import pytest_asyncio @@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks): embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, - provider_id=os.environ["PROVIDER_ID"], ) await banks_impl.register_memory_bank(bank)
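
Below the diff, a minimal standalone sketch (not part of the patch) of what the new process_type helper in docs/openapi_generator/pyopenapi/operations.py does to a streaming return annotation. The Turn/Chunk classes are placeholders for the real llama_stack types, and the sketch uses the public Union[...] constructor instead of the private typing._UnionGenericAlias the patch uses; it only illustrates why the /agents/turn/create response in the regenerated spec becomes oneOf: [Turn, AgentTurnResponseStreamChunk].

import collections.abc
import typing
from typing import AsyncIterator, Union


class Turn: ...   # placeholder for llama_stack.apis.agents.Turn
class Chunk: ...  # placeholder for AgentTurnResponseStreamChunk


def process_type(t):
    # An SSE stream (AsyncIterator[X]) is represented in the spec by its item type X.
    if typing.get_origin(t) is collections.abc.AsyncIterator:
        return typing.get_args(t)[0]
    # Unions are processed element-wise, so Union[Turn, AsyncIterator[Chunk]]
    # collapses to Union[Turn, Chunk].
    if typing.get_origin(t) is typing.Union:
        return Union[tuple(process_type(a) for a in typing.get_args(t))]
    return t


print(process_type(Union[Turn, AsyncIterator[Chunk]]))
# typing.Union[__main__.Turn, __main__.Chunk]

Emitting the item type (rather than a List of chunks) is the choice noted in the patch comment: the stream is documented by its per-event payload instead of as an array.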
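
And a short caller-side sketch, under the assumption that `client` is any object implementing the updated Inference protocol, of the contract now declared by chat_completion (and mirrored by Agents.create_agent_turn): with stream=False the awaited call returns the complete response object, with stream=True it returns an async iterator of stream chunks, which is what the example() in llama_stack/distribution/client.py relies on. The demo function and its arguments are illustrative only.

async def demo(client, model: str, message) -> None:
    # Non-streaming: the awaited call returns a ChatCompletionResponse.
    response = await client.chat_completion(model=model, messages=[message], stream=False)
    print(response.completion_message)

    # Streaming: the awaited call returns AsyncIterator[ChatCompletionResponseStreamChunk].
    iterator = await client.chat_completion(model=model, messages=[message], stream=True)
    async for chunk in iterator:
        print(chunk.event.delta, end="")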