supported models wip

Xi Yan 2024-09-21 18:37:22 -07:00
parent 20a4302877
commit c0199029e5
10 changed files with 215 additions and 34 deletions

View file

@@ -20,6 +20,7 @@ class Api(Enum):
     agents = "agents"
     memory = "memory"
     telemetry = "telemetry"
+    models = "models"
 @json_schema_type
@@ -81,6 +82,29 @@ class RouterProviderSpec(ProviderSpec):
         raise AssertionError("Should not be called on RouterProviderSpec")
+@json_schema_type
+class BuiltinProviderSpec(ProviderSpec):
+    provider_id: str = "builtin"
+    config_class: str = ""
+    docker_image: Optional[str] = None
+    api_dependencies: List[Api] = []
+    provider_data_validator: Optional[str] = Field(
+        default=None,
+    )
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have:
+ - `get_builtin_impl(run_config)`: returns the builtin provider implementation
+""",
+    )
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_id: str = Field(
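For reference, a minimal standalone sketch of how a spec like the new BuiltinProviderSpec might be filled in for the models API. It mirrors the fields added above using plain pydantic; the sketch class name, the module path, and the example values are illustrative assumptions, not taken from this commit.

from typing import List, Optional

from pydantic import BaseModel, Field


# Standalone mirror of the fields added above, for illustration only; the real
# class inherits from ProviderSpec in the datatypes module shown in this diff.
class BuiltinProviderSpecSketch(BaseModel):
    provider_id: str = "builtin"
    config_class: str = ""
    docker_image: Optional[str] = None
    api_dependencies: List[str] = Field(default_factory=list)
    provider_data_validator: Optional[str] = None
    module: str = Field(
        ...,
        description="Module expected to expose `get_builtin_impl(run_config)`",
    )
    pip_packages: List[str] = Field(default_factory=list)


# Hypothetical spec for a builtin models provider; the module path is made up.
models_spec = BuiltinProviderSpecSketch(module="llama_stack.providers.builtin.models")
print(models_spec.provider_id, models_spec.module)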

View file

@@ -11,6 +11,7 @@ from typing import Dict, List
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.telemetry import Telemetry

@@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
+        Api.models: Models,
     }
     for api, protocol in protocols.items():
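Below is a hedged, self-contained sketch of the registration pattern these two hunks extend: each API protocol is mapped to an Api member, and endpoints are collected by inspecting the protocol's methods for route metadata. The decorator name route_info, the list_models method, and the route string are made-up stand-ins, not the llama_stack API.

import inspect
from enum import Enum
from typing import Callable, Dict, List, Protocol


class Api(Enum):
    inference = "inference"
    models = "models"


def route_info(route: str, method: str = "GET") -> Callable:
    # Attach route metadata to a protocol method (stand-in for the real decorator).
    def wrap(fn: Callable) -> Callable:
        fn.__route_info__ = (route, method)
        return fn
    return wrap


class Models(Protocol):
    @route_info("/models/list")
    def list_models(self) -> List[str]: ...


def collect_endpoints(protocols: Dict[Api, type]) -> Dict[Api, List[tuple]]:
    # Mimics what api_endpoints() does: walk each protocol and keep methods
    # that carry route metadata.
    endpoints: Dict[Api, List[tuple]] = {}
    for api, protocol in protocols.items():
        members = inspect.getmembers(protocol, predicate=inspect.isfunction)
        endpoints[api] = [
            fn.__route_info__ for _, fn in members if hasattr(fn, "__route_info__")
        ]
    return endpoints


print(collect_endpoints({Api.models: Models}))  # {<Api.models: 'models'>: [('/models/list', 'GET')]}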

View file

@@ -51,6 +51,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
 from llama_stack.distribution.distribution import api_endpoints, api_providers
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.utils.dynamic import (
+    instantiate_builtin_provider,
     instantiate_provider,
     instantiate_router,
 )

@@ -306,15 +307,6 @@ async def resolve_impls_with_routing(
         api = Api(api_str)
         providers = all_providers[api]
-        # check for regular providers without routing
-        if api_str in stack_run_config.provider_map:
-            provider_map_entry = stack_run_config.provider_map[api_str]
-            if provider_map_entry.provider_id not in providers:
-                raise ValueError(
-                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
-                )
-            specs[api] = providers[provider_map_entry.provider_id]
         # check for routing table, we need to pass routing table to the router implementation
         if api_str in stack_run_config.provider_routing_table:
             specs[api] = RouterProviderSpec(

@@ -323,6 +315,19 @@
                 api_dependencies=[],
                 routing_table=stack_run_config.provider_routing_table[api_str],
             )
+        else:
+            if api_str in stack_run_config.provider_map:
+                provider_map_entry = stack_run_config.provider_map[api_str]
+                provider_id = provider_map_entry.provider_id
+            else:
+                # not defined in config, will be a builtin provider, assign builtin provider id
+                provider_id = "builtin"
+            if provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                )
+            specs[api] = providers[provider_id]
     sorted_specs = topological_sort(specs.values())

@@ -338,7 +343,7 @@
                 spec, api.value, stack_run_config.provider_routing_table
             )
         else:
-            raise ValueError(f"Cannot find provider_config for Api {api.value}")
+            impl = await instantiate_builtin_provider(spec, stack_run_config)
         impls[api] = impl
     return impls, specs
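The new control flow can be summarized as: a routing-table entry takes priority, an explicit provider_map entry is used next, and any remaining API falls back to the "builtin" provider id. A hedged standalone sketch of that precedence follows; the dict-based arguments are simplifications of StackRunConfig and the provider registry, not the real types.

from typing import Dict


def resolve_provider_id(
    api_str: str,
    provider_routing_table: Dict[str, list],
    provider_map: Dict[str, str],
    available_providers: Dict[str, object],
) -> str:
    if api_str in provider_routing_table:
        return "router"  # handled by RouterProviderSpec, not looked up here
    # Fall back to "builtin" for APIs the run config does not mention.
    provider_id = provider_map.get(api_str, "builtin")
    if provider_id not in available_providers:
        raise ValueError(
            f"Unknown provider `{provider_id}` is not available for API `{api_str}`"
        )
    return provider_id


# `models` is absent from the provider_map, so it resolves to the builtin provider.
print(resolve_provider_id("models", {}, {"inference": "meta-reference"},
                          {"builtin": object(), "meta-reference": object()}))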

View file

@@ -29,6 +29,18 @@ async def instantiate_router(
     return impl
+async def instantiate_builtin_provider(
+    provider_spec: BuiltinProviderSpec,
+    run_config: StackRunConfig,
+):
+    print("!!! instantiate_builtin_provider")
+    module = importlib.import_module(provider_spec.module)
+    fn = getattr(module, "get_builtin_impl")
+    impl = await fn(run_config)
+    impl.__provider_spec__ = provider_spec
+    return impl
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider_spec: ProviderSpec,
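instantiate_builtin_provider imports provider_spec.module and awaits its get_builtin_impl(run_config) hook. A hedged sketch of a module that would satisfy that contract is below; BuiltinModelsImpl and its initialize method are hypothetical and not part of this commit.

# Hypothetical contents of the module named in BuiltinProviderSpec.module.
# The loader above does: fn = getattr(module, "get_builtin_impl"); impl = await fn(run_config)


class BuiltinModelsImpl:
    # Hypothetical implementation object handed back to the stack.
    def __init__(self, run_config) -> None:
        self.run_config = run_config

    async def initialize(self) -> None:
        # Load whatever model metadata the models API needs (placeholder).
        pass


async def get_builtin_impl(run_config):
    impl = BuiltinModelsImpl(run_config)
    await impl.initialize()
    return impl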