From 164d0e25c776c3390d9f6951a04a4c4ee7e52724 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 20 Sep 2024 15:33:31 -0700
Subject: [PATCH] stage tmp changes

---
 llama_stack/distribution/server/server.py  | 50 +++++++++++++++++++++-
 llama_stack/examples/router-table-run.yaml | 32 +++++++-------
 2 files changed, 65 insertions(+), 17 deletions(-)

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 64c1111e7..645e5ed34 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import asyncio
+import importlib
 import inspect
 import json
 import signal
@@ -456,7 +457,52 @@ def run_main_DEPRECATED(


 def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False):
-    raise ValueError("Not implemented")
+    app = FastAPI()
+
+    all_endpoints = api_endpoints()
+
+    apis_to_serve = config.apis_to_serve
+
+    # get the unified router implementation
+    module = importlib.import_module("llama_stack.distribution.routers")
+    get_router_fn = getattr(module, "get_router_impl")
+    router_impl = asyncio.run(get_router_fn())
+
+    cprint(router_impl, "blue")
+
+    for api_str in apis_to_serve:
+        api = Api(api_str)
+        endpoints = all_endpoints[api]
+
+        print(api, endpoints)
+        for endpoint in endpoints:
+            print(endpoint.route)
+            impl_method = getattr(router_impl, "process_request")
+            attr = getattr(app, endpoint.method)(endpoint.route, response_model=None)(
+                create_dynamic_typed_route(impl_method, endpoint.method)
+            )
+            print(endpoint, attr)
+
+    # TODO: check if it is a simple endpoint
+
+    for route in app.routes:
+        if isinstance(route, APIRoute):
+            cprint(
+                f"Serving {next(iter(route.methods))} {route.path}",
+                "white",
+                attrs=["bold"],
+            )
+
+    app.exception_handler(RequestValidationError)(global_exception_handler)
+    app.exception_handler(Exception)(global_exception_handler)
+    signal.signal(signal.SIGINT, handle_sigint)
+
+    import uvicorn
+
+    # FYI this does not do hot-reloads
+    listen_host = "::" if not disable_ipv6 else "0.0.0.0"
+    print(f"Listening on {listen_host}:{port}")
+    uvicorn.run(app, host=listen_host, port=port)


 def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
@@ -466,8 +512,10 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     cprint(f"StackRunConfig: {config}", "blue")

     if not config.provider_routing_table:
+        cprint("- running old implementation", "red")
         run_main_DEPRECATED(config, port, disable_ipv6)
     else:
+        cprint("- running new implementation with routers", "red")
         run_main(config, port, disable_ipv6)


diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
index 379ccebe3..9fbc394c1 100644
--- a/llama_stack/examples/router-table-run.yaml
+++ b/llama_stack/examples/router-table-run.yaml
@@ -3,27 +3,27 @@ image_name: local
 docker_image: null
 conda_env: local
 apis_to_serve:
-# - inference
-- memory
-- telemetry
+- inference
+# - memory
+# - telemetry
 provider_map:
   telemetry:
     provider_id: meta-reference
     config: {}
 provider_routing_table:
-  # inference:
-  # - routing_key: Meta-Llama3.1-8B-Instruct
-  #   provider_id: meta-reference
-  #   config:
-  #     model: Meta-Llama3.1-8B-Instruct
-  #     quantization: null
-  #     torch_seed: null
-  #     max_seq_len: 4096
-  #     max_batch_size: 1
-  # - routing_key: Meta-Llama3.1-8B
-  #   provider_id: remote::ollama
-  #   config:
-  #     url: http://ollama-url-1.com
+  inference:
+  - routing_key: Meta-Llama3.1-8B-Instruct
+    provider_id: meta-reference
+    config:
+      model: Meta-Llama3.1-8B-Instruct
+      quantization: null
+      torch_seed: null
+      max_seq_len: 4096
+      max_batch_size: 1
+  - routing_key: Meta-Llama3.1-8B
+    provider_id: remote::ollama
+    config:
+      url: http://ollama-url-1.com
   memory:
   - routing_key: keyvalue
     provider_id: remote::pgvector
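
Reviewer note: run_main above wires every endpoint to a single generic process_request method by applying FastAPI's method decorators programmatically. Below is a minimal, self-contained sketch of that registration pattern; the RouterImpl class, the process_request signature, and the two example routes are illustrative assumptions, not the actual llama_stack implementations.

```python
# Minimal sketch (not llama_stack's actual code) of the dynamic route
# registration pattern used in run_main: one generic handler serves
# every route, and HTTP-method decorators are applied programmatically.
from fastapi import FastAPI, Request


class RouterImpl:
    """Illustrative stand-in for the unified router implementation."""

    async def process_request(self, request: Request) -> dict:
        # A real router would dispatch to a provider here; this just
        # echoes the path so the sketch is runnable end to end.
        return {"handled": request.url.path}


def create_dynamic_typed_route(impl_method, method: str):
    # Assumed shape of the helper: wrap a router coroutine into a
    # FastAPI-compatible endpoint function.
    async def endpoint(request: Request):
        return await impl_method(request)

    return endpoint


app = FastAPI()
router_impl = RouterImpl()

# Hypothetical endpoint table standing in for api_endpoints().
for method, route in [("post", "/inference/chat_completion"), ("get", "/models/list")]:
    getattr(app, method)(route, response_model=None)(
        create_dynamic_typed_route(router_impl.process_request, method)
    )
```

The call getattr(app, method)(route, response_model=None)(handler) is simply the decorator form, e.g. @app.post(route, response_model=None), applied by hand, which is what lets one loop register an arbitrary endpoint table.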
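
The newly enabled provider_routing_table in router-table-run.yaml maps a routing_key (here, a model identifier) to a provider_id plus provider config, with two inference providers registered. A hedged sketch of how a router could resolve such a table at request time follows; the inlined table mirrors the YAML above, but resolve_provider is a hypothetical helper, not llama_stack's actual API.

```python
# Sketch of routing-key dispatch over a provider_routing_table shaped
# like router-table-run.yaml. resolve_provider is hypothetical.
from typing import Any

provider_routing_table: dict[str, list[dict[str, Any]]] = {
    "inference": [
        {
            "routing_key": "Meta-Llama3.1-8B-Instruct",
            "provider_id": "meta-reference",
            "config": {"model": "Meta-Llama3.1-8B-Instruct", "max_seq_len": 4096},
        },
        {
            "routing_key": "Meta-Llama3.1-8B",
            "provider_id": "remote::ollama",
            "config": {"url": "http://ollama-url-1.com"},
        },
    ],
}


def resolve_provider(api: str, routing_key: str) -> dict[str, Any]:
    """Return the provider entry whose routing_key matches, else raise."""
    for entry in provider_routing_table.get(api, []):
        if entry["routing_key"] == routing_key:
            return entry
    raise ValueError(f"No provider registered for {api}/{routing_key}")


# A request keyed on Meta-Llama3.1-8B would land on the ollama provider:
print(resolve_provider("inference", "Meta-Llama3.1-8B")["provider_id"])
```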