forked from phoenix-oss/llama-stack-mirror
[API Updates] Model / shield / memory-bank routing + agent persistence + support for private headers (#92)
This is yet another of those large PRs (hopefully we will have less and less of them as things mature fast). This one introduces substantial improvements and some simplifications to the stack. Most important bits: * Agents reference implementation now has support for session / turn persistence. The default implementation uses sqlite but there's also support for using Redis. * We have re-architected the structure of the Stack APIs to allow for more flexible routing. The motivating use cases are: - routing model A to ollama and model B to a remote provider like Together - routing shield A to local impl while shield B to a remote provider like Bedrock - routing a vector memory bank to Weaviate while routing a keyvalue memory bank to Redis * Support for provider specific parameters to be passed from the clients. A client can pass data using `x_llamastack_provider_data` parameter which can be type-checked and provided to the Adapter implementations.
This commit is contained in:
parent
8bf8c07eb3
commit
ec4fc800cc
130 changed files with 9701 additions and 11227 deletions
|
@ -93,4 +93,5 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
f"Failed to build target {build_config.name} with return code {return_code}",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
return return_code
|
||||
|
|
|
@ -9,12 +9,21 @@ from typing import Any
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.distribution import api_providers, stack_apis
|
||||
from llama_stack.apis.memory.memory import MemoryBankType
|
||||
from llama_stack.distribution.distribution import (
|
||||
api_providers,
|
||||
builtin_automatically_routed_apis,
|
||||
stack_apis,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.providers.impls.meta_reference.safety.config import (
|
||||
MetaReferenceShieldType,
|
||||
)
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit.validation import Validator
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
def make_routing_entry_type(config_class: Any):
|
||||
|
@ -25,71 +34,139 @@ def make_routing_entry_type(config_class: Any):
|
|||
return BaseModelWithConfig
|
||||
|
||||
|
||||
def get_builtin_apis(provider_backed_apis: List[str]) -> List[str]:
|
||||
"""Get corresponding builtin APIs given provider backed APIs"""
|
||||
res = []
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in provider_backed_apis:
|
||||
res.append(inf.routing_table_api.value)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
# TODO: make sure we can deal with existing configuration values correctly
|
||||
# instead of just overwriting them
|
||||
def configure_api_providers(
|
||||
config: StackRunConfig, spec: DistributionSpec
|
||||
) -> StackRunConfig:
|
||||
apis = config.apis_to_serve or list(spec.providers.keys())
|
||||
config.apis_to_serve = [a for a in apis if a != "telemetry"]
|
||||
# append the bulitin routing APIs
|
||||
apis += get_builtin_apis(apis)
|
||||
|
||||
router_api2builtin_api = {
|
||||
inf.router_api.value: inf.routing_table_api.value
|
||||
for inf in builtin_automatically_routed_apis()
|
||||
}
|
||||
|
||||
config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))
|
||||
|
||||
apis = [v.value for v in stack_apis()]
|
||||
all_providers = api_providers()
|
||||
|
||||
# configure simple case for with non-routing providers to api_providers
|
||||
for api_str in spec.providers.keys():
|
||||
if api_str not in apis:
|
||||
raise ValueError(f"Unknown API `{api_str}`")
|
||||
|
||||
cprint(f"Configuring API `{api_str}`...\n", "white", attrs=["bold"])
|
||||
cprint(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
|
||||
api = Api(api_str)
|
||||
|
||||
provider_or_providers = spec.providers[api_str]
|
||||
if isinstance(provider_or_providers, list) and len(provider_or_providers) > 1:
|
||||
print(
|
||||
"You have specified multiple providers for this API. We will configure a routing table now. For each provider, provide a routing key followed by provider configuration.\n"
|
||||
p = spec.providers[api_str]
|
||||
cprint(f"=== Configuring provider `{p}` for API {api_str}...", "green")
|
||||
|
||||
if isinstance(p, list):
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of multiple providers {p} is not supported, configuring {p[0]} only, please manually configure {p[1:]} in routing_table of run.yaml",
|
||||
"yellow",
|
||||
)
|
||||
p = p[0]
|
||||
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.api_providers.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
|
||||
if api_str in router_api2builtin_api:
|
||||
# a routing api, we need to infer and assign it a routing_key and put it in the routing_table
|
||||
routing_key = "<PLEASE_FILL_ROUTING_KEY>"
|
||||
routing_entries = []
|
||||
for p in provider_or_providers:
|
||||
print(f"Configuring provider `{p}`...")
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
|
||||
# TODO: we need to validate the routing keys, and
|
||||
# perhaps it is better if we break this out into asking
|
||||
# for a routing key separately from the associated config
|
||||
wrapper_type = make_routing_entry_type(config_type)
|
||||
rt_entry = prompt_for_config(wrapper_type, None)
|
||||
|
||||
if api_str == "inference":
|
||||
if hasattr(cfg, "model"):
|
||||
routing_key = cfg.model
|
||||
else:
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported model your provider has for inference: ",
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
)
|
||||
routing_entries.append(
|
||||
ProviderRoutingEntry(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_id=p,
|
||||
routing_key=rt_entry.routing_key,
|
||||
config=rt_entry.config.dict(),
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
config.provider_map[api_str] = routing_entries
|
||||
else:
|
||||
p = (
|
||||
provider_or_providers[0]
|
||||
if isinstance(provider_or_providers, list)
|
||||
else provider_or_providers
|
||||
)
|
||||
print(f"Configuring provider `{p}`...")
|
||||
provider_spec = all_providers[api][p]
|
||||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
try:
|
||||
provider_config = config.provider_map.get(api_str)
|
||||
if provider_config:
|
||||
existing = config_type(**provider_config.config)
|
||||
|
||||
if api_str == "safety":
|
||||
# TODO: add support for other safety providers, and simplify safety provider config
|
||||
if p == "meta-reference":
|
||||
for shield_type in MetaReferenceShieldType:
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=shield_type.value,
|
||||
provider_id=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
existing = None
|
||||
except Exception:
|
||||
existing = None
|
||||
cfg = prompt_for_config(config_type, existing)
|
||||
config.provider_map[api_str] = GenericProviderConfig(
|
||||
cprint(
|
||||
f"[WARN] Interactive configuration of safety provider {p} is not supported, please manually configure safety shields types in routing_table of run.yaml",
|
||||
"yellow",
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_id=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
if api_str == "memory":
|
||||
bank_types = list([x.value for x in MemoryBankType])
|
||||
routing_key = prompt(
|
||||
"> Please enter the supported memory bank type your provider has for memory: ",
|
||||
default="vector",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in bank_types,
|
||||
error_message="Invalid provider, please enter one of the following: {}".format(
|
||||
bank_types
|
||||
),
|
||||
),
|
||||
)
|
||||
routing_entries.append(
|
||||
RoutableProviderConfig(
|
||||
routing_key=routing_key,
|
||||
provider_id=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
)
|
||||
|
||||
config.routing_table[api_str] = routing_entries
|
||||
config.api_providers[api_str] = PlaceholderProviderConfig(
|
||||
providers=p if isinstance(p, list) else [p]
|
||||
)
|
||||
else:
|
||||
config.api_providers[api_str] = GenericProviderConfig(
|
||||
provider_id=p,
|
||||
config=cfg.dict(),
|
||||
)
|
||||
|
||||
print("")
|
||||
|
||||
return config
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import RedisImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: RedisImplConfig, _deps):
|
||||
from .redis import RedisControlPlaneAdapter
|
||||
|
||||
impl = RedisControlPlaneAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RedisImplConfig(BaseModel):
|
||||
url: str = Field(
|
||||
description="The URL for the Redis server",
|
||||
)
|
||||
namespace: Optional[str] = Field(
|
||||
default=None,
|
||||
description="All keys will be prefixed with this namespace",
|
||||
)
|
|
@ -1,62 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from llama_stack.apis.control_plane import * # noqa: F403
|
||||
|
||||
|
||||
from .config import RedisImplConfig
|
||||
|
||||
|
||||
class RedisControlPlaneAdapter(ControlPlane):
|
||||
def __init__(self, config: RedisImplConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.redis = Redis.from_url(self.config.url)
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
return f"{self.config.namespace}:{key}"
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
||||
) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.set(key, value)
|
||||
if expiration:
|
||||
await self.redis.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> Optional[ControlPlaneValue]:
|
||||
key = self._namespaced_key(key)
|
||||
value = await self.redis.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
ttl = await self.redis.ttl(key)
|
||||
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
|
||||
return ControlPlaneValue(key=key, value=value, expiration=expiration)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
keys = await self.redis.keys(f"{start_key}*")
|
||||
result = []
|
||||
for key in keys:
|
||||
if key <= end_key:
|
||||
value = await self.get(key)
|
||||
if value:
|
||||
result.append(value)
|
||||
return result
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import SqliteControlPlaneConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
|
||||
from .control_plane import SqliteControlPlane
|
||||
|
||||
impl = SqliteControlPlane(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SqliteControlPlaneConfig(BaseModel):
|
||||
db_path: str = Field(
|
||||
description="File path for the sqlite database",
|
||||
)
|
||||
table_name: str = Field(
|
||||
default="llamastack_control_plane",
|
||||
description="Table into which all the keys will be placed",
|
||||
)
|
|
@ -1,79 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.control_plane import * # noqa: F403
|
||||
|
||||
|
||||
from .config import SqliteControlPlaneConfig
|
||||
|
||||
|
||||
class SqliteControlPlane(ControlPlane):
|
||||
def __init__(self, config: SqliteControlPlaneConfig):
|
||||
self.db_path = config.db_path
|
||||
self.table_name = config.table_name
|
||||
|
||||
async def initialize(self):
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
expiration TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
||||
) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||
(key, json.dumps(value), expiration),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> Optional[ControlPlaneValue]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
return ControlPlaneValue(
|
||||
key=key, value=json.loads(value), expiration=expiration
|
||||
)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
key, value, expiration = row
|
||||
result.append(
|
||||
ControlPlaneValue(
|
||||
key=key, value=json.loads(value), expiration=expiration
|
||||
)
|
||||
)
|
||||
return result
|
|
@ -1,35 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ControlPlaneValue(BaseModel):
|
||||
key: str
|
||||
value: Any
|
||||
expiration: Optional[datetime] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ControlPlane(Protocol):
|
||||
@webmethod(route="/control_plane/set")
|
||||
async def set(
|
||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/control_plane/get", method="GET")
|
||||
async def get(self, key: str) -> Optional[ControlPlaneValue]: ...
|
||||
|
||||
@webmethod(route="/control_plane/delete")
|
||||
async def delete(self, key: str) -> None: ...
|
||||
|
||||
@webmethod(route="/control_plane/range", method="GET")
|
||||
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...
|
|
@ -1,29 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.control_plane,
|
||||
provider_id="sqlite",
|
||||
pip_packages=["aiosqlite"],
|
||||
module="llama_stack.providers.impls.sqlite.control_plane",
|
||||
config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.control_plane,
|
||||
AdapterSpec(
|
||||
adapter_id="redis",
|
||||
pip_packages=["redis"],
|
||||
module="llama_stack.providers.adapters.control_plane.redis",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -6,11 +6,11 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -19,8 +19,13 @@ class Api(Enum):
|
|||
safety = "safety"
|
||||
agents = "agents"
|
||||
memory = "memory"
|
||||
|
||||
telemetry = "telemetry"
|
||||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
memory_banks = "memory_banks"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ApiEndpoint(BaseModel):
|
||||
|
@ -43,31 +48,69 @@ class ProviderSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class RoutingTable(Protocol):
|
||||
def get_routing_keys(self) -> List[str]: ...
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_id: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class PlaceholderProviderConfig(BaseModel):
|
||||
"""Placeholder provider config for API whose provider are defined in routing_table"""
|
||||
|
||||
providers: List[str]
|
||||
|
||||
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: str
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
@json_schema_type
|
||||
class RouterProviderSpec(ProviderSpec):
|
||||
class AutoRoutedProviderSpec(ProviderSpec):
|
||||
provider_id: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
docker_image: Optional[str] = None
|
||||
routing_table_api: Api
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_id: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_id: str
|
||||
config: Dict[str, Any]
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -92,6 +135,9 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
default=None,
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -115,17 +161,18 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
- `get_provider_impl(config, deps)`: returns the local implementation
|
||||
""",
|
||||
)
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
url: str = Field(..., description="The URL for the provider")
|
||||
host: str = "localhost"
|
||||
port: int
|
||||
|
||||
@validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, url: str) -> str:
|
||||
if not url.startswith("http"):
|
||||
raise ValueError(f"URL must start with http: {url}")
|
||||
return url.rstrip("/")
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
def remote_provider_id(adapter_id: str) -> str:
|
||||
|
@ -159,6 +206,12 @@ as being "Llama Stack compatible"
|
|||
return self.adapter.pip_packages
|
||||
return []
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> Optional[str]:
|
||||
if self.adapter:
|
||||
return self.adapter.provider_data_validator
|
||||
return None
|
||||
|
||||
|
||||
# Can avoid this by using Pydantic computed_field
|
||||
def remote_provider_spec(
|
||||
|
@ -192,14 +245,6 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderRoutingEntry(GenericProviderConfig):
|
||||
routing_key: str
|
||||
|
||||
|
||||
ProviderMapEntry = Union[GenericProviderConfig, List[ProviderRoutingEntry]]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StackRunConfig(BaseModel):
|
||||
built_at: datetime
|
||||
|
@ -223,18 +268,28 @@ this could be just a hash
|
|||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
provider_map: Dict[str, ProviderMapEntry] = Field(
|
||||
|
||||
api_providers: Dict[
|
||||
str, Union[GenericProviderConfig, PlaceholderProviderConfig]
|
||||
] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
""",
|
||||
)
|
||||
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
|
||||
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
|
||||
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
|
||||
|
||||
As examples:
|
||||
- the "inference" API interprets the routing_key as a "model"
|
||||
- the "memory" API interprets the routing_key as the type of a "memory bank"
|
||||
|
||||
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||
E.g. The following is a ProviderRoutingEntry for models:
|
||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||
provider_id: meta-reference
|
||||
config:
|
||||
model: Meta-Llama3.1-8B-Instruct
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 4096
|
||||
max_batch_size: 1
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -11,9 +11,14 @@ from typing import Dict, List
|
|||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
|
||||
|
||||
# These are the dependencies needed by the distribution server.
|
||||
|
@ -29,6 +34,28 @@ def stack_apis() -> List[Api]:
|
|||
return [v for v in Api]
|
||||
|
||||
|
||||
class AutoRoutedApiInfo(BaseModel):
|
||||
routing_table_api: Api
|
||||
router_api: Api
|
||||
|
||||
|
||||
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||
return [
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.models,
|
||||
router_api=Api.inference,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.shields,
|
||||
router_api=Api.safety,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.memory_banks,
|
||||
router_api=Api.memory,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
|
@ -38,6 +65,9 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.agents: Agents,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.models: Models,
|
||||
Api.shields: Shields,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
@ -66,7 +96,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
|
||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
ret = {}
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
for api in stack_apis():
|
||||
if api in routing_table_apis:
|
||||
continue
|
||||
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {
|
||||
|
|
49
llama_stack/distribution/request_headers.py
Normal file
49
llama_stack/distribution/request_headers.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .utils.dynamic import instantiate_class_type
|
||||
|
||||
_THREAD_LOCAL = threading.local()
|
||||
|
||||
|
||||
def get_request_provider_data() -> Any:
|
||||
return getattr(_THREAD_LOCAL, "provider_data", None)
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str], validator_class: Optional[str]):
|
||||
if not validator_class:
|
||||
return
|
||||
|
||||
keys = [
|
||||
"X-LlamaStack-ProviderData",
|
||||
"x-llamastack-providerdata",
|
||||
]
|
||||
for key in keys:
|
||||
val = headers.get(key, None)
|
||||
if val:
|
||||
break
|
||||
|
||||
if not val:
|
||||
return
|
||||
|
||||
try:
|
||||
val = json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
print("Provider data not encoded as a JSON object!", val)
|
||||
return
|
||||
|
||||
validator = instantiate_class_type(validator_class)
|
||||
try:
|
||||
provider_data = validator(**val)
|
||||
except Exception as e:
|
||||
print("Error parsing provider data", e)
|
||||
return
|
||||
|
||||
_THREAD_LOCAL.provider_data = provider_data
|
50
llama_stack/distribution/routers/__init__.py
Normal file
50
llama_stack/distribution/routers/__init__.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
_deps,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
}
|
||||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
|
||||
|
||||
api_to_routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
"safety": SafetyRouter,
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
await impl.initialize()
|
||||
return impl
|
169
llama_stack/distribution/routers/routers.py
Normal file
169
llama_stack/distribution/routers/routers.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutingTable
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
"""Routes to an provider based on the memory bank type"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
name: str,
|
||||
config: MemoryBankConfig,
|
||||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(bank_type).create_memory_bank(
|
||||
name, config, url
|
||||
)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
documents: List[MemoryBankDocument],
|
||||
ttl_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
return await self.get_provider_from_bank_id(bank_id).insert_documents(
|
||||
bank_id, documents, ttl_seconds
|
||||
)
|
||||
|
||||
async def query_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
query: InterleavedTextMedia,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> QueryDocumentsResponse:
|
||||
return await self.get_provider_from_bank_id(bank_id).query_documents(
|
||||
bank_id, query, params
|
||||
)
|
||||
|
||||
|
||||
class InferenceRouter(Inference):
|
||||
"""Routes to an provider based on the model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
return await self.routing_table.get_provider_impl(model).embeddings(
|
||||
model=model,
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
116
llama_stack/distribution/routers/routing_tables.py
Normal file
116
llama_stack/distribution/routers/routing_tables.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||
) -> None:
|
||||
self.providers = {k: v for k, v in inner_impls}
|
||||
self.routing_keys = list(self.providers.keys())
|
||||
self.routing_table_config = routing_table_config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.providers.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Optional[Any]:
|
||||
return self.providers.get(routing_key)
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return None
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
|
||||
async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
|
@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -45,9 +42,17 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.distribution import (
|
||||
api_endpoints,
|
||||
api_providers,
|
||||
builtin_automatically_routed_apis,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
|
||||
|
||||
|
@ -176,7 +181,9 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
def create_dynamic_typed_route(
|
||||
func: Any, method: str, provider_data_validator: Optional[str]
|
||||
):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
|
@ -188,9 +195,11 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
|
||||
if is_streaming:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -217,8 +226,11 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
|
||||
else:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers, provider_data_validator)
|
||||
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
|
@ -232,20 +244,23 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
"request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
|
||||
)
|
||||
]
|
||||
new_params.extend(sig.parameters.values())
|
||||
|
||||
if method == "post":
|
||||
# make sure every parameter is annotated with Body() so FASTAPI doesn't
|
||||
# do anything too intelligent and ask for some parameters in the query
|
||||
# and some in the body
|
||||
endpoint.__signature__ = sig.replace(
|
||||
parameters=[
|
||||
param.replace(
|
||||
annotation=Annotated[param.annotation, Body(..., embed=True)]
|
||||
)
|
||||
for param in sig.parameters.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
endpoint.__signature__ = sig
|
||||
new_params = [new_params[0]] + [
|
||||
param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
for param in new_params[1:]
|
||||
]
|
||||
|
||||
endpoint.__signature__ = sig.replace(parameters=new_params)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
@ -276,52 +291,92 @@ def snake_to_camel(snake_str):
|
|||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||
|
||||
|
||||
async def resolve_impls(
|
||||
provider_map: Dict[str, ProviderMapEntry],
|
||||
) -> Dict[Api, Any]:
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
all_providers = api_providers()
|
||||
|
||||
specs = {}
|
||||
for api_str, item in provider_map.items():
|
||||
configs = {}
|
||||
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
if isinstance(item, GenericProviderConfig):
|
||||
if item.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[item.provider_id]
|
||||
else:
|
||||
assert isinstance(item, list)
|
||||
inner_specs = []
|
||||
for rt_entry in item:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
# skip checks for API whose provider config is specified in routing_table
|
||||
if isinstance(config, PlaceholderProviderConfig):
|
||||
continue
|
||||
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.providers.routers.{api.value.lower()}",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
if config.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[config.provider_id]
|
||||
configs[api] = config
|
||||
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_table.keys())
|
||||
)
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
print("router_api", info.router_api)
|
||||
if info.router_api.value not in run_config.routing_table:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
|
||||
routing_table = run_config.routing_table[info.router_api.value]
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
for rt_entry in routing_table:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_id}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
impl = await instantiate_provider(spec, deps, provider_map[api.value])
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
return impls, specs
|
||||
|
@ -333,15 +388,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in apis_to_serve:
|
||||
apis_to_serve.add(inf.routing_table_api)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
|
@ -365,7 +428,15 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(impl_method, endpoint.method)
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
(
|
||||
provider_spec.provider_data_validator
|
||||
if not isinstance(provider_spec, RoutingTableProviderSpec)
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
for route in app.routes:
|
||||
|
|
|
@ -15,3 +15,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
|||
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
|
||||
|
||||
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
|
||||
|
||||
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
|
||||
|
|
|
@ -8,6 +8,7 @@ import importlib
|
|||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
def instantiate_class_type(fully_qualified_name):
|
||||
|
@ -20,7 +21,7 @@ def instantiate_class_type(fully_qualified_name):
|
|||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: ProviderMapEntry,
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
|
@ -35,13 +36,20 @@ async def instantiate_provider(
|
|||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, RouterProviderSpec):
|
||||
method = "get_router_impl"
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, List)
|
||||
routing_table = provider_config
|
||||
|
||||
assert isinstance(provider_config, list)
|
||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in provider_config:
|
||||
for routing_entry in routing_table:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_id],
|
||||
deps,
|
||||
|
@ -50,7 +58,7 @@ async def instantiate_provider(
|
|||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [inner_impls, deps]
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
|
|
@ -83,10 +83,12 @@ def prompt_for_discriminated_union(
|
|||
if isinstance(typ, FieldInfo):
|
||||
inner_type = typ.annotation
|
||||
discriminator = typ.discriminator
|
||||
default_value = typ.default
|
||||
else:
|
||||
args = get_args(typ)
|
||||
inner_type = args[0]
|
||||
discriminator = args[1].discriminator
|
||||
default_value = args[1].default
|
||||
|
||||
union_types = get_args(inner_type)
|
||||
# Find the discriminator field in each union type
|
||||
|
@ -99,9 +101,14 @@ def prompt_for_discriminated_union(
|
|||
type_map[value] = t
|
||||
|
||||
while True:
|
||||
discriminator_value = input(
|
||||
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
|
||||
)
|
||||
prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
|
||||
if default_value is not None:
|
||||
prompt += f" (default: {default_value})"
|
||||
|
||||
discriminator_value = input(f"{prompt}: ")
|
||||
if discriminator_value == "" and default_value is not None:
|
||||
discriminator_value = default_value
|
||||
|
||||
if discriminator_value in type_map:
|
||||
chosen_type = type_map[discriminator_value]
|
||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue