mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge branch 'main' into evals_5
This commit is contained in:
commit
caf253e08f
5 changed files with 168 additions and 6 deletions
|
@ -70,10 +70,10 @@ class TogetherInferenceAdapter(
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
response_format: Optional[ResponseFormat] = None,
|
|
||||||
tools: Optional[List[ToolDefinition]] = None,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
@ -96,6 +96,7 @@ class TogetherInferenceAdapter(
|
||||||
tools=tools or [],
|
tools=tools or [],
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
tool_prompt_format=tool_prompt_format,
|
tool_prompt_format=tool_prompt_format,
|
||||||
|
response_format=response_format,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
@ -130,11 +131,23 @@ class TogetherInferenceAdapter(
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
    """Assemble the keyword arguments for the provider completion call.

    Starts from the sampling options derived from ``request`` and, when the
    request carries a response format, adds a ``response_format`` entry.

    Raises:
        NotImplementedError: if the response format type is the grammar type.
        ValueError: if the response format type is neither JSON-schema nor
            grammar.
    """
    sampling = get_sampling_options(request)

    fmt = request.response_format
    if fmt:
        # Reject the unsupported variants first so the supported case reads
        # straight through.
        if fmt.type == ResponseFormatType.grammar.value:
            raise NotImplementedError("Grammar response format not supported yet")
        if fmt.type != ResponseFormatType.json_schema.value:
            raise ValueError(f"Unknown response format {fmt.type}")
        sampling["response_format"] = {
            "type": "json_object",
            "schema": fmt.schema,
        }

    params = {
        "model": self.map_to_provider_model(request.model),
        "prompt": chat_completion_request_to_prompt(request, self.formatter),
        "stream": request.stream,
    }
    # Sampling options are merged last, so they win on any key collision.
    params.update(sampling)
    return params
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
|
|
148
llama_stack/providers/tests/agents/test_agent_persistence.py
Normal file
148
llama_stack/providers/tests/agents/test_agent_persistence.py
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||||
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# 1. Ensure you have a conda environment with the right dependencies installed.
|
||||||
|
# This includes `pytest` and `pytest-asyncio`.
|
||||||
|
#
|
||||||
|
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||||
|
#
|
||||||
|
# 3. Run:
|
||||||
|
#
|
||||||
|
# ```bash
|
||||||
|
# PROVIDER_ID=<your_provider> \
|
||||||
|
# PROVIDER_CONFIG=provider_config.yaml \
|
||||||
|
# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
|
||||||
|
# --tb=short --disable-warnings
|
||||||
|
# ```
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
async def agents_settings():
    """Session-scoped fixture bundling the resolved implementations for the tests.

    Returns a dict with the agents implementation (``"impl"``), the memory
    implementation (``"memory_impl"``) and the agent parameters shared by the
    tests (``"common_params"``).
    """
    resolved = await resolve_impls_for_test(
        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
    )

    common_params = {
        "model": "Llama3.1-8B-Instruct",
        "instructions": "You are a helpful assistant.",
    }

    settings = {
        "impl": resolved[Api.agents],
        "memory_impl": resolved[Api.memory],
        "common_params": common_params,
    }
    return settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_messages():
    """Provide the single user message used as the input of an agent turn."""
    weather_question = UserMessage(content="What's the weather like today?")
    return [weather_question]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_delete_agents_and_sessions(agents_settings, sample_messages):
    """Check that deleting a session, then its agent, removes both from the KV store.

    Creates an agent with session persistence enabled, opens a session,
    deletes the session and the agent through the agents API, and asserts
    that the corresponding ``session:...`` and ``agent:...`` keys are gone.
    """
    agents_impl = agents_settings["impl"]
    common_params = agents_settings["common_params"]

    # First, create an agent
    agent_config = AgentConfig(
        model=common_params["model"],
        instructions=common_params["instructions"],
        enable_session_persistence=True,
        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
        input_shields=[],
        output_shields=[],
        tools=[],
        max_infer_iters=5,
    )

    create_response = await agents_impl.create_agent(agent_config)
    agent_id = create_response.agent_id

    # Create a session
    session_create_response = await agents_impl.create_agent_session(
        agent_id, "Test Session"
    )
    session_id = session_create_response.session_id

    # The settings fixture in this file exposes only "impl", "memory_impl" and
    # "common_params"; indexing "persistence" directly raised KeyError. Use the
    # entry when a fixture supplies it, otherwise fall back to the default
    # SQLite config, as test_get_agent_turns_and_steps does.
    # NOTE(review): assumes the default SqliteKVStoreConfig() addresses the same
    # store the agents provider persists to -- confirm against provider config.
    persistence_config = agents_settings.get("persistence") or SqliteKVStoreConfig()
    persistence_store = await kvstore_impl(persistence_config)

    await agents_impl.delete_agents_session(agent_id, session_id)
    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")

    await agents_impl.delete_agents(agent_id)
    agent_response = await persistence_store.get(f"agent:{agent_id}")

    assert session_response is None
    assert agent_response is None
|
||||||
|
|
||||||
|
|
||||||
|
# Without the asyncio marker pytest-asyncio (strict mode) does not run this
# coroutine as a test; the sibling test in this file is marked the same way.
@pytest.mark.asyncio
async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
    """Check that a completed turn and its first step can be fetched back.

    Creates an agent and a session, streams one turn to completion, then
    compares the turn carried by the final streamed event with both the value
    read from the KV store and the value returned by ``get_agents_turn``.
    Finally fetches the first step through ``get_agents_step`` and compares it
    with the step recorded on the turn.
    """
    agents_impl = agents_settings["impl"]
    common_params = agents_settings["common_params"]

    # First, create an agent
    agent_config = AgentConfig(
        model=common_params["model"],
        instructions=common_params["instructions"],
        enable_session_persistence=True,
        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
        input_shields=[],
        output_shields=[],
        tools=[],
        max_infer_iters=5,
    )

    create_response = await agents_impl.create_agent(agent_config)
    agent_id = create_response.agent_id

    # Create a session
    session_create_response = await agents_impl.create_agent_session(
        agent_id, "Test Session"
    )
    session_id = session_create_response.session_id

    # Create and execute a turn
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=sample_messages,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
    ]
    # Fail with a readable message instead of an IndexError on [-1] below.
    assert turn_response, "create_agent_turn produced no stream chunks"

    final_event = turn_response[-1].event.payload
    turn_id = final_event.turn.turn_id

    persistence_store = await kvstore_impl(SqliteKVStoreConfig())
    # NOTE(review): the stored value is compared directly with the Turn object
    # below; if the KV store returns a serialized form this comparison needs a
    # matching serialization step -- confirm against the kvstore contract.
    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)

    assert isinstance(response, Turn)
    assert response == final_event.turn
    assert turn == final_event.turn

    steps = final_event.turn.steps
    # Same reasoning as above: make an empty step list an explicit failure.
    assert steps, "completed turn recorded no steps"
    step_id = steps[0].step_id
    step_response = await agents_impl.get_agents_step(
        agent_id, session_id, turn_id, step_id
    )

    assert isinstance(step_response.step, Step)
    assert step_response.step == steps[0]
|
|
@ -195,6 +195,7 @@ async def test_structured_output(inference_settings):
|
||||||
"meta-reference",
|
"meta-reference",
|
||||||
"remote::fireworks",
|
"remote::fireworks",
|
||||||
"remote::tgi",
|
"remote::tgi",
|
||||||
|
"remote::together",
|
||||||
):
|
):
|
||||||
pytest.skip("Other inference providers don't support structured output yet")
|
pytest.skip("Other inference providers don't support structured output yet")
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,10 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from llama_models.llama3.api import * # noqa: F403
|
from llama_models.llama3.api import * # noqa: F403
|
||||||
from llama_stack.inference.api import * # noqa: F403
|
from llama_stack.apis.inference.inference import * # noqa: F403
|
||||||
from llama_stack.inference.prompt_adapter import chat_completion_request_to_messages
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
chat_completion_request_to_messages,
|
||||||
|
)
|
||||||
|
|
||||||
MODEL = "Llama3.1-8B-Instruct"
|
MODEL = "Llama3.1-8B-Instruct"
|
||||||
|
|
||||||
|
|
|
@ -34,8 +34,6 @@ def get_sampling_options(request: ChatCompletionRequest) -> dict:
|
||||||
if params := request.sampling_params:
|
if params := request.sampling_params:
|
||||||
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
|
||||||
if getattr(params, attr):
|
if getattr(params, attr):
|
||||||
if attr == "max_tokens":
|
|
||||||
options["num_predict"] = getattr(params, attr)
|
|
||||||
options[attr] = getattr(params, attr)
|
options[attr] = getattr(params, attr)
|
||||||
|
|
||||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue