mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Further generalize Xi's changes (#88)
* Further generalize Xi's changes - introduce a slightly more general notion of an AutoRouted provider - the AutoRouted provider is associated with a RoutingTable provider - e.g. inference -> models - Introduced safety -> shields and memory -> memory_banks correspondences * typo * Basic build and run succeeded
This commit is contained in:
parent
b8914bb56f
commit
c1ab66f1e6
21 changed files with 597 additions and 418 deletions
|
@ -3,9 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuiltinImplConfig(BaseModel): ...
|
||||
from .memory_banks import * # noqa: F401 F403
|
67
llama_stack/apis/memory_banks/client.py
Normal file
67
llama_stack/apis/memory_banks/client.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from .memory_banks import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryBanksClient(MemoryBanks):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [MemoryBankSpec(**x) for x in response.json()]
|
||||
|
||||
async def get_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/memory_banks/get",
|
||||
json={
|
||||
"bank_type": bank_type,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return MemoryBankSpec(**j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = MemoryBanksClient(f"http://{host}:{port}")
|
||||
|
||||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
asyncio.run(run_main(host, port, stream))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
32
llama_stack/apis/memory_banks/memory_banks.py
Normal file
32
llama_stack/apis/memory_banks/memory_banks.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.memory import MemoryBankType
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryBankSpec(BaseModel):
|
||||
bank_type: MemoryBankType
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_id, and corresponding config. ",
|
||||
)
|
||||
|
||||
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory_banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
|
||||
|
||||
@webmethod(route="/memory_banks/get", method="GET")
|
||||
async def get_memory_bank(
|
||||
self, bank_type: MemoryBankType
|
||||
) -> Optional[MemoryBankSpec]: ...
|
|
@ -5,15 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
from termcolor import cprint
|
||||
|
||||
from .models import * # noqa: F403
|
||||
|
@ -29,18 +25,18 @@ class ModelsClient(Models):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> ModelsListResponse:
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/models/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return ModelsListResponse(**response.json())
|
||||
return [ModelServingSpec(**x) for x in response.json()]
|
||||
|
||||
async def get_model(self, core_model_id: str) -> ModelsGetResponse:
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
response = await client.get(
|
||||
f"{self.base_url}/models/get",
|
||||
json={
|
||||
"core_model_id": core_model_id,
|
||||
|
@ -48,7 +44,10 @@ class ModelsClient(Models):
|
|||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return ModelsGetResponse(**response.json())
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
return ModelServingSpec(**j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
|
|
|
@ -4,14 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import Model
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelServingSpec(BaseModel):
|
||||
|
@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
|
|||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_id, and corresponding config. ",
|
||||
)
|
||||
api: str = Field(
|
||||
description="The API that this model is serving (e.g. inference / safety).",
|
||||
default="inference",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelsListResponse(BaseModel):
|
||||
models_list: List[ModelServingSpec]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ModelsGetResponse(BaseModel):
|
||||
core_model_spec: Optional[ModelServingSpec] = None
|
||||
|
||||
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> ModelsListResponse: ...
|
||||
async def list_models(self) -> List[ModelServingSpec]: ...
|
||||
|
||||
@webmethod(route="/models/get", method="POST")
|
||||
async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
|
||||
@webmethod(route="/models/get", method="GET")
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
|
||||
|
|
7
llama_stack/apis/shields/__init__.py
Normal file
7
llama_stack/apis/shields/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .shields import * # noqa: F401 F403
|
67
llama_stack/apis/shields/client.py
Normal file
67
llama_stack/apis/shields/client.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
from termcolor import cprint
|
||||
|
||||
from .shields import * # noqa: F403
|
||||
|
||||
|
||||
class ShieldsClient(Shields):
|
||||
def __init__(self, base_url: str):
|
||||
self.base_url = base_url
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/list",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ShieldSpec(**x) for x in response.json()]
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{self.base_url}/shields/get",
|
||||
json={
|
||||
"shield_type": shield_type,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
|
||||
return ShieldSpec(**j)
|
||||
|
||||
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = ShieldsClient(f"http://{host}:{port}")
|
||||
|
||||
response = await client.list_shields()
|
||||
cprint(f"list_shields response={response}", "green")
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
asyncio.run(run_main(host, port, stream))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
28
llama_stack/apis/shields/shields.py
Normal file
28
llama_stack/apis/shields/shields.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldSpec(BaseModel):
|
||||
shield_type: str
|
||||
provider_config: GenericProviderConfig = Field(
|
||||
description="Provider config for the model, including provider_id, and corresponding config. ",
|
||||
)
|
||||
|
||||
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[ShieldSpec]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
|
@ -19,8 +19,12 @@ class Api(Enum):
|
|||
safety = "safety"
|
||||
agents = "agents"
|
||||
memory = "memory"
|
||||
|
||||
telemetry = "telemetry"
|
||||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
memory_banks = "memory_banks"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -44,27 +48,36 @@ class ProviderSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class RoutingTable(Protocol):
|
||||
def get_routing_keys(self) -> List[str]: ...
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
class GenericProviderConfig(BaseModel):
|
||||
provider_id: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderRoutingEntry(GenericProviderConfig):
|
||||
class RoutableProviderConfig(GenericProviderConfig):
|
||||
routing_key: str
|
||||
|
||||
|
||||
class RoutingTableConfig(BaseModel):
|
||||
entries: List[RoutableProviderConfig] = Field(...)
|
||||
keys: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
# Example: /inference, /safety
|
||||
@json_schema_type
|
||||
class RouterProviderSpec(ProviderSpec):
|
||||
class AutoRoutedProviderSpec(ProviderSpec):
|
||||
provider_id: str = "router"
|
||||
config_class: str = ""
|
||||
|
||||
docker_image: Optional[str] = None
|
||||
|
||||
routing_table: List[ProviderRoutingEntry] = Field(
|
||||
default_factory=list,
|
||||
description="Routing table entries corresponding to the API",
|
||||
)
|
||||
routing_table_api: Api
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
|
@ -79,18 +92,17 @@ class RouterProviderSpec(ProviderSpec):
|
|||
|
||||
@property
|
||||
def pip_packages(self) -> List[str]:
|
||||
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
@json_schema_type
|
||||
class BuiltinProviderSpec(ProviderSpec):
|
||||
provider_id: str = "builtin"
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
provider_id: str = "routing_table"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
api_dependencies: List[Api] = []
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
inner_specs: List[ProviderSpec]
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
|
@ -99,10 +111,7 @@ class BuiltinProviderSpec(ProviderSpec):
|
|||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
pip_packages: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -130,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
supported_model_ids: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The list of model ids that this adapter supports",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -243,9 +248,6 @@ in the runtime configuration to help route to the correct provider.""",
|
|||
)
|
||||
|
||||
|
||||
ProviderMapEntry = GenericProviderConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StackRunConfig(BaseModel):
|
||||
built_at: datetime
|
||||
|
@ -269,25 +271,17 @@ this could be just a hash
|
|||
description="""
|
||||
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
|
||||
)
|
||||
provider_map: Dict[str, ProviderMapEntry] = Field(
|
||||
|
||||
api_providers: Dict[str, GenericProviderConfig] = Field(
|
||||
description="""
|
||||
Provider configurations for each of the APIs provided by this package.
|
||||
|
||||
Given an API, you can specify a single provider or a "routing table". Each entry in the routing
|
||||
table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
|
||||
|
||||
As examples:
|
||||
- the "inference" API interprets the routing_key as a "model"
|
||||
- the "memory" API interprets the routing_key as the type of a "memory bank"
|
||||
|
||||
The key may support wild-cards alsothe routing_key to route to the correct provider.""",
|
||||
""",
|
||||
)
|
||||
provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
|
||||
routing_tables: Dict[str, RoutingTableConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="""
|
||||
API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
|
||||
|
||||
E.g. The following is a ProviderRoutingEntry for inference API:
|
||||
E.g. The following is a ProviderRoutingEntry for models:
|
||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||
provider_id: meta-reference
|
||||
config:
|
||||
|
|
|
@ -8,11 +8,15 @@ import importlib
|
|||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
|
||||
|
@ -30,6 +34,28 @@ def stack_apis() -> List[Api]:
|
|||
return [v for v in Api]
|
||||
|
||||
|
||||
class AutoRoutedApiInfo(BaseModel):
|
||||
routing_table_api: Api
|
||||
router_api: Api
|
||||
|
||||
|
||||
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
|
||||
return [
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.models,
|
||||
router_api=Api.inference,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.shields,
|
||||
router_api=Api.safety,
|
||||
),
|
||||
AutoRoutedApiInfo(
|
||||
routing_table_api=Api.memory_banks,
|
||||
router_api=Api.memory,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
apis = {}
|
||||
|
||||
|
@ -40,6 +66,8 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.models: Models,
|
||||
Api.shields: Shields,
|
||||
Api.memory_banks: MemoryBanks,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
@ -68,7 +96,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
|
||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
ret = {}
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
for api in stack_apis():
|
||||
if api in routing_table_apis:
|
||||
continue
|
||||
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {
|
||||
|
|
|
@ -4,25 +4,47 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
async def get_router_impl(
|
||||
api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
|
||||
):
|
||||
from .routers import InferenceRouter, MemoryRouter
|
||||
from .routing_table import RoutingTable
|
||||
async def get_routing_table_impl(
|
||||
api: Api,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: RoutingTableConfig,
|
||||
_deps,
|
||||
) -> Any:
|
||||
from .routing_tables import (
|
||||
MemoryBanksRoutingTable,
|
||||
ModelsRoutingTable,
|
||||
ShieldsRoutingTable,
|
||||
)
|
||||
|
||||
api2routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
api_to_tables = {
|
||||
"memory_banks": MemoryBanksRoutingTable,
|
||||
"models": ModelsRoutingTable,
|
||||
"shields": ShieldsRoutingTable,
|
||||
}
|
||||
if api.value not in api_to_tables:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
# initialize routing table with concrete provider impls
|
||||
routing_table = RoutingTable(provider_routing_table)
|
||||
|
||||
impl = api2routers[api](routing_table)
|
||||
impl = api_to_tables[api.value](inner_impls, routing_table_config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
|
||||
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
|
||||
|
||||
api_to_routers = {
|
||||
"memory": MemoryRouter,
|
||||
"inference": InferenceRouter,
|
||||
"safety": SafetyRouter,
|
||||
}
|
||||
if api.value not in api_to_routers:
|
||||
raise ValueError(f"API {api.value} not found in router map")
|
||||
|
||||
impl = api_to_routers[api.value](routing_table)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,17 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import RoutingTable
|
||||
|
||||
from .routing_table import RoutingTable
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from types import MethodType
|
||||
|
||||
from termcolor import cprint
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
|
||||
class MemoryRouter(Memory):
|
||||
|
@ -24,22 +20,24 @@ class MemoryRouter(Memory):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.api = Api.memory.value
|
||||
self.routing_table = routing_table
|
||||
self.bank_id_to_type = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.routing_table.initialize(self.api)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.routing_table.shutdown(self.api)
|
||||
pass
|
||||
|
||||
def get_provider_from_bank_id(self, bank_id: str) -> Any:
|
||||
bank_type = self.bank_id_to_type.get(bank_id)
|
||||
if not bank_type:
|
||||
raise ValueError(f"Could not find bank type for {bank_id}")
|
||||
|
||||
return self.routing_table.get_provider_impl(self.api, bank_type)
|
||||
provider = self.routing_table.get_provider_impl(bank_type)
|
||||
if not provider:
|
||||
raise ValueError(f"Could not find provider for {bank_type}")
|
||||
return provider
|
||||
|
||||
async def create_memory_bank(
|
||||
self,
|
||||
|
@ -48,14 +46,15 @@ class MemoryRouter(Memory):
|
|||
url: Optional[URL] = None,
|
||||
) -> MemoryBank:
|
||||
bank_type = config.type
|
||||
bank = await self.routing_table.get_provider_impl(
|
||||
self.api, bank_type
|
||||
provider = await self.routing_table.get_provider_impl(
|
||||
bank_type
|
||||
).create_memory_bank(name, config, url)
|
||||
self.bank_id_to_type[bank.bank_id] = bank_type
|
||||
return bank
|
||||
|
||||
async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
|
||||
provider = self.get_provider_from_bank_id(bank_id)
|
||||
return await provider.get_memory_bank(bank_id)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
|
@ -85,34 +84,31 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.api = Api.inference.value
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await self.routing_table.initialize(self.api)
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
await self.routing_table.shutdown(self.api)
|
||||
pass
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = [],
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
# TODO: we need to fix streaming response to align provider implementations with Protocol.
|
||||
async for chunk in self.routing_table.get_provider_impl(
|
||||
self.api, model
|
||||
).chat_completion(
|
||||
async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools,
|
||||
tools=tools or [],
|
||||
tool_choice=tool_choice,
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
stream=stream,
|
||||
|
@ -128,7 +124,7 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
|
||||
return await self.routing_table.get_provider_impl(self.api, model).completion(
|
||||
return await self.routing_table.get_provider_impl(model).completion(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
|
@ -141,7 +137,33 @@ class InferenceRouter(Inference):
|
|||
model: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
return await self.routing_table.get_provider_impl(self.api, model).embeddings(
|
||||
return await self.routing_table.get_provider_impl(model).embeddings(
|
||||
model=model,
|
||||
contents=contents,
|
||||
)
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
self,
|
||||
routing_table: RoutingTable,
|
||||
) -> None:
|
||||
self.routing_table = routing_table
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
Api,
|
||||
GenericProviderConfig,
|
||||
ProviderRoutingEntry,
|
||||
)
|
||||
from llama_stack.distribution.distribution import api_providers
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
class RoutingTable:
|
||||
def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
|
||||
self.provider_routing_table = provider_routing_table
|
||||
# map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
|
||||
self.api2routes = {}
|
||||
|
||||
async def initialize(self, api_str: str) -> None:
|
||||
"""Initialize the routing table with concrete provider impls"""
|
||||
if api_str not in self.provider_routing_table:
|
||||
raise ValueError(f"API {api_str} not found in routing table")
|
||||
|
||||
providers = api_providers()[Api(api_str)]
|
||||
routing_list = self.provider_routing_table[api_str]
|
||||
|
||||
self.api2routes[api_str] = {}
|
||||
for rt_entry in routing_list:
|
||||
rt_key = rt_entry.routing_key
|
||||
provider_id = rt_entry.provider_id
|
||||
impl = await instantiate_provider(
|
||||
providers[provider_id],
|
||||
deps=[],
|
||||
provider_config=GenericProviderConfig(
|
||||
provider_id=provider_id, config=rt_entry.config
|
||||
),
|
||||
)
|
||||
cprint(f"impl = {impl}", "red")
|
||||
self.api2routes[api_str][rt_key] = impl
|
||||
|
||||
cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
|
||||
|
||||
async def shutdown(self, api_str: str) -> None:
|
||||
"""Shutdown the routing table"""
|
||||
if api_str not in self.api2routes:
|
||||
return
|
||||
|
||||
for impl in self.api2routes[api_str].values():
|
||||
await impl.shutdown()
|
||||
|
||||
def get_provider_impl(self, api: str, routing_key: str) -> Any:
|
||||
"""Get the provider impl for a given api and routing key"""
|
||||
return self.api2routes[api][routing_key]
|
118
llama_stack/distribution/routers/routing_tables.py
Normal file
118
llama_stack/distribution/routers/routing_tables.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
class CommonRoutingTableImpl(RoutingTable):
|
||||
def __init__(
|
||||
self,
|
||||
inner_impls: List[Tuple[str, Any]],
|
||||
routing_table_config: RoutingTableConfig,
|
||||
) -> None:
|
||||
self.providers = {k: v for k, v in inner_impls}
|
||||
self.routing_keys = list(self.providers.keys())
|
||||
self.routing_table_config = routing_table_config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
for p in self.providers.values():
|
||||
await p.shutdown()
|
||||
|
||||
async def get_provider_impl(self, routing_key: str) -> Optional[Any]:
|
||||
return self.providers.get(routing_key)
|
||||
|
||||
async def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
||||
async def get_provider_config(
|
||||
self, routing_key: str
|
||||
) -> Optional[GenericProviderConfig]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == routing_key:
|
||||
return entry
|
||||
return None
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
||||
async def list_models(self) -> List[ModelServingSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
model_id = entry.routing_key
|
||||
specs.append(
|
||||
ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == core_model_id:
|
||||
return ModelServingSpec(
|
||||
llama_model=resolve_model(core_model_id),
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
|
||||
async def list_shields(self) -> List[ShieldSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
specs.append(
|
||||
ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == shield_type:
|
||||
return ShieldSpec(
|
||||
shield_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankSpec]:
|
||||
specs = []
|
||||
for entry in self.routing_table_config.entries:
|
||||
specs.append(
|
||||
MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||
for entry in self.routing_table_config.entries:
|
||||
if entry.routing_key == bank_type:
|
||||
return MemoryBankSpec(
|
||||
bank_type=entry.routing_key,
|
||||
provider_config=entry,
|
||||
)
|
||||
return None
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import signal
|
||||
|
@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import (
|
||||
instantiate_builtin_provider,
|
||||
instantiate_provider,
|
||||
instantiate_router,
|
||||
from llama_stack.distribution.distribution import (
|
||||
api_endpoints,
|
||||
api_providers,
|
||||
builtin_automatically_routed_apis,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_provider
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
|
@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
|
|||
return "".join(word.capitalize() for word in snake_str.split("_"))
|
||||
|
||||
|
||||
async def resolve_impls_with_routing(
|
||||
stack_run_config: StackRunConfig,
|
||||
) -> Dict[Api, Any]:
|
||||
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
"""
|
||||
Does two things:
|
||||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
|
@ -302,48 +299,80 @@ async def resolve_impls_with_routing(
|
|||
"""
|
||||
all_providers = api_providers()
|
||||
specs = {}
|
||||
configs = {}
|
||||
|
||||
for api_str in stack_run_config.apis_to_serve:
|
||||
for api_str, config in run_config.api_providers.items():
|
||||
api = Api(api_str)
|
||||
|
||||
# TODO: check that these APIs are not in the routing table part of the config
|
||||
providers = all_providers[api]
|
||||
|
||||
# check for routing table, we need to pass routing table to the router implementation
|
||||
if api_str in stack_run_config.provider_routing_table:
|
||||
specs[api] = RouterProviderSpec(
|
||||
api=api,
|
||||
module=f"llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
routing_table=stack_run_config.provider_routing_table[api_str],
|
||||
)
|
||||
else:
|
||||
if api_str in stack_run_config.provider_map:
|
||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
||||
provider_id = provider_map_entry.provider_id
|
||||
else:
|
||||
# not defined in config, will be a builtin provider, assign builtin provider id
|
||||
provider_id = "builtin"
|
||||
|
||||
if provider_id not in providers:
|
||||
if config.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[provider_id]
|
||||
specs[api] = providers[config.provider_id]
|
||||
configs[api] = config
|
||||
|
||||
apis_to_serve = run_config.apis_to_serve or set(
|
||||
list(specs.keys()) + list(run_config.routing_tables.keys())
|
||||
)
|
||||
print("apis_to_serve", apis_to_serve)
|
||||
for info in builtin_automatically_routed_apis():
|
||||
source_api = info.routing_table_api
|
||||
|
||||
assert (
|
||||
source_api not in specs
|
||||
), f"Routing table API {source_api} specified in wrong place?"
|
||||
assert (
|
||||
info.router_api not in specs
|
||||
), f"Auto-routed API {info.router_api} specified in wrong place?"
|
||||
|
||||
if info.router_api.value not in apis_to_serve:
|
||||
continue
|
||||
|
||||
if source_api.value not in run_config.routing_tables:
|
||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||
|
||||
routing_table = run_config.routing_tables[source_api.value]
|
||||
|
||||
providers = all_providers[info.router_api]
|
||||
|
||||
inner_specs = []
|
||||
for rt_entry in routing_table.entries:
|
||||
if rt_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
inner_specs.append(providers[rt_entry.provider_id])
|
||||
|
||||
specs[source_api] = RoutingTableProviderSpec(
|
||||
api=source_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=[],
|
||||
inner_specs=inner_specs,
|
||||
)
|
||||
configs[source_api] = routing_table
|
||||
|
||||
specs[info.router_api] = AutoRoutedProviderSpec(
|
||||
api=info.router_api,
|
||||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=source_api,
|
||||
api_dependencies=[source_api],
|
||||
)
|
||||
configs[info.router_api] = {}
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
print(f"Resolved {len(sorted_specs)} providers in topological order")
|
||||
for spec in sorted_specs:
|
||||
print(f" {spec.api}: {spec.provider_id}")
|
||||
print("")
|
||||
impls = {}
|
||||
for spec in sorted_specs:
|
||||
api = spec.api
|
||||
deps = {api: impls[api] for api in spec.api_dependencies}
|
||||
if api.value in stack_run_config.provider_map:
|
||||
provider_config = stack_run_config.provider_map[api.value]
|
||||
impl = await instantiate_provider(spec, deps, provider_config)
|
||||
elif api.value in stack_run_config.provider_routing_table:
|
||||
impl = await instantiate_router(
|
||||
spec, api.value, stack_run_config.provider_routing_table
|
||||
)
|
||||
else:
|
||||
impl = await instantiate_builtin_provider(spec, stack_run_config)
|
||||
impl = await instantiate_provider(spec, deps, configs[api])
|
||||
|
||||
impls[api] = impl
|
||||
|
||||
return impls, specs
|
||||
|
@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
# impls, specs = asyncio.run(resolve_impls(config.provider_map))
|
||||
impls, specs = asyncio.run(resolve_impls_with_routing(config))
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
all_endpoints = api_endpoints()
|
||||
|
||||
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
|
||||
if config.apis_to_serve:
|
||||
apis_to_serve = set(config.apis_to_serve)
|
||||
for inf in builtin_automatically_routed_apis():
|
||||
if inf.router_api.value in apis_to_serve:
|
||||
apis_to_serve.add(inf.routing_table_api)
|
||||
else:
|
||||
apis_to_serve = set(impls.keys())
|
||||
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
|
@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
provider_spec.provider_data_validator,
|
||||
(
|
||||
provider_spec.provider_data_validator
|
||||
if not isinstance(provider_spec, RoutingTableProviderSpec)
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -16,36 +16,11 @@ def instantiate_class_type(fully_qualified_name):
|
|||
return getattr(module, class_name)
|
||||
|
||||
|
||||
async def instantiate_router(
|
||||
provider_spec: RouterProviderSpec,
|
||||
api: str,
|
||||
provider_routing_table: Dict[str, Any],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
fn = getattr(module, "get_router_impl")
|
||||
impl = await fn(api, provider_routing_table)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
return impl
|
||||
|
||||
|
||||
async def instantiate_builtin_provider(
|
||||
provider_spec: BuiltinProviderSpec,
|
||||
run_config: StackRunConfig,
|
||||
):
|
||||
print("!!! instantiate_builtin_provider")
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
fn = getattr(module, "get_builtin_impl")
|
||||
impl = await fn(run_config)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
return impl
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
deps: Dict[str, Any],
|
||||
provider_config: ProviderMapEntry,
|
||||
provider_config: Union[GenericProviderConfig, RoutingTable],
|
||||
):
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
|
||||
|
@ -60,6 +35,29 @@ async def instantiate_provider(
|
|||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider_config.config)
|
||||
args = [config, deps]
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
|
||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||
method = "get_routing_table_impl"
|
||||
|
||||
assert isinstance(provider_config, RoutingTableConfig)
|
||||
routing_table = provider_config
|
||||
|
||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||
inner_impls = []
|
||||
for routing_entry in routing_table.entries:
|
||||
impl = await instantiate_provider(
|
||||
inner_specs[routing_entry.provider_id],
|
||||
deps,
|
||||
routing_entry,
|
||||
)
|
||||
inner_impls.append((routing_entry.routing_key, impl))
|
||||
|
||||
config = None
|
||||
args = [provider_spec.api, inner_impls, routing_table, deps]
|
||||
else:
|
||||
method = "get_provider_impl"
|
||||
|
||||
|
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
|
||||
|
||||
from .config import BuiltinImplConfig # noqa
|
||||
|
||||
|
||||
async def get_builtin_impl(config: StackRunConfig):
|
||||
from .models import BuiltinModelsImpl
|
||||
|
||||
assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = BuiltinModelsImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,113 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import StopReason
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.distribution import Api, api_providers
|
||||
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.datatypes import CoreModelId, Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from termcolor import cprint
|
||||
|
||||
|
||||
class BuiltinModelsImpl(Models):
|
||||
def __init__(
|
||||
self,
|
||||
config: StackRunConfig,
|
||||
) -> None:
|
||||
self.run_config = config
|
||||
self.models = {}
|
||||
# check against inference & safety api
|
||||
apis_with_models = [Api.inference, Api.safety]
|
||||
|
||||
all_providers = api_providers()
|
||||
|
||||
for api in apis_with_models:
|
||||
|
||||
# check against provider_map (simple case single model)
|
||||
if api.value in config.provider_map:
|
||||
providers_for_api = all_providers[api]
|
||||
provider_spec = config.provider_map[api.value]
|
||||
core_model_id = provider_spec.config
|
||||
# get supported model ids from the provider
|
||||
supported_model_ids = self.get_supported_model_ids(
|
||||
api.value, provider_spec, providers_for_api
|
||||
)
|
||||
for model_id in supported_model_ids:
|
||||
self.models[model_id] = ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=provider_spec,
|
||||
api=api.value,
|
||||
)
|
||||
|
||||
# check against provider_routing_table (router with multiple models)
|
||||
# with routing table, we use the routing_key as the supported models
|
||||
if api.value in config.provider_routing_table:
|
||||
routing_table = config.provider_routing_table[api.value]
|
||||
for rt_entry in routing_table:
|
||||
model_id = rt_entry.routing_key
|
||||
self.models[model_id] = ModelServingSpec(
|
||||
llama_model=resolve_model(model_id),
|
||||
provider_config=GenericProviderConfig(
|
||||
provider_id=rt_entry.provider_id,
|
||||
config=rt_entry.config,
|
||||
),
|
||||
api=api.value,
|
||||
)
|
||||
|
||||
print("BuiltinModelsImpl models", self.models)
|
||||
|
||||
def get_supported_model_ids(
|
||||
self,
|
||||
api: str,
|
||||
provider_spec: GenericProviderConfig,
|
||||
providers_for_api: Dict[str, ProviderSpec],
|
||||
) -> List[str]:
|
||||
serving_models_list = []
|
||||
if api == Api.inference.value:
|
||||
provider_id = provider_spec.provider_id
|
||||
if provider_id == "meta-reference":
|
||||
serving_models_list.append(provider_spec.config["model"])
|
||||
if provider_id in {
|
||||
remote_provider_id("ollama"),
|
||||
remote_provider_id("fireworks"),
|
||||
remote_provider_id("together"),
|
||||
}:
|
||||
adapter_supported_models = providers_for_api[
|
||||
provider_id
|
||||
].adapter.supported_model_ids
|
||||
serving_models_list.extend(adapter_supported_models)
|
||||
elif api == Api.safety.value:
|
||||
if provider_spec.config and "llama_guard_shield" in provider_spec.config:
|
||||
llama_guard_shield = provider_spec.config["llama_guard_shield"]
|
||||
serving_models_list.append(llama_guard_shield["model"])
|
||||
if provider_spec.config and "prompt_guard_shield" in provider_spec.config:
|
||||
prompt_guard_shield = provider_spec.config["prompt_guard_shield"]
|
||||
serving_models_list.append(prompt_guard_shield["model"])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported api {api} for builtin models")
|
||||
|
||||
return serving_models_list
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_models(self) -> ModelsListResponse:
|
||||
return ModelsListResponse(models_list=list(self.models.values()))
|
||||
|
||||
async def get_model(self, core_model_id: str) -> ModelsGetResponse:
|
||||
if core_model_id in self.models:
|
||||
return ModelsGetResponse(core_model_spec=self.models[core_model_id])
|
||||
print(f"Cannot find {core_model_id} in model registry")
|
||||
return ModelsGetResponse()
|
|
@ -6,17 +6,14 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.datatypes import ModelFamily
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from llama_models.sku_list import all_registered_models, resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetaReferenceImplConfig(BaseModel):
|
||||
model: str = Field(
|
||||
default="Meta-Llama3.1-8B-Instruct",
|
||||
|
@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
|
|||
m.descriptor()
|
||||
for m in all_registered_models()
|
||||
if m.model_family == ModelFamily.llama3_1
|
||||
or m.core_model_id == CoreModelId.llama_guard_3_8b
|
||||
]
|
||||
if model not in permitted_models:
|
||||
model_list = "\n\t".join(permitted_models)
|
||||
|
|
|
@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter_id="ollama",
|
||||
pip_packages=["ollama"],
|
||||
module="llama_stack.providers.adapters.inference.ollama",
|
||||
supported_model_ids=[
|
||||
"Meta-Llama3.1-8B-Instruct",
|
||||
"Meta-Llama3.1-70B-Instruct",
|
||||
],
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
],
|
||||
module="llama_stack.providers.adapters.inference.fireworks",
|
||||
config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
|
||||
supported_model_ids=[
|
||||
"Meta-Llama3.1-8B-Instruct",
|
||||
"Meta-Llama3.1-70B-Instruct",
|
||||
"Meta-Llama3.1-405B-Instruct",
|
||||
],
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.adapters.inference.together",
|
||||
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
|
||||
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
|
||||
supported_model_ids=[
|
||||
"Meta-Llama3.1-8B-Instruct",
|
||||
"Meta-Llama3.1-70B-Instruct",
|
||||
"Meta-Llama3.1-405B-Instruct",
|
||||
],
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
BuiltinProviderSpec(
|
||||
api=Api.models,
|
||||
provider_id="builtin",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.impls.builtin.models",
|
||||
config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
|
||||
api_dependencies=[],
|
||||
)
|
||||
]
|
Loading…
Add table
Add a link
Reference in a new issue