mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
Merge branch 'eval_task_register' into mmlu_benchmark
This commit is contained in:
commit
cc6edf6287
72 changed files with 306 additions and 304 deletions
29
.github/PULL_REQUEST_TEMPLATE.md
vendored
29
.github/PULL_REQUEST_TEMPLATE.md
vendored
|
@ -1,17 +1,14 @@
|
|||
# What does this PR do?
|
||||
|
||||
Closes # (issue)
|
||||
In short, provide a summary of what this PR does and why. Usually, the relevant context should be present in a linked issue.
|
||||
|
||||
- [ ] Addresses issue (#issue)
|
||||
|
||||
## Feature/Issue validation/testing/test plan
|
||||
|
||||
Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
|
||||
Please also list any relevant details for your test configuration or test plan.
|
||||
|
||||
- [ ] Test A
|
||||
Logs for Test A
|
||||
|
||||
- [ ] Test B
|
||||
Logs for Test B
|
||||
Please describe:
|
||||
- tests you ran to verify your changes with result summaries.
|
||||
- provide instructions so it can be reproduced.
|
||||
|
||||
|
||||
## Sources
|
||||
|
@ -20,12 +17,10 @@ Please link relevant resources if necessary.
|
|||
|
||||
|
||||
## Before submitting
|
||||
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
|
||||
- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
|
||||
Pull Request section?
|
||||
- [ ] Was this discussed/approved via a Github issue? Please add a link
|
||||
to it if that's the case.
|
||||
- [ ] Did you make sure to update the documentation with your changes?
|
||||
- [ ] Did you write any new necessary tests?
|
||||
|
||||
Thanks for contributing 🎉!
|
||||
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
|
||||
- [ ] Ran pre-commit to handle lint / formatting issues.
|
||||
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
|
||||
Pull Request section?
|
||||
- [ ] Updated relevant documentation.
|
||||
- [ ] Wrote necessary unit or integration tests.
|
||||
|
|
|
@ -38,14 +38,16 @@ EvalCandidate = Annotated[
|
|||
|
||||
@json_schema_type
|
||||
class BenchmarkEvalTaskConfig(BaseModel):
|
||||
eval_candidate: EvalCandidate # type: ignore
|
||||
type: Literal["benchmark"] = "benchmark"
|
||||
eval_candidate: EvalCandidate
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AppEvalTaskConfig(BaseModel):
|
||||
eval_candidate: EvalCandidate # type: ignore
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field( # type: ignore
|
||||
description="Map between scoring function id and parameters",
|
||||
type: Literal["app"] = "app"
|
||||
eval_candidate: EvalCandidate
|
||||
scoring_params: Dict[str, ScoringFnParams] = Field(
|
||||
description="Map between scoring function id and parameters for each scoring function you want to run",
|
||||
default_factory=dict,
|
||||
)
|
||||
# we could optinally add any specific dataset config here
|
||||
|
@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel):
|
|||
|
||||
|
||||
class Eval(Protocol):
|
||||
@webmethod(route="/eval/run_benchmark_eval", method="POST")
|
||||
async def run_benchmark_eval(
|
||||
@webmethod(route="/eval/run_benchmark", method="POST")
|
||||
async def run_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
eval_task_config: BenchmarkEvalTaskConfig,
|
||||
benchmark_config: BenchmarkEvalTaskConfig,
|
||||
) -> Job: ...
|
||||
|
||||
@webmethod(route="/eval/run_eval", method="POST")
|
||||
async def run_eval(
|
||||
self,
|
||||
eval_task_def: EvalTaskDef,
|
||||
eval_task_config: EvalTaskConfig,
|
||||
task: EvalTaskDef,
|
||||
task_config: AppEvalTaskConfig,
|
||||
) -> Job: ...
|
||||
|
||||
@webmethod(route="/eval/evaluate_rows", method="POST")
|
||||
|
@ -83,14 +85,17 @@ class Eval(Protocol):
|
|||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
eval_task_config: EvalTaskConfig, # type: ignore
|
||||
task_config: EvalTaskConfig,
|
||||
eval_task_id: Optional[str] = None,
|
||||
) -> EvaluateResponse: ...
|
||||
|
||||
@webmethod(route="/eval/job/status", method="GET")
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
|
||||
async def job_status(
|
||||
self, job_id: str, eval_task_id: str
|
||||
) -> Optional[JobStatus]: ...
|
||||
|
||||
@webmethod(route="/eval/job/cancel", method="POST")
|
||||
async def job_cancel(self, job_id: str) -> None: ...
|
||||
async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ...
|
||||
|
||||
@webmethod(route="/eval/job/result", method="GET")
|
||||
async def job_result(self, job_id: str) -> EvaluateResponse: ...
|
||||
async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ...
|
||||
|
|
|
@ -48,8 +48,7 @@ class Scoring(Protocol):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse: ...
|
||||
|
||||
|
@ -57,6 +56,5 @@ class Scoring(Protocol):
|
|||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
) -> ScoreResponse: ...
|
||||
|
|
|
@ -33,7 +33,7 @@ class ScoringConfigType(Enum):
|
|||
|
||||
@json_schema_type
|
||||
class LLMAsJudgeScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringConfigType.llm_as_judge.value] = ( # type: ignore
|
||||
type: Literal[ScoringConfigType.llm_as_judge.value] = (
|
||||
ScoringConfigType.llm_as_judge.value
|
||||
)
|
||||
judge_model: str
|
||||
|
@ -46,7 +46,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RegexParserScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringConfigType.regex_parser.value] = ( # type: ignore
|
||||
type: Literal[ScoringConfigType.regex_parser.value] = (
|
||||
ScoringConfigType.regex_parser.value
|
||||
)
|
||||
parsing_regexes: Optional[List[str]] = Field(
|
||||
|
@ -75,8 +75,8 @@ class ScoringFnDef(BaseModel):
|
|||
return_type: ParamType = Field(
|
||||
description="The return type of the deterministic function",
|
||||
)
|
||||
params: Optional[ScoringFnParams] = Field( # type: ignore
|
||||
description="The parameters for the scoring function for benchmark eval, we could override this for app eval",
|
||||
params: Optional[ScoringFnParams] = Field(
|
||||
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
|
||||
default=None,
|
||||
)
|
||||
# We can optionally add information here to support packaging of code, etc.
|
||||
|
|
|
@ -8,6 +8,8 @@ import inspect
|
|||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
@ -100,6 +102,12 @@ async def resolve_impls(
|
|||
)
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_warning:
|
||||
cprint(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
"red",
|
||||
attrs=["bold"],
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
|
|
|
@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
res = {}
|
||||
for fn_identifier in scoring_functions:
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score_batch(
|
||||
dataset_id=dataset_id,
|
||||
scoring_functions=[fn_identifier],
|
||||
scoring_params=scoring_params,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
res.update(score_response.results)
|
||||
|
||||
|
@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
|
|||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
# look up and map each scoring function to its provider impl
|
||||
for fn_identifier in scoring_functions:
|
||||
for fn_identifier in scoring_functions.keys():
|
||||
score_response = await self.routing_table.get_provider_impl(
|
||||
fn_identifier
|
||||
).score(
|
||||
input_rows=input_rows,
|
||||
scoring_functions=[fn_identifier],
|
||||
scoring_params=scoring_params,
|
||||
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
|
||||
)
|
||||
res.update(score_response.results)
|
||||
|
||||
|
@ -272,24 +268,21 @@ class EvalRouter(Eval):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_benchmark_eval(
|
||||
async def run_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
eval_task_config: BenchmarkEvalTaskConfig,
|
||||
benchmark_config: BenchmarkEvalTaskConfig,
|
||||
) -> Job:
|
||||
pass
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
eval_task_def: EvalTaskDef,
|
||||
eval_task_config: EvalTaskConfig,
|
||||
task: EvalTaskDef,
|
||||
task_config: AppEvalTaskConfig,
|
||||
) -> Job:
|
||||
# NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
|
||||
return await self.routing_table.get_provider_impl(
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER
|
||||
).run_eval(
|
||||
eval_task_def=eval_task_def,
|
||||
eval_task_config=eval_task_config,
|
||||
return await self.routing_table.get_provider_impl(task.identifier).run_eval(
|
||||
task=task,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
@webmethod(route="/eval/evaluate_rows", method="POST")
|
||||
|
@ -297,29 +290,42 @@ class EvalRouter(Eval):
|
|||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
eval_task_config: EvalTaskConfig, # type: ignore
|
||||
task_config: EvalTaskConfig,
|
||||
eval_task_id: Optional[str] = None,
|
||||
) -> EvaluateResponse:
|
||||
# NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task
|
||||
# We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier
|
||||
return await self.routing_table.get_provider_impl(
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER
|
||||
).evaluate_rows(
|
||||
if eval_task_id is None:
|
||||
eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
|
||||
return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows(
|
||||
input_rows=input_rows,
|
||||
scoring_functions=scoring_functions,
|
||||
eval_task_config=eval_task_config,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER
|
||||
).job_status(job_id)
|
||||
async def job_status(
|
||||
self,
|
||||
job_id: str,
|
||||
eval_task_id: str,
|
||||
) -> Optional[JobStatus]:
|
||||
return await self.routing_table.get_provider_impl(eval_task_id).job_status(
|
||||
job_id, eval_task_id
|
||||
)
|
||||
|
||||
async def job_cancel(self, job_id: str) -> None:
|
||||
await self.routing_table.get_provider_impl(
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER
|
||||
).job_cancel(job_id)
|
||||
async def job_cancel(
|
||||
self,
|
||||
job_id: str,
|
||||
eval_task_id: str,
|
||||
) -> None:
|
||||
await self.routing_table.get_provider_impl(eval_task_id).job_cancel(
|
||||
job_id, eval_task_id
|
||||
)
|
||||
|
||||
async def job_result(self, job_id: str) -> EvaluateResponse:
|
||||
return await self.routing_table.get_provider_impl(
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER
|
||||
).job_result(job_id)
|
||||
async def job_result(
|
||||
self,
|
||||
job_id: str,
|
||||
eval_task_id: str,
|
||||
) -> EvaluateResponse:
|
||||
return await self.routing_table.get_provider_impl(eval_task_id).job_result(
|
||||
job_id, eval_task_id
|
||||
)
|
||||
|
|
|
@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import (
|
|||
get_provider_registry,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.store import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
@ -281,21 +279,8 @@ def main(
|
|||
config = StackRunConfig(**yaml.safe_load(fp))
|
||||
|
||||
app = FastAPI()
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
if config.metadata_store:
|
||||
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
|
||||
else:
|
||||
dist_kvstore = asyncio.run(
|
||||
kvstore_impl(
|
||||
SqliteKVStoreConfig(
|
||||
db_path=(
|
||||
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
||||
).as_posix()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
|
||||
|
||||
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
|
||||
if Api.telemetry in impls:
|
||||
|
|
|
@ -9,9 +9,17 @@ from typing import Dict, List, Protocol
|
|||
|
||||
import pydantic
|
||||
|
||||
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObjectWithProvider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.kvstore import (
|
||||
KVStore,
|
||||
kvstore_impl,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
|
||||
|
||||
class DistributionRegistry(Protocol):
|
||||
|
@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
self.cache[obj.identifier].append(obj)
|
||||
|
||||
return success
|
||||
|
||||
|
||||
async def create_dist_registry(
|
||||
config: StackRunConfig,
|
||||
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
if config.metadata_store:
|
||||
dist_kvstore = await kvstore_impl(config.metadata_store)
|
||||
else:
|
||||
dist_kvstore = await kvstore_impl(
|
||||
SqliteKVStoreConfig(
|
||||
db_path=(
|
||||
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
||||
).as_posix()
|
||||
)
|
||||
)
|
||||
|
||||
return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore
|
||||
|
|
|
@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
|
|||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
deprecation_warning: Optional[str] = Field(
|
||||
default=None,
|
||||
description="If this provider is deprecated, specify the warning message here",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
|
|
|
@ -4,10 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
|
@ -11,9 +11,8 @@ from datetime import datetime
|
|||
|
||||
from typing import List, Optional
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
|
@ -10,14 +10,13 @@ from jinja2 import Template
|
|||
from llama_models.llama3.api import * # noqa: F403
|
||||
|
||||
|
||||
from termcolor import cprint # noqa: F401
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
from termcolor import cprint # noqa: F401
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
|
|
@ -9,8 +9,7 @@ from typing import List
|
|||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
|
||||
from ..safety import ShieldRunnerMixin
|
||||
from .builtin import BaseTool
|
||||
|
||||
|
|
@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class MetaReferenceInferenceConfig(BaseModel):
|
|
@ -35,13 +35,12 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
|
@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
get_model_parallel_src_rank,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
from .generation import TokenResult
|
||||
|
||||
|
|
@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
|
|||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import nn, Tensor
|
||||
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
|
@ -5,9 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@json_schema_type
|
|
@ -5,13 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
|
@ -8,10 +8,11 @@ import logging
|
|||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import faiss
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
|
@ -7,11 +7,17 @@ from enum import Enum
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import BenchmarkEvalTaskConfig
|
||||
from .....apis.eval.eval import (
|
||||
AppEvalTaskConfig,
|
||||
BenchmarkEvalTaskConfig,
|
||||
Eval,
|
||||
EvalTaskConfig,
|
||||
EvaluateResponse,
|
||||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.eval_tasks import EvalTaskDef
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
@ -19,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
|
|||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
# NOTE: this is the default eval task identifier for app eval
|
||||
# it is used to make the router work for all app evals
|
||||
# For app eval using other eval providers, the eval task identifier will be different
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"
|
||||
|
||||
|
||||
|
@ -88,21 +98,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def run_benchmark_eval(
|
||||
async def run_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
eval_task_config: BenchmarkEvalTaskConfig,
|
||||
benchmark_config: BenchmarkEvalTaskConfig,
|
||||
) -> Job:
|
||||
raise NotImplementedError("Benchmark eval is not implemented yet")
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
eval_task_def: EvalTaskDef,
|
||||
eval_task_config: EvalTaskConfig,
|
||||
task: EvalTaskDef,
|
||||
task_config: AppEvalTaskConfig,
|
||||
) -> Job:
|
||||
dataset_id = eval_task_def.dataset_id
|
||||
candidate = eval_task_config.eval_candidate
|
||||
scoring_functions = eval_task_def.scoring_functions
|
||||
dataset_id = task.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task.scoring_functions
|
||||
|
||||
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
|
@ -112,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
res = await self.evaluate_rows(
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
eval_task_config=eval_task_config,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
|
@ -125,9 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
eval_task_config: EvalTaskConfig,
|
||||
task_config: EvalTaskConfig,
|
||||
eval_task_id: Optional[str] = None,
|
||||
) -> EvaluateResponse:
|
||||
candidate = eval_task_config.eval_candidate
|
||||
candidate = task_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
raise NotImplementedError(
|
||||
"Evaluation with generation has not been implemented for agents"
|
||||
|
@ -179,23 +190,33 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
for input_r, generated_r in zip(input_rows, generations)
|
||||
]
|
||||
|
||||
if task_config.type == "app" and task_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: None for scoring_fn_id in scoring_functions
|
||||
}
|
||||
|
||||
score_response = await self.scoring_api.score(
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
|
||||
)
|
||||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
|
||||
return None
|
||||
|
||||
async def job_cancel(self, job_id: str) -> None:
|
||||
async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(job_id)
|
||||
async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(job_id, eval_task_id)
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
|
||||
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestFaissMemoryImpl:
|
||||
@pytest.fixture
|
||||
def faiss_impl(self):
|
||||
# Create a temporary SQLite database file
|
||||
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
|
||||
return FaissMemoryImpl(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, faiss_impl):
|
||||
# Test empty initialization
|
||||
await faiss_impl.initialize()
|
||||
assert len(faiss_impl.cache) == 0
|
||||
|
||||
# Test initialization with existing banks
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
# Register a bank and reinitialize to test loading
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
# Create new instance to test initialization with existing data
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
|
||||
assert len(new_impl.cache) == 1
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_memory_bank(self, faiss_impl):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await faiss_impl.initialize()
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
assert "test_bank" in faiss_impl.cache
|
||||
assert faiss_impl.cache["test_bank"].bank == bank
|
||||
|
||||
# Verify persistence
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
|
@ -89,8 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
|
@ -101,7 +100,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
scoring_params=scoring_params,
|
||||
)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
|
@ -115,17 +113,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn_params = None
|
||||
if scoring_params is not None:
|
||||
scoring_fn_params = scoring_params.get(scoring_fn_id, None)
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
|
|
|
@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"scikit-learn",
|
||||
]
|
||||
+ kvstore_dependencies(),
|
||||
module="llama_stack.providers.inline.meta_reference.agents",
|
||||
config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
|
||||
module="llama_stack.providers.inline.agents.meta_reference",
|
||||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
|
|
|
@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=META_REFERENCE_DEPS,
|
||||
module="llama_stack.providers.inline.meta_reference.inference",
|
||||
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
|
||||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
|
@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"torchao==0.5.0",
|
||||
]
|
||||
),
|
||||
module="llama_stack.providers.inline.meta_reference.inference",
|
||||
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
|
||||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.vllm",
|
||||
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
|
@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -140,13 +149,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.inline.vllm",
|
||||
config_class="llama_stack.providers.inline.vllm.VLLMConfig",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.memory,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.meta_reference.memory",
|
||||
config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
deprecation_warning="Please use the `faiss` provider instead.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_type="faiss",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
|
|
|
@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.inline.meta_reference.safety",
|
||||
config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
|
@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.inline.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
|
|||
default="https://api.fireworks.ai/inference",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
|
|
@ -9,12 +9,11 @@ from typing import AsyncGenerator
|
|||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
|
@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
FIREWORKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
|
@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
|
|||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
|
||||
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
class FireworksInferenceAdapter(
|
||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
|
@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
fireworks_api_key = self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
raise ValueError(
|
||||
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
||||
)
|
||||
fireworks_api_key = provider_data.fireworks_api_key
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_completion(request, client)
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request, client)
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest, client: Fireworks
|
||||
self, request: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = client.completion.acreate(**params)
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self._get_client().completion.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await client.chat.completions.acreate(**params)
|
||||
r = await self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
r = await client.completion.acreate(**params)
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
if "messages" in params:
|
||||
stream = client.chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = client.completion.acreate(**params)
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
stream = await self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = self._get_client().completion.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
|
@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
elif isinstance(request, CompletionRequest):
|
||||
else:
|
||||
assert (
|
||||
not media_present
|
||||
), "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if "prompt" in input_dict:
|
||||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -4,9 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
||||
from .together import TogetherInferenceAdapter
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import pytest_asyncio
|
|||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.agents import (
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class Testeval:
|
|||
response = await eval_impl.evaluate_rows(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
eval_task_config=AppEvalTaskConfig(
|
||||
task_config=AppEvalTaskConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
sampling_params=SamplingParams(),
|
||||
|
@ -76,13 +76,13 @@ class Testeval:
|
|||
]
|
||||
|
||||
response = await eval_impl.run_eval(
|
||||
eval_task_def=EvalTaskDef(
|
||||
task=EvalTaskDef(
|
||||
# NOTE: this is needed to make the router work for all app evals
|
||||
identifier="meta-reference::app_eval",
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
),
|
||||
eval_task_config=AppEvalTaskConfig(
|
||||
task_config=AppEvalTaskConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
sampling_params=SamplingParams(),
|
||||
|
@ -90,9 +90,13 @@ class Testeval:
|
|||
),
|
||||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = await eval_impl.job_status(response.job_id)
|
||||
job_status = await eval_impl.job_status(
|
||||
response.job_id, "meta-reference::app_eval"
|
||||
)
|
||||
assert job_status and job_status.value == "completed"
|
||||
eval_response = await eval_impl.job_result(response.job_id)
|
||||
eval_response = await eval_impl.job_result(
|
||||
response.job_id, "meta-reference::app_eval"
|
||||
)
|
||||
|
||||
assert eval_response is not None
|
||||
assert len(eval_response.generations) == 5
|
||||
|
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.meta_reference.inference import (
|
||||
from llama_stack.providers.inline.inference.meta_reference import (
|
||||
MetaReferenceInferenceConfig,
|
||||
)
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
|
||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.meta_reference.safety import (
|
||||
from llama_stack.providers.inline.safety.meta_reference import (
|
||||
LlamaGuardShieldConfig,
|
||||
SafetyConfig,
|
||||
)
|
||||
|
|
|
@ -44,10 +44,10 @@ class TestScoring:
|
|||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_functions = [
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
"meta-reference::equality",
|
||||
]
|
||||
scoring_functions = {
|
||||
"meta-reference::llm_as_judge_8b_correctness": None,
|
||||
"meta-reference::equality": None,
|
||||
}
|
||||
response = await scoring_impl.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
|
@ -83,7 +83,7 @@ class TestScoring:
|
|||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
params = {
|
||||
scoring_functions = {
|
||||
"meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
|
||||
judge_model="Llama3.1-405B-Instruct",
|
||||
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
|
||||
|
@ -91,13 +91,9 @@ class TestScoring:
|
|||
)
|
||||
}
|
||||
|
||||
scoring_functions = [
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
]
|
||||
response = await scoring_impl.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
scoring_params=params,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
|
@ -108,7 +104,6 @@ class TestScoring:
|
|||
response = await scoring_impl.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
scoring_params=params,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue