Revert "stage tmp changes"

This reverts commit 164d0e25c7.
This commit is contained in:
Xi Yan 2024-09-21 12:42:07 -07:00
parent c22844f5f6
commit 3ea55d9b0f
2 changed files with 17 additions and 65 deletions

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import importlib
import inspect import inspect
import json import json
import signal import signal
@ -457,52 +456,7 @@ def run_main_DEPRECATED(
def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False): def run_main(config: StackRunConfig, port: int = 5000, disable_ipv6: bool = False):
app = FastAPI() raise ValueError("Not implemented")
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve
# get unified router
module = importlib.import_module("llama_stack.distribution.routers")
get_router_fn = getattr(module, "get_router_impl")
router_impl = asyncio.run(get_router_fn())
cprint(router_impl, "blue")
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
print(api, endpoints)
for endpoint in endpoints:
print(endpoint.route)
impl_method = getattr(router_impl, "process_request")
attr = getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(impl_method, endpoint.method)
)
print(endpoint, attr)
# check if it is a simple endpoint
for route in app.routes:
if isinstance(route, APIRoute):
cprint(
f"Serving {next(iter(route.methods))} {route.path}",
"white",
attrs=["bold"],
)
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, handle_sigint)
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)
def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
@ -512,10 +466,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
cprint(f"StackRunConfig: {config}", "blue") cprint(f"StackRunConfig: {config}", "blue")
if not config.provider_routing_table: if not config.provider_routing_table:
cprint("- running old implementation", "red")
run_main_DEPRECATED(config, port, disable_ipv6) run_main_DEPRECATED(config, port, disable_ipv6)
else: else:
cprint("- running new implementation with routers", "red")
run_main(config, port, disable_ipv6) run_main(config, port, disable_ipv6)

View file

@ -3,27 +3,27 @@ image_name: local
docker_image: null docker_image: null
conda_env: local conda_env: local
apis_to_serve: apis_to_serve:
- inference # - inference
# - memory - memory
# - telemetry - telemetry
provider_map: provider_map:
telemetry: telemetry:
provider_id: meta-reference provider_id: meta-reference
config: {} config: {}
provider_routing_table: provider_routing_table:
inference: # inference:
- routing_key: Meta-Llama3.1-8B-Instruct # - routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference # provider_id: meta-reference
config: # config:
model: Meta-Llama3.1-8B-Instruct # model: Meta-Llama3.1-8B-Instruct
quantization: null # quantization: null
torch_seed: null # torch_seed: null
max_seq_len: 4096 # max_seq_len: 4096
max_batch_size: 1 # max_batch_size: 1
- routing_key: Meta-Llama3.1-8B # - routing_key: Meta-Llama3.1-8B
provider_id: remote::ollama # provider_id: remote::ollama
config: # config:
url: http:ollama-url-1.com # url: http:ollama-url-1.com
memory: memory:
- routing_key: keyvalue - routing_key: keyvalue
provider_id: remote::pgvector provider_id: remote::pgvector