From 73399fe9053f4557cb77c38021028ecee6ba4ebc Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 20 Sep 2024 11:22:58 -0700
Subject: [PATCH] example config

---
 llama_stack/distribution/server/server.py  | 17 ++--
 llama_stack/examples/router-table-run.yaml | 94 ++++++++++++++++++++++
 2 files changed, 106 insertions(+), 5 deletions(-)
 create mode 100644 llama_stack/examples/router-table-run.yaml

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 25374c97c..deb7bf787 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated
 
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import api_endpoints, api_providers
@@ -287,14 +287,21 @@ def snake_to_camel(snake_str):
     return "".join(word.capitalize() for word in snake_str.split("_"))
 
 
+async def resolve_impls_with_routing(
+    stack_run_config: StackRunConfig,
+) -> Dict[Api, Any]:
+    raise NotImplementedError("This is not implemented yet")
+
+
 async def resolve_impls(
-    provider_map: Dict[str, ProviderMapEntry],
+    stack_run_config: StackRunConfig,
 ) -> Dict[Api, Any]:
     """
     Does two things:
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
+    provider_map = stack_run_config.provider_map
     all_providers = api_providers()
 
     specs = {}
@@ -344,7 +351,7 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
 
     app = FastAPI()
 
-    impls, specs = asyncio.run(resolve_impls(config.provider_map))
+    impls, specs = asyncio.run(resolve_impls(config))
 
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
new file mode 100644
index 000000000..aec0fca7e
--- /dev/null
+++ b/llama_stack/examples/router-table-run.yaml
@@ -0,0 +1,94 @@
+built_at: '2024-09-18T13:41:17.656743'
+image_name: local
+docker_image: null
+conda_env: local
+apis_to_serve:
+- inference
+- memory
+provider_map:
+  # use builtin-router as a dummy field
+  memory: builtin-router
+  inference: builtin-router
+routing_table:
+  inference:
+    - routing_key: Meta-Llama3.1-8B-Instruct
+      provider_id: meta-reference
+      config:
+        model: Meta-Llama3.1-8B-Instruct
+        quantization: null
+        torch_seed: null
+        max_seq_len: 4096
+        max_batch_size: 1
+    - routing_key: Meta-Llama3.1-8B
+      provider_id: remote::ollama
+      config:
+        url: http://ollama-url-1.com
+  memory:
+    - routing_key: keyvalue
+      provider_id: remote::pgvector
+      config:
+        host: localhost
+        port: 5432
+        db: vectordb
+        user: vectoruser
+    - routing_key: vector
+      provider_id: meta-reference
+      config: {}
+
+
+
+# safety:
+#   provider_id: meta-reference
+#   config:
+#     llama_guard_shield:
+#       model: Llama-Guard-3-8B
+#       excluded_categories: []
+#       disable_input_check: false
+#       disable_output_check: false
+#     prompt_guard_shield:
+#       model: Prompt-Guard-86M
+# telemetry:
+#   provider_id: meta-reference
+#   config: {}
+# agents:
+#   provider_id: meta-reference
+#   config: {}
+# memory:
+#   provider_id: meta-reference
+#   config: {}
+# models:
+#   provider_id: builtin
+#   config:
+#     models_config:
+#     - core_model_id: Meta-Llama3.1-8B-Instruct
+#       provider_id: meta-reference
+#       api: inference
+#       config:
+#         model: Meta-Llama3.1-8B-Instruct
+#         quantization: null
+#         torch_seed: null
+#         max_seq_len: 4096
+#         max_batch_size: 1
+#     - core_model_id: Meta-Llama3.1-8B
+#       provider_id: meta-reference
+#       api: inference
+#       config:
+#         model: Meta-Llama3.1-8B
+#         quantization: null
+#         torch_seed: null
+#         max_seq_len: 4096
+#         max_batch_size: 1
+#     - core_model_id: Llama-Guard-3-8B
+#       provider_id: meta-reference
+#       api: safety
+#       config:
+#         model: Llama-Guard-3-8B
+#         excluded_categories: []
+#         disable_input_check: false
+#         disable_output_check: false
+#     - core_model_id: Prompt-Guard-86M
+#       provider_id: meta-reference
+#       api: safety
+#       config:
+#         model: Prompt-Guard-86M
+
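
Note (not part of the patch): resolve_impls_with_routing is intentionally left as a NotImplementedError stub above, so the routing logic is only implied by the shape of router-table-run.yaml. The sketch below is a minimal, hypothetical illustration of how the routing_table section could be loaded into a per-API lookup keyed by routing_key. The RoutingEntry dataclass and load_routing_table helper are invented names for illustration only, not llama_stack APIs.

# Hypothetical sketch: parse the routing_table section of a run config YAML
# into {api: {routing_key: RoutingEntry}}. Requires pyyaml. This is not the
# patch's implementation of resolve_impls_with_routing.
from dataclasses import dataclass
from typing import Any, Dict

import yaml


@dataclass
class RoutingEntry:  # invented container type, not a llama_stack datatype
    routing_key: str
    provider_id: str
    config: Dict[str, Any]


def load_routing_table(path: str) -> Dict[str, Dict[str, RoutingEntry]]:
    """Index routing entries by API, then by routing_key."""
    with open(path) as f:
        run_config = yaml.safe_load(f)

    table: Dict[str, Dict[str, RoutingEntry]] = {}
    for api, entries in (run_config.get("routing_table") or {}).items():
        table[api] = {
            entry["routing_key"]: RoutingEntry(
                routing_key=entry["routing_key"],
                provider_id=entry["provider_id"],
                config=entry.get("config") or {},
            )
            for entry in entries
        }
    return table


if __name__ == "__main__":
    table = load_routing_table("llama_stack/examples/router-table-run.yaml")
    # e.g. the Meta-Llama3.1-8B-Instruct routing key maps to the meta-reference provider
    print(table["inference"]["Meta-Llama3.1-8B-Instruct"].provider_id)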