Merge branch 'main' into henrytu/cerebras-integration

Henry Tu 2024-12-02 10:57:59 -05:00 committed by GitHub
commit c29e3271d3
38 changed files with 523 additions and 139 deletions

View file

@@ -35,7 +35,7 @@ class NeedsRequestProviderData:
provider_data = validator(**val)
return provider_data
except Exception as e:
log.error("Error parsing provider data", e)
log.error(f"Error parsing provider data: {e}")
def set_request_provider_data(headers: Dict[str, str]):

View file

@@ -0,0 +1,11 @@
# Llama Stack UI
> [!NOTE]
> This is a work in progress.
## Running Streamlit App
```
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,173 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import pandas as pd
import streamlit as st
from modules.api import LlamaStackEvaluation
from modules.utils import process_dataset
EVALUATION_API = LlamaStackEvaluation()
def main():
# Add collapsible sidebar
with st.sidebar:
# Add collapse button
if "sidebar_state" not in st.session_state:
st.session_state.sidebar_state = True
if st.session_state.sidebar_state:
st.title("Navigation")
page = st.radio(
"Select a Page",
["Application Evaluation"],
index=0,
)
else:
page = "Application Evaluation" # Default page when sidebar is collapsed
# Main content area
st.title("🦙 Llama Stack Evaluations")
if page == "Application Evaluation":
application_evaluation_page()
def application_evaluation_page():
# File uploader
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
if uploaded_file is None:
st.error("No file uploaded")
return
# Process uploaded file
df = process_dataset(uploaded_file)
if df is None:
st.error("Error processing file")
return
# Display dataset information
st.success("Dataset loaded successfully!")
# Display dataframe preview
st.subheader("Dataset Preview")
st.dataframe(df)
# Select Scoring Functions to Run Evaluation On
st.subheader("Select Scoring Functions")
scoring_functions = EVALUATION_API.list_scoring_functions()
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
scoring_functions_names = list(scoring_functions.keys())
selected_scoring_functions = st.multiselect(
"Choose one or more scoring functions",
options=scoring_functions_names,
help="Choose one or more scoring functions.",
)
available_models = EVALUATION_API.list_models()
available_models = [m.identifier for m in available_models]
scoring_params = {}
if selected_scoring_functions:
st.write("Selected:")
for scoring_fn_id in selected_scoring_functions:
scoring_fn = scoring_functions[scoring_fn_id]
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
new_params = None
if scoring_fn.params:
new_params = {}
for param_name, param_value in scoring_fn.params.to_dict().items():
if param_name == "type":
new_params[param_name] = param_value
continue
if param_name == "judge_model":
value = st.selectbox(
f"Select **{param_name}** for {scoring_fn_id}",
options=available_models,
index=0,
key=f"{scoring_fn_id}_{param_name}",
)
new_params[param_name] = value
else:
value = st.text_area(
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
value=json.dumps(param_value, indent=2),
height=80,
)
try:
new_params[param_name] = json.loads(value)
except json.JSONDecodeError:
st.error(
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
)
st.json(new_params)
scoring_params[scoring_fn_id] = new_params
# Add run evaluation button & slider
total_rows = len(df)
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = df.to_dict(orient="records")
if num_rows < total_rows:
rows = rows[:num_rows]
# Create separate containers for progress text and results
progress_text_container = st.empty()
results_container = st.empty()
output_res = {}
for i, r in enumerate(rows):
# Update progress
progress = i / len(rows)
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
score_res = EVALUATION_API.run_scoring(
r,
scoring_function_ids=selected_scoring_functions,
scoring_params=scoring_params,
)
for k in r.keys():
if k not in output_res:
output_res[k] = []
output_res[k].append(r[k])
for fn_id in selected_scoring_functions:
if fn_id not in output_res:
output_res[fn_id] = []
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
# Display current row results using separate containers
progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})"
)
results_container.json(
score_res.to_json(),
expanded=2,
)
progress_bar.progress(1.0, text="Evaluation complete!")
# Display results in dataframe
if output_res:
output_df = pd.DataFrame(output_res)
st.subheader("Evaluation Results")
st.dataframe(output_df)
if __name__ == "__main__":
main()
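
The page above walks each uploaded row through the selected scoring functions. A minimal sketch of a compatible input file follows; the column names are assumptions taken from the `expected_answer` / `generated_answer` fields the braintrust scorer reads later in this diff, so adapt them to whatever scoring functions you select.

```
# Hedged sketch of a dataset the evaluation page can consume.
# Column names are assumptions based on the scoring code in this diff
# ("expected_answer", "generated_answer"); adapt them to your scoring functions.
import pandas as pd

rows = pd.DataFrame(
    {
        "input_query": ["What is the capital of France?"],
        "generated_answer": ["Paris"],
        "expected_answer": ["Paris"],
    }
)
rows.to_csv("eval_rows.csv", index=False)  # upload this file via the page's file uploader
```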

View file

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional
from llama_stack_client import LlamaStackClient
class LlamaStackEvaluation:
def __init__(self):
self.client = LlamaStackClient(
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"),
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
},
)
def list_scoring_functions(self):
"""List all available scoring functions"""
return self.client.scoring_functions.list()
def list_models(self):
"""List all available judge models"""
return self.client.models.list()
def run_scoring(
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
return self.client.scoring.score(
input_rows=[row], scoring_functions=scoring_params
)
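
For reference, a minimal standalone usage sketch of this wrapper, assuming a Llama Stack server is reachable at `LLAMA_STACK_ENDPOINT` (default `http://localhost:5000`) and that the chosen scoring function is registered on it; the row contents and the function id are placeholders:

```
# Hypothetical standalone use of LlamaStackEvaluation; row values and the
# scoring function id are placeholders. Braintrust scorers additionally need
# an OpenAI key (see the provider-data changes later in this diff).
from modules.api import LlamaStackEvaluation

api = LlamaStackEvaluation()
row = {
    "input_query": "What is the capital of France?",
    "generated_answer": "Paris",
    "expected_answer": "Paris",
}
res = api.run_scoring(
    row,
    scoring_function_ids=["braintrust::answer-correctness"],
    scoring_params=None,
)
print(res.results["braintrust::answer-correctness"].score_rows[0])
```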

View file

@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pandas as pd
import streamlit as st
def process_dataset(file):
if file is None:
return "No file uploaded", None
try:
# Determine file type and read accordingly
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == ".csv":
df = pd.read_csv(file)
elif file_ext in [".xlsx", ".xls"]:
df = pd.read_excel(file)
else:
return "Unsupported file format. Please upload a CSV or Excel file.", None
return df
except Exception as e:
st.error(f"Error processing file: {str(e)}")
return None

View file

@@ -0,0 +1,3 @@
streamlit
pandas
llama-stack-client>=0.0.55

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -6,10 +6,15 @@
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from pydantic import BaseModel
from .config import BraintrustScoringConfig
class BraintrustProviderDataValidator(BaseModel):
openai_api_key: str
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],

View file

@@ -12,9 +12,11 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
@@ -24,7 +26,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
):
def __init__(
self,
config: BraintrustScoringConfig,
@@ -79,12 +83,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None:
# api key is in the request headers
if self.config.openai_api_key is None:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@@ -105,6 +122,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
@@ -118,6 +136,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
await self.set_api_key()
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:

View file

@@ -6,4 +6,8 @@
from llama_stack.apis.scoring import * # noqa: F401, F403
class BraintrustScoringConfig(BaseModel): ...
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="The OpenAI API Key",
)

View file

@@ -10,7 +10,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
params=None,
provider_id="braintrust",
provider_resource_id="answer-correctness",

View file

@@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]:
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class HuggingfaceDatasetIOConfig(BaseModel):

View file

@@ -9,6 +9,7 @@ from llama_stack.apis.datasetio import * # noqa: F403
import datasets as hf_datasets
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl

View file

@@ -35,7 +35,9 @@ class NVIDIAConfig(BaseModel):
"""
url: str = Field(
default="https://integrate.api.nvidia.com",
default_factory=lambda: os.getenv(
"NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
),
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[str] = Field(

View file

@@ -89,8 +89,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model_id,
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
@@ -194,8 +195,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model_id,
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
@@ -249,7 +251,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
def _get_params(self, request: ChatCompletionRequest) -> dict:
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
request, self.register_helper.get_llama_model(request.model), self.formatter
)
return dict(
prompt=prompt,

View file

@@ -6,10 +6,14 @@
import pytest
from ..agents.fixtures import AGENTS_FIXTURES
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES
@@ -20,6 +24,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "localfs",
"inference": "fireworks",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
@@ -30,6 +37,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
@@ -40,6 +50,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "huggingface",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference_huggingface_datasetio",
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
@@ -75,6 +88,9 @@ def pytest_generate_tests(metafunc):
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
"agents": AGENTS_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)

View file

@@ -40,14 +40,30 @@ async def eval_stack(request):
providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
for key in [
"datasetio",
"eval",
"scoring",
"inference",
"agents",
"safety",
"memory",
]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
[
Api.eval,
Api.datasetio,
Api.inference,
Api.scoring,
Api.agents,
Api.safety,
Api.memory,
],
providers,
provider_data,
)

View file

@@ -21,6 +21,7 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
@@ -172,6 +173,22 @@ def inference_nvidia() -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_tgi() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="tgi",
provider_type="remote::tgi",
config=TGIImplConfig(
url=get_env_or_fail("TGI_URL"),
api_token=os.getenv("TGI_API_TOKEN", None),
).model_dump(),
)
],
)
def get_model_short_name(model_name: str) -> str:
"""Convert model name to a short test identifier.
@@ -207,6 +224,7 @@ INFERENCE_FIXTURES = [
"bedrock",
"cerebras",
"nvidia",
"tgi",
]

View file

@@ -10,9 +10,10 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@@ -40,7 +41,9 @@ def scoring_braintrust() -> ProviderFixture:
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
config={},
config=BraintrustScoringConfig(
openai_api_key=get_env_or_fail("OPENAI_API_KEY"),
).model_dump(),
)
],
)

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -29,7 +29,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints

View file

@@ -31,7 +31,7 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints