mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 06:53:47 +00:00
Merge branch 'eval_task_register' into mmlu_benchmark
This commit is contained in:
commit
cc6edf6287
72 changed files with 306 additions and 304 deletions
|
@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
|
|||
default_factory=list,
|
||||
description="Higher-level API surfaces may depend on other providers to provide their functionality",
|
||||
)
|
||||
deprecation_warning: Optional[str] = Field(
|
||||
default=None,
|
||||
description="If this provider is deprecated, specify the warning message here",
|
||||
)
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
deps__: List[str] = Field(default_factory=list)
|
||||
|
|
|
@ -4,10 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
|
@ -11,9 +11,8 @@ from datetime import datetime
|
|||
|
||||
from typing import List, Optional
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
|
@ -10,14 +10,13 @@ from jinja2 import Template
|
|||
from llama_models.llama3.api import * # noqa: F403
|
||||
|
||||
|
||||
from termcolor import cprint # noqa: F401
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
DefaultMemoryQueryGeneratorConfig,
|
||||
LLMMemoryQueryGeneratorConfig,
|
||||
MemoryQueryGenerator,
|
||||
MemoryQueryGeneratorConfig,
|
||||
)
|
||||
from termcolor import cprint # noqa: F401
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
|
|
@ -9,8 +9,7 @@ from typing import List
|
|||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
|
||||
|
||||
from ..safety import ShieldRunnerMixin
|
||||
from .builtin import BaseTool
|
||||
|
||||
|
|
@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F401, F403
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class MetaReferenceInferenceConfig(BaseModel):
|
|
@ -35,13 +35,12 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
|
@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
get_model_parallel_src_rank,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
|
||||
|
||||
from .generation import TokenResult
|
||||
|
||||
|
|
@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
|
|||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from termcolor import cprint
|
||||
from torch import nn, Tensor
|
||||
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
|
@ -5,9 +5,9 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@json_schema_type
|
|
@ -5,13 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@json_schema_type
|
|
@ -8,10 +8,11 @@ import logging
|
|||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
import faiss
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
|
@ -7,11 +7,17 @@ from enum import Enum
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import BenchmarkEvalTaskConfig
|
||||
from .....apis.eval.eval import (
|
||||
AppEvalTaskConfig,
|
||||
BenchmarkEvalTaskConfig,
|
||||
Eval,
|
||||
EvalTaskConfig,
|
||||
EvaluateResponse,
|
||||
JobStatus,
|
||||
)
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
from llama_stack.apis.eval_tasks import EvalTaskDef
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
@ -19,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
|
|||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
# NOTE: this is the default eval task identifier for app eval
|
||||
# it is used to make the router work for all app evals
|
||||
# For app eval using other eval providers, the eval task identifier will be different
|
||||
DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"
|
||||
|
||||
|
||||
|
@ -88,21 +98,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
|
||||
)
|
||||
|
||||
async def run_benchmark_eval(
|
||||
async def run_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
eval_task_config: BenchmarkEvalTaskConfig,
|
||||
benchmark_config: BenchmarkEvalTaskConfig,
|
||||
) -> Job:
|
||||
raise NotImplementedError("Benchmark eval is not implemented yet")
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
eval_task_def: EvalTaskDef,
|
||||
eval_task_config: EvalTaskConfig,
|
||||
task: EvalTaskDef,
|
||||
task_config: AppEvalTaskConfig,
|
||||
) -> Job:
|
||||
dataset_id = eval_task_def.dataset_id
|
||||
candidate = eval_task_config.eval_candidate
|
||||
scoring_functions = eval_task_def.scoring_functions
|
||||
dataset_id = task.dataset_id
|
||||
candidate = task_config.eval_candidate
|
||||
scoring_functions = task.scoring_functions
|
||||
|
||||
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
|
@ -112,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
res = await self.evaluate_rows(
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
eval_task_config=eval_task_config,
|
||||
task_config=task_config,
|
||||
)
|
||||
|
||||
# TODO: currently needs to wait for generation before returning
|
||||
|
@ -125,9 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
eval_task_config: EvalTaskConfig,
|
||||
task_config: EvalTaskConfig,
|
||||
eval_task_id: Optional[str] = None,
|
||||
) -> EvaluateResponse:
|
||||
candidate = eval_task_config.eval_candidate
|
||||
candidate = task_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
raise NotImplementedError(
|
||||
"Evaluation with generation has not been implemented for agents"
|
||||
|
@ -179,23 +190,33 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
for input_r, generated_r in zip(input_rows, generations)
|
||||
]
|
||||
|
||||
if task_config.type == "app" and task_config.scoring_params is not None:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
|
||||
for scoring_fn_id in scoring_functions
|
||||
}
|
||||
else:
|
||||
scoring_functions_dict = {
|
||||
scoring_fn_id: None for scoring_fn_id in scoring_functions
|
||||
}
|
||||
|
||||
score_response = await self.scoring_api.score(
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions
|
||||
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
|
||||
)
|
||||
|
||||
return EvaluateResponse(generations=generations, scores=score_response.results)
|
||||
|
||||
async def job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
|
||||
if job_id in self.jobs:
|
||||
return JobStatus.completed
|
||||
|
||||
return None
|
||||
|
||||
async def job_cancel(self, job_id: str) -> None:
|
||||
async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
|
||||
raise NotImplementedError("Job cancel is not implemented yet")
|
||||
|
||||
async def job_result(self, job_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(job_id)
|
||||
async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
|
||||
status = await self.job_status(job_id, eval_task_id)
|
||||
if not status or status != JobStatus.completed:
|
||||
raise ValueError(f"Job is not completed, Status: {status.value}")
|
||||
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
|
||||
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestFaissMemoryImpl:
|
||||
@pytest.fixture
|
||||
def faiss_impl(self):
|
||||
# Create a temporary SQLite database file
|
||||
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
|
||||
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
|
||||
return FaissMemoryImpl(config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, faiss_impl):
|
||||
# Test empty initialization
|
||||
await faiss_impl.initialize()
|
||||
assert len(faiss_impl.cache) == 0
|
||||
|
||||
# Test initialization with existing banks
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
# Register a bank and reinitialize to test loading
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
# Create new instance to test initialization with existing data
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
|
||||
assert len(new_impl.cache) == 1
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_memory_bank(self, faiss_impl):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
type=MemoryBankType.vector.value,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await faiss_impl.initialize()
|
||||
await faiss_impl.register_memory_bank(bank)
|
||||
|
||||
assert "test_bank" in faiss_impl.cache
|
||||
assert faiss_impl.cache["test_bank"].bank == bank
|
||||
|
||||
# Verify persistence
|
||||
new_impl = FaissMemoryImpl(faiss_impl.config)
|
||||
await new_impl.initialize()
|
||||
assert "test_bank" in new_impl.cache
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
|
@ -89,8 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
|
@ -101,7 +100,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
res = await self.score(
|
||||
input_rows=all_rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
scoring_params=scoring_params,
|
||||
)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
|
@ -115,17 +113,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score(
|
||||
self,
|
||||
input_rows: List[Dict[str, Any]],
|
||||
scoring_functions: List[str],
|
||||
scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
for scoring_fn_id in scoring_functions.keys():
|
||||
if scoring_fn_id not in self.scoring_fn_id_impls:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
|
||||
scoring_fn_params = None
|
||||
if scoring_params is not None:
|
||||
scoring_fn_params = scoring_params.get(scoring_fn_id, None)
|
||||
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
|
||||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
|
|
|
@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"scikit-learn",
|
||||
]
|
||||
+ kvstore_dependencies(),
|
||||
module="llama_stack.providers.inline.meta_reference.agents",
|
||||
config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
|
||||
module="llama_stack.providers.inline.agents.meta_reference",
|
||||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
|
|
|
@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.inference,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=META_REFERENCE_DEPS,
|
||||
module="llama_stack.providers.inline.meta_reference.inference",
|
||||
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
|
||||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
|
@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"torchao==0.5.0",
|
||||
]
|
||||
),
|
||||
module="llama_stack.providers.inline.meta_reference.inference",
|
||||
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
|
||||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.vllm",
|
||||
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
|
@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
@ -140,13 +149,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
),
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="vllm",
|
||||
pip_packages=[
|
||||
"vllm",
|
||||
],
|
||||
module="llama_stack.providers.inline.vllm",
|
||||
config_class="llama_stack.providers.inline.vllm.VLLMConfig",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
api=Api.memory,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.meta_reference.memory",
|
||||
config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
deprecation_warning="Please use the `faiss` provider instead.",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_type="faiss",
|
||||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
|
|
|
@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"transformers",
|
||||
"torch --index-url https://download.pytorch.org/whl/cpu",
|
||||
],
|
||||
module="llama_stack.providers.inline.meta_reference.safety",
|
||||
config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
],
|
||||
|
@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[
|
||||
"codeshield",
|
||||
],
|
||||
module="llama_stack.providers.inline.meta_reference.codeshield",
|
||||
config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
|
||||
module="llama_stack.providers.inline.safety.meta_reference",
|
||||
config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
|
||||
api_dependencies=[],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
|
|||
default="https://api.fireworks.ai/inference",
|
||||
description="The URL for the Fireworks server",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Fireworks.ai API Key",
|
||||
)
|
||||
|
|
|
@ -9,12 +9,11 @@ from typing import AsyncGenerator
|
|||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
|
@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
FIREWORKS_SUPPORTED_MODELS = {
|
||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||
|
@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
|
|||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
|
||||
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
|
||||
}
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||
class FireworksInferenceAdapter(
|
||||
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||
|
@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_client(self) -> Fireworks:
|
||||
fireworks_api_key = None
|
||||
if self.config.api_key is not None:
|
||||
fireworks_api_key = self.config.api_key
|
||||
else:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.fireworks_api_key:
|
||||
raise ValueError(
|
||||
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
||||
)
|
||||
fireworks_api_key = provider_data.fireworks_api_key
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_completion(request, client)
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request, client)
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _nonstream_completion(
|
||||
self, request: CompletionRequest, client: Fireworks
|
||||
self, request: CompletionRequest
|
||||
) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_completion(
|
||||
self, request: CompletionRequest, client: Fireworks
|
||||
) -> AsyncGenerator:
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = client.completion.acreate(**params)
|
||||
# Wrapper for async generator similar
|
||||
async def _to_async_generator():
|
||||
stream = self._get_client().completion.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
def _build_options(
|
||||
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
|
||||
) -> dict:
|
||||
options = get_sampling_options(sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return options
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
client = Fireworks(api_key=self.config.api_key)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await client.chat.completions.acreate(**params)
|
||||
r = await self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
r = await client.completion.acreate(**params)
|
||||
r = await self._get_client().completion.acreate(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
self, request: ChatCompletionRequest
|
||||
) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
if "messages" in params:
|
||||
stream = client.chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = client.completion.acreate(**params)
|
||||
async def _to_async_generator():
|
||||
if "messages" in params:
|
||||
stream = await self._get_client().chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = self._get_client().completion.create(**params)
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
|
@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
elif isinstance(request, CompletionRequest):
|
||||
else:
|
||||
assert (
|
||||
not media_present
|
||||
), "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if "prompt" in input_dict:
|
||||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
"type": "json_object",
|
||||
"schema": fmt.json_schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
options["response_format"] = {
|
||||
"type": "grammar",
|
||||
"grammar": fmt.bnf,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@ -4,9 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
||||
|
||||
class TogetherProviderDataValidator(BaseModel):
|
||||
together_api_key: str
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherImplConfig, _deps):
|
||||
from .together import TogetherInferenceAdapter
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import pytest_asyncio
|
|||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.meta_reference.agents import (
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class Testeval:
|
|||
response = await eval_impl.evaluate_rows(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
eval_task_config=AppEvalTaskConfig(
|
||||
task_config=AppEvalTaskConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
sampling_params=SamplingParams(),
|
||||
|
@ -76,13 +76,13 @@ class Testeval:
|
|||
]
|
||||
|
||||
response = await eval_impl.run_eval(
|
||||
eval_task_def=EvalTaskDef(
|
||||
task=EvalTaskDef(
|
||||
# NOTE: this is needed to make the router work for all app evals
|
||||
identifier="meta-reference::app_eval",
|
||||
dataset_id="test_dataset_for_eval",
|
||||
scoring_functions=scoring_functions,
|
||||
),
|
||||
eval_task_config=AppEvalTaskConfig(
|
||||
task_config=AppEvalTaskConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
sampling_params=SamplingParams(),
|
||||
|
@ -90,9 +90,13 @@ class Testeval:
|
|||
),
|
||||
)
|
||||
assert response.job_id == "0"
|
||||
job_status = await eval_impl.job_status(response.job_id)
|
||||
job_status = await eval_impl.job_status(
|
||||
response.job_id, "meta-reference::app_eval"
|
||||
)
|
||||
assert job_status and job_status.value == "completed"
|
||||
eval_response = await eval_impl.job_result(response.job_id)
|
||||
eval_response = await eval_impl.job_result(
|
||||
response.job_id, "meta-reference::app_eval"
|
||||
)
|
||||
|
||||
assert eval_response is not None
|
||||
assert len(eval_response.generations) == 5
|
||||
|
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.meta_reference.inference import (
|
||||
from llama_stack.providers.inline.inference.meta_reference import (
|
||||
MetaReferenceInferenceConfig,
|
||||
)
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
|
||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.meta_reference.safety import (
|
||||
from llama_stack.providers.inline.safety.meta_reference import (
|
||||
LlamaGuardShieldConfig,
|
||||
SafetyConfig,
|
||||
)
|
||||
|
|
|
@ -44,10 +44,10 @@ class TestScoring:
|
|||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
scoring_functions = [
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
"meta-reference::equality",
|
||||
]
|
||||
scoring_functions = {
|
||||
"meta-reference::llm_as_judge_8b_correctness": None,
|
||||
"meta-reference::equality": None,
|
||||
}
|
||||
response = await scoring_impl.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
|
@ -83,7 +83,7 @@ class TestScoring:
|
|||
)
|
||||
assert len(rows.rows) == 3
|
||||
|
||||
params = {
|
||||
scoring_functions = {
|
||||
"meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
|
||||
judge_model="Llama3.1-405B-Instruct",
|
||||
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
|
||||
|
@ -91,13 +91,9 @@ class TestScoring:
|
|||
)
|
||||
}
|
||||
|
||||
scoring_functions = [
|
||||
"meta-reference::llm_as_judge_8b_correctness",
|
||||
]
|
||||
response = await scoring_impl.score(
|
||||
input_rows=rows.rows,
|
||||
scoring_functions=scoring_functions,
|
||||
scoring_params=params,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
|
@ -108,7 +104,6 @@ class TestScoring:
|
|||
response = await scoring_impl.score_batch(
|
||||
dataset_id="test_dataset",
|
||||
scoring_functions=scoring_functions,
|
||||
scoring_params=params,
|
||||
)
|
||||
assert len(response.results) == len(scoring_functions)
|
||||
for x in scoring_functions:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue