mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 00:40:02 +00:00
Merge remote-tracking branch 'origin/main' into math_500_judge
This commit is contained in:
commit
8eff285017
87 changed files with 2299 additions and 495 deletions
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MetaReferenceAgentsImplConfig, deps: Dict[Api, Any]):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from rich.markup import escape
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
|
|
@ -182,7 +181,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
|
|
@ -192,7 +191,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
with tracing.span("resume_turn") as span:
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
|
@ -391,7 +390,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
with tracing.span("run_shields") as span:
|
||||
async with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
|
|
@ -509,7 +508,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference") as span:
|
||||
async with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
|
@ -686,7 +685,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
tool_name = tool_call.tool_name
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
with tracing.span(
|
||||
async with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
|
|
@ -1030,7 +1029,7 @@ async def execute_tool_call_maybe(
|
|||
**toolgroup_args.get(group_name, {}),
|
||||
},
|
||||
)
|
||||
logger.info(f"tool call {name} completed with result: {escape(str(result))}")
|
||||
logger.info(f"tool call {name} completed with result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
|
|
@ -21,6 +22,8 @@ from llama_stack.apis.agents import (
|
|||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
|
@ -84,7 +87,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_config = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
|
|
@ -120,7 +123,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
|
|
@ -160,7 +163,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
|
|
@ -188,12 +191,12 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self.get_agent(request.agent_id)
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
return turn
|
||||
|
||||
|
|
@ -210,7 +213,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id: str,
|
||||
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
agent = await self.get_agent(agent_id)
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
|
@ -232,3 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_agents(self) -> ListAgentsResponse:
|
||||
pass
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
pass
|
||||
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import List
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -32,15 +33,14 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: List[Message], identifiers: List[str]) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LocalFSDatasetIOConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .datasetio import LocalFSDatasetIOImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
url = str(dataset_info.dataset_def.url.uri)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceEvalConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .eval import MetaReferenceEvalImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
|
@ -11,7 +13,7 @@ from llama_stack.providers.inline.inference.sentence_transformers.config import
|
|||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: VLLMConfig, _deps) -> Any:
|
||||
async def get_provider_impl(config: VLLMConfig, _deps: Dict[str, Any]):
|
||||
from .vllm import VLLMInferenceImpl
|
||||
|
||||
impl = VLLMInferenceImpl(config)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TorchtunePostTrainingConfig
|
||||
|
||||
|
|
@ -15,7 +15,7 @@ from .config import TorchtunePostTrainingConfig
|
|||
|
||||
async def get_provider_impl(
|
||||
config: TorchtunePostTrainingConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .post_training import TorchtunePostTrainingImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,9 @@ class TorchtunePostTrainingImpl:
|
|||
self.jobs = {}
|
||||
self.checkpoints_dict = {}
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps):
|
||||
async def get_provider_impl(config: CodeScannerConfig, deps: Dict[str, Any]):
|
||||
from .code_scanner import MetaReferenceCodeScannerSafetyImpl
|
||||
|
||||
impl = MetaReferenceCodeScannerSafetyImpl(config, deps)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps):
|
||||
async def get_provider_impl(config: LlamaGuardConfig, deps: Dict[str, Any]):
|
||||
from .llama_guard import LlamaGuardSafetyImpl
|
||||
|
||||
assert isinstance(config, LlamaGuardConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import PromptGuardConfig # noqa: F401
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps):
|
||||
async def get_provider_impl(config: PromptGuardConfig, deps: Dict[str, Any]):
|
||||
from .prompt_guard import PromptGuardSafetyImpl
|
||||
|
||||
impl = PromptGuardSafetyImpl(config, deps)
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BasicScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import BasicScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -23,10 +23,11 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
|
||||
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
|
||||
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
|
||||
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
|
||||
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn, RegexParserMathResponseScoringFn]
|
||||
|
||||
|
||||
class BasicScoringImpl(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MATH_ANSWER_REGEXES = [r".*final answer is:?\s*\$\\boxed{(?P<X>.*)}\$"]
|
||||
|
||||
|
||||
regex_parser_math_response = ScoringFn(
|
||||
identifier="basic::regex_parser_math_response",
|
||||
description="For math related benchmarks, extract answer from the generated response and expected_answer and see if they match",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="regex-parser-math-response",
|
||||
params=RegexParserScoringFnParams(
|
||||
parsing_regexes=MATH_ANSWER_REGEXES,
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
regex_parser_math_response,
|
||||
)
|
||||
|
||||
|
||||
class RegexParserMathResponseScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for math benchamrks that parses answer from generated response according to context and check match with expected_answer.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
regex_parser_math_response.identifier: regex_parser_math_response,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: Dict[str, Any],
|
||||
scoring_fn_identifier: Optional[str] = None,
|
||||
scoring_params: Optional[ScoringFnParams] = None,
|
||||
) -> ScoringResultRow:
|
||||
assert scoring_fn_identifier is not None, "Scoring function identifier not found."
|
||||
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
|
||||
if scoring_params is not None:
|
||||
fn_def.params = scoring_params
|
||||
|
||||
assert fn_def.params is not None and fn_def.params.type == ScoringFnParamsType.regex_parser.value, (
|
||||
f"RegexParserScoringFnParams not found for {fn_def}."
|
||||
)
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
||||
parsing_regexes = fn_def.params.parsing_regexes
|
||||
assert len(parsing_regexes) == 1, (
|
||||
"Only one parsing regex is supported for regex_parser_math_response scoring function."
|
||||
)
|
||||
parsing_regexes = fn_def.params.parsing_regexes[0]
|
||||
|
||||
normalized_generated_answer = normalize_final_answer(
|
||||
first_answer(generated_answer),
|
||||
parsing_regexes,
|
||||
match_first=True,
|
||||
)
|
||||
normalized_generated_answer = try_evaluate_frac(try_evaluate_latex(normalized_generated_answer))
|
||||
|
||||
normalized_expected_answer = normalize_final_answer(expected_answer, r".*")
|
||||
normalized_expected_answer = try_evaluate_frac(try_evaluate_latex(normalized_expected_answer))
|
||||
|
||||
score = 1.0 if normalized_generated_answer == normalized_expected_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
330
llama_stack/providers/inline/scoring/basic/utils/math_utils.py
Normal file
|
|
@ -0,0 +1,330 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import re
|
||||
from typing import Sequence
|
||||
|
||||
from llama_stack.providers.utils.scoring.basic_scoring_utils import time_limit
|
||||
|
||||
# from minerva
|
||||
SUBSTITUTIONS = [
|
||||
("an ", ""),
|
||||
("a ", ""),
|
||||
(".$", "$"),
|
||||
("\\$", ""),
|
||||
(r"\ ", ""),
|
||||
(" ", ""),
|
||||
("mbox", "text"),
|
||||
(",\\text{and}", ","),
|
||||
("\\text{and}", ","),
|
||||
("\\text{m}", "\\text{}"),
|
||||
]
|
||||
|
||||
REMOVED_EXPRESSIONS = [
|
||||
"square",
|
||||
"ways",
|
||||
"integers",
|
||||
"dollars",
|
||||
"mph",
|
||||
"inches",
|
||||
"ft",
|
||||
"hours",
|
||||
"km",
|
||||
"units",
|
||||
"\\ldots",
|
||||
"sue",
|
||||
"points",
|
||||
"feet",
|
||||
"minutes",
|
||||
"digits",
|
||||
"cents",
|
||||
"degrees",
|
||||
"cm",
|
||||
"gm",
|
||||
"pounds",
|
||||
"meters",
|
||||
"meals",
|
||||
"edges",
|
||||
"students",
|
||||
"childrentickets",
|
||||
"multiples",
|
||||
"\\text{s}",
|
||||
"\\text{.}",
|
||||
"\\text{\ns}",
|
||||
"\\text{}^2",
|
||||
"\\text{}^3",
|
||||
"\\text{\n}",
|
||||
"\\text{}",
|
||||
r"\mathrm{th}",
|
||||
r"^\circ",
|
||||
r"^{\circ}",
|
||||
r"\;",
|
||||
r",\!",
|
||||
"{,}",
|
||||
'"',
|
||||
"\\dots",
|
||||
]
|
||||
|
||||
|
||||
def try_evaluate_frac(expression: str, fmt: str = "0.2e") -> str:
|
||||
if isinstance(expression, float):
|
||||
return expression
|
||||
new_expression = f"{expression}"
|
||||
regex = re.compile(r"\\frac{([^}]+)}{([^}]+)}")
|
||||
for match in re.finditer(regex, expression):
|
||||
try:
|
||||
value = float(match.group(1)) / float(match.group(2))
|
||||
new_expression = new_expression.replace(
|
||||
match.group(),
|
||||
f"{{value:{fmt}}}".format(value=value),
|
||||
1,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
return new_expression
|
||||
|
||||
|
||||
def try_evaluate_latex(expression: str, fmt: str = ".2e") -> str:
|
||||
try:
|
||||
with time_limit(seconds=5):
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
value = parse_latex(expression).evalf() # type: ignore
|
||||
return f"{{value:{fmt}}}".format(value=value)
|
||||
except Exception:
|
||||
return expression
|
||||
|
||||
|
||||
def first_answer(text: str, markers: Sequence[str] = ("Q:", "A:")) -> str:
|
||||
for marker in markers:
|
||||
text = text.split(marker)[0]
|
||||
return text
|
||||
|
||||
|
||||
def extract_result_from_boxed(answer: str) -> str:
|
||||
box_start = "\\boxed"
|
||||
# format is `\\boxed <value>$` or `\\boxed{<value>}`, with potential white spaces framing `<value>`
|
||||
start = answer.rfind(box_start)
|
||||
if start < 0:
|
||||
return ""
|
||||
answer = answer[start + len(box_start) :].strip()
|
||||
ends_with_curly = answer.startswith("{")
|
||||
i = 0
|
||||
open_braces = 0
|
||||
while i < len(answer):
|
||||
if answer[i] == "{":
|
||||
open_braces += 1
|
||||
elif answer[i] == "}":
|
||||
open_braces -= 1
|
||||
if open_braces == 0:
|
||||
if ends_with_curly:
|
||||
answer = answer[: i + 1].strip()
|
||||
break
|
||||
elif answer[i] == "$":
|
||||
answer = answer[:i].strip()
|
||||
break
|
||||
i += 1
|
||||
else:
|
||||
return ""
|
||||
# remove extra curly braces
|
||||
while True:
|
||||
if answer.startswith("{") and answer.endswith("}"):
|
||||
answer = answer[1:-1].strip()
|
||||
else:
|
||||
break
|
||||
return answer
|
||||
|
||||
|
||||
# from minerva paper + _normalise_result from xavierm
|
||||
def normalize_final_answer(final_answer: str, regex_pattern: str, match_first: bool = True) -> str:
|
||||
"""Extract and normalize a final answer to a quantitative reasoning question."""
|
||||
match = re.findall(regex_pattern, final_answer)
|
||||
extraction: str
|
||||
if len(match) > 0:
|
||||
if match_first:
|
||||
extraction = match[0]
|
||||
else:
|
||||
extraction = match[-1]
|
||||
else:
|
||||
extraction = extract_result_from_boxed(final_answer)
|
||||
|
||||
if len(extraction) == 0:
|
||||
return final_answer
|
||||
else:
|
||||
final_answer = extraction
|
||||
final_answer = final_answer.split("=")[-1]
|
||||
for before, after in SUBSTITUTIONS:
|
||||
final_answer = final_answer.replace(before, after)
|
||||
for expr in REMOVED_EXPRESSIONS:
|
||||
final_answer = final_answer.replace(expr, "")
|
||||
# Extract answer that is in LaTeX math, is bold,
|
||||
# is surrounded by a box, etc.
|
||||
final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
|
||||
final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
|
||||
final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
|
||||
# Normalize shorthand TeX:
|
||||
# \fracab -> \frac{a}{b}
|
||||
# \frac{abc}{bef} -> \frac{abc}{bef}
|
||||
# \fracabc -> \frac{a}{b}c
|
||||
# \sqrta -> \sqrt{a}
|
||||
# \sqrtab -> sqrt{a}b
|
||||
final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
|
||||
final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
|
||||
final_answer = final_answer.replace("$", "")
|
||||
# Normalize 100,000 -> 100000
|
||||
if final_answer.replace(",", "").isdigit():
|
||||
final_answer = final_answer.replace(",", "")
|
||||
# If the final answer is a single letter in parentheses, remove the parentheses
|
||||
# Example: (a) -> a (but not (ab) -> ab)
|
||||
if re.match(r"\([a-zA-Z]\)", final_answer):
|
||||
final_answer = final_answer[1]
|
||||
return _normalise_result(final_answer)
|
||||
|
||||
|
||||
def _normalise_result(string: str) -> str:
|
||||
# linebreaks
|
||||
string = string.replace("\n", "")
|
||||
|
||||
# remove inverse spaces
|
||||
string = string.replace("\\!", "")
|
||||
|
||||
# replace \\ with \
|
||||
string = string.replace("\\\\", "\\")
|
||||
|
||||
# replace tfrac and dfrac with frac
|
||||
string = string.replace("cfrac", "frac")
|
||||
string = string.replace("tfrac", "frac")
|
||||
string = string.replace("dfrac", "frac")
|
||||
|
||||
# remove \left and \right
|
||||
string = string.replace("\\left", "")
|
||||
string = string.replace("\\le", "")
|
||||
string = string.replace("\\right", "")
|
||||
|
||||
# Remove circ (degrees)
|
||||
string = string.replace("^{\\circ}", "")
|
||||
string = string.replace("^\\circ", "")
|
||||
|
||||
# remove dollar signs
|
||||
string = string.replace("\\$", "")
|
||||
|
||||
# remove units (on the right)
|
||||
string = _remove_right_units(string)
|
||||
|
||||
# remove percentage
|
||||
string = string.replace("\\%", "")
|
||||
string = string.replace(r"\%", "")
|
||||
|
||||
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||
string = string.replace(" .", " 0.")
|
||||
string = string.replace("{.", "{0.")
|
||||
# if empty, return empty string
|
||||
if len(string) == 0:
|
||||
return string
|
||||
if string[0] == ".":
|
||||
string = "0" + string
|
||||
|
||||
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||
string = string.split("=")[-1]
|
||||
|
||||
# fix sqrt3 --> sqrt{3}
|
||||
string = _fix_sqrt(string)
|
||||
|
||||
# remove spaces
|
||||
string = string.replace(" ", "")
|
||||
|
||||
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||
string = _fix_fracs(string)
|
||||
|
||||
# manually change 0.5 --> \frac{1}{2}
|
||||
if string == "0.5":
|
||||
string = "\\frac{1}{2}"
|
||||
|
||||
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||
string = _fix_a_slash_b(string)
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def _remove_right_units(string: str) -> str:
|
||||
# "\\text{ " only ever occurs (at least in the val set) when describing units
|
||||
try:
|
||||
if "\\text{ " in string:
|
||||
splits = string.split("\\text{ ")
|
||||
assert len(splits) == 2
|
||||
return splits[0]
|
||||
else:
|
||||
return string
|
||||
except AssertionError:
|
||||
return string
|
||||
|
||||
|
||||
def _fix_sqrt(string: str) -> str:
|
||||
if "\\sqrt" not in string:
|
||||
return string
|
||||
splits = string.split("\\sqrt")
|
||||
new_string = splits[0]
|
||||
for split in splits[1:]:
|
||||
if len(split) == 0:
|
||||
return string
|
||||
if split[0] != "{":
|
||||
a = split[0]
|
||||
new_substr = "\\sqrt{" + a + "}" + split[1:]
|
||||
else:
|
||||
new_substr = "\\sqrt" + split
|
||||
new_string += new_substr
|
||||
return new_string
|
||||
|
||||
|
||||
def _fix_fracs(string: str) -> str:
|
||||
substrs = string.split("\\frac")
|
||||
new_str = substrs[0]
|
||||
if len(substrs) > 1:
|
||||
substrs = substrs[1:]
|
||||
for substr in substrs:
|
||||
new_str += "\\frac"
|
||||
if len(substr) == 0:
|
||||
return string
|
||||
if substr[0] == "{":
|
||||
new_str += substr
|
||||
else:
|
||||
try:
|
||||
assert len(substr) >= 2
|
||||
except AssertionError:
|
||||
return string
|
||||
a = substr[0]
|
||||
b = substr[1]
|
||||
if b != "{":
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}{" + b + "}"
|
||||
else:
|
||||
if len(substr) > 2:
|
||||
post_substr = substr[2:]
|
||||
new_str += "{" + a + "}" + b + post_substr
|
||||
else:
|
||||
new_str += "{" + a + "}" + b
|
||||
string = new_str
|
||||
return string
|
||||
|
||||
|
||||
def _fix_a_slash_b(string: str) -> str:
|
||||
if len(string.split("/")) != 2:
|
||||
return string
|
||||
a = string.split("/")[0]
|
||||
b = string.split("/")[1]
|
||||
try:
|
||||
ia = int(a)
|
||||
ib = int(b)
|
||||
assert string == "{}/{}".format(ia, ib)
|
||||
new_string = "\\frac{" + str(ia) + "}{" + str(ib) + "}"
|
||||
return new_string
|
||||
except (ValueError, AssertionError):
|
||||
return string
|
||||
|
|
@ -3,11 +3,11 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import BraintrustScoringConfig
|
||||
|
||||
|
|
@ -18,7 +18,7 @@ class BraintrustProviderDataValidator(BaseModel):
|
|||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .braintrust import BraintrustScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import LlmAsJudgeScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: LlmAsJudgeScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
deps: Dict[Api, Any],
|
||||
):
|
||||
from .scoring import LlmAsJudgeScoringImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import CodeInterpreterToolConfig
|
||||
|
||||
__all__ = ["CodeInterpreterToolConfig", "CodeInterpreterToolRuntimeImpl"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps):
|
||||
async def get_provider_impl(config: CodeInterpreterToolConfig, _deps: Dict[str, Any]):
|
||||
from .code_interpreter import CodeInterpreterToolRuntimeImpl
|
||||
|
||||
impl = CodeInterpreterToolRuntimeImpl(config)
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import ChromaVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import (
|
||||
ChromaVectorIOAdapter,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import FaissVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .faiss import FaissVectorIOAdapter
|
||||
|
||||
assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
|
@ -99,7 +100,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
await self._save_index()
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
distances, indices = self.index.search(embedding.reshape(1, -1).astype(np.float32), k)
|
||||
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import MilvusVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
from .config import SQLiteVectorIOConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]):
|
||||
async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, Any]):
|
||||
from .sqlite_vec import SQLiteVecVectorIOAdapter
|
||||
|
||||
assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue