From ddebf9b6e7f992092d3d200e6c3dcb6691a6f2a6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 23 Sep 2024 08:46:33 -0700 Subject: [PATCH] [api_updates_3] fix CLI for routing_table, bug fixes for memory & safety (#90) * fix llama stack build * fix configure * fix configure for simple case * configure w/ routing * move examples config * fix memory router naming * issue w/ safety * fix config w/ safety * update memory endpoints * allow providers in api_providers * configure script works * all endpoints w/ build->configure->run simple local works * new example run.yaml * run openapi generator --- docs/resources/llama-stack-spec.html | 418 ++++++++-------- docs/resources/llama-stack-spec.yaml | 456 +++++++++--------- llama_stack/apis/memory/client.py | 8 +- llama_stack/apis/memory/memory.py | 18 +- llama_stack/apis/memory_banks/memory_banks.py | 2 +- llama_stack/cli/stack/build.py | 15 +- llama_stack/cli/stack/configure.py | 2 +- llama_stack/distribution/configure.py | 161 +++++-- llama_stack/distribution/datatypes.py | 19 +- llama_stack/distribution/distribution.py | 4 +- llama_stack/distribution/routers/__init__.py | 2 +- llama_stack/distribution/routers/routers.py | 7 +- .../distribution/routers/routing_tables.py | 16 +- llama_stack/distribution/server/server.py | 20 +- llama_stack/distribution/utils/dynamic.py | 5 +- tests/examples/local-run.yaml | 87 ++++ tests/examples/router-local-run.yaml | 50 -- tests/examples/simple-local-run.yaml | 40 -- 18 files changed, 725 insertions(+), 605 deletions(-) create mode 100644 tests/examples/local-run.yaml delete mode 100644 tests/examples/router-local-run.yaml delete mode 100644 tests/examples/simple-local-run.yaml diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index 3933233b2..95b08d6ca 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -21,7 +21,7 @@ "info": { "title": "[DRAFT] Llama Stack Specification", "version": "0.0.1", - "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-20 14:53:17.090953" + "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-09-23 01:08:55.758597" }, "servers": [ { @@ -422,7 +422,7 @@ } } }, - "/memory_banks/create": { + "/memory/create": { "post": { "responses": { "200": { @@ -561,7 +561,7 @@ } } }, - "/memory_bank/documents/delete": { + "/memory/documents/delete": { "post": { "responses": { "200": { @@ -594,7 +594,7 @@ } } }, - "/memory_banks/drop": { + "/memory/drop": { "post": { "responses": { "200": { @@ -988,7 +988,7 @@ ] } }, - "/memory_bank/documents/get": { + "/memory/documents/get": { "post": { "responses": { "200": { @@ -1180,7 +1180,7 @@ ] } }, - "/memory_banks/get": { + "/memory/get": { "get": { "responses": { "200": { @@ -1407,7 +1407,7 @@ ] } }, - "/memory_bank/insert": { + "/memory/insert": { "post": { "responses": { "200": { @@ -1440,7 +1440,7 @@ } } }, - "/memory_banks/list": { + "/memory/list": { "get": { "responses": { "200": { @@ -1543,7 +1543,7 @@ } } }, - "/memory_bank/query": { + "/memory/query": { "post": { "responses": { "200": { @@ -1743,7 +1743,7 @@ } } }, - "/memory_bank/update": { + "/memory/update": { "post": { "responses": { "200": { @@ -2584,183 +2584,7 @@ "$ref": "#/components/schemas/FunctionCallToolDefinition" }, { - "type": "object", - "properties": { - "input_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "output_shields": { - "type": "array", - "items": { - "type": "string" - } - }, - "type": { - "type": "string", - "const": "memory" - }, - "memory_bank_configs": { - "type": "array", - "items": { - "oneOf": [ - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "vector" - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "keyvalue" - }, - "keys": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type", - "keys" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "keyword" - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type" - ] - }, - { - "type": "object", - "properties": { - "bank_id": { - "type": "string" - }, - "type": { - "type": "string", - "const": "graph" - }, - "entities": { - "type": "array", - "items": { - "type": "string" - } - } - }, - "additionalProperties": false, - "required": [ - "bank_id", - "type", - "entities" - ] - } - ] - } - }, - "query_generator_config": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "default" - }, - "sep": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "sep" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "llm" - }, - "model": { - "type": "string" - }, - "template": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "type", - "model", - "template" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "custom" - } - }, - "additionalProperties": false, - "required": [ - "type" - ] - } - ] - }, - "max_tokens_in_context": { - "type": "integer" - }, - "max_chunks": { - "type": "integer" - } - }, - "additionalProperties": false, - "required": [ - "type", - "memory_bank_configs", - "query_generator_config", - "max_tokens_in_context", - "max_chunks" - ] + "$ref": "#/components/schemas/MemoryToolDefinition" } ] } @@ -2771,17 +2595,25 @@ "tool_prompt_format": { "$ref": "#/components/schemas/ToolPromptFormat" }, + "max_infer_iters": { + "type": "integer" + }, "model": { "type": "string" }, "instructions": { "type": "string" + }, + "enable_session_persistence": { + "type": "boolean" } }, "additionalProperties": false, "required": [ + "max_infer_iters", "model", - "instructions" + "instructions", + "enable_session_persistence" ] }, "CodeInterpreterToolDefinition": { @@ -2859,6 +2691,185 @@ "parameters" ] }, + "MemoryToolDefinition": { + "type": "object", + "properties": { + "input_shields": { + "type": "array", + "items": { + "type": "string" + } + }, + "output_shields": { + "type": "array", + "items": { + "type": "string" + } + }, + "type": { + "type": "string", + "const": "memory" + }, + "memory_bank_configs": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "vector" + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type" + ] + }, + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "keyvalue" + }, + "keys": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type", + "keys" + ] + }, + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "keyword" + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type" + ] + }, + { + "type": "object", + "properties": { + "bank_id": { + "type": "string" + }, + "type": { + "type": "string", + "const": "graph" + }, + "entities": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "bank_id", + "type", + "entities" + ] + } + ] + } + }, + "query_generator_config": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "default" + }, + "sep": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "sep" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "llm" + }, + "model": { + "type": "string" + }, + "template": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "model", + "template" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "custom" + } + }, + "additionalProperties": false, + "required": [ + "type" + ] + } + ] + }, + "max_tokens_in_context": { + "type": "integer" + }, + "max_chunks": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "type", + "memory_bank_configs", + "query_generator_config", + "max_tokens_in_context", + "max_chunks" + ] + }, "PhotogenToolDefinition": { "type": "object", "properties": { @@ -5569,31 +5580,28 @@ ], "tags": [ { - "name": "Agents" - }, - { - "name": "RewardScoring" - }, - { - "name": "Evaluations" + "name": "PostTraining" }, { "name": "Safety" }, { - "name": "Telemetry" - }, - { - "name": "PostTraining" + "name": "SyntheticDataGeneration" }, { "name": "Datasets" }, { - "name": "Inference" + "name": "Telemetry" }, { - "name": "SyntheticDataGeneration" + "name": "Evaluations" + }, + { + "name": "RewardScoring" + }, + { + "name": "Agents" }, { "name": "Memory" @@ -5601,6 +5609,9 @@ { "name": "BatchInference" }, + { + "name": "Inference" + }, { "name": "BuiltinTool", "description": "" @@ -5733,6 +5744,10 @@ "name": "FunctionCallToolDefinition", "description": "" }, + { + "name": "MemoryToolDefinition", + "description": "" + }, { "name": "PhotogenToolDefinition", "description": "" @@ -6174,6 +6189,7 @@ "MemoryBank", "MemoryBankDocument", "MemoryRetrievalStep", + "MemoryToolDefinition", "MetricEvent", "OptimizerConfig", "PhotogenToolDefinition", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 8cfd6ee2e..d08a2a2c1 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -4,12 +4,16 @@ components: AgentConfig: additionalProperties: false properties: + enable_session_persistence: + type: boolean input_shields: items: type: string type: array instructions: type: string + max_infer_iters: + type: integer model: type: string output_shields: @@ -30,127 +34,13 @@ components: - $ref: '#/components/schemas/PhotogenToolDefinition' - $ref: '#/components/schemas/CodeInterpreterToolDefinition' - $ref: '#/components/schemas/FunctionCallToolDefinition' - - additionalProperties: false - properties: - input_shields: - items: - type: string - type: array - max_chunks: - type: integer - max_tokens_in_context: - type: integer - memory_bank_configs: - items: - oneOf: - - additionalProperties: false - properties: - bank_id: - type: string - type: - const: vector - type: string - required: - - bank_id - - type - type: object - - additionalProperties: false - properties: - bank_id: - type: string - keys: - items: - type: string - type: array - type: - const: keyvalue - type: string - required: - - bank_id - - type - - keys - type: object - - additionalProperties: false - properties: - bank_id: - type: string - type: - const: keyword - type: string - required: - - bank_id - - type - type: object - - additionalProperties: false - properties: - bank_id: - type: string - entities: - items: - type: string - type: array - type: - const: graph - type: string - required: - - bank_id - - type - - entities - type: object - type: array - output_shields: - items: - type: string - type: array - query_generator_config: - oneOf: - - additionalProperties: false - properties: - sep: - type: string - type: - const: default - type: string - required: - - type - - sep - type: object - - additionalProperties: false - properties: - model: - type: string - template: - type: string - type: - const: llm - type: string - required: - - type - - model - - template - type: object - - additionalProperties: false - properties: - type: - const: custom - type: string - required: - - type - type: object - type: - const: memory - type: string - required: - - type - - memory_bank_configs - - query_generator_config - - max_tokens_in_context - - max_chunks - type: object + - $ref: '#/components/schemas/MemoryToolDefinition' type: array required: + - max_infer_iters - model - instructions + - enable_session_persistence type: object AgentCreateResponse: additionalProperties: false @@ -1182,6 +1072,124 @@ components: - memory_bank_ids - inserted_context type: object + MemoryToolDefinition: + additionalProperties: false + properties: + input_shields: + items: + type: string + type: array + max_chunks: + type: integer + max_tokens_in_context: + type: integer + memory_bank_configs: + items: + oneOf: + - additionalProperties: false + properties: + bank_id: + type: string + type: + const: vector + type: string + required: + - bank_id + - type + type: object + - additionalProperties: false + properties: + bank_id: + type: string + keys: + items: + type: string + type: array + type: + const: keyvalue + type: string + required: + - bank_id + - type + - keys + type: object + - additionalProperties: false + properties: + bank_id: + type: string + type: + const: keyword + type: string + required: + - bank_id + - type + type: object + - additionalProperties: false + properties: + bank_id: + type: string + entities: + items: + type: string + type: array + type: + const: graph + type: string + required: + - bank_id + - type + - entities + type: object + type: array + output_shields: + items: + type: string + type: array + query_generator_config: + oneOf: + - additionalProperties: false + properties: + sep: + type: string + type: + const: default + type: string + required: + - type + - sep + type: object + - additionalProperties: false + properties: + model: + type: string + template: + type: string + type: + const: llm + type: string + required: + - type + - model + - template + type: object + - additionalProperties: false + properties: + type: + const: custom + type: string + required: + - type + type: object + type: + const: memory + type: string + required: + - type + - memory_bank_configs + - query_generator_config + - max_tokens_in_context + - max_chunks + type: object MetricEvent: additionalProperties: false properties: @@ -2341,7 +2349,7 @@ info: description: "This is the specification of the llama stack that provides\n \ \ a set of endpoints and their corresponding interfaces that are tailored\ \ to\n best leverage Llama Models. The specification is still in\ - \ draft and subject to change.\n Generated at 2024-09-20 14:53:17.090953" + \ draft and subject to change.\n Generated at 2024-09-23 01:08:55.758597" title: '[DRAFT] Llama Stack Specification' version: 0.0.1 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema @@ -2944,7 +2952,32 @@ paths: description: OK tags: - Inference - /memory_bank/documents/delete: + /memory/create: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateMemoryBankRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/MemoryBank' + description: OK + tags: + - Memory + /memory/documents/delete: post: parameters: - description: JSON-encoded provider data which will be made available to the @@ -2965,7 +2998,7 @@ paths: description: OK tags: - Memory - /memory_bank/documents/get: + /memory/documents/get: post: parameters: - in: query @@ -2995,99 +3028,7 @@ paths: description: OK tags: - Memory - /memory_bank/insert: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/InsertDocumentsRequest' - required: true - responses: - '200': - description: OK - tags: - - Memory - /memory_bank/query: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/QueryDocumentsRequest' - required: true - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/QueryDocumentsResponse' - description: OK - tags: - - Memory - /memory_bank/update: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/UpdateDocumentsRequest' - required: true - responses: - '200': - description: OK - tags: - - Memory - /memory_banks/create: - post: - parameters: - - description: JSON-encoded provider data which will be made available to the - adapter servicing the API - in: header - name: X-LlamaStack-ProviderData - required: false - schema: - type: string - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/CreateMemoryBankRequest' - required: true - responses: - '200': - content: - application/json: - schema: - $ref: '#/components/schemas/MemoryBank' - description: OK - tags: - - Memory - /memory_banks/drop: + /memory/drop: post: parameters: - description: JSON-encoded provider data which will be made available to the @@ -3112,7 +3053,7 @@ paths: description: OK tags: - Memory - /memory_banks/get: + /memory/get: get: parameters: - in: query @@ -3138,7 +3079,28 @@ paths: description: OK tags: - Memory - /memory_banks/list: + /memory/insert: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/InsertDocumentsRequest' + required: true + responses: + '200': + description: OK + tags: + - Memory + /memory/list: get: parameters: - description: JSON-encoded provider data which will be made available to the @@ -3157,6 +3119,52 @@ paths: description: OK tags: - Memory + /memory/query: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueryDocumentsRequest' + required: true + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/QueryDocumentsResponse' + description: OK + tags: + - Memory + /memory/update: + post: + parameters: + - description: JSON-encoded provider data which will be made available to the + adapter servicing the API + in: header + name: X-LlamaStack-ProviderData + required: false + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateDocumentsRequest' + required: true + responses: + '200': + description: OK + tags: + - Memory /post_training/job/artifacts: get: parameters: @@ -3444,17 +3452,17 @@ security: servers: - url: http://any-hosted-llama-stack.com tags: -- name: Agents -- name: RewardScoring -- name: Evaluations -- name: Safety -- name: Telemetry - name: PostTraining -- name: Datasets -- name: Inference +- name: Safety - name: SyntheticDataGeneration +- name: Datasets +- name: Telemetry +- name: Evaluations +- name: RewardScoring +- name: Agents - name: Memory - name: BatchInference +- name: Inference - description: name: BuiltinTool - description: name: FunctionCallToolDefinition +- description: + name: MemoryToolDefinition - description: name: PhotogenToolDefinition @@ -3922,6 +3933,7 @@ x-tagGroups: - MemoryBank - MemoryBankDocument - MemoryRetrievalStep + - MemoryToolDefinition - MetricEvent - OptimizerConfig - PhotogenToolDefinition diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index 0cddf0d0e..b4bfcb34d 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -38,7 +38,7 @@ class MemoryClient(Memory): async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: async with httpx.AsyncClient() as client: r = await client.get( - f"{self.base_url}/memory_banks/get", + f"{self.base_url}/memory/get", params={ "bank_id": bank_id, }, @@ -59,7 +59,7 @@ class MemoryClient(Memory): ) -> MemoryBank: async with httpx.AsyncClient() as client: r = await client.post( - f"{self.base_url}/memory_banks/create", + f"{self.base_url}/memory/create", json={ "name": name, "config": config.dict(), @@ -81,7 +81,7 @@ class MemoryClient(Memory): ) -> None: async with httpx.AsyncClient() as client: r = await client.post( - f"{self.base_url}/memory_bank/insert", + f"{self.base_url}/memory/insert", json={ "bank_id": bank_id, "documents": [d.dict() for d in documents], @@ -99,7 +99,7 @@ class MemoryClient(Memory): ) -> QueryDocumentsResponse: async with httpx.AsyncClient() as client: r = await client.post( - f"{self.base_url}/memory_bank/query", + f"{self.base_url}/memory/query", json={ "bank_id": bank_id, "query": query, diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index a26ff67ea..261dd93ee 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -96,7 +96,7 @@ class MemoryBank(BaseModel): class Memory(Protocol): - @webmethod(route="/memory_banks/create") + @webmethod(route="/memory/create") async def create_memory_bank( self, name: str, @@ -104,13 +104,13 @@ class Memory(Protocol): url: Optional[URL] = None, ) -> MemoryBank: ... - @webmethod(route="/memory_banks/list", method="GET") + @webmethod(route="/memory/list", method="GET") async def list_memory_banks(self) -> List[MemoryBank]: ... - @webmethod(route="/memory_banks/get", method="GET") + @webmethod(route="/memory/get", method="GET") async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ... - @webmethod(route="/memory_banks/drop", method="DELETE") + @webmethod(route="/memory/drop", method="DELETE") async def drop_memory_bank( self, bank_id: str, @@ -118,7 +118,7 @@ class Memory(Protocol): # this will just block now until documents are inserted, but it should # probably return a Job instance which can be polled for completion - @webmethod(route="/memory_bank/insert") + @webmethod(route="/memory/insert") async def insert_documents( self, bank_id: str, @@ -126,14 +126,14 @@ class Memory(Protocol): ttl_seconds: Optional[int] = None, ) -> None: ... - @webmethod(route="/memory_bank/update") + @webmethod(route="/memory/update") async def update_documents( self, bank_id: str, documents: List[MemoryBankDocument], ) -> None: ... - @webmethod(route="/memory_bank/query") + @webmethod(route="/memory/query") async def query_documents( self, bank_id: str, @@ -141,14 +141,14 @@ class Memory(Protocol): params: Optional[Dict[str, Any]] = None, ) -> QueryDocumentsResponse: ... - @webmethod(route="/memory_bank/documents/get", method="GET") + @webmethod(route="/memory/documents/get", method="GET") async def get_documents( self, bank_id: str, document_ids: List[str], ) -> List[MemoryBankDocument]: ... - @webmethod(route="/memory_bank/documents/delete", method="DELETE") + @webmethod(route="/memory/documents/delete", method="DELETE") async def delete_documents( self, bank_id: str, diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 23bfb69e1..721983b19 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -7,11 +7,11 @@ from typing import List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field from llama_stack.apis.memory import MemoryBankType from llama_stack.distribution.datatypes import GenericProviderConfig +from pydantic import BaseModel, Field @json_schema_type diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index dea705628..f787b1a8e 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -160,7 +160,11 @@ class StackBuild(Subcommand): def _run_stack_build_command(self, args: argparse.Namespace) -> None: import yaml - from llama_stack.distribution.distribution import Api, api_providers + from llama_stack.distribution.distribution import ( + Api, + api_providers, + builtin_automatically_routed_apis, + ) from llama_stack.distribution.utils.dynamic import instantiate_class_type from prompt_toolkit import prompt from prompt_toolkit.validation import Validator @@ -213,8 +217,15 @@ class StackBuild(Subcommand): ) providers = dict() + all_providers = api_providers() + routing_table_apis = set( + x.routing_table_api for x in builtin_automatically_routed_apis() + ) + for api in Api: - all_providers = api_providers() + if api in routing_table_apis: + continue + providers_for_api = all_providers[api] api_provider = prompt( diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py index ff2976c96..1c4453e90 100644 --- a/llama_stack/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -145,7 +145,7 @@ class StackConfigure(Subcommand): built_at=datetime.now(), image_name=image_name, apis_to_serve=[], - provider_map={}, + api_providers={}, ) config = configure_api_providers(config, build_config.distribution_spec) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index ab1f31de6..35130c027 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -9,12 +9,21 @@ from typing import Any from pydantic import BaseModel from llama_stack.distribution.datatypes import * # noqa: F403 -from termcolor import cprint - -from llama_stack.distribution.distribution import api_providers, stack_apis +from llama_stack.apis.memory.memory import MemoryBankType +from llama_stack.distribution.distribution import ( + api_providers, + builtin_automatically_routed_apis, + stack_apis, +) from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config +from llama_stack.providers.impls.meta_reference.safety.config import ( + MetaReferenceShieldType, +) +from prompt_toolkit import prompt +from prompt_toolkit.validation import Validator +from termcolor import cprint def make_routing_entry_type(config_class: Any): @@ -25,71 +34,139 @@ def make_routing_entry_type(config_class: Any): return BaseModelWithConfig +def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]: + """Get corresponding builtin APIs given provider backed APIs""" + res = [] + for inf in builtin_automatically_routed_apis(): + if inf.router_api.value in provider_backed_apis: + res.append(inf.routing_table_api.value) + + return res + + # TODO: make sure we can deal with existing configuration values correctly # instead of just overwriting them def configure_api_providers( config: StackRunConfig, spec: DistributionSpec ) -> StackRunConfig: apis = config.apis_to_serve or list(spec.providers.keys()) - config.apis_to_serve = [a for a in apis if a != "telemetry"] + # append the bulitin routing APIs + apis += get_builtin_apis(apis) + + router_api2builtin_api = { + inf.router_api.value: inf.routing_table_api.value + for inf in builtin_automatically_routed_apis() + } + + config.apis_to_serve = list(set([a for a in apis if a != "telemetry"])) apis = [v.value for v in stack_apis()] all_providers = api_providers() + # configure simple case for with non-routing providers to api_providers for api_str in spec.providers.keys(): if api_str not in apis: raise ValueError(f"Unknown API `{api_str}`") - cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"]) + cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) api = Api(api_str) - provider_or_providers = spec.providers[api_str] - if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1: - print( - "You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n" + p = spec.providers[api_str] + cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green") + + if isinstance(p, list): + cprint( + f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml", + "yellow", ) + p = p[0] + provider_spec = all_providers[api][p] + config_type = instantiate_class_type(provider_spec.config_class) + try: + provider_config = config.api_providers.get(api_str) + if provider_config: + existing = config_type(**provider_config.config) + else: + existing = None + except Exception: + existing = None + cfg = prompt_for_config(config_type, existing) + + if api_str in router_api2builtin_api: + # a routing api, we need to infer and assign it a routing_key and put it in the routing_table + routing_key = "" routing_entries = [] - for p in provider_or_providers: - print(f"Configuring provider `{p}`...") - provider_spec = all_providers[api][p] - config_type = instantiate_class_type(provider_spec.config_class) - - # TODO: we need to validate the routing keys, and - # perhaps it is better if we break this out into asking - # for a routing key separately from the associated config - wrapper_type = make_routing_entry_type(config_type) - rt_entry = prompt_for_config(wrapper_type, None) - + if api_str == "inference": + if hasattr(cfg, "model"): + routing_key = cfg.model + else: + routing_key = prompt( + "> Please enter the supported model your provider has for inference: ", + default="Meta-Llama3.1-8B-Instruct", + ) routing_entries.append( - ProviderRoutingEntry( + RoutableProviderConfig( + routing_key=routing_key, provider_id=p, - routing_key=rt_entry.routing_key, - config=rt_entry.config.dict(), + config=cfg.dict(), ) ) - config.provider_map[api_str] = routing_entries - else: - p = ( - provider_or_providers[0] - if isinstance(provider_or_providers, list) - else provider_or_providers - ) - print(f"Configuring provider `{p}`...") - provider_spec = all_providers[api][p] - config_type = instantiate_class_type(provider_spec.config_class) - try: - provider_config = config.provider_map.get(api_str) - if provider_config: - existing = config_type(**provider_config.config) + + if api_str == "safety": + # TODO: add support for other safety providers, and simplify safety provider config + if p == "meta-reference": + for shield_type in MetaReferenceShieldType: + routing_entries.append( + RoutableProviderConfig( + routing_key=shield_type.value, + provider_id=p, + config=cfg.dict(), + ) + ) else: - existing = None - except Exception: - existing = None - cfg = prompt_for_config(config_type, existing) - config.provider_map[api_str] = GenericProviderConfig( + cprint( + f"[WARN] Interactive configuration of safety provider {p} is not supported, please manually configure safety shields types in routing_table of run.yaml", + "yellow", + ) + routing_entries.append( + RoutableProviderConfig( + routing_key=routing_key, + provider_id=p, + config=cfg.dict(), + ) + ) + + if api_str == "memory": + bank_types = list([x.value for x in MemoryBankType]) + routing_key = prompt( + "> Please enter the supported memory bank type your provider has for memory: ", + default="vector", + validator=Validator.from_callable( + lambda x: x in bank_types, + error_message="Invalid provider, please enter one of the following: {}".format( + bank_types + ), + ), + ) + routing_entries.append( + RoutableProviderConfig( + routing_key=routing_key, + provider_id=p, + config=cfg.dict(), + ) + ) + + config.routing_table[api_str] = routing_entries + config.api_providers[api_str] = PlaceholderProviderConfig( + providers=p if isinstance(p, list) else [p] + ) + else: + config.api_providers[api_str] = GenericProviderConfig( provider_id=p, config=cfg.dict(), ) + print("") + return config diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a3ff86cdf..619b5b078 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -59,17 +59,16 @@ class GenericProviderConfig(BaseModel): config: Dict[str, Any] +class PlaceholderProviderConfig(BaseModel): + """Placeholder provider config for API whose provider are defined in routing_table""" + + providers: List[str] + + class RoutableProviderConfig(GenericProviderConfig): routing_key: str -class RoutingTableConfig(BaseModel): - entries: List[RoutableProviderConfig] = Field(...) - keys: Optional[List[str]] = Field( - default=None, - ) - - # Example: /inference, /safety @json_schema_type class AutoRoutedProviderSpec(ProviderSpec): @@ -270,12 +269,14 @@ this could be just a hash The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", ) - api_providers: Dict[str, GenericProviderConfig] = Field( + api_providers: Dict[ + str, Union[GenericProviderConfig, PlaceholderProviderConfig] + ] = Field( description=""" Provider configurations for each of the APIs provided by this package. """, ) - routing_tables: Dict[str, RoutingTableConfig] = Field( + routing_table: Dict[str, List[RoutableProviderConfig]] = Field( default_factory=dict, description=""" diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 6b72afed5..b641b6582 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -8,8 +8,6 @@ import importlib import inspect from typing import Dict, List -from pydantic import BaseModel - from llama_stack.apis.agents import Agents from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory @@ -19,6 +17,8 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry +from pydantic import BaseModel + from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec # These are the dependencies needed by the distribution server. diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index e8b8938b0..363c863aa 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -12,7 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403 async def get_routing_table_impl( api: Api, inner_impls: List[Tuple[str, Any]], - routing_table_config: RoutingTableConfig, + routing_table_config: Dict[str, List[RoutableProviderConfig]], _deps, ) -> Any: from .routing_tables import ( diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 6d296d20e..ba32e5986 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -46,9 +46,9 @@ class MemoryRouter(Memory): url: Optional[URL] = None, ) -> MemoryBank: bank_type = config.type - provider = await self.routing_table.get_provider_impl( - bank_type - ).create_memory_bank(name, config, url) + bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank( + name, config, url + ) self.bank_id_to_type[bank.bank_id] = bank_type return bank @@ -162,6 +162,7 @@ class SafetyRouter(Safety): messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: + print(f"Running shield {shield_type}") return await self.routing_table.get_provider_impl(shield_type).run_shield( shield_type=shield_type, messages=messages, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index cd014d28d..fcd4d2b2b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable): def __init__( self, inner_impls: List[Tuple[str, Any]], - routing_table_config: RoutingTableConfig, + routing_table_config: Dict[str, List[RoutableProviderConfig]], ) -> None: self.providers = {k: v for k, v in inner_impls} self.routing_keys = list(self.providers.keys()) @@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable): return self.routing_keys def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == routing_key: return entry return None @@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelServingSpec]: specs = [] - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: model_id = entry.routing_key specs.append( ModelServingSpec( @@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return specs async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == core_model_id: return ModelServingSpec( llama_model=resolve_model(core_model_id), @@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldSpec]: specs = [] - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: specs.append( ShieldSpec( shield_type=entry.routing_key, @@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return specs async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == shield_type: return ShieldSpec( shield_type=entry.routing_key, @@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankSpec]: specs = [] - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: specs.append( MemoryBankSpec( bank_type=entry.routing_key, @@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return specs async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == bank_type: return MemoryBankSpec( bank_type=entry.routing_key, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 18433596f..f09e1c586 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import ( @@ -307,6 +307,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An # TODO: check that these APIs are not in the routing table part of the config providers = all_providers[api] + # skip checks for API whose provider config is specified in routing_table + if isinstance(config, PlaceholderProviderConfig): + continue + if config.provider_id not in providers: raise ValueError( f"Unknown provider `{config.provider_id}` is not available for API `{api}`" @@ -315,9 +319,8 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An configs[api] = config apis_to_serve = run_config.apis_to_serve or set( - list(specs.keys()) + list(run_config.routing_tables.keys()) + list(specs.keys()) + list(run_config.routing_table.keys()) ) - print("apis_to_serve", apis_to_serve) for info in builtin_automatically_routed_apis(): source_api = info.routing_table_api @@ -331,15 +334,16 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An if info.router_api.value not in apis_to_serve: continue - if source_api.value not in run_config.routing_tables: + print("router_api", info.router_api) + if info.router_api.value not in run_config.routing_table: raise ValueError(f"Routing table for `{source_api.value}` is not provided?") - routing_table = run_config.routing_tables[source_api.value] + routing_table = run_config.routing_table[info.router_api.value] providers = all_providers[info.router_api] inner_specs = [] - for rt_entry in routing_table.entries: + for rt_entry in routing_table: if rt_entry.provider_id not in providers: raise ValueError( f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 6d9c57dfd..e15ab63d6 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -8,6 +8,7 @@ import importlib from typing import Any, Dict from llama_stack.distribution.datatypes import * # noqa: F403 +from termcolor import cprint def instantiate_class_type(fully_qualified_name): @@ -43,12 +44,12 @@ async def instantiate_provider( elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" - assert isinstance(provider_config, RoutingTableConfig) + assert isinstance(provider_config, List) routing_table = provider_config inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} inner_impls = [] - for routing_entry in routing_table.entries: + for routing_entry in routing_table: impl = await instantiate_provider( inner_specs[routing_entry.provider_id], deps, diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml new file mode 100644 index 000000000..2ae975cdc --- /dev/null +++ b/tests/examples/local-run.yaml @@ -0,0 +1,87 @@ +built_at: '2024-09-23T00:54:40.551416' +image_name: test-2 +docker_image: null +conda_env: test-2 +apis_to_serve: +- shields +- agents +- models +- memory +- memory_banks +- inference +- safety +api_providers: + inference: + providers: + - meta-reference + safety: + providers: + - meta-reference + agents: + provider_id: meta-reference + config: + persistence_store: + namespace: null + type: sqlite + db_path: /home/xiyan/.llama/runtime/kvstore.db + memory: + providers: + - meta-reference + telemetry: + provider_id: meta-reference + config: {} +routing_table: + inference: + - provider_id: meta-reference + config: + model: Meta-Llama3.1-8B-Instruct + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + routing_key: Meta-Llama3.1-8B-Instruct + safety: + - provider_id: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M + routing_key: llama_guard + - provider_id: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M + routing_key: code_scanner_guard + - provider_id: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M + routing_key: injection_shield + - provider_id: meta-reference + config: + llama_guard_shield: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + prompt_guard_shield: + model: Prompt-Guard-86M + routing_key: jailbreak_shield + memory: + - provider_id: meta-reference + config: {} + routing_key: vector diff --git a/tests/examples/router-local-run.yaml b/tests/examples/router-local-run.yaml deleted file mode 100644 index 08cf9a804..000000000 --- a/tests/examples/router-local-run.yaml +++ /dev/null @@ -1,50 +0,0 @@ -built_at: '2024-09-18T13:41:17.656743' -image_name: local -docker_image: null -conda_env: local -apis_to_serve: -- inference -- memory -- telemetry -- agents -- safety -- models -provider_map: - telemetry: - provider_id: meta-reference - config: {} - safety: - provider_id: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M - agents: - provider_id: meta-reference - config: {} -provider_routing_table: - inference: - - routing_key: Meta-Llama3.1-8B-Instruct - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - - routing_key: Meta-Llama3.1-8B - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - memory: - - routing_key: vector - provider_id: meta-reference - config: {} diff --git a/tests/examples/simple-local-run.yaml b/tests/examples/simple-local-run.yaml deleted file mode 100644 index f517116aa..000000000 --- a/tests/examples/simple-local-run.yaml +++ /dev/null @@ -1,40 +0,0 @@ -built_at: '2024-09-19T22:50:36.239761' -image_name: simple-local -docker_image: null -conda_env: simple-local -apis_to_serve: -- inference -- safety -- agents -- memory -- models -- telemetry -provider_map: - inference: - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - safety: - provider_id: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M - agents: - provider_id: meta-reference - config: {} - memory: - provider_id: meta-reference - config: {} - telemetry: - provider_id: meta-reference - config: {} -provider_routing_table: {}