mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-05 12:21:52 +00:00
supported models wip
This commit is contained in:
parent
20a4302877
commit
c0199029e5
10 changed files with 215 additions and 34 deletions
|
@ -20,6 +20,7 @@ class Api(Enum):
|
|||
agents = "agents"
|
||||
memory = "memory"
|
||||
telemetry = "telemetry"
|
||||
models = "models"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -81,6 +82,29 @@ class RouterProviderSpec(ProviderSpec):
|
|||
raise AssertionError("Should not be called on RouterProviderSpec")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BuiltinProviderSpec(ProviderSpec):
|
||||
provider_id: str = "builtin"
|
||||
config_class: str = ""
|
||||
docker_image: Optional[str] = None
|
||||
api_dependencies: List[Api] = []
|
||||
provider_data_validator: Optional[str] = Field(
|
||||
default=None,
|
||||
)
|
||||
module: str = Field(
|
||||
...,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_router_impl(config, provider_specs, deps)`: returns the router implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_id: str = Field(
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Dict, List
|
|||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
|
||||
|
@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.agents: Agents,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
Api.models: Models,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
|
|
@ -51,6 +51,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
|||
from llama_stack.distribution.distribution import api_endpoints, api_providers
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.utils.dynamic import (
|
||||
instantiate_builtin_provider,
|
||||
instantiate_provider,
|
||||
instantiate_router,
|
||||
)
|
||||
|
@ -306,15 +307,6 @@ async def resolve_impls_with_routing(
|
|||
api = Api(api_str)
|
||||
providers = all_providers[api]
|
||||
|
||||
# check for regular providers without routing
|
||||
if api_str in stack_run_config.provider_map:
|
||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
||||
if provider_map_entry.provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[provider_map_entry.provider_id]
|
||||
|
||||
# check for routing table, we need to pass routing table to the router implementation
|
||||
if api_str in stack_run_config.provider_routing_table:
|
||||
specs[api] = RouterProviderSpec(
|
||||
|
@ -323,6 +315,19 @@ async def resolve_impls_with_routing(
|
|||
api_dependencies=[],
|
||||
routing_table=stack_run_config.provider_routing_table[api_str],
|
||||
)
|
||||
else:
|
||||
if api_str in stack_run_config.provider_map:
|
||||
provider_map_entry = stack_run_config.provider_map[api_str]
|
||||
provider_id = provider_map_entry.provider_id
|
||||
else:
|
||||
# not defined in config, will be a builtin provider, assign builtin provider id
|
||||
provider_id = "builtin"
|
||||
|
||||
if provider_id not in providers:
|
||||
raise ValueError(
|
||||
f"Unknown provider `{provider_id}` is not available for API `{api}`"
|
||||
)
|
||||
specs[api] = providers[provider_id]
|
||||
|
||||
sorted_specs = topological_sort(specs.values())
|
||||
|
||||
|
@ -338,7 +343,7 @@ async def resolve_impls_with_routing(
|
|||
spec, api.value, stack_run_config.provider_routing_table
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Cannot find provider_config for Api {api.value}")
|
||||
impl = await instantiate_builtin_provider(spec, stack_run_config)
|
||||
impls[api] = impl
|
||||
|
||||
return impls, specs
|
||||
|
|
|
@ -29,6 +29,18 @@ async def instantiate_router(
|
|||
return impl
|
||||
|
||||
|
||||
async def instantiate_builtin_provider(
|
||||
provider_spec: BuiltinProviderSpec,
|
||||
run_config: StackRunConfig,
|
||||
):
|
||||
print("!!! instantiate_builtin_provider")
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
fn = getattr(module, "get_builtin_impl")
|
||||
impl = await fn(run_config)
|
||||
impl.__provider_spec__ = provider_spec
|
||||
return impl
|
||||
|
||||
|
||||
# returns a class implementing the protocol corresponding to the Api
|
||||
async def instantiate_provider(
|
||||
provider_spec: ProviderSpec,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue