mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-30 07:39:38 +00:00)

commit 7071c46422 (parent d2ec822b12)

    instantiate inference models

6 changed files with 40 additions and 20 deletions
@@ -321,12 +321,10 @@ async def resolve_impls(
             inner_specs=inner_specs,
         )

-    for k, v in specs.items():
-        cprint(k, "blue")
-        cprint(v, "blue")

     sorted_specs = topological_sort(specs.values())
+    cprint(f"sorted_specs={sorted_specs}", "red")

     impls = {}
     for spec in sorted_specs:
         api = spec.api
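The hunk above shows resolve_impls ordering provider specs with topological_sort so each provider is instantiated only after its dependencies. A minimal self-contained sketch of that kind of dependency ordering (topo_sort below is illustrative: the repo's actual topological_sort takes a list of specs, not a dict of names):

from collections import deque
from typing import Dict, List


def topo_sort(requires: Dict[str, List[str]]) -> List[str]:
    """Order keys so each item appears after everything it requires."""
    indegree = {name: len(deps) for name, deps in requires.items()}
    dependents: Dict[str, List[str]] = {name: [] for name in requires}
    for name, deps in requires.items():
        for dep in deps:
            dependents[dep].append(name)
    ready = deque(name for name, n in indegree.items() if n == 0)
    order: List[str] = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for dependent in dependents[name]:
            indegree[dependent] -= 1
            if indegree[dependent] == 0:
                ready.append(dependent)
    if len(order) != len(requires):
        raise ValueError("cycle detected among provider specs")
    return order


# e.g. the inference router must come after the models provider it reads from
print(topo_sort({"models": [], "inference-router": ["models"]}))
# -> ['models', 'inference-router']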
@@ -53,7 +53,8 @@ async def instantiate_provider(
            args = [inner_impls, deps]
        elif isinstance(provider_config, str) and provider_config == "models-router":
            config = None
-           args = [[], deps]
+           assert len(deps) == 1 and Api.models in deps
+           args = [deps[Api.models]]
        else:
            raise ValueError(f"provider_config {provider_config} is not valid")
    else:
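The "models-router" branch above now encodes a contract: the router is resolved with exactly one dependency, the Models API implementation, which becomes its sole constructor argument. A self-contained toy mirroring that branch (Api, ToyRouter, and instantiate here are illustrative stand-ins, not the repo's definitions):

import asyncio
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    models = "models"


class ToyRouter:
    def __init__(self, models_api: Any) -> None:
        self.models_api = models_api

    async def initialize(self) -> None:
        pass


async def instantiate(provider_config: Any, deps: Dict[Api, Any]) -> Any:
    # Mirrors the updated branch: the "models-router" marker means the router
    # is constructed from exactly one dependency, the Models API impl.
    if isinstance(provider_config, str) and provider_config == "models-router":
        assert len(deps) == 1 and Api.models in deps
        impl = ToyRouter(deps[Api.models])
    else:
        raise ValueError(f"provider_config {provider_config} is not valid")
    await impl.initialize()
    return impl


print(asyncio.run(instantiate("models-router", {Api.models: object()})))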
@@ -18,8 +18,6 @@ async def get_provider_impl(config: BuiltinImplConfig, deps: Dict[Api, ProviderS
         config, BuiltinImplConfig
     ), f"Unexpected config type: {type(config)}"

-    print(config)
-
     impl = BuiltinModelsImpl(config)
     await impl.initialize()
     return impl
@@ -25,7 +25,6 @@ class BuiltinModelsImpl(Models):
         config: BuiltinImplConfig,
     ) -> None:
         self.config = config
-        cprint(self.config, "red")
         self.models = {
             entry.core_model_id: ModelSpec(
                 llama_model_metadata=resolve_model(entry.core_model_id),
@@ -9,9 +9,9 @@ from typing import Any, List, Tuple
 from llama_stack.distribution.datatypes import Api


-async def get_router_impl(inner_impls: List[Tuple[str, Any]], deps: List[Api]):
+async def get_router_impl(models_api: Api):
     from .inference import InferenceRouterImpl

-    impl = InferenceRouterImpl(inner_impls, deps)
+    impl = InferenceRouterImpl(models_api)
     await impl.initialize()
     return impl
@@ -8,7 +8,14 @@ from typing import Any, AsyncGenerator, Dict, List, Tuple

 from llama_stack.distribution.datatypes import Api
 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.apis.models import Models
+
+from llama_stack.distribution.datatypes import GenericProviderConfig
+from llama_stack.distribution.distribution import api_providers
+from llama_stack.distribution.utils.dynamic import instantiate_provider
+from llama_stack.providers.impls.builtin.models.models import BuiltinModelsImpl
 from llama_stack.providers.registry.inference import available_providers
+from termcolor import cprint


 class InferenceRouterImpl(Inference):
@@ -16,19 +23,36 @@ class InferenceRouterImpl(Inference):

     def __init__(
         self,
-        inner_impls: List[Tuple[str, Any]],
-        deps: List[Api],
+        models_api: Models,
     ) -> None:
-        self.inner_impls = inner_impls
-        self.deps = deps
-        print("INIT INFERENCE ROUTER!")
-
-        # self.providers = {}
-        # for routing_key, provider_impl in inner_impls:
-        #     self.providers[routing_key] = provider_impl
+        # map of model_id to provider impl
+        self.providers = {}
+        self.models_api = models_api

     async def initialize(self) -> None:
-        pass
+        inference_providers = api_providers()[Api.inference]
+
+        models_list_response = await self.models_api.list_models()
+        for model_spec in models_list_response.models_list:
+            if model_spec.api != Api.inference.value:
+                continue
+
+            if model_spec.provider_id not in inference_providers:
+                raise ValueError(
+                    f"provider_id {model_spec.provider_id} is not available for inference. Please check run.yaml config spec to define a valid provider"
+                )
+            impl = await instantiate_provider(
+                inference_providers[model_spec.provider_id],
+                deps=[],
+                provider_config=GenericProviderConfig(
+                    provider_id=model_spec.provider_id,
+                    config=model_spec.provider_config,
+                ),
+            )
+            cprint(f"impl={impl}", "blue")
+            # look up and initialize provider implementations for each model
+            core_model_id = model_spec.llama_model_metadata.core_model_id

     async def shutdown(self) -> None:
         pass
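The commit builds the model-to-provider map in initialize() but stops short of dispatch. A hypothetical, self-contained sketch of how such a map could route a request by model id (ToyModelRouter and EchoProvider are invented for illustration; InferenceRouterImpl's actual routing methods are not part of this diff):

import asyncio
from typing import Any, Dict, List


class EchoProvider:
    async def chat_completion(self, model: str, messages: List[str]) -> str:
        return f"{model}: {messages[-1]}"


class ToyModelRouter:
    def __init__(self, providers: Dict[str, Any]) -> None:
        # map of model_id to provider impl, as in InferenceRouterImpl
        self.providers = providers

    async def chat_completion(self, model: str, messages: List[str]) -> str:
        if model not in self.providers:
            raise ValueError(f"no inference provider registered for model {model}")
        # forward the request to whichever provider serves this model
        return await self.providers[model].chat_completion(model, messages)


router = ToyModelRouter({"llama3-8b": EchoProvider()})
print(asyncio.run(router.chat_completion("llama3-8b", ["hello"])))
# -> llama3-8b: hello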