Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

commit c1ab66f1e6 (parent b8914bb56f)

Further generalize Xi's changes (#88)

* Further generalize Xi's changes
  - Introduce a slightly more general notion of an AutoRouted provider.
  - Each AutoRouted provider is associated with a RoutingTable provider, e.g. inference -> models.
  - Introduced the safety -> shields and memory -> memory_banks correspondences.
* typo
* Basic build and run succeeded

21 changed files with 597 additions and 418 deletions
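In short: each auto-routed API is paired with a routing-table API (inference -> models, safety -> shields, memory -> memory_banks), and the router resolves each request's routing key against that table to pick a concrete provider. Below is a minimal, self-contained sketch of that flow; the toy classes and placeholder providers are illustrative only, not code from this commit.

# Toy illustration of the AutoRouted/RoutingTable split; method names mirror
# the diff below (get_provider_impl, routing_table) but the providers here are
# hypothetical lambdas.
from typing import Any, Callable, Dict


class ToyRoutingTable:
    # Stands in for CommonRoutingTableImpl: maps routing_key -> provider impl.
    def __init__(self, inner_impls: Dict[str, Callable[[str], str]]) -> None:
        self.providers = inner_impls

    def get_provider_impl(self, routing_key: str) -> Any:
        return self.providers.get(routing_key)


class ToyInferenceRouter:
    # Stands in for InferenceRouter: the model name is the routing key.
    def __init__(self, routing_table: ToyRoutingTable) -> None:
        self.routing_table = routing_table

    def chat_completion(self, model: str, prompt: str) -> str:
        provider = self.routing_table.get_provider_impl(model)
        if provider is None:
            raise ValueError(f"Could not find provider for {model}")
        return provider(prompt)


table = ToyRoutingTable({"8b": lambda p: f"8b: {p}", "70b": lambda p: f"70b: {p}"})
router = ToyInferenceRouter(table)
print(router.chat_completion("8b", "hello"))  # routed to the "8b" provider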
@@ -3,9 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_models.schema_utils import json_schema_type
-
-from pydantic import BaseModel
 
-@json_schema_type
-class BuiltinImplConfig(BaseModel): ...
+from .memory_banks import *  # noqa: F401 F403
llama_stack/apis/memory_banks/client.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from typing import List, Optional
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .memory_banks import *  # noqa: F403
+
+
+class MemoryBanksClient(MemoryBanks):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_memory_banks(self) -> List[MemoryBankSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/memory_banks/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return [MemoryBankSpec(**x) for x in response.json()]
+
+    async def get_memory_bank(
+        self, bank_type: MemoryBankType
+    ) -> Optional[MemoryBankSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/memory_banks/get",
+                json={
+                    "bank_type": bank_type,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            j = response.json()
+            if j is None:
+                return None
+            return MemoryBankSpec(**j)
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = MemoryBanksClient(f"http://{host}:{port}")
+
+    response = await client.list_memory_banks()
+    cprint(f"list_memory_banks response={response}", "green")
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
llama_stack/apis/memory_banks/memory_banks.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
+
+from llama_stack.apis.memory import MemoryBankType
+
+from llama_stack.distribution.datatypes import GenericProviderConfig
+
+
+@json_schema_type
+class MemoryBankSpec(BaseModel):
+    bank_type: MemoryBankType
+    provider_config: GenericProviderConfig = Field(
+        description="Provider config for the model, including provider_id, and corresponding config. ",
+    )
+
+
+class MemoryBanks(Protocol):
+    @webmethod(route="/memory_banks/list", method="GET")
+    async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
+
+    @webmethod(route="/memory_banks/get", method="GET")
+    async def get_memory_bank(
+        self, bank_type: MemoryBankType
+    ) -> Optional[MemoryBankSpec]: ...
@@ -5,15 +5,11 @@
 # the root directory of this source tree.
 
 import asyncio
-import json
-from pathlib import Path
 
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import fire
 import httpx
-from llama_stack.distribution.datatypes import RemoteProviderConfig
 from termcolor import cprint
 
 from .models import *  # noqa: F403

@@ -29,18 +25,18 @@ class ModelsClient(Models):
     async def shutdown(self) -> None:
         pass
 
-    async def list_models(self) -> ModelsListResponse:
+    async def list_models(self) -> List[ModelServingSpec]:
         async with httpx.AsyncClient() as client:
             response = await client.get(
                 f"{self.base_url}/models/list",
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return ModelsListResponse(**response.json())
+            return [ModelServingSpec(**x) for x in response.json()]
 
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
         async with httpx.AsyncClient() as client:
-            response = await client.post(
+            response = await client.get(
                 f"{self.base_url}/models/get",
                 json={
                     "core_model_id": core_model_id,

@@ -48,7 +44,10 @@ class ModelsClient(Models):
                 headers={"Content-Type": "application/json"},
             )
             response.raise_for_status()
-            return ModelsGetResponse(**response.json())
+            j = response.json()
+            if j is None:
+                return None
+            return ModelServingSpec(**j)
 
 
 async def run_main(host: str, port: int, stream: bool):
@@ -4,14 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Optional, Protocol
+from typing import List, Optional, Protocol
 
 from llama_models.llama3.api.datatypes import Model
 
 from llama_models.schema_utils import json_schema_type, webmethod
-from llama_stack.distribution.datatypes import GenericProviderConfig
 from pydantic import BaseModel, Field
 
+from llama_stack.distribution.datatypes import GenericProviderConfig
+
 
 @json_schema_type
 class ModelServingSpec(BaseModel):

@@ -21,25 +22,11 @@ class ModelServingSpec(BaseModel):
     provider_config: GenericProviderConfig = Field(
         description="Provider config for the model, including provider_id, and corresponding config. ",
     )
-    api: str = Field(
-        description="The API that this model is serving (e.g. inference / safety).",
-        default="inference",
-    )
-
-
-@json_schema_type
-class ModelsListResponse(BaseModel):
-    models_list: List[ModelServingSpec]
-
-
-@json_schema_type
-class ModelsGetResponse(BaseModel):
-    core_model_spec: Optional[ModelServingSpec] = None
 
 
 class Models(Protocol):
     @webmethod(route="/models/list", method="GET")
-    async def list_models(self) -> ModelsListResponse: ...
+    async def list_models(self) -> List[ModelServingSpec]: ...
 
-    @webmethod(route="/models/get", method="POST")
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...
+    @webmethod(route="/models/get", method="GET")
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: ...
llama_stack/apis/shields/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .shields import *  # noqa: F401 F403
llama_stack/apis/shields/client.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from typing import List, Optional
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .shields import *  # noqa: F403
+
+
+class ShieldsClient(Shields):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_shields(self) -> List[ShieldSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/shields/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return [ShieldSpec(**x) for x in response.json()]
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/shields/get",
+                json={
+                    "shield_type": shield_type,
+                },
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+
+            j = response.json()
+            if j is None:
+                return None
+
+            return ShieldSpec(**j)
+
+
+async def run_main(host: str, port: int, stream: bool):
+    client = ShieldsClient(f"http://{host}:{port}")
+
+    response = await client.list_shields()
+    cprint(f"list_shields response={response}", "green")
+
+
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
llama_stack/apis/shields/shields.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List, Optional, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
+
+from llama_stack.distribution.datatypes import GenericProviderConfig
+
+
+@json_schema_type
+class ShieldSpec(BaseModel):
+    shield_type: str
+    provider_config: GenericProviderConfig = Field(
+        description="Provider config for the model, including provider_id, and corresponding config. ",
+    )
+
+
+class Shields(Protocol):
+    @webmethod(route="/shields/list", method="GET")
+    async def list_shields(self) -> List[ShieldSpec]: ...
+
+    @webmethod(route="/shields/get", method="GET")
+    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: ...
@@ -6,7 +6,7 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type
 

@@ -19,8 +19,12 @@ class Api(Enum):
     safety = "safety"
     agents = "agents"
     memory = "memory"
 
     telemetry = "telemetry"
 
     models = "models"
+    shields = "shields"
+    memory_banks = "memory_banks"
 
 
 @json_schema_type

@@ -44,27 +48,36 @@ class ProviderSpec(BaseModel):
     )
 
 
+class RoutingTable(Protocol):
+    def get_routing_keys(self) -> List[str]: ...
+
+    def get_provider_impl(self, routing_key: str) -> Any: ...
+
+
 class GenericProviderConfig(BaseModel):
     provider_id: str
     config: Dict[str, Any]
 
 
-@json_schema_type
-class ProviderRoutingEntry(GenericProviderConfig):
+class RoutableProviderConfig(GenericProviderConfig):
     routing_key: str
 
 
+class RoutingTableConfig(BaseModel):
+    entries: List[RoutableProviderConfig] = Field(...)
+    keys: Optional[List[str]] = Field(
+        default=None,
+    )
+
+
+# Example: /inference, /safety
 @json_schema_type
-class RouterProviderSpec(ProviderSpec):
+class AutoRoutedProviderSpec(ProviderSpec):
     provider_id: str = "router"
     config_class: str = ""
 
     docker_image: Optional[str] = None
-    routing_table: List[ProviderRoutingEntry] = Field(
-        default_factory=list,
-        description="Routing table entries corresponding to the API",
-    )
+    routing_table_api: Api
     module: str = Field(
         ...,
         description="""

@@ -79,18 +92,17 @@ class RouterProviderSpec(ProviderSpec):
 
     @property
     def pip_packages(self) -> List[str]:
-        raise AssertionError("Should not be called on RouterProviderSpec")
+        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
 
 
+# Example: /models, /shields
 @json_schema_type
-class BuiltinProviderSpec(ProviderSpec):
-    provider_id: str = "builtin"
+class RoutingTableProviderSpec(ProviderSpec):
+    provider_id: str = "routing_table"
     config_class: str = ""
     docker_image: Optional[str] = None
-    api_dependencies: List[Api] = []
-    provider_data_validator: Optional[str] = Field(
-        default=None,
-    )
+    inner_specs: List[ProviderSpec]
     module: str = Field(
         ...,
         description="""

@@ -99,10 +111,7 @@ class BuiltinProviderSpec(ProviderSpec):
 - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
 """,
     )
-    pip_packages: List[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
+    pip_packages: List[str] = Field(default_factory=list)
 
 
 @json_schema_type

@@ -130,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
     provider_data_validator: Optional[str] = Field(
         default=None,
     )
-    supported_model_ids: List[str] = Field(
-        default_factory=list,
-        description="The list of model ids that this adapter supports",
-    )
 
 
 @json_schema_type

@@ -243,9 +248,6 @@ in the runtime configuration to help route to the correct provider.""",
     )
 
 
-ProviderMapEntry = GenericProviderConfig
-
-
 @json_schema_type
 class StackRunConfig(BaseModel):
     built_at: datetime

@@ -269,25 +271,17 @@ this could be just a hash
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
     )
-    provider_map: Dict[str, ProviderMapEntry] = Field(
+    api_providers: Dict[str, GenericProviderConfig] = Field(
         description="""
 Provider configurations for each of the APIs provided by this package.
-
-Given an API, you can specify a single provider or a "routing table". Each entry in the routing
-table has a (routing_key, provider_config) tuple. How the key is interpreted is API-specific.
-
-As examples:
-- the "inference" API interprets the routing_key as a "model"
-- the "memory" API interprets the routing_key as the type of a "memory bank"
-
-The key may support wild-cards alsothe routing_key to route to the correct provider.""",
+""",
     )
-    provider_routing_table: Dict[str, List[ProviderRoutingEntry]] = Field(
+    routing_tables: Dict[str, RoutingTableConfig] = Field(
         default_factory=dict,
         description="""
-API: List[ProviderRoutingEntry] map. Each ProviderRoutingEntry is a (routing_key, provider_config) tuple.
-
-E.g. The following is a ProviderRoutingEntry for inference API:
+E.g. The following is a ProviderRoutingEntry for models:
         - routing_key: Meta-Llama3.1-8B-Instruct
           provider_id: meta-reference
           config:
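With these datatypes, a run config carries two parts: api_providers (one provider per directly-served API) and routing_tables (for each routing-table API, a list of routing_key/provider_id/config entries). A hedged sketch of what a routing_tables value shaped like RoutingTableConfig could look like in Python; the model id echoes the docstring example above, and the empty config dict is a placeholder:

# Hypothetical fragment; field names follow RoutableProviderConfig /
# RoutingTableConfig from this diff, values are illustrative only.
routing_tables = {
    "models": {
        "entries": [
            {
                "routing_key": "Meta-Llama3.1-8B-Instruct",
                "provider_id": "meta-reference",
                "config": {},  # provider-specific settings would go here
            },
        ],
    },
}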
@@ -8,11 +8,15 @@ import importlib
 import inspect
 from typing import Dict, List
 
+from pydantic import BaseModel
+
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
 
 from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec

@@ -30,6 +34,28 @@ def stack_apis() -> List[Api]:
     return [v for v in Api]
 
 
+class AutoRoutedApiInfo(BaseModel):
+    routing_table_api: Api
+    router_api: Api
+
+
+def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
+    return [
+        AutoRoutedApiInfo(
+            routing_table_api=Api.models,
+            router_api=Api.inference,
+        ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.shields,
+            router_api=Api.safety,
+        ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.memory_banks,
+            router_api=Api.memory,
+        ),
+    ]
+
+
 def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
     apis = {}
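builtin_automatically_routed_apis() above is the single source of truth for the router/table pairing. As a trivial illustration of the lookup it enables (a plain dict stand-in, not repo code):

# The three built-in pairs, flattened to a dict for illustration only.
ROUTER_TO_TABLE = {"inference": "models", "safety": "shields", "memory": "memory_banks"}

assert ROUTER_TO_TABLE["safety"] == "shields"
assert set(ROUTER_TO_TABLE.values()) == {"models", "shields", "memory_banks"}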
@@ -40,6 +66,8 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.memory: Memory,
         Api.telemetry: Telemetry,
         Api.models: Models,
+        Api.shields: Shields,
+        Api.memory_banks: MemoryBanks,
     }
 
     for api, protocol in protocols.items():

@@ -68,7 +96,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
+    routing_table_apis = set(
+        x.routing_table_api for x in builtin_automatically_routed_apis()
+    )
     for api in stack_apis():
+        if api in routing_table_apis:
+            continue
+
         name = api.name.lower()
         module = importlib.import_module(f"llama_stack.providers.registry.{name}")
         ret[api] = {
@@ -4,25 +4,47 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
 
-from llama_stack.distribution.datatypes import Api, ProviderRoutingEntry
+from llama_stack.distribution.datatypes import *  # noqa: F403
 
 
-async def get_router_impl(
-    api: str, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]
-):
-    from .routers import InferenceRouter, MemoryRouter
-    from .routing_table import RoutingTable
+async def get_routing_table_impl(
+    api: Api,
+    inner_impls: List[Tuple[str, Any]],
+    routing_table_config: RoutingTableConfig,
+    _deps,
+) -> Any:
+    from .routing_tables import (
+        MemoryBanksRoutingTable,
+        ModelsRoutingTable,
+        ShieldsRoutingTable,
+    )
 
-    api2routers = {
-        "memory": MemoryRouter,
-        "inference": InferenceRouter,
+    api_to_tables = {
+        "memory_banks": MemoryBanksRoutingTable,
+        "models": ModelsRoutingTable,
+        "shields": ShieldsRoutingTable,
     }
+    if api.value not in api_to_tables:
+        raise ValueError(f"API {api.value} not found in router map")
 
-    # initialize routing table with concrete provider impls
-    routing_table = RoutingTable(provider_routing_table)
+    impl = api_to_tables[api.value](inner_impls, routing_table_config)
+    await impl.initialize()
+    return impl
 
-    impl = api2routers[api](routing_table)
+
+async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
+    from .routers import InferenceRouter, MemoryRouter, SafetyRouter
+
+    api_to_routers = {
+        "memory": MemoryRouter,
+        "inference": InferenceRouter,
+        "safety": SafetyRouter,
+    }
+    if api.value not in api_to_routers:
+        raise ValueError(f"API {api.value} not found in router map")
+
+    impl = api_to_routers[api.value](routing_table)
     await impl.initialize()
     return impl
@@ -4,17 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any, AsyncGenerator, Dict, List, Tuple
+from typing import Any, AsyncGenerator, Dict, List
 
-from llama_stack.distribution.datatypes import Api
+from llama_stack.distribution.datatypes import RoutingTable
 
-from .routing_table import RoutingTable
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
-from types import MethodType
-
-from termcolor import cprint
+from llama_stack.apis.safety import *  # noqa: F403
 
 
 class MemoryRouter(Memory):

@@ -24,22 +20,24 @@ class MemoryRouter(Memory):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        self.api = Api.memory.value
         self.routing_table = routing_table
         self.bank_id_to_type = {}
 
     async def initialize(self) -> None:
-        await self.routing_table.initialize(self.api)
+        pass
 
     async def shutdown(self) -> None:
-        await self.routing_table.shutdown(self.api)
+        pass
 
     def get_provider_from_bank_id(self, bank_id: str) -> Any:
         bank_type = self.bank_id_to_type.get(bank_id)
         if not bank_type:
             raise ValueError(f"Could not find bank type for {bank_id}")
 
-        return self.routing_table.get_provider_impl(self.api, bank_type)
+        provider = self.routing_table.get_provider_impl(bank_type)
+        if not provider:
+            raise ValueError(f"Could not find provider for {bank_type}")
+        return provider
 
     async def create_memory_bank(
         self,

@@ -48,14 +46,15 @@ class MemoryRouter(Memory):
         url: Optional[URL] = None,
     ) -> MemoryBank:
         bank_type = config.type
-        bank = await self.routing_table.get_provider_impl(
-            self.api, bank_type
+        provider = await self.routing_table.get_provider_impl(
+            bank_type
         ).create_memory_bank(name, config, url)
         self.bank_id_to_type[bank.bank_id] = bank_type
         return bank
 
     async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
-        return await self.get_provider_from_bank_id(bank_id).get_memory_bank(bank_id)
+        provider = self.get_provider_from_bank_id(bank_id)
+        return await provider.get_memory_bank(bank_id)
 
     async def insert_documents(
         self,

@@ -85,34 +84,31 @@ class InferenceRouter(Inference):
         self,
         routing_table: RoutingTable,
     ) -> None:
-        self.api = Api.inference.value
         self.routing_table = routing_table
 
     async def initialize(self) -> None:
-        await self.routing_table.initialize(self.api)
+        pass
 
     async def shutdown(self) -> None:
-        await self.routing_table.shutdown(self.api)
+        pass
 
     async def chat_completion(
         self,
         model: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        tools: Optional[List[ToolDefinition]] = [],
+        tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(
-            self.api, model
-        ).chat_completion(
+        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
-            tools=tools,
+            tools=tools or [],
             tool_choice=tool_choice,
             tool_prompt_format=tool_prompt_format,
             stream=stream,

@@ -128,7 +124,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
-        return await self.routing_table.get_provider_impl(self.api, model).completion(
+        return await self.routing_table.get_provider_impl(model).completion(
             model=model,
             content=content,
             sampling_params=sampling_params,

@@ -141,7 +137,33 @@ class InferenceRouter(Inference):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        return await self.routing_table.get_provider_impl(self.api, model).embeddings(
+        return await self.routing_table.get_provider_impl(model).embeddings(
             model=model,
             contents=contents,
         )
+
+
+class SafetyRouter(Safety):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_shield(
+        self,
+        shield_type: str,
+        messages: List[Message],
+        params: Dict[str, Any] = None,
+    ) -> RunShieldResponse:
+        return await self.routing_table.get_provider_impl(shield_type).run_shield(
+            shield_type=shield_type,
+            messages=messages,
+            params=params,
+        )
@@ -1,60 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-
-from typing import Any, Dict, List
-
-from llama_stack.distribution.datatypes import (
-    Api,
-    GenericProviderConfig,
-    ProviderRoutingEntry,
-)
-from llama_stack.distribution.distribution import api_providers
-from llama_stack.distribution.utils.dynamic import instantiate_provider
-from termcolor import cprint
-
-
-class RoutingTable:
-    def __init__(self, provider_routing_table: Dict[str, List[ProviderRoutingEntry]]):
-        self.provider_routing_table = provider_routing_table
-        # map {api: {routing_key: impl}}, e.g. {'inference': {'8b': <MetaReferenceImpl>, '70b': <OllamaImpl>}}
-        self.api2routes = {}
-
-    async def initialize(self, api_str: str) -> None:
-        """Initialize the routing table with concrete provider impls"""
-        if api_str not in self.provider_routing_table:
-            raise ValueError(f"API {api_str} not found in routing table")
-
-        providers = api_providers()[Api(api_str)]
-        routing_list = self.provider_routing_table[api_str]
-
-        self.api2routes[api_str] = {}
-        for rt_entry in routing_list:
-            rt_key = rt_entry.routing_key
-            provider_id = rt_entry.provider_id
-            impl = await instantiate_provider(
-                providers[provider_id],
-                deps=[],
-                provider_config=GenericProviderConfig(
-                    provider_id=provider_id, config=rt_entry.config
-                ),
-            )
-            cprint(f"impl = {impl}", "red")
-            self.api2routes[api_str][rt_key] = impl
-
-        cprint(f"> Initialized implementations for {api_str} in routing table", "blue")
-
-    async def shutdown(self, api_str: str) -> None:
-        """Shutdown the routing table"""
-        if api_str not in self.api2routes:
-            return
-
-        for impl in self.api2routes[api_str].values():
-            await impl.shutdown()
-
-    def get_provider_impl(self, api: str, routing_key: str) -> Any:
-        """Get the provider impl for a given api and routing key"""
-        return self.api2routes[api][routing_key]
llama_stack/distribution/routers/routing_tables.py (new file, 118 lines)
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, List, Optional, Tuple
+
+from llama_models.sku_list import resolve_model
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+
+from llama_stack.apis.models import *  # noqa: F403
+from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.memory_banks import *  # noqa: F403
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+class CommonRoutingTableImpl(RoutingTable):
+    def __init__(
+        self,
+        inner_impls: List[Tuple[str, Any]],
+        routing_table_config: RoutingTableConfig,
+    ) -> None:
+        self.providers = {k: v for k, v in inner_impls}
+        self.routing_keys = list(self.providers.keys())
+        self.routing_table_config = routing_table_config
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        for p in self.providers.values():
+            await p.shutdown()
+
+    async def get_provider_impl(self, routing_key: str) -> Optional[Any]:
+        return self.providers.get(routing_key)
+
+    async def get_routing_keys(self) -> List[str]:
+        return self.routing_keys
+
+    async def get_provider_config(
+        self, routing_key: str
+    ) -> Optional[GenericProviderConfig]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == routing_key:
+                return entry
+        return None
+
+
+class ModelsRoutingTable(CommonRoutingTableImpl, Models):
+
+    async def list_models(self) -> List[ModelServingSpec]:
+        specs = []
+        for entry in self.routing_table_config.entries:
+            model_id = entry.routing_key
+            specs.append(
+                ModelServingSpec(
+                    llama_model=resolve_model(model_id),
+                    provider_config=entry,
+                )
+            )
+        return specs
+
+    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == core_model_id:
+                return ModelServingSpec(
+                    llama_model=resolve_model(core_model_id),
+                    provider_config=entry,
+                )
+        return None
+
+
+class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
+
+    async def list_shields(self) -> List[ShieldSpec]:
+        specs = []
+        for entry in self.routing_table_config.entries:
+            specs.append(
+                ShieldSpec(
+                    shield_type=entry.routing_key,
+                    provider_config=entry,
+                )
+            )
+        return specs
+
+    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == shield_type:
+                return ShieldSpec(
+                    shield_type=entry.routing_key,
+                    provider_config=entry,
+                )
+        return None
+
+
+class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
+
+    async def list_memory_banks(self) -> List[MemoryBankSpec]:
+        specs = []
+        for entry in self.routing_table_config.entries:
+            specs.append(
+                MemoryBankSpec(
+                    bank_type=entry.routing_key,
+                    provider_config=entry,
+                )
+            )
+        return specs
+
+    async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
+        for entry in self.routing_table_config.entries:
+            if entry.routing_key == bank_type:
+                return MemoryBankSpec(
+                    bank_type=entry.routing_key,
+                    provider_config=entry,
+                )
+        return None
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import asyncio
-import importlib
 import inspect
 import json
 import signal

@@ -36,6 +35,9 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.routing import APIRoute
+from pydantic import BaseModel, ValidationError
+from termcolor import cprint
+from typing_extensions import Annotated
 
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,

@@ -43,18 +45,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
     SpanStatus,
     start_trace,
 )
-from pydantic import BaseModel, ValidationError
-from termcolor import cprint
-from typing_extensions import Annotated
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
-from llama_stack.distribution.distribution import api_endpoints, api_providers
-from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.utils.dynamic import (
-    instantiate_builtin_provider,
-    instantiate_provider,
-    instantiate_router,
+from llama_stack.distribution.distribution import (
+    api_endpoints,
+    api_providers,
+    builtin_automatically_routed_apis,
 )
+from llama_stack.distribution.request_headers import set_request_provider_data
+from llama_stack.distribution.utils.dynamic import instantiate_provider
 
 
 def is_async_iterator_type(typ):

@@ -292,9 +291,7 @@ def snake_to_camel(snake_str):
     return "".join(word.capitalize() for word in snake_str.split("_"))
 
 
-async def resolve_impls_with_routing(
-    stack_run_config: StackRunConfig,
-) -> Dict[Api, Any]:
+async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
     """
     Does two things:
     - flatmaps, sorts and resolves the providers in dependency order

@@ -302,48 +299,80 @@
     """
     all_providers = api_providers()
     specs = {}
+    configs = {}
 
-    for api_str in stack_run_config.apis_to_serve:
+    for api_str, config in run_config.api_providers.items():
         api = Api(api_str)
+
+        # TODO: check that these APIs are not in the routing table part of the config
         providers = all_providers[api]
 
-        # check for routing table, we need to pass routing table to the router implementation
-        if api_str in stack_run_config.provider_routing_table:
-            specs[api] = RouterProviderSpec(
-                api=api,
-                module=f"llama_stack.distribution.routers",
-                api_dependencies=[],
-                routing_table=stack_run_config.provider_routing_table[api_str],
-            )
-        else:
-            if api_str in stack_run_config.provider_map:
-                provider_map_entry = stack_run_config.provider_map[api_str]
-                provider_id = provider_map_entry.provider_id
-            else:
-                # not defined in config, will be a builtin provider, assign builtin provider id
-                provider_id = "builtin"
-
-            if provider_id not in providers:
-                raise ValueError(
-                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
-                )
-            specs[api] = providers[provider_id]
+        if config.provider_id not in providers:
+            raise ValueError(
+                f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
+            )
+        specs[api] = providers[config.provider_id]
+        configs[api] = config
+
+    apis_to_serve = run_config.apis_to_serve or set(
+        list(specs.keys()) + list(run_config.routing_tables.keys())
+    )
+    print("apis_to_serve", apis_to_serve)
+    for info in builtin_automatically_routed_apis():
+        source_api = info.routing_table_api
+
+        assert (
+            source_api not in specs
+        ), f"Routing table API {source_api} specified in wrong place?"
+        assert (
+            info.router_api not in specs
+        ), f"Auto-routed API {info.router_api} specified in wrong place?"
+
+        if info.router_api.value not in apis_to_serve:
+            continue
+
+        if source_api.value not in run_config.routing_tables:
+            raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
+
+        routing_table = run_config.routing_tables[source_api.value]
+
+        providers = all_providers[info.router_api]
+
+        inner_specs = []
+        for rt_entry in routing_table.entries:
+            if rt_entry.provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
+                )
+            inner_specs.append(providers[rt_entry.provider_id])
+
+        specs[source_api] = RoutingTableProviderSpec(
+            api=source_api,
+            module="llama_stack.distribution.routers",
+            api_dependencies=[],
+            inner_specs=inner_specs,
+        )
+        configs[source_api] = routing_table
+
+        specs[info.router_api] = AutoRoutedProviderSpec(
+            api=info.router_api,
+            module="llama_stack.distribution.routers",
+            routing_table_api=source_api,
+            api_dependencies=[source_api],
+        )
+        configs[info.router_api] = {}
 
     sorted_specs = topological_sort(specs.values())
+    print(f"Resolved {len(sorted_specs)} providers in topological order")
+    for spec in sorted_specs:
+        print(f"  {spec.api}: {spec.provider_id}")
+    print("")
     impls = {}
     for spec in sorted_specs:
         api = spec.api
         deps = {api: impls[api] for api in spec.api_dependencies}
-        if api.value in stack_run_config.provider_map:
-            provider_config = stack_run_config.provider_map[api.value]
-            impl = await instantiate_provider(spec, deps, provider_config)
-        elif api.value in stack_run_config.provider_routing_table:
-            impl = await instantiate_router(
-                spec, api.value, stack_run_config.provider_routing_table
-            )
-        else:
-            impl = await instantiate_builtin_provider(spec, stack_run_config)
+        impl = await instantiate_provider(spec, deps, configs[api])
         impls[api] = impl
 
     return impls, specs
@@ -355,16 +384,23 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
 
     app = FastAPI()
 
-    # impls, specs = asyncio.run(resolve_impls(config.provider_map))
     impls, specs = asyncio.run(resolve_impls_with_routing(config))
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
 
     all_endpoints = api_endpoints()
 
-    apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
+    if config.apis_to_serve:
+        apis_to_serve = set(config.apis_to_serve)
+        for inf in builtin_automatically_routed_apis():
+            if inf.router_api.value in apis_to_serve:
+                apis_to_serve.add(inf.routing_table_api)
+    else:
+        apis_to_serve = set(impls.keys())
+
     for api_str in apis_to_serve:
         api = Api(api_str)
 
         endpoints = all_endpoints[api]
         impl = impls[api]

@@ -391,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
                 create_dynamic_typed_route(
                     impl_method,
                     endpoint.method,
-                    provider_spec.provider_data_validator,
+                    (
+                        provider_spec.provider_data_validator
+                        if not isinstance(provider_spec, RoutingTableProviderSpec)
+                        else None
+                    ),
                 )
             )
@@ -16,36 +16,11 @@ def instantiate_class_type(fully_qualified_name):
     return getattr(module, class_name)
 
 
-async def instantiate_router(
-    provider_spec: RouterProviderSpec,
-    api: str,
-    provider_routing_table: Dict[str, Any],
-):
-    module = importlib.import_module(provider_spec.module)
-
-    fn = getattr(module, "get_router_impl")
-    impl = await fn(api, provider_routing_table)
-    impl.__provider_spec__ = provider_spec
-    return impl
-
-
-async def instantiate_builtin_provider(
-    provider_spec: BuiltinProviderSpec,
-    run_config: StackRunConfig,
-):
-    print("!!! instantiate_builtin_provider")
-    module = importlib.import_module(provider_spec.module)
-    fn = getattr(module, "get_builtin_impl")
-    impl = await fn(run_config)
-    impl.__provider_spec__ = provider_spec
-    return impl
-
-
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider_spec: ProviderSpec,
     deps: Dict[str, Any],
-    provider_config: ProviderMapEntry,
+    provider_config: Union[GenericProviderConfig, RoutingTable],
 ):
     module = importlib.import_module(provider_spec.module)

@@ -60,6 +35,29 @@ async def instantiate_provider(
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider_config.config)
         args = [config, deps]
+    elif isinstance(provider_spec, AutoRoutedProviderSpec):
+        method = "get_auto_router_impl"
+
+        config = None
+        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
+    elif isinstance(provider_spec, RoutingTableProviderSpec):
+        method = "get_routing_table_impl"
+
+        assert isinstance(provider_config, RoutingTableConfig)
+        routing_table = provider_config
+
+        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+        inner_impls = []
+        for routing_entry in routing_table.entries:
+            impl = await instantiate_provider(
+                inner_specs[routing_entry.provider_id],
+                deps,
+                routing_entry,
+            )
+            inner_impls.append((routing_entry.routing_key, impl))
+
+        config = None
+        args = [provider_spec.api, inner_impls, routing_table, deps]
     else:
         method = "get_provider_impl"
@@ -1,21 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict
-
-from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
-
-from .config import BuiltinImplConfig  # noqa
-
-
-async def get_builtin_impl(config: StackRunConfig):
-    from .models import BuiltinModelsImpl
-
-    assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
-
-    impl = BuiltinModelsImpl(config)
-    await impl.initialize()
-    return impl
@@ -1,113 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-import asyncio
-
-from typing import AsyncIterator, Union
-
-from llama_models.llama3.api.datatypes import StopReason
-from llama_models.sku_list import resolve_model
-
-from llama_stack.distribution.distribution import Api, api_providers
-
-from llama_stack.apis.models import *  # noqa: F403
-from llama_models.llama3.api.datatypes import *  # noqa: F403
-from llama_models.datatypes import CoreModelId, Model
-from llama_models.sku_list import resolve_model
-
-from llama_stack.distribution.datatypes import *  # noqa: F403
-from termcolor import cprint
-
-
-class BuiltinModelsImpl(Models):
-    def __init__(
-        self,
-        config: StackRunConfig,
-    ) -> None:
-        self.run_config = config
-        self.models = {}
-        # check against inference & safety api
-        apis_with_models = [Api.inference, Api.safety]
-
-        all_providers = api_providers()
-
-        for api in apis_with_models:
-
-            # check against provider_map (simple case single model)
-            if api.value in config.provider_map:
-                providers_for_api = all_providers[api]
-                provider_spec = config.provider_map[api.value]
-                core_model_id = provider_spec.config
-                # get supported model ids from the provider
-                supported_model_ids = self.get_supported_model_ids(
-                    api.value, provider_spec, providers_for_api
-                )
-                for model_id in supported_model_ids:
-                    self.models[model_id] = ModelServingSpec(
-                        llama_model=resolve_model(model_id),
-                        provider_config=provider_spec,
-                        api=api.value,
-                    )
-
-            # check against provider_routing_table (router with multiple models)
-            # with routing table, we use the routing_key as the supported models
-            if api.value in config.provider_routing_table:
-                routing_table = config.provider_routing_table[api.value]
-                for rt_entry in routing_table:
-                    model_id = rt_entry.routing_key
-                    self.models[model_id] = ModelServingSpec(
-                        llama_model=resolve_model(model_id),
-                        provider_config=GenericProviderConfig(
-                            provider_id=rt_entry.provider_id,
-                            config=rt_entry.config,
-                        ),
-                        api=api.value,
-                    )
-
-        print("BuiltinModelsImpl models", self.models)
-
-    def get_supported_model_ids(
-        self,
-        api: str,
-        provider_spec: GenericProviderConfig,
-        providers_for_api: Dict[str, ProviderSpec],
-    ) -> List[str]:
-        serving_models_list = []
-        if api == Api.inference.value:
-            provider_id = provider_spec.provider_id
-            if provider_id == "meta-reference":
-                serving_models_list.append(provider_spec.config["model"])
-            if provider_id in {
-                remote_provider_id("ollama"),
-                remote_provider_id("fireworks"),
-                remote_provider_id("together"),
-            }:
-                adapter_supported_models = providers_for_api[
-                    provider_id
-                ].adapter.supported_model_ids
-                serving_models_list.extend(adapter_supported_models)
-        elif api == Api.safety.value:
-            if provider_spec.config and "llama_guard_shield" in provider_spec.config:
-                llama_guard_shield = provider_spec.config["llama_guard_shield"]
-                serving_models_list.append(llama_guard_shield["model"])
-            if provider_spec.config and "prompt_guard_shield" in provider_spec.config:
-                prompt_guard_shield = provider_spec.config["prompt_guard_shield"]
-                serving_models_list.append(prompt_guard_shield["model"])
-        else:
-            raise NotImplementedError(f"Unsupported api {api} for builtin models")
-
-        return serving_models_list
-
-    async def initialize(self) -> None:
-        pass
-
-    async def list_models(self) -> ModelsListResponse:
-        return ModelsListResponse(models_list=list(self.models.values()))
-
-    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
-        if core_model_id in self.models:
-            return ModelsGetResponse(core_model_spec=self.models[core_model_id])
-        print(f"Cannot find {core_model_id} in model registry")
-        return ModelsGetResponse()
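For reference, the deleted `BuiltinModelsImpl` built its registry from two run-config sections: `provider_map` (a single provider per API, with models inferred from the provider) and `provider_routing_table` (one entry per served model, keyed by `routing_key`). An illustrative routing entry, shown as a plain dict because the concrete entry class is defined elsewhere in `llama_stack.distribution.datatypes`:

# Illustrative shape only; the key names are taken from the code above.
entry = {
    "routing_key": "Meta-Llama3.1-8B-Instruct",  # the model id being served
    "provider_id": "meta-reference",             # provider that serves it
    "config": {"model": "Meta-Llama3.1-8B-Instruct"},
}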
@@ -6,17 +6,14 @@
 
 from typing import Optional
 
-from llama_models.datatypes import ModelFamily
-from llama_models.schema_utils import json_schema_type
+from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import all_registered_models, resolve_model
 
+from llama_stack.apis.inference import *  # noqa: F401, F403
 from pydantic import BaseModel, Field, field_validator
 
-from llama_stack.apis.inference import QuantizationConfig
-
 
-@json_schema_type
 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
         default="Meta-Llama3.1-8B-Instruct",
@@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
             m.descriptor()
             for m in all_registered_models()
             if m.model_family == ModelFamily.llama3_1
+            or m.core_model_id == CoreModelId.llama_guard_3_8b
         ]
         if model not in permitted_models:
             model_list = "\n\t".join(permitted_models)
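The second hunk widens the permitted-model check to admit Llama Guard 3 8B alongside the Llama 3.1 family. A sketch of the validator it edits, reconstructed from the identifiers visible in this diff (the exact error message is an assumption):

from llama_models.datatypes import CoreModelId, ModelFamily
from llama_models.sku_list import all_registered_models
from pydantic import BaseModel, Field, field_validator


class MetaReferenceImplConfig(BaseModel):
    model: str = Field(default="Meta-Llama3.1-8B-Instruct")

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        # Allow any Llama 3.1 model, plus the Llama Guard 3 8B safety model.
        permitted_models = [
            m.descriptor()
            for m in all_registered_models()
            if m.model_family == ModelFamily.llama3_1
            or m.core_model_id == CoreModelId.llama_guard_3_8b
        ]
        if model not in permitted_models:
            model_list = "\n\t".join(permitted_models)
            raise ValueError(f"Unknown model: {model}. Permitted models:\n\t{model_list}")
        return model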
@@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
                 adapter_id="ollama",
                 pip_packages=["ollama"],
                 module="llama_stack.providers.adapters.inference.ollama",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                ],
             ),
         ),
         remote_provider_spec(
@@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
                 ],
                 module="llama_stack.providers.adapters.inference.fireworks",
                 config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                    "Meta-Llama3.1-405B-Instruct",
-                ],
             ),
         ),
         remote_provider_spec(
@@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.adapters.inference.together",
                 config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
                 header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                    "Meta-Llama3.1-405B-Instruct",
-                ],
             ),
         ),
     ]
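Dropping `supported_model_ids` from the adapter specs moves the question of which models a deployment serves out of the static provider registry and into the per-run routing table. A hypothetical routing-table fragment for the inference API (key names mirror the routing code earlier in this diff; the ollama config values are assumptions):

# Hypothetical run-config fragment, expressed as Python for consistency.
provider_routing_table = {
    "inference": [
        {
            "routing_key": "Meta-Llama3.1-8B-Instruct",
            "provider_id": "remote::ollama",   # i.e. remote_provider_id("ollama")
            "config": {"host": "localhost", "port": 11434},  # assumed values
        },
    ],
}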
@@ -1,22 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import List
-
-from llama_stack.distribution.datatypes import *  # noqa: F403
-
-
-def available_providers() -> List[ProviderSpec]:
-    return [
-        BuiltinProviderSpec(
-            api=Api.models,
-            provider_id="builtin",
-            pip_packages=[],
-            module="llama_stack.providers.impls.builtin.models",
-            config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
-            api_dependencies=[],
-        )
-    ]
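With the builtin provider and its registry entry deleted, model listing falls to the generic `RoutingTableProviderSpec` path reconstructed at the top of this section. A hypothetical resolution, reusing `instantiate_provider` with the (provider_spec, deps, provider_config) argument order visible in its recursive call:

# Hypothetical: both names below are stand-ins, not identifiers from this diff.
impl = await instantiate_provider(
    models_routing_table_spec,  # assumed RoutingTableProviderSpec for Api.models
    deps,
    models_routing_table,       # assumed RoutingTableConfig listing served models
)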