mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Merge branch 'evals_8' into evals_9
This commit is contained in:
commit
f3aab94029
6 changed files with 16 additions and 17 deletions
|
@ -44,9 +44,7 @@ class ScoringFnDef(BaseModel):
|
||||||
description="List of parameters for the deterministic function",
|
description="List of parameters for the deterministic function",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
return_type: ParamType = Field(
|
return_type: ParamType
|
||||||
description="The return type of the deterministic function",
|
|
||||||
)
|
|
||||||
context: Optional[LLMAsJudgeContext] = None
|
context: Optional[LLMAsJudgeContext] = None
|
||||||
# We can optionally add information here to support packaging of code, etc.
|
# We can optionally add information here to support packaging of code, etc.
|
||||||
|
|
||||||
|
|
|
@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
builtin_automatically_routed_apis,
|
|
||||||
get_provider_registry,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):
|
||||||
|
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||||
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
async def resolve_impls(
|
||||||
|
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Does two things:
|
Does two things:
|
||||||
- flatmaps, sorts and resolves the providers in dependency order
|
- flatmaps, sorts and resolves the providers in dependency order
|
||||||
- for each API, produces either a (local, passthrough or router) implementation
|
- for each API, produces either a (local, passthrough or router) implementation
|
||||||
"""
|
"""
|
||||||
all_api_providers = get_provider_registry()
|
|
||||||
|
|
||||||
routing_table_apis = set(
|
routing_table_apis = set(
|
||||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||||
)
|
)
|
||||||
|
@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
|
|
||||||
specs = {}
|
specs = {}
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if provider.provider_type not in all_api_providers[api]:
|
if provider.provider_type not in provider_registry[api]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||||
)
|
)
|
||||||
|
|
||||||
p = all_api_providers[api][provider.provider_type]
|
p = provider_registry[api][provider.provider_type]
|
||||||
p.deps__ = [a.value for a in p.api_dependencies]
|
p.deps__ = [a.value for a in p.api_dependencies]
|
||||||
spec = ProviderWithSpec(
|
spec = ProviderWithSpec(
|
||||||
spec=p,
|
spec=p,
|
||||||
|
|
|
@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import (
|
||||||
|
builtin_automatically_routed_apis,
|
||||||
|
get_provider_registry,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
|
@ -276,7 +279,7 @@ def main(
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
impls = asyncio.run(resolve_impls(config))
|
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import yaml
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
|
|
||||||
|
@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
||||||
providers=chosen,
|
providers=chosen,
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
impls = await resolve_impls(run_config)
|
impls = await resolve_impls(run_config, get_provider_registry())
|
||||||
|
|
||||||
if "provider_data" in config_dict:
|
if "provider_data" in config_dict:
|
||||||
provider_id = chosen[api.value][0].provider_id
|
provider_id = chosen[api.value][0].provider_id
|
||||||
|
|
|
@ -2,7 +2,7 @@ blobfile
|
||||||
fire
|
fire
|
||||||
httpx
|
httpx
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
llama-models>=0.0.46
|
llama-models>=0.0.47
|
||||||
prompt-toolkit
|
prompt-toolkit
|
||||||
python-dotenv
|
python-dotenv
|
||||||
pydantic>=2
|
pydantic>=2
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -16,7 +16,7 @@ def read_requirements():
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="llama_stack",
|
name="llama_stack",
|
||||||
version="0.0.46",
|
version="0.0.47",
|
||||||
author="Meta Llama",
|
author="Meta Llama",
|
||||||
author_email="llama-oss@meta.com",
|
author_email="llama-oss@meta.com",
|
||||||
description="Llama Stack",
|
description="Llama Stack",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue