Merge branch 'main' into evals_8

This commit is contained in:
Xi Yan 2024-10-28 13:08:28 -07:00
commit dc79f1c2c6
6 changed files with 16 additions and 17 deletions

View file

@@ -44,9 +44,7 @@ class ScoringFnDef(BaseModel):
description="List of parameters for the deterministic function",
default_factory=list,
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
return_type: ParamType
context: Optional[LLMAsJudgeContext] = None
# We can optionally add information here to support packaging of code, etc.

View file

@@ -25,10 +25,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -65,14 +62,14 @@ class ProviderWithSpec(Provider):
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
async def resolve_impls(
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
) -> Dict[Api, Any]:
"""
Does two things:
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
all_api_providers = get_provider_registry()
routing_table_apis = set(
x.routing_table_api for x in builtin_automatically_routed_apis()
)
@@ -89,12 +86,12 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
specs = {}
for provider in providers:
if provider.provider_type not in all_api_providers[api]:
if provider.provider_type not in provider_registry[api]:
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = all_api_providers[api][provider.provider_type]
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,

View file

@@ -26,7 +26,10 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@@ -276,7 +279,7 @@ def main(
app = FastAPI()
impls = asyncio.run(resolve_impls(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])

View file

@@ -13,6 +13,7 @@ import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
@@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config)
impls = await resolve_impls(run_config, get_provider_registry())
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id

View file

@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
llama-models>=0.0.46
llama-models>=0.0.47
prompt-toolkit
python-dotenv
pydantic>=2

View file

@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
version="0.0.46",
version="0.0.47",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",