supported models wip

This commit is contained in:
Xi Yan 2024-09-21 18:37:22 -07:00
parent 20a4302877
commit c0199029e5
10 changed files with 215 additions and 34 deletions

View file

@@ -29,17 +29,12 @@ class ModelServingSpec(BaseModel):
@json_schema_type
class ModelsListResponse(BaseModel):
models_list: List[ModelSpec]
models_list: List[ModelServingSpec]
@json_schema_type
class ModelsGetResponse(BaseModel):
core_model_spec: Optional[ModelSpec] = None
@json_schema_type
class ModelsRegisterResponse(BaseModel):
core_model_spec: Optional[ModelSpec] = None
core_model_spec: Optional[ModelServingSpec] = None
class Models(Protocol):
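For orientation, a minimal sketch of the reshaped datatypes, assuming ModelServingSpec carries the fields it is constructed with later in this commit (llama_model, provider_config, api); the field types and the protocol's method signatures are inferred from BuiltinModelsImpl below, not copied from the actual file:

from typing import List, Optional, Protocol
from llama_models.datatypes import Model
from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel

class ModelServingSpec(BaseModel):
    # assumed fields, mirroring how BuiltinModelsImpl builds instances below
    llama_model: Optional[Model] = None                       # resolved SKU metadata
    provider_config: Optional[GenericProviderConfig] = None   # provider serving this model
    api: Optional[str] = None                                  # e.g. "inference" or "safety"

class ModelsListResponse(BaseModel):
    models_list: List[ModelServingSpec]

class ModelsGetResponse(BaseModel):
    core_model_spec: Optional[ModelServingSpec] = None

class Models(Protocol):
    # method names inferred from BuiltinModelsImpl later in this commit
    async def list_models(self) -> ModelsListResponse: ...
    async def get_model(self, core_model_id: str) -> ModelsGetResponse: ...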

View file

@@ -20,6 +20,7 @@ class Api(Enum):
agents = "agents"
memory = "memory"
telemetry = "telemetry"
models = "models"
@json_schema_type
@@ -81,6 +82,29 @@ class RouterProviderSpec(ProviderSpec):
raise AssertionError("Should not be called on RouterProviderSpec")
@json_schema_type
class BuiltinProviderSpec(ProviderSpec):
provider_id: str = "builtin"
config_class: str = ""
docker_image: Optional[str] = None
api_dependencies: List[Api] = []
provider_data_validator: Optional[str] = Field(
default=None,
)
module: str = Field(
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_builtin_impl(run_config)`: returns the builtin implementation
""",
)
pip_packages: List[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
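For illustration, the builtin models provider registered later in this commit amounts to constructing roughly this spec (defaults such as provider_id="builtin" come from the class above):

BuiltinProviderSpec(
    api=Api.models,
    provider_id="builtin",  # already the default; shown for clarity
    pip_packages=[],
    module="llama_stack.providers.impls.builtin.models",
    config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
    api_dependencies=[],
)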
@json_schema_type
class AdapterSpec(BaseModel):
adapter_id: str = Field(

View file

@@ -11,6 +11,7 @@ from typing import Dict, List
from llama_stack.apis.agents import Agents
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.telemetry import Telemetry
@@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
Api.agents: Agents,
Api.memory: Memory,
Api.telemetry: Telemetry,
Api.models: Models,
}
for api, protocol in protocols.items():

View file

@@ -51,6 +51,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import (
instantiate_builtin_provider,
instantiate_provider,
instantiate_router,
)
@@ -306,15 +307,6 @@ async def resolve_impls_with_routing(
api = Api(api_str)
providers = all_providers[api]
# check for regular providers without routing
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
if provider_map_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_map_entry.provider_id]
# check for routing table, we need to pass routing table to the router implementation
if api_str in stack_run_config.provider_routing_table:
specs[api] = RouterProviderSpec(
@@ -323,6 +315,19 @@ async def resolve_impls_with_routing(
api_dependencies=[],
routing_table=stack_run_config.provider_routing_table[api_str],
)
else:
if api_str in stack_run_config.provider_map:
provider_map_entry = stack_run_config.provider_map[api_str]
provider_id = provider_map_entry.provider_id
else:
# not defined in config, will be a builtin provider, assign builtin provider id
provider_id = "builtin"
if provider_id not in providers:
raise ValueError(
f"Unknown provider `{provider_id}` is not available for API `{api}`"
)
specs[api] = providers[provider_id]
sorted_specs = topological_sort(specs.values())
@@ -338,7 +343,7 @@ async def resolve_impls_with_routing(
spec, api.value, stack_run_config.provider_routing_table
)
else:
raise ValueError(f"Cannot find provider_config for Api {api.value}")
impl = await instantiate_builtin_provider(spec, stack_run_config)
impls[api] = impl
return impls, specs
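In short, each served API is now resolved in this order: a routing-table entry wins, then an explicit provider_map entry, and otherwise the "builtin" provider is used. A standalone sketch of just that fallback (the helper name and simplified types are illustrative, not part of this commit):

from typing import Any, Dict, List, Optional

def pick_provider_id(
    api_str: str,
    provider_map: Dict[str, Any],
    provider_routing_table: Dict[str, List[Any]],
) -> Optional[str]:
    # Mirrors the fallback order in resolve_impls_with_routing above.
    if api_str in provider_routing_table:
        return None  # router-backed: a RouterProviderSpec is built instead
    if api_str in provider_map:
        return provider_map[api_str].provider_id  # explicitly configured provider
    return "builtin"  # not configured: fall back to the builtin provider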

View file

@@ -29,6 +29,18 @@ async def instantiate_router(
return impl
async def instantiate_builtin_provider(
provider_spec: BuiltinProviderSpec,
run_config: StackRunConfig,
):
print("!!! instantiate_builtin_provider")
module = importlib.import_module(provider_spec.module)
fn = getattr(module, "get_builtin_impl")
impl = await fn(run_config)
impl.__provider_spec__ = provider_spec
return impl
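Usage mirrors the server change above: an API with no entry in the run config resolves to a BuiltinProviderSpec, which is handed to this loader together with the whole StackRunConfig. A hedged sketch of that call (the wrapper function is illustrative):

from llama_stack.distribution.datatypes import Api, BuiltinProviderSpec, StackRunConfig
from llama_stack.distribution.utils.dynamic import instantiate_builtin_provider

async def load_builtin_models_impl(run_config: StackRunConfig):
    # Values echo the registry entry added later in this commit.
    spec = BuiltinProviderSpec(
        api=Api.models,
        module="llama_stack.providers.impls.builtin.models",
        config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
    )
    return await instantiate_builtin_provider(spec, run_config)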
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,

View file

@@ -8,6 +8,7 @@ apis_to_serve:
- telemetry
- agents
- safety
- models
provider_map:
telemetry:
provider_id: meta-reference
@@ -22,20 +23,32 @@ provider_map:
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
# inference:
# provider_id: meta-reference
# config:
# model: Meta-Llama3.1-8B-Instruct
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
inference:
provider_id: remote::ollama
config:
agents:
provider_id: meta-reference
config: {}
provider_routing_table:
inference:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
# - routing_key: Meta-Llama3.1-8B
# inference:
# - routing_key: Meta-Llama3.1-8B-Instruct
# provider_id: meta-reference
# config:
# model: Meta-Llama3.1-8B-Instruct
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# - routing_key: Meta-Llama3.1-8B-Instruct
# provider_id: meta-reference
# config:
# model: Meta-Llama3.1-8B
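Net effect of the config change: the models API is now served, inference is driven by the routing table (routing_key → provider config), and safety/telemetry/agents stay in provider_map. A hedged sketch of loading such a config (path and loading style are assumptions, not from this commit):

import yaml
from llama_stack.distribution.datatypes import StackRunConfig

with open("run.yaml") as f:  # placeholder path
    run_config = StackRunConfig(**yaml.safe_load(f))

print(run_config.apis_to_serve)           # now includes "models"
print(run_config.provider_routing_table)  # inference entries keyed by routing_key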

View file

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
async def get_builtin_impl(config: StackRunConfig):
from .models import BuiltinModelsImpl
assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
impl = BuiltinModelsImpl(config)
await impl.initialize()
return impl

View file

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class BuiltinImplConfig(BaseModel): ...

View file

@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import AsyncIterator, Union
from llama_models.llama3.api.datatypes import StopReason
from llama_models.sku_list import resolve_model
from llama_stack.apis.models import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.datatypes import CoreModelId, Model
from llama_models.sku_list import resolve_model
from llama_stack.distribution.datatypes import (
Api,
GenericProviderConfig,
StackRunConfig,
)
from termcolor import cprint
class BuiltinModelsImpl(Models):
def __init__(
self,
config: StackRunConfig,
) -> None:
print("BuiltinModelsImpl init")
self.run_config = config
self.models = {}
print("BuiltinModelsImpl run_config", config)
# check against inference & safety api
apis_with_models = [Api.inference, Api.safety]
for api in apis_with_models:
# check against provider_map (simple case single model)
if api.value in config.provider_map:
provider_spec = config.provider_map[api.value]
core_model_id = provider_spec.config
print("provider_spec", provider_spec)
model_spec = ModelServingSpec(
provider_config=provider_spec,
)
# get supported model ids from the provider
supported_model_ids = self.resolve_supported_model_ids()
for model_id in supported_model_ids:
self.models[model_id] = ModelServingSpec(
llama_model=resolve_model(model_id),
provider_config=provider_spec,
api=api.value,
)
# check against provider_routing_table (router with multiple models)
# with routing table, we use the routing_key as the supported models
def resolve_supported_model_ids(self) -> list[CoreModelId]:
# TODO: for remote providers, provide registry to list supported models
return ["Meta-Llama3.1-8B-Instruct"]
async def initialize(self) -> None:
pass
async def list_models(self) -> ModelsListResponse:
pass
# return ModelsListResponse(models_list=list(self.models.values()))
async def get_model(self, core_model_id: str) -> ModelsGetResponse:
pass
# if core_model_id in self.models:
# return ModelsGetResponse(core_model_spec=self.models[core_model_id])
# raise RuntimeError(f"Cannot find {core_model_id} in model registry")
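list_models and get_model are still stubs in this commit; the commented-out lines point at the intended completion, roughly:

async def list_models(self) -> ModelsListResponse:
    return ModelsListResponse(models_list=list(self.models.values()))

async def get_model(self, core_model_id: str) -> ModelsGetResponse:
    if core_model_id in self.models:
        return ModelsGetResponse(core_model_spec=self.models[core_model_id])
    raise RuntimeError(f"Cannot find {core_model_id} in model registry")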

View file

@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
BuiltinProviderSpec(
api=Api.models,
provider_id="builtin",
pip_packages=[],
module="llama_stack.providers.impls.builtin.models",
config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
api_dependencies=[],
)
]
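A quick illustrative check of the new registry entry (not part of the commit):

specs = available_providers()
assert len(specs) == 1
assert specs[0].provider_id == "builtin"
assert specs[0].api == Api.models
assert specs[0].module == "llama_stack.providers.impls.builtin.models"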