Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit dc79f1c2c6: Merge branch 'main' into evals_8

6 changed files with 16 additions and 17 deletions
@@ -44,9 +44,7 @@ class ScoringFnDef(BaseModel):
         description="List of parameters for the deterministic function",
         default_factory=list,
     )
-    return_type: ParamType = Field(
-        description="The return type of the deterministic function",
-    )
+    return_type: ParamType
     context: Optional[LLMAsJudgeContext] = None
     # We can optionally add information here to support packaging of code, etc.
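For orientation, a minimal sketch of how the simplified field reads in a pydantic model. ParamType and LLMAsJudgeContext are placeholder stand-ins here, and the parameters field name is assumed; only the three fields shown in the hunk are taken from the diff.

# Sketch only: return_type drops the Field(...) wrapper and becomes a bare
# annotation, so its description metadata is no longer attached to it.
from typing import List, Optional
from pydantic import BaseModel, Field

class ParamType(BaseModel):          # placeholder stand-in for the real type
    type: str = "string"

class LLMAsJudgeContext(BaseModel):  # placeholder stand-in
    judge_model: Optional[str] = None

class ScoringFnDef(BaseModel):
    parameters: List[ParamType] = Field(
        description="List of parameters for the deterministic function",
        default_factory=list,
    )
    return_type: ParamType                       # was: ParamType = Field(...)
    context: Optional[LLMAsJudgeContext] = None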
@@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
-from llama_stack.distribution.distribution import (
-    builtin_automatically_routed_apis,
-    get_provider_registry,
-)
+from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):


 # TODO: this code is not very straightforward to follow and needs one more round of refactoring
-async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
+async def resolve_impls(
+    run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
+) -> Dict[Api, Any]:
     """
     Does two things:
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    all_api_providers = get_provider_registry()
-
     routing_table_apis = set(
         x.routing_table_api for x in builtin_automatically_routed_apis()
     )
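The practical effect is that callers now build the provider registry themselves and inject it into the resolver. A minimal sketch of the call-site pattern, using only imports that appear elsewhere in this diff; build_impls is a hypothetical wrapper name for illustration.

from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls

async def build_impls(run_config):
    # resolve_impls() no longer calls get_provider_registry() internally;
    # the Dict[Api, Dict[str, ProviderSpec]] registry is passed in by the caller.
    provider_registry = get_provider_registry()
    return await resolve_impls(run_config, provider_registry)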
@@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:

     specs = {}
     for provider in providers:
-        if provider.provider_type not in all_api_providers[api]:
+        if provider.provider_type not in provider_registry[api]:
             raise ValueError(
                 f"Provider `{provider.provider_type}` is not available for API `{api}`"
             )

-        p = all_api_providers[api][provider.provider_type]
+        p = provider_registry[api][provider.provider_type]
         p.deps__ = [a.value for a in p.api_dependencies]
         spec = ProviderWithSpec(
             spec=p,
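Reduced to a standalone sketch, the lookup is a two-level dict access guarded by an up-front membership check. lookup_provider_spec is a hypothetical helper mirroring the logic above, not a function in the codebase.

def lookup_provider_spec(provider_registry, api, provider_type):
    # provider_registry maps each Api to {provider_type: ProviderSpec};
    # an unknown provider_type for that API is rejected before instantiation.
    if provider_type not in provider_registry[api]:
        raise ValueError(
            f"Provider `{provider_type}` is not available for API `{api}`"
        )
    return provider_registry[api][provider_type]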
@@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError
 from termcolor import cprint
 from typing_extensions import Annotated

-from llama_stack.distribution.distribution import builtin_automatically_routed_apis
+from llama_stack.distribution.distribution import (
+    builtin_automatically_routed_apis,
+    get_provider_registry,
+)

 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -276,7 +279,7 @@ def main(

     app = FastAPI()

-    impls = asyncio.run(resolve_impls(config))
+    impls = asyncio.run(resolve_impls(config, get_provider_registry()))
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
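Taken with the import change above, server startup now resolves the registry once and hands it to the resolver. A sketch of that wiring under those assumptions; only the lines shown in the hunk come from the diff, and create_app is a hypothetical wrapper around them.

import asyncio
from fastapi import FastAPI

from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls

def create_app(config):
    app = FastAPI()
    # build the registry at startup and inject it, matching the new signature
    impls = asyncio.run(resolve_impls(config, get_provider_registry()))
    return app, impls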
@@ -13,6 +13,7 @@ import yaml

 from llama_stack.distribution.datatypes import * # noqa: F403
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
+from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls
@@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
         providers=chosen,
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
-    impls = await resolve_impls(run_config)
+    impls = await resolve_impls(run_config, get_provider_registry())

     if "provider_data" in config_dict:
         provider_id = chosen[api.value][0].provider_id
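Callers of the test helper are unaffected; it still takes an Api plus optional deps and returns the resolved implementations. An illustrative usage sketch, assuming Api and resolve_impls_for_test are available from the surrounding test module and that a provider for the chosen API is configured.

import asyncio

async def _demo():
    # Api.scoring is used as an example; any Api with a configured
    # provider in the test fixtures would do.
    impls = await resolve_impls_for_test(Api.scoring)
    return impls[Api.scoring]

# asyncio.run(_demo())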
@@ -2,7 +2,7 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.46
+llama-models>=0.0.47
 prompt-toolkit
 python-dotenv
 pydantic>=2
setup.py
@@ -16,7 +16,7 @@ def read_requirements():

 setup(
     name="llama_stack",
-    version="0.0.46",
+    version="0.0.47",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama Stack",
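The package version and the llama-models pin in requirements.txt move in lockstep (0.0.46 to 0.0.47). For completeness, a plausible shape of the read_requirements() helper named in the hunk header; this is an assumption, and the actual upstream implementation may differ.

import os

def read_requirements():
    # read requirements.txt next to setup.py and drop blanks/comments,
    # so the published package carries the llama-models>=0.0.47 pin
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "requirements.txt")) as f:
        return [ln.strip() for ln in f if ln.strip() and not ln.startswith("#")]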