Merge branch 'meta-llama:main' into feat/litellm_sambanova_usage

Jorge Piedrahita Ortiz 2025-03-12 15:12:42 -05:00 committed by GitHub
commit e49bcd46fe
90 changed files with 3142 additions and 586 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
from .agents import MetaReferenceAgentsImpl
impl = MetaReferenceAgentsImpl(

View file

@ -181,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
async with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
@ -191,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
async with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
@ -218,18 +218,10 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
if isinstance(request.tool_responses[0], ToolResponseMessage):
tool_response_messages = request.tool_responses
tool_responses = [
ToolResponse(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
else:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
tool_responses = request.tool_responses
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, tool_name=x.tool_name, content=x.content)
for x in request.tool_responses
]
messages.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
@ -252,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=tool_responses,
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
@ -390,7 +382,7 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
with tracing.span("run_shields") as span:
async with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
@ -508,7 +500,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference") as span:
async with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -685,7 +677,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
with tracing.span(
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,

View file

@ -12,6 +12,7 @@ import uuid
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
Session,
Turn,
)
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id=agent_id,
)
async def get_agent(self, agent_id: str) -> ChatAgent:
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
@ -169,7 +172,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: Union[List[ToolResponse], List[ToolResponseMessage]],
tool_responses: List[ToolResponse],
stream: Optional[bool] = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
return turn
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
session_id: str,
turn_ids: Optional[List[str]] = None,
) -> Session:
agent = await self.get_agent(agent_id)
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def shutdown(self) -> None:
pass
async def list_agents(self) -> ListAgentsResponse:
pass
async def get_agent(self, agent_id: str) -> Agent:
pass
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
pass

View file

@ -10,6 +10,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.providers.utils.telemetry import tracing
log = logging.getLogger(__name__)
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(
shield_id=identifier,
messages=messages,
)
for identifier in identifiers
]
)
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
for identifier, response in zip(identifiers, responses, strict=False):
if not response.violation:
continue
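
A minimal standalone sketch (toy stand-ins, not from the changed files) of why the zip(identifiers, responses) pairing above stays correct: asyncio.gather returns results in the same order as the awaitables it was given, regardless of which shield call finishes first.

import asyncio

async def fake_shield(identifier: str) -> str:
    # the slow shield finishes last, but its result still lands in position 0
    await asyncio.sleep(0.01 if identifier == "slow" else 0)
    return f"result-for-{identifier}"

async def main():
    identifiers = ["slow", "fast"]
    responses = await asyncio.gather(*[fake_shield(i) for i in identifiers])
    print(list(zip(identifiers, responses)))
    # [('slow', 'result-for-slow'), ('fast', 'result-for-fast')]

asyncio.run(main())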

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LocalFSDatasetIOConfig
async def get_provider_impl(
config: LocalFSDatasetIOConfig,
_deps,
_deps: Dict[str, Any],
):
from .datasetio import LocalFSDatasetIOImpl

View file

@ -172,7 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import MetaReferenceEvalConfig
async def get_provider_impl(
config: MetaReferenceEvalConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .eval import MetaReferenceEvalImpl

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Union
from typing import Any, Dict, Union
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
_deps,
_deps: Dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
_deps: Dict[str, Any],
):
from .sentence_transformers import SentenceTransformersInferenceImpl

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import Any, Dict
from .config import VLLMConfig
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
from .vllm import VLLMInferenceImpl
impl = VLLMInferenceImpl(config)

View file

@ -4,9 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import TorchtunePostTrainingConfig
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .post_training import TorchtunePostTrainingImpl

View file

@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
self.jobs = {}
self.checkpoints_dict = {}
async def shutdown(self):
pass
async def supervised_fine_tune(
self,
job_uuid: str,

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeScannerConfig
async def get_provider_impl(config: CodeScannerConfig, deps):
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import LlamaGuardConfig
async def get_provider_impl(config: LlamaGuardConfig, deps):
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
from .llama_guard import LlamaGuardSafetyImpl
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,10 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import PromptGuardConfig # noqa: F401
async def get_provider_impl(config: PromptGuardConfig, deps):
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
from .prompt_guard import PromptGuardSafetyImpl
impl = PromptGuardSafetyImpl(config, deps)

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BasicScoringConfig
async def get_provider_impl(
config: BasicScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import BasicScoringImpl

View file

@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
class BasicScoringImpl(

View file

@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
regex_parser_math_response = ScoringFn(
identifier="basic::regex_parser_math_response",
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
return_type=NumberType(),
provider_id="basic",
provider_resource_id="regex-parser-math-response",
params=RegexParserScoringFnParams(
parsing_regexes=MATH_ANSWER_REGEXES,
aggregation_functions=[AggregationFunctionType.accuracy],
),
)
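
A quick check (hypothetical snippet, not part of the changed files) of what MATH_ANSWER_REGEXES extracts: re.findall returns the contents of the named group, so an answer wrapped in \boxed{...} comes back as the bare value.

import re

MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
text = r"Therefore, the final answer is: $\boxed{42}$"
print(re.findall(MATH_ANSWER_REGEXES[0], text))  # ['42']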

View file

@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
from .fn_defs.regex_parser_math_response import (
regex_parser_math_response,
)
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn for math benchmarks that parses the answer from the generated response according to context and checks whether it matches the expected_answer.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {
regex_parser_math_response.identifier: regex_parser_math_response,
}
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
f"RegexParserScoringFnParams not found for {fn_def}."
)
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
parsing_regexes = fn_def.params.parsing_regexes
assert len(parsing_regexes) == 1, (
"Only one parsing regex is supported for regex_parser_math_response scoring function."
)
parsing_regexes = fn_def.params.parsing_regexes[0]
normalized_generated_answer = normalize_final_answer(
first_answer(generated_answer),
parsing_regexes,
match_first=True,
)
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
return {
"score": score,
}

View file

@ -0,0 +1,330 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Sequence
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
# from minerva
SUBSTITUTIONS = [
("an ", ""),
("a ", ""),
(".$", "$"),
("\\$", ""),
(r"\ ", ""),
(" ", ""),
("mbox", "text"),
(",\\text{and}", ","),
("\\text{and}", ","),
("\\text{m}", "\\text{}"),
]
REMOVED_EXPRESSIONS = [
"square",
"ways",
"integers",
"dollars",
"mph",
"inches",
"ft",
"hours",
"km",
"units",
"\\ldots",
"sue",
"points",
"feet",
"minutes",
"digits",
"cents",
"degrees",
"cm",
"gm",
"pounds",
"meters",
"meals",
"edges",
"students",
"childrentickets",
"multiples",
"\\text{s}",
"\\text{.}",
"\\text{\ns}",
"\\text{}^2",
"\\text{}^3",
"\\text{\n}",
"\\text{}",
r"\mathrm{th}",
r"^\circ",
r"^{\circ}",
r"\;",
r",\!",
"{,}",
'"',
"\\dots",
]
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
if isinstance(expression, float):
return expression
new_expression = f"{expression}"
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
for match in re.finditer(regex, expression):
try:
value = float(match.group(1)) / float(match.group(2))
new_expression = new_expression.replace(
match.group(),
f"{{value:{fmt}}}".format(value=value),
1,
)
except Exception:
continue
return new_expression
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
try:
with time_limit(seconds=5):
from sympy.parsing.latex import parse_latex
value = parse_latex(expression).evalf() # type: ignore
return f"{{value:{fmt}}}".format(value=value)
except Exception:
return expression
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
for marker in markers:
text = text.split(marker)[0]
return text
def extract_result_from_boxed(answer: str) -> str:
box_start = "\\boxed"
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
start = answer.rfind(box_start)
if start < 0:
return ""
answer = answer[start + len(box_start) :].strip()
ends_with_curly = answer.startswith("{")
i = 0
open_braces = 0
while i < len(answer):
if answer[i] == "{":
open_braces += 1
elif answer[i] == "}":
open_braces -= 1
if open_braces == 0:
if ends_with_curly:
answer = answer[: i + 1].strip()
break
elif answer[i] == "$":
answer = answer[:i].strip()
break
i += 1
else:
return ""
# remove extra curly braces
while True:
if answer.startswith("{") and answer.endswith("}"):
answer = answer[1:-1].strip()
else:
break
return answer
# from minerva paper + _normalise_result from xavierm
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
"""Extract and normalize a final answer to a quantitative reasoning question."""
match = re.findall(regex_pattern, final_answer)
extraction: str
if len(match) > 0:
if match_first:
extraction = match[0]
else:
extraction = match[-1]
else:
extraction = extract_result_from_boxed(final_answer)
if len(extraction) == 0:
return final_answer
else:
final_answer = extraction
final_answer = final_answer.split("=")[-1]
for before, after in SUBSTITUTIONS:
final_answer = final_answer.replace(before, after)
for expr in REMOVED_EXPRESSIONS:
final_answer = final_answer.replace(expr, "")
# Extract answer that is in LaTeX math, is bold,
# is surrounded by a box, etc.
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
# Normalize shorthand TeX:
# \fracab -> \frac{a}{b}
# \frac{abc}{bef} -> \frac{abc}{bef}
# \fracabc -> \frac{a}{b}c
# \sqrta -> \sqrt{a}
# \sqrtab -> sqrt{a}b
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
final_answer = final_answer.replace("$", "")
# Normalize 100,000 -> 100000
if final_answer.replace(",", "").isdigit():
final_answer = final_answer.replace(",", "")
# If the final answer is a single letter in parentheses, remove the parentheses
# Example: (a) -> a (but not (ab) -> ab)
if re.match(r"\([a-zA-Z]\)", final_answer):
final_answer = final_answer[1]
return _normalise_result(final_answer)
def _normalise_result(string: str) -> str:
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace cfrac, tfrac and dfrac with frac
string = string.replace("cfrac", "frac")
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\le", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = _remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace(r"\%", "")
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
string = string.split("=")[-1]
# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)
return string
def _remove_right_units(string: str) -> str:
# "\\text{ " only ever occurs (at least in the val set) when describing units
try:
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
except AssertionError:
return string
def _fix_sqrt(string: str) -> str:
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if len(split) == 0:
return string
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
def _fix_fracs(string: str) -> str:
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if len(substr) == 0:
return string
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
def _fix_a_slash_b(string: str) -> str:
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
ia = int(a)
ib = int(b)
assert string == "{}/{}".format(ia, ib)
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
return new_string
except (ValueError, AssertionError):
return string
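
Hypothetical usage of the helpers above (import path inferred from the basic scoring provider's `from ..utils.math_utils import ...`), showing the typical chain for a MATH-style answer: pull the value out of \boxed{...}, let normalization rewrite a/b as \frac{a}{b}, then evaluate the fraction numerically.

from llama_stack.providers.inline.scoring.basic.utils.math_utils import (
    extract_result_from_boxed,
    normalize_final_answer,
    try_evaluate_frac,
)

raw = r"The final answer is: $\boxed{3/4}$"
regex = r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"

print(extract_result_from_boxed(raw))        # 3/4
answer = normalize_final_answer(raw, regex)  # '\\frac{3}{4}'
print(try_evaluate_frac(answer))             # 7.50e-01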

View file

@ -3,11 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import BraintrustScoringConfig
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .braintrust import BraintrustScoringImpl

View file

@ -3,16 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import Api
from .config import LlmAsJudgeScoringConfig
async def get_provider_impl(
config: LlmAsJudgeScoringConfig,
deps: Dict[Api, ProviderSpec],
deps: Dict[Api, Any],
):
from .scoring import LlmAsJudgeScoringImpl

View file

@ -73,6 +73,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps.get(Api.datasetio)
self.meter = None
resource = Resource.create(
{
@ -171,6 +172,8 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
if self.meter is None:
return
if isinstance(event.value, int):
counter = self._get_or_create_counter(event.metric, event.unit)
counter.add(event.value, attributes=event.attributes)

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import CodeInterpreterToolConfig
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
from .code_interpreter import CodeInterpreterToolRuntimeImpl
impl = CodeInterpreterToolRuntimeImpl(config)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import ChromaVectorIOConfig
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.chroma.chroma import (
ChromaVectorIOAdapter,
)

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import FaissVectorIOConfig
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
from .faiss import FaissVectorIOAdapter
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
impl = MilvusVectorIOAdapter(config, deps[Api.inference])

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from typing import Any, Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.datatypes import Api
from .config import SQLiteVectorIOConfig
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
from .sqlite_vec import SQLiteVecVectorIOAdapter
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"

View file

@ -34,6 +34,8 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
),
# NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
# source distribution and the wheels are not available for all platforms.
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::sqlite-vec",

View file

@ -24,10 +24,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,

View file

@ -4,12 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack_client import LlamaStackClient
from llama_stack_client import AsyncLlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -24,6 +26,7 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .config import PassthroughImplConfig
@ -46,7 +49,7 @@ class PassthroughInferenceAdapter(Inference):
async def register_model(self, model: Model) -> Model:
return model
def _get_client(self) -> LlamaStackClient:
def _get_client(self) -> AsyncLlamaStackClient:
passthrough_url = None
passthrough_api_key = None
provider_data = None
@ -71,7 +74,7 @@ class PassthroughInferenceAdapter(Inference):
)
passthrough_api_key = provider_data.passthrough_api_key
return LlamaStackClient(
return AsyncLlamaStackClient(
base_url=passthrough_url,
api_key=passthrough_api_key,
provider_data=provider_data,
@ -91,7 +94,7 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
request_params = {
"model_id": model.provider_resource_id,
"content": content,
"sampling_params": sampling_params,
@ -100,10 +103,13 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
# only pass through the not None params
return client.inference.completion(**params)
return await client.inference.completion(**json_params)
async def chat_completion(
self,
@ -120,10 +126,14 @@ class PassthroughInferenceAdapter(Inference):
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
client = self._get_client()
model = await self.model_store.get_model(model_id)
params = {
# TODO: revisit this remove tool_calls from messages logic
for message in messages:
if hasattr(message, "tool_calls"):
message.tool_calls = None
request_params = {
"model_id": model.provider_resource_id,
"messages": messages,
"sampling_params": sampling_params,
@ -135,10 +145,39 @@ class PassthroughInferenceAdapter(Inference):
"logprobs": logprobs,
}
params = {key: value for key, value in params.items() if value is not None}
# only pass through the not None params
return client.inference.chat_completion(**params)
request_params = {key: value for key, value in request_params.items() if value is not None}
# cast everything to json dict
json_params = self.cast_value_to_json_dict(request_params)
if stream:
return self._stream_chat_completion(json_params)
else:
return await self._nonstream_chat_completion(json_params)
async def _nonstream_chat_completion(self, json_params: Dict[str, Any]) -> ChatCompletionResponse:
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()
stream_response = await client.inference.chat_completion(**json_params)
async for chunk in stream_response:
chunk = chunk.to_dict()
# temporary hack to remove the metrics from the response
chunk["metrics"] = []
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings(
self,
@ -151,10 +190,29 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
model = await self.model_store.get_model(model_id)
return client.inference.embeddings(
return await client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
def cast_value_to_json_dict(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
json_params = {}
for key, value in request_params.items():
json_input = convert_pydantic_to_json_value(value)
if isinstance(json_input, dict):
json_input = {k: v for k, v in json_input.items() if v is not None}
elif isinstance(json_input, list):
json_input = [x for x in json_input if x is not None]
new_input = []
for x in json_input:
if isinstance(x, dict):
x = {k: v for k, v in x.items() if v is not None}
new_input.append(x)
json_input = new_input
json_params[key] = json_input
return json_params
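
A standalone illustration of the None-stripping that cast_value_to_json_dict performs, so the server never receives explicit nulls for unset optional fields (plain JSON values stand in for the output of convert_pydantic_to_json_value; the helper name here is made up).

def strip_nones(value):
    # drop None entries from dicts, and from dicts nested inside lists, one level deep
    if isinstance(value, dict):
        return {k: v for k, v in value.items() if v is not None}
    if isinstance(value, list):
        return [
            {k: v for k, v in x.items() if v is not None} if isinstance(x, dict) else x
            for x in value
            if x is not None
        ]
    return value

params = {
    "sampling_params": {"temperature": 0.7, "top_p": None},
    "messages": [{"role": "user", "content": "hi", "context": None}],
}
print({k: strip_nones(v) for k, v in params.items()})
# {'sampling_params': {'temperature': 0.7}, 'messages': [{'role': 'user', 'content': 'hi'}]}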

View file

@ -26,5 +26,5 @@ class TogetherImplConfig(BaseModel):
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.together.xyz/v1",
"api_key": "${env.TOGETHER_API_KEY}",
"api_key": "${env.TOGETHER_API_KEY:}",
}

View file

@ -6,7 +6,7 @@
from typing import AsyncGenerator, List, Optional, Union
from together import Together
from together import AsyncTogether
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -59,12 +59,15 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config
self._client = None
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
if self._client:
await self._client.close()
self._client = None
async def completion(
self,
@ -91,35 +94,32 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
else:
return await self._nonstream_completion(request)
def _get_client(self) -> Together:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
return Together(api_key=together_api_key)
def _get_client(self) -> AsyncTogether:
if not self._client:
together_api_key = None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
together_api_key = config_api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
self._client = AsyncTogether(api_key=together_api_key)
return self._client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client().completions.create(**params)
client = self._get_client()
r = await client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
client = await self._get_client()
stream = await client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@ -184,25 +184,21 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
r = self._get_client().chat.completions.create(**params)
r = await client.chat.completions.create(**params)
else:
r = self._get_client().completions.create(**params)
r = await client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
client = self._get_client()
if "messages" in params:
stream = await client.chat.completions.create(**params)
else:
stream = await client.completions.create(**params)
# if we shift to TogetherAsyncClient, we won't need this wrapper
async def _to_async_generator():
if "messages" in params:
s = self._get_client().chat.completions.create(**params)
else:
s = self._get_client().completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
@ -240,7 +236,8 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
r = self._get_client().embeddings.create(
client = self._get_client()
r = await client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)

View file

@ -615,6 +615,14 @@ def convert_tool_call(
return valid_tool_call
PYTHON_TYPE_TO_LITELLM_TYPE = {
"int": "integer",
"float": "number",
"bool": "boolean",
"str": "string",
}
def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
"""
Convert a ToolDefinition to an OpenAI API-compatible dictionary.
@ -675,7 +683,7 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
properties = parameters["properties"]
required = []
for param_name, param in tool.parameters.items():
properties[param_name] = {"type": param.param_type}
properties[param_name] = {"type": PYTHON_TYPE_TO_LITELLM_TYPE.get(param.param_type, param.param_type)}
if param.description:
properties[param_name].update(description=param.description)
if param.default:
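
A minimal check (standalone snippet, not from the changed files) of the fallback behavior of PYTHON_TYPE_TO_LITELLM_TYPE used above: Python type names are mapped to JSON Schema types, and anything not in the map is passed through unchanged.

PYTHON_TYPE_TO_LITELLM_TYPE = {
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "str": "string",
}

for param_type in ("int", "float", "str", "object"):
    print(param_type, "->", PYTHON_TYPE_TO_LITELLM_TYPE.get(param_type, param_type))
# int -> integer
# float -> number
# str -> string
# object -> object   (already a valid JSON Schema type, passed through as-is)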

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import contextlib
import signal
from types import FrameType
from typing import Iterator, Optional
class TimeoutError(Exception):
pass
@contextlib.contextmanager
def time_limit(seconds: float) -> Iterator[None]:
def signal_handler(signum: int, frame: Optional[FrameType]) -> None:
raise TimeoutError("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
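
Hypothetical usage of time_limit (assuming this new module is the basic_scoring_utils imported by math_utils above); because it relies on SIGALRM, it only works on Unix and in the main thread.

from llama_stack.providers.utils.scoring.basic_scoring_utils import (
    TimeoutError as ScoringTimeoutError,  # the module's own TimeoutError, not the builtin
    time_limit,
)

try:
    with time_limit(seconds=0.5):
        while True:  # stand-in for a LaTeX parse that never returns
            pass
except ScoringTimeoutError:
    print("timed out, falling back to the unparsed expression")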

View file

@ -6,6 +6,7 @@
import asyncio
import base64
import contextvars
import logging
import queue
import threading
@ -24,9 +25,10 @@ from llama_stack.apis.telemetry import (
Telemetry,
UnstructuredLogEvent,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value
log = logging.getLogger(__name__)
logger = get_logger(__name__, category="core")
def generate_short_uuid(len: int = 8):
@ -36,7 +38,7 @@ def generate_short_uuid(len: int = 8):
return encoded.rstrip(b"=").decode("ascii")[:len]
CURRENT_TRACE_CONTEXT = None
CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)
BACKGROUND_LOGGER = None
@ -51,7 +53,7 @@ class BackgroundLogger:
try:
self.log_queue.put_nowait(event)
except queue.Full:
log.error("Log queue is full, dropping event")
logger.error("Log queue is full, dropping event")
def _process_logs(self):
while True:
@ -129,35 +131,36 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
if BACKGROUND_LOGGER is None:
BACKGROUND_LOGGER = BackgroundLogger(api)
logger = logging.getLogger()
logger.setLevel(level)
logger.addHandler(TelemetryHandler())
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None:
log.info("No Telemetry implementation set. Skipping trace initialization...")
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
return
trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context
CURRENT_TRACE_CONTEXT.set(context)
return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
logger.debug("No trace context to end")
return
context.pop_span(status)
CURRENT_TRACE_CONTEXT = None
CURRENT_TRACE_CONTEXT.set(None)
def severity(levelname: str) -> LogSeverity:
@ -188,7 +191,7 @@ class TelemetryHandler(logging.Handler):
if BACKGROUND_LOGGER is None:
raise RuntimeError("Telemetry API not initialized")
context = CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT.get()
if context is None:
return
@ -218,16 +221,22 @@ class SpanContextManager:
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def set_attribute(self, key: str, value: Any):
if self.span:
@ -237,16 +246,22 @@ class SpanContextManager:
async def __aenter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to push span")
return self
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
context = CURRENT_TRACE_CONTEXT.get()
if not context:
logger.debug("No trace context to pop span")
return
context.pop_span()
def __call__(self, func: Callable):
@wraps(func)
@ -275,7 +290,11 @@ def span(name: str, attributes: Dict[str, Any] = None):
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if CURRENT_TRACE_CONTEXT is None:
logger.debug("No trace context to get current span")
return None
context = CURRENT_TRACE_CONTEXT.get()
if context:
return context.get_current_span()
return None
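
A self-contained sketch (not part of the diff) of why the module-level global was replaced with a contextvars.ContextVar: each asyncio task gets its own copy of the context, so concurrent requests can each hold their own trace without clobbering one another.

import asyncio
import contextvars

CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None)

async def handle_request(trace_id: str) -> str:
    CURRENT_TRACE_CONTEXT.set(trace_id)   # analogous to start_trace()
    await asyncio.sleep(0.01)             # other requests run in the meantime
    return CURRENT_TRACE_CONTEXT.get()    # still this task's trace id

async def main():
    print(await asyncio.gather(*(handle_request(f"trace-{i}") for i in range(3))))
    # ['trace-0', 'trace-1', 'trace-2'] -- a plain global would have been overwritten

asyncio.run(main())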