Merge branch 'main' into pr1573

This commit is contained in:
Xi Yan 2025-03-12 11:38:02 -07:00
commit 31e3409909
16 changed files with 737 additions and 115 deletions

View file

@ -422,6 +422,7 @@ def main():
"host": listen_host,
"port": port,
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
}
if ssl_config:
uvicorn_config.update(ssl_config)

View file

@ -170,6 +170,11 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None
}
dictConfig(logging_config)
# Ensure third-party libraries follow the root log level
for _, logger in logging.root.manager.loggerDict.items():
if isinstance(logger, logging.Logger):
logger.setLevel(root_level)
def get_logger(name: str, category: str = "uncategorized") -> logging.LoggerAdapter:
"""

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack_client import LlamaStackClient
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig
@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference):
async def register_model(self, model: Model) -> Model:
return model
def _get_client(self) -> LlamaStackClient:
def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None
passthrough_api_key = None
provider_data = None
@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference):
)
passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient(
return AsyncLlamaStackClient(
base_url=passthrough_url,
api_key=passthrough_api_key,
provider_data=provider_data,
@ -91,7 +94,7 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
request_params = {
"model_id": model.provider_resource_id,
"content": content,
"sampling_params": sampling_params,
@ -100,10 +103,13 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
# only pass through the not None params
return client.inference.completion(**params)
return await client.inference.completion(**json_params)
async def chat_completion(
self,
@ -120,10 +126,14 @@ class PassthroughInferenceAdapter(Inference):
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
# TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
@ -135,10 +145,39 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
# only pass through the not None params
return client.inference.chat_completion(**params)
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings(
self,
@ -151,10 +190,29 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
model = await self.model_store.get_model(model_id)
return client.inference.embeddings(
return await client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params

View file

@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY}",
"api_key": "${env.TOGETHER_API_KEY:}",
}

View file

@ -615,6 +615,14 @@ def convert_tool_call(
return valid_tool_call
PYTHON_TYPE_TO_LITELLM_TYPE = {
"int": "integer",
"float": "number",
"bool": "boolean",
"str": "string",
}
def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
@ -675,7 +683,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
if param.description:
properties[param_name].update(description=param.description)
if param.default:

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .open_benchmark import get_distribution_template # noqa: F401

View file

@ -0,0 +1,293 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Tuple
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
BenchmarkInput,
DatasetInput,
ModelInput,
Provider,
ShieldInput,
ToolGroupInput,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
# in this template, we allow each API key to be optional
providers = [
(
"openai",
[
ProviderModelEntry(
provider_model_id="openai/gpt-4o",
model_type=ModelType.llm,
)
],
OpenAIConfig.sample_run_config(api_key="${env.OPENAI_API_KEY:}"),
),
(
"anthropic",
[
ProviderModelEntry(
provider_model_id="anthropic/claude-3-5-sonnet-latest",
model_type=ModelType.llm,
)
],
AnthropicConfig.sample_run_config(api_key="${env.ANTHROPIC_API_KEY:}"),
),
(
"gemini",
[
ProviderModelEntry(
provider_model_id="gemini/gemini-1.5-flash",
model_type=ModelType.llm,
)
],
GeminiConfig.sample_run_config(api_key="${env.GEMINI_API_KEY:}"),
),
(
"groq",
[],
GroqConfig.sample_run_config(api_key="${env.GROQ_API_KEY:}"),
),
(
"together",
[],
TogetherImplConfig.sample_run_config(api_key="${env.TOGETHER_API_KEY:}"),
),
]
inference_providers = []
available_models = {}
for provider_id, model_entries, config in providers:
inference_providers.append(
Provider(
provider_id=provider_id,
provider_type=f"remote::{provider_id}",
config=config,
)
)
available_models[provider_id] = model_entries
return inference_providers, available_models
def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers()
providers = {
"inference": [p.provider_type for p in inference_providers],
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
name = "open-benchmark"
vector_io_providers = [
Provider(
provider_id="sqlite-vec",
provider_type="inline::sqlite-vec",
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
),
Provider(
provider_id="${env.ENABLE_CHROMADB+chromadb}",
provider_type="remote::chromadb",
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
),
Provider(
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
db="${env.PGVECTOR_DB:}",
user="${env.PGVECTOR_USER:}",
password="${env.PGVECTOR_PASSWORD:}",
),
),
]
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
default_models = get_model_registry(available_models) + [
ModelInput(
model_id="meta-llama/Llama-3.3-70B-Instruct",
provider_id="groq",
provider_model_id="groq/llama-3.3-70b-versatile",
model_type=ModelType.llm,
),
ModelInput(
model_id="meta-llama/Llama-3.1-405B-Instruct",
provider_id="together",
provider_model_id="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
model_type=ModelType.llm,
),
]
default_datasets = [
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
metadata={
"path": "llamastack/simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
metadata={
"path": "llamastack/mmlu_cot",
"name": "all",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
metadata={
"path": "llamastack/gpqa_0shot_cot",
"name": "gpqa_main",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
metadata={
"path": "llamastack/math_500",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
),
]
default_benchmarks = [
BenchmarkInput(
benchmark_id="meta-reference-simpleqa",
dataset_id="simpleqa",
scoring_functions=["llm-as-judge::405b-simpleqa"],
),
BenchmarkInput(
benchmark_id="meta-reference-mmlu-cot",
dataset_id="mmlu_cot",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
),
BenchmarkInput(
benchmark_id="meta-reference-gpqa-cot",
dataset_id="gpqa_cot",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
),
BenchmarkInput(
benchmark_id="meta-reference-math-500",
dataset_id="math_500",
scoring_functions=["basic::regex_parser_math_response"],
),
]
return DistributionTemplate(
name=name,
distro_type="self_hosted",
description="Distribution for running open benchmarks",
container_image=None,
template_path=None,
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": inference_providers,
"vector_io": vector_io_providers,
},
default_models=default_models,
default_tool_groups=default_tool_groups,
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
default_datasets=default_datasets,
default_benchmarks=default_benchmarks,
),
},
run_config_env_vars={
"LLAMA_STACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"TOGETHER_API_KEY": (
"",
"Together API Key",
),
"OPENAI_API_KEY": (
"",
"OpenAI API Key",
),
"GEMINI_API_KEY": (
"",
"Gemini API Key",
),
"ANTHROPIC_API_KEY": (
"",
"Anthropic API Key",
),
"GROQ_API_KEY": (
"",
"Groq API Key",
),
},
)

View file

@ -38,7 +38,7 @@ providers:
- provider_id: sqlite-vec
provider_type: inline::sqlite-vec
config:
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/sqlite_vec.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/sqlite_vec.db
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
provider_type: remote::chromadb
config:
@ -62,14 +62,14 @@ providers:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/agents_store.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -114,18 +114,13 @@ providers:
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/registry.db
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/registry.db
models:
- metadata: {}
model_id: openai/gpt-4o
provider_id: openai
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
- metadata: {}
model_id: anthropic/claude-3-5-sonnet-latest
provider_id: anthropic
@ -141,84 +136,95 @@ models:
provider_id: groq
provider_model_id: groq/llama-3.3-70b-versatile
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-405B-Instruct
provider_id: together
provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo
model_type: llm
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- dataset_id: simpleqa
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/simpleqa
metadata:
path: llamastack/simpleqa
name:
split: train
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
- dataset_id: mmlu_cot
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
metadata:
path: llamastack/mmlu_cot
name: all
split: test
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
- dataset_id: gpqa_cot
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
metadata:
path: llamastack/gpqa_0shot_cot
name: gpqa_main
split: train
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
- dataset_id: math_500
provider_id: huggingface
url:
uri: https://huggingface.co/datasets/llamastack/math_500
metadata:
path: llamastack/math_500
name:
split: test
dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/simpleqa
metadata:
path: llamastack/simpleqa
split: train
dataset_id: simpleqa
provider_id: huggingface
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
metadata:
path: llamastack/mmlu_cot
name: all
split: test
dataset_id: mmlu_cot
provider_id: huggingface
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
metadata:
path: llamastack/gpqa_0shot_cot
name: gpqa_main
split: train
dataset_id: gpqa_cot
provider_id: huggingface
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/math_500
metadata:
path: llamastack/math_500
split: test
dataset_id: math_500
provider_id: huggingface
scoring_fns: []
benchmarks:
- benchmark_id: meta-reference-simpleqa
dataset_id: simpleqa
scoring_functions: ["llm-as-judge::405b-simpleqa"]
- benchmark_id: meta-reference-mmlu-cot
dataset_id: mmlu_cot
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
- benchmark_id: meta-reference-gpqa-cot
dataset_id: gpqa_cot
scoring_functions: ["basic::regex_parser_multiple_choice_answer"]
- benchmark_id: meta-reference-math-500
dataset_id: math_500
scoring_functions: ["basic::regex_parser_math_response"]
- dataset_id: simpleqa
scoring_functions:
- llm-as-judge::405b-simpleqa
metadata: {}
benchmark_id: meta-reference-simpleqa
- dataset_id: mmlu_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-mmlu-cot
- dataset_id: gpqa_cot
scoring_functions:
- basic::regex_parser_multiple_choice_answer
metadata: {}
benchmark_id: meta-reference-gpqa-cot
- dataset_id: math_500
scoring_functions:
- basic::regex_parser_math_response
metadata: {}
benchmark_id: meta-reference-math-500
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search

View file

@ -14,7 +14,9 @@ from pydantic import BaseModel, Field
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
Api,
BenchmarkInput,
BuildConfig,
DatasetInput,
DistributionSpec,
ModelInput,
Provider,
@ -56,6 +58,8 @@ class RunConfigSettings(BaseModel):
default_models: Optional[List[ModelInput]] = None
default_shields: Optional[List[ShieldInput]] = None
default_tool_groups: Optional[List[ToolGroupInput]] = None
default_datasets: Optional[List[DatasetInput]] = None
default_benchmarks: Optional[List[BenchmarkInput]] = None
def run_config(
self,
@ -113,6 +117,8 @@ class RunConfigSettings(BaseModel):
models=self.default_models or [],
shields=self.default_shields or [],
tool_groups=self.default_tool_groups or [],
datasets=self.default_datasets or [],
benchmarks=self.default_benchmarks or [],
)

View file

@ -16,7 +16,7 @@ providers:
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY}
api_key: ${env.TOGETHER_API_KEY:}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}

View file

@ -16,7 +16,7 @@ providers:
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: ${env.TOGETHER_API_KEY}
api_key: ${env.TOGETHER_API_KEY:}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}