fix: clean up test imports (#1600)
# What does this PR do?
- Clean up dead SDK code in https://github.com/meta-llama/llama-stack-client-python/pull/198
- Regen for the local cache key issue
## Test Plan
```
pytest -v -s --nbval-lax ./docs/getting_started.ipynb
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/ --text-model meta-llama/Llama-3.3-70B-Instruct
```
- CI: 1382351211
<img width="1658" alt="image" src="https://github.com/user-attachments/assets/1a2de383-35a2-47a0-8d80-d666d4970c34" />
Parent 5e54113b19 · commit 98811cc034
3 changed files with 35 additions and 13 deletions
```diff
@@ -10,8 +10,7 @@ from uuid import uuid4
 import pytest
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client.lib.agents.event_logger import EventLogger
-from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
-from llama_stack_client.types.memory_insert_params import Document
+from llama_stack_client.types.agents.turn_create_params import Document
 from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig

 from llama_stack.apis.agents.agents import (
@@ -242,7 +241,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inference

     codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
     session_id = codex_agent.create_session(f"test-session-{uuid4()}")
-    inflation_doc = AgentDocument(
+    inflation_doc = Document(
         content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
         mime_type="text/csv",
     )
```
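With the duplicate import gone, the agent tests use a single `Document` type (from `turn_create_params`) for attachments. As a rough sketch of how that type flows into a turn, assuming the `Agent.create_turn(..., documents=...)` call shape used by the attachment tests (the helper name and prompt below are made up):

```python
from uuid import uuid4

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agents.turn_create_params import Document


def attach_csv_and_ask(client, agent_config):
    # Sketch only: client and agent_config are assumed to be set up as in the tests above.
    agent = Agent(client, **agent_config)
    session_id = agent.create_session(f"doc-session-{uuid4()}")
    doc = Document(
        content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
        mime_type="text/csv",
    )
    # The document is passed alongside the user message as a turn attachment.
    return agent.create_turn(
        messages=[{"role": "user", "content": "What is this file about?"}],
        session_id=session_id,
        documents=[doc],
    )
```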
```diff
@@ -9,11 +9,25 @@ import mimetypes
 import os
 from pathlib import Path

 import pytest

 # How to run this test:
 #
 # LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio
+
+
+@pytest.fixture
+def dataset_for_test(llama_stack_client):
+    dataset_id = "test_dataset"
+    register_dataset(llama_stack_client, dataset_id=dataset_id)
+    yield
+    # Teardown - this always runs, even if the test fails
+    try:
+        llama_stack_client.datasets.unregister(dataset_id)
+    except Exception as e:
+        print(f"Warning: Failed to unregister test_dataset: {e}")
+
+
 def data_url_from_file(file_path: str) -> str:
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"File not found: {file_path}")
@@ -80,8 +94,7 @@ def test_register_unregister_dataset(llama_stack_client):
     assert len(response) == 0


-def test_get_rows_paginated(llama_stack_client):
-    register_dataset(llama_stack_client)
+def test_get_rows_paginated(llama_stack_client, dataset_for_test):
     response = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
         rows_in_page=3,
```
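The new `dataset_for_test` fixture relies on pytest's yield-fixture pattern: everything after `yield` is teardown, so the dataset is unregistered even when the test body fails. Other datasetio tests can opt into the same setup/teardown just by naming the fixture as a parameter; a minimal sketch (hypothetical test, `rows` field assumed from the existing assertions):

```python
def test_rows_first_page(llama_stack_client, dataset_for_test):
    # dataset_for_test has already registered "test_dataset" and will
    # unregister it after this function returns, pass or fail.
    response = llama_stack_client.datasetio.get_rows_paginated(
        dataset_id="test_dataset",
        rows_in_page=3,
    )
    assert len(response.rows) <= 3
```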
```diff
@@ -10,6 +10,19 @@ import pytest
 from ..datasetio.test_datasetio import register_dataset


+@pytest.fixture
+def rag_dataset_for_test(llama_stack_client):
+    dataset_id = "test_dataset"
+    register_dataset(llama_stack_client, for_rag=True, dataset_id=dataset_id)
+    yield  # This is where the test function will run
+
+    # Teardown - this always runs, even if the test fails
+    try:
+        llama_stack_client.datasets.unregister(dataset_id)
+    except Exception as e:
+        print(f"Warning: Failed to unregister test_dataset: {e}")
+
+
 @pytest.fixture
 def sample_judge_prompt_template():
     return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
@@ -79,9 +92,7 @@ def test_scoring_functions_register(
     # TODO: add unregister api for scoring functions


-def test_scoring_score(llama_stack_client):
-    register_dataset(llama_stack_client, for_rag=True)
-
+def test_scoring_score(llama_stack_client, rag_dataset_for_test):
     # scoring individual rows
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
@@ -115,9 +126,9 @@ def test_scoring_score(llama_stack_client):
     assert len(response.results[x].score_rows) == 5


-def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
-    register_dataset(llama_stack_client, for_rag=True)
-
+def test_scoring_score_with_params_llm_as_judge(
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, rag_dataset_for_test
+):
     # scoring individual rows
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
@@ -167,9 +178,8 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge
     ],
 )
 def test_scoring_score_with_aggregation_functions(
-    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
+    llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id, rag_dataset_for_test
 ):
-    register_dataset(llama_stack_client, for_rag=True)
     rows = llama_stack_client.datasetio.get_rows_paginated(
         dataset_id="test_dataset",
         rows_in_page=3,
```
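In the scoring tests, `rag_dataset_for_test` plays the same role and replaces the per-test `register_dataset(..., for_rag=True)` calls. A rough sketch of a test consuming it (hypothetical test name; the `scoring.score` call shape and the `basic::equality` function id are assumptions, not part of this diff):

```python
def test_scoring_equality_sketch(llama_stack_client, rag_dataset_for_test):
    rows = llama_stack_client.datasetio.get_rows_paginated(
        dataset_id="test_dataset",
        rows_in_page=3,
    )
    # Passing None keeps the scoring function's default parameters.
    response = llama_stack_client.scoring.score(
        input_rows=rows.rows,
        scoring_functions={"basic::equality": None},
    )
    assert "basic::equality" in response.results
    assert len(response.results["basic::equality"].score_rows) == len(rows.rows)
```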