Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00

Commit c0199029e5 (parent 20a4302877): supported models wip

10 changed files with 215 additions and 34 deletions
@@ -29,17 +29,12 @@ class ModelServingSpec(BaseModel):
 @json_schema_type
 class ModelsListResponse(BaseModel):
-    models_list: List[ModelSpec]
+    models_list: List[ModelServingSpec]


 @json_schema_type
 class ModelsGetResponse(BaseModel):
-    core_model_spec: Optional[ModelSpec] = None
+    core_model_spec: Optional[ModelServingSpec] = None


-@json_schema_type
-class ModelsRegisterResponse(BaseModel):
-    core_model_spec: Optional[ModelSpec] = None
-
-
 class Models(Protocol):
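Net effect of this hunk: the list/get responses now wrap ModelServingSpec instead of the old ModelSpec, and the register response type is dropped. A minimal, self-contained sketch of the resulting shapes; field types beyond those visible in this commit (llama_model, provider_config, api, seen in models.py below) are assumptions:

from typing import Any, List, Optional

from pydantic import BaseModel


class ModelServingSpec(BaseModel):
    llama_model: Optional[Any] = None      # resolved model descriptor (assumed)
    provider_config: Optional[Any] = None  # provider entry serving the model (assumed)
    api: Optional[str] = None              # e.g. "inference" or "safety" (assumed)


class ModelsListResponse(BaseModel):
    models_list: List[ModelServingSpec]


class ModelsGetResponse(BaseModel):
    core_model_spec: Optional[ModelServingSpec] = None


# A listing with one entry and an empty get response both validate:
print(ModelsListResponse(models_list=[ModelServingSpec(api="inference")]))
print(ModelsGetResponse())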
@@ -20,6 +20,7 @@ class Api(Enum):
     agents = "agents"
     memory = "memory"
     telemetry = "telemetry"
+    models = "models"


 @json_schema_type
@@ -81,6 +82,29 @@ class RouterProviderSpec(ProviderSpec):
         raise AssertionError("Should not be called on RouterProviderSpec")


+@json_schema_type
+class BuiltinProviderSpec(ProviderSpec):
+    provider_id: str = "builtin"
+    config_class: str = ""
+    docker_image: Optional[str] = None
+    api_dependencies: List[Api] = []
+    provider_data_validator: Optional[str] = Field(
+        default=None,
+    )
+    module: str = Field(
+        ...,
+        description="""
+Fully-qualified name of the module to import. The module is expected to have:
+
+ - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
+""",
+    )
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_id: str = Field(
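BuiltinProviderSpec pins provider_id to "builtin" and defaults most fields, so a builtin provider only needs to declare its module. A standalone sketch of that pattern under an assumed minimal base class (the real ProviderSpec carries more fields):

from typing import List, Optional

from pydantic import BaseModel, Field


class ProviderSpec(BaseModel):  # assumed minimal base, for the demo only
    api: str


class BuiltinProviderSpec(ProviderSpec):
    provider_id: str = "builtin"
    config_class: str = ""
    docker_image: Optional[str] = None
    pip_packages: List[str] = Field(default_factory=list)
    module: str = Field(..., description="Fully-qualified module to import")


spec = BuiltinProviderSpec(api="models", module="llama_stack.providers.impls.builtin.models")
print(spec.provider_id, spec.module)  # builtin llama_stack.providers.impls.builtin.models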
@@ -11,6 +11,7 @@ from typing import Dict, List
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.telemetry import Telemetry

@@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
+        Api.models: Models,
     }

     for api, protocol in protocols.items():
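Registering Api.models: Models in this protocols map is what surfaces the new API: api_endpoints() walks each protocol's methods for route metadata. A toy illustration of that pattern; the route decorator below is a stand-in, not the real llama_stack webmethod machinery:

from typing import Callable, List, Protocol


def route(path: str) -> Callable:
    # stand-in for a webmethod-style decorator that tags methods with a path
    def wrap(fn: Callable) -> Callable:
        fn.__route__ = path
        return fn
    return wrap


class Models(Protocol):
    @route("/models/list")
    def list_models(self): ...

    @route("/models/get")
    def get_model(self, core_model_id: str): ...


def endpoints_of(protocol: type) -> List[str]:
    # collect the paths tagged onto the protocol's methods
    return [fn.__route__ for fn in vars(protocol).values()
            if callable(fn) and hasattr(fn, "__route__")]


print(endpoints_of(Models))  # ['/models/list', '/models/get']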
@@ -51,6 +51,7 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import api_endpoints, api_providers
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.utils.dynamic import (
+    instantiate_builtin_provider,
     instantiate_provider,
     instantiate_router,
 )
@@ -306,15 +307,6 @@ async def resolve_impls_with_routing(
         api = Api(api_str)
         providers = all_providers[api]

-        # check for regular providers without routing
-        if api_str in stack_run_config.provider_map:
-            provider_map_entry = stack_run_config.provider_map[api_str]
-            if provider_map_entry.provider_id not in providers:
-                raise ValueError(
-                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
-                )
-            specs[api] = providers[provider_map_entry.provider_id]
-
         # check for routing table, we need to pass routing table to the router implementation
         if api_str in stack_run_config.provider_routing_table:
             specs[api] = RouterProviderSpec(
@@ -323,6 +315,19 @@ async def resolve_impls_with_routing(
                 api_dependencies=[],
                 routing_table=stack_run_config.provider_routing_table[api_str],
             )
+        else:
+            if api_str in stack_run_config.provider_map:
+                provider_map_entry = stack_run_config.provider_map[api_str]
+                provider_id = provider_map_entry.provider_id
+            else:
+                # not defined in config, will be a builtin provider, assign builtin provider id
+                provider_id = "builtin"
+
+            if provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                )
+            specs[api] = providers[provider_id]

     sorted_specs = topological_sort(specs.values())
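Taken together with the deletion above, the per-API resolution order becomes: routing table first, then an explicit provider_map entry, then the implicit "builtin" fallback. A condensed sketch of that logic with simplified stand-in types:

from typing import Dict


def resolve_provider_id(api_str: str,
                        provider_map: Dict[str, str],
                        routing_table: Dict[str, list]) -> str:
    if api_str in routing_table:
        return "<router>"                 # handled by RouterProviderSpec
    if api_str in provider_map:
        return provider_map[api_str]      # explicitly configured provider
    return "builtin"                      # implicit builtin provider


assert resolve_provider_id("models", {}, {}) == "builtin"
assert resolve_provider_id("telemetry", {"telemetry": "meta-reference"}, {}) == "meta-reference"
assert resolve_provider_id("memory", {}, {"memory": []}) == "<router>"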
@@ -338,7 +343,7 @@ async def resolve_impls_with_routing(
                 spec, api.value, stack_run_config.provider_routing_table
             )
         else:
-            raise ValueError(f"Cannot find provider_config for Api {api.value}")
+            impl = await instantiate_builtin_provider(spec, stack_run_config)
         impls[api] = impl

     return impls, specs
@@ -29,6 +29,18 @@ async def instantiate_router(
     return impl


+async def instantiate_builtin_provider(
+    provider_spec: BuiltinProviderSpec,
+    run_config: StackRunConfig,
+):
+    print("!!! instantiate_builtin_provider")
+    module = importlib.import_module(provider_spec.module)
+    fn = getattr(module, "get_builtin_impl")
+    impl = await fn(run_config)
+    impl.__provider_spec__ = provider_spec
+    return impl
+
+
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider_spec: ProviderSpec,
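instantiate_builtin_provider relies on a loading convention: the module named by provider_spec.module must expose an async get_builtin_impl(run_config). A self-contained sketch of that contract using a fabricated in-memory module (the module name is made up for the demo):

import asyncio
import importlib
import sys
import types

# fabricate a module that satisfies the convention
stub = types.ModuleType("demo_builtin_models")


async def get_builtin_impl(run_config):
    return {"loaded_with": run_config}


stub.get_builtin_impl = get_builtin_impl
sys.modules["demo_builtin_models"] = stub


async def main():
    module = importlib.import_module("demo_builtin_models")
    fn = getattr(module, "get_builtin_impl")
    impl = await fn({"apis_to_serve": ["models"]})
    print(impl)  # {'loaded_with': {'apis_to_serve': ['models']}}


asyncio.run(main())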
@@ -8,6 +8,7 @@ apis_to_serve:
 - telemetry
 - agents
 - safety
+- models
 provider_map:
   telemetry:
     provider_id: meta-reference
|
@ -22,27 +23,39 @@ provider_map:
|
||||||
disable_output_check: false
|
disable_output_check: false
|
||||||
prompt_guard_shield:
|
prompt_guard_shield:
|
||||||
model: Prompt-Guard-86M
|
model: Prompt-Guard-86M
|
||||||
|
# inference:
|
||||||
|
# provider_id: meta-reference
|
||||||
|
# config:
|
||||||
|
# model: Meta-Llama3.1-8B-Instruct
|
||||||
|
# quantization: null
|
||||||
|
# torch_seed: null
|
||||||
|
# max_seq_len: 4096
|
||||||
|
# max_batch_size: 1
|
||||||
|
inference:
|
||||||
|
provider_id: remote::ollama
|
||||||
|
config:
|
||||||
|
|
||||||
agents:
|
agents:
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
provider_routing_table:
|
provider_routing_table:
|
||||||
inference:
|
# inference:
|
||||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
# - routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
provider_id: meta-reference
|
# provider_id: meta-reference
|
||||||
config:
|
# config:
|
||||||
model: Meta-Llama3.1-8B-Instruct
|
# model: Meta-Llama3.1-8B-Instruct
|
||||||
quantization: null
|
# quantization: null
|
||||||
torch_seed: null
|
# torch_seed: null
|
||||||
max_seq_len: 4096
|
# max_seq_len: 4096
|
||||||
max_batch_size: 1
|
# max_batch_size: 1
|
||||||
# - routing_key: Meta-Llama3.1-8B
|
# - routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
# provider_id: meta-reference
|
# provider_id: meta-reference
|
||||||
# config:
|
# config:
|
||||||
# model: Meta-Llama3.1-8B
|
# model: Meta-Llama3.1-8B
|
||||||
# quantization: null
|
# quantization: null
|
||||||
# torch_seed: null
|
# torch_seed: null
|
||||||
# max_seq_len: 4096
|
# max_seq_len: 4096
|
||||||
# max_batch_size: 1
|
# max_batch_size: 1
|
||||||
memory:
|
memory:
|
||||||
# - routing_key: keyvalue
|
# - routing_key: keyvalue
|
||||||
# provider_id: remote::pgvector
|
# provider_id: remote::pgvector
|
||||||
|
|
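With inference moved out of the routing table and into provider_map, the same config keys drive the fallback logic sketched earlier. A small example of reading such a run config (PyYAML assumed available; the snippet mirrors the file above, not a llama_stack API):

import yaml

raw = """
apis_to_serve:
- inference
- models
provider_map:
  inference:
    provider_id: remote::ollama
    config:
provider_routing_table:
  memory: []
"""

cfg = yaml.safe_load(raw)
print("inference provider:", cfg["provider_map"]["inference"]["provider_id"])
print("routed apis:", list(cfg["provider_routing_table"]))
print("models in provider_map?", "models" in cfg["provider_map"])  # False -> builtin fallback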
llama_stack/providers/impls/builtin/models/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
+
+
+async def get_builtin_impl(config: StackRunConfig):
+    from .models import BuiltinModelsImpl
+
+    assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
+
+    impl = BuiltinModelsImpl(config)
+    await impl.initialize()
+    return impl
llama_stack/providers/impls/builtin/models/config.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel
+
+
+@json_schema_type
+class BuiltinImplConfig(BaseModel): ...
llama_stack/providers/impls/builtin/models/models.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import asyncio
+
+from typing import AsyncIterator, Union
+
+from llama_models.llama3.api.datatypes import StopReason
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.datatypes import CoreModelId, Model
+from llama_models.sku_list import resolve_model
+
+from llama_stack.distribution.datatypes import (
+    Api,
+    GenericProviderConfig,
+    StackRunConfig,
+)
+from termcolor import cprint
+
+
+class BuiltinModelsImpl(Models):
+    def __init__(
+        self,
+        config: StackRunConfig,
+    ) -> None:
+        print("BuiltinModelsImpl init")
+
+        self.run_config = config
+        self.models = {}
+
+        print("BuiltinModelsImpl run_config", config)
+
+        # check against inference & safety api
+        apis_with_models = [Api.inference, Api.safety]
+
+        for api in apis_with_models:
+            # check against provider_map (simple case single model)
+            if api.value in config.provider_map:
+                provider_spec = config.provider_map[api.value]
+                core_model_id = provider_spec.config
+                print("provider_spec", provider_spec)
+                model_spec = ModelServingSpec(
+                    provider_config=provider_spec,
+                )
+                # get supported model ids from the provider
+                supported_model_ids = self.get_supported_model_ids(provider_spec)
+                for model_id in supported_model_ids:
+                    self.models[model_id] = ModelServingSpec(
+                        llama_model=resolve_model(model_id),
+                        provider_config=provider_spec,
+                        api=api.value,
+                    )
+
+            # check against provider_routing_table (router with multiple models)
+            # with routing table, we use the routing_key as the supported models
+
+    def resolve_supported_model_ids(self) -> list[CoreModelId]:
+        # TODO: for remote providers, provide registry to list supported models
+
+        return ["Meta-Llama3.1-8B-Instruct"]
+
+    async def initialize(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        pass
+        # return ModelsListResponse(models_list=list(self.models.values()))
+
+    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+        pass
+        # if core_model_id in self.models:
+        #     return ModelsGetResponse(core_model_spec=self.models[core_model_id])
+        # raise RuntimeError(f"Cannot find {core_model_id} in model registry")
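The commented bodies hint at where this WIP is headed; note also that __init__ calls self.get_supported_model_ids() while the method defined is resolve_supported_model_ids(), a mismatch still to be reconciled. A hedged sketch of the completed accessors over a plain dict registry (a dict stands in for ModelServingSpec; not code from this commit):

import asyncio
from typing import Dict, List, Optional


class ModelsRegistrySketch:
    def __init__(self) -> None:
        # core_model_id -> serving spec
        self.models: Dict[str, dict] = {}

    async def list_models(self) -> List[dict]:
        return list(self.models.values())

    async def get_model(self, core_model_id: str) -> Optional[dict]:
        if core_model_id in self.models:
            return self.models[core_model_id]
        raise RuntimeError(f"Cannot find {core_model_id} in model registry")


registry = ModelsRegistrySketch()
registry.models["Meta-Llama3.1-8B-Instruct"] = {"api": "inference"}
print(asyncio.run(registry.list_models()))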
llama_stack/providers/registry/models.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        BuiltinProviderSpec(
+            api=Api.models,
+            provider_id="builtin",
+            pip_packages=[],
+            module="llama_stack.providers.impls.builtin.models",
+            config_class="llama_stack.providers.impls.builtin.models.BuiltinImplConfig",
+            api_dependencies=[],
+        )
+    ]
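For this spec to be picked up, api_providers() must fold each registry module's available_providers() output into the {Api: {provider_id: spec}} map that resolve_impls_with_routing consults. The fold below is an assumed simplification with dict stand-ins, not code from this commit:

from typing import Dict, List


def build_provider_map(specs: List[dict]) -> Dict[str, Dict[str, dict]]:
    by_api: Dict[str, Dict[str, dict]] = {}
    for spec in specs:
        by_api.setdefault(spec["api"], {})[spec["provider_id"]] = spec
    return by_api


all_providers = build_provider_map([
    {"api": "models", "provider_id": "builtin",
     "module": "llama_stack.providers.impls.builtin.models"},
])
assert "builtin" in all_providers["models"]  # the fallback in resolve_impls_with_routing finds it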