mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-29 04:32:01 +00:00
Merge branch 'main' into fix/nvidia-launch-customization
This commit is contained in:
commit
6659ed995a
53 changed files with 2203 additions and 217 deletions
|
|
@ -75,19 +75,24 @@ def openai_client(client_with_models):
|
|||
return OpenAI(base_url=base_url, api_key="bar")
|
||||
|
||||
|
||||
@pytest.fixture(params=["openai_client", "llama_stack_client"])
|
||||
def compat_client(request):
|
||||
return request.getfixturevalue(request.param)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
"inference:completion:sanity",
|
||||
],
|
||||
)
|
||||
def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
# ollama needs more verbose prompting for some reason here...
|
||||
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
||||
response = openai_client.completions.create(
|
||||
response = llama_stack_client.completions.create(
|
||||
model=text_model_id,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
|
|
@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text
|
|||
"inference:completion:sanity",
|
||||
],
|
||||
)
|
||||
def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
# ollama needs more verbose prompting for some reason here...
|
||||
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
||||
response = openai_client.completions.create(
|
||||
response = llama_stack_client.completions.create(
|
||||
model=text_model_id,
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
|
|
@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
|
|||
0,
|
||||
],
|
||||
)
|
||||
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
|
||||
def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
|
||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||
|
||||
prompt = "Hello, world!"
|
||||
response = openai_client.completions.create(
|
||||
response = llama_stack_client.completions.create(
|
||||
model=text_model_id,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
|
|
@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
|
|||
assert len(choice.prompt_logprobs) > 0
|
||||
|
||||
|
||||
def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
|
||||
def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
|
||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||
|
||||
prompt = "I am feeling really sad today."
|
||||
response = openai_client.completions.create(
|
||||
response = llama_stack_client.completions.create(
|
||||
model=text_model_id,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
|
|
@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
|||
assert choice.text in ["joy", "sadness"]
|
||||
|
||||
|
||||
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
|
@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
|||
"inference:chat_completion:non_streaming_02",
|
||||
],
|
||||
)
|
||||
def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
response = compat_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
messages=[
|
||||
{
|
||||
|
|
@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models,
|
|||
"inference:chat_completion:streaming_02",
|
||||
],
|
||||
)
|
||||
def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
response = compat_client.chat.completions.create(
|
||||
model=text_model_id,
|
||||
messages=[{"role": "user", "content": question}],
|
||||
stream=True,
|
||||
|
|
|
|||
|
|
@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionMessage,
|
||||
SystemMessage,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.models.llama.datatypes import StopReason, ToolCall
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import (
|
||||
VLLMInferenceAdapter,
|
||||
|
|
@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
|||
assert request.tool_config.tool_choice == ToolChoice.none
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_response(vllm_inference_adapter):
|
||||
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
||||
into the expected JSON format."""
|
||||
|
||||
# Patch the call to vllm so we can inspect the arguments sent were correct
|
||||
with patch.object(
|
||||
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_nonstream_completion:
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant"),
|
||||
UserMessage(content="How many?"),
|
||||
CompletionMessage(
|
||||
content="",
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="foo",
|
||||
tool_name="knowledge_search",
|
||||
arguments={"query": "How many?"},
|
||||
arguments_json='{"query": "How many?"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
|
||||
]
|
||||
await vllm_inference_adapter.chat_completion(
|
||||
"mock-model",
|
||||
messages,
|
||||
stream=False,
|
||||
tools=[],
|
||||
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||
)
|
||||
|
||||
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
|
||||
{
|
||||
"id": "foo",
|
||||
"type": "function",
|
||||
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_delta_empty_tool_call_buf():
|
||||
"""
|
||||
|
|
|
|||
201
tests/unit/providers/nvidia/test_eval.py
Normal file
201
tests/unit/providers/nvidia/test_eval.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
|
||||
|
||||
MOCK_DATASET_ID = "default/test-dataset"
|
||||
MOCK_BENCHMARK_ID = "test-benchmark"
|
||||
|
||||
|
||||
class TestNVIDIAEvalImpl(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
|
||||
|
||||
# Create mock APIs
|
||||
self.datasetio_api = MagicMock()
|
||||
self.datasets_api = MagicMock()
|
||||
self.scoring_api = MagicMock()
|
||||
self.inference_api = MagicMock()
|
||||
self.agents_api = MagicMock()
|
||||
|
||||
self.config = NVIDIAEvalConfig(
|
||||
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
|
||||
)
|
||||
|
||||
self.eval_impl = NVIDIAEvalImpl(
|
||||
config=self.config,
|
||||
datasetio_api=self.datasetio_api,
|
||||
datasets_api=self.datasets_api,
|
||||
scoring_api=self.scoring_api,
|
||||
inference_api=self.inference_api,
|
||||
agents_api=self.agents_api,
|
||||
)
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.evaluator_get_patcher = patch(
|
||||
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
|
||||
)
|
||||
self.evaluator_post_patcher = patch(
|
||||
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
|
||||
)
|
||||
|
||||
self.mock_evaluator_get = self.evaluator_get_patcher.start()
|
||||
self.mock_evaluator_post = self.evaluator_post_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.evaluator_get_patcher.stop()
|
||||
self.evaluator_post_patcher.stop()
|
||||
|
||||
def _assert_request_body(self, expected_json):
|
||||
"""Helper method to verify request body in Evaluator POST request is correct"""
|
||||
call_args = self.mock_evaluator_post.call_args
|
||||
actual_json = call_args[0][1]
|
||||
|
||||
# Check that all expected keys contain the expected values in the actual JSON
|
||||
for key, value in expected_json.items():
|
||||
assert key in actual_json, f"Key '{key}' missing in actual JSON"
|
||||
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
|
||||
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
|
||||
else:
|
||||
assert actual_json[key] == value, f"Value mismatch for '{key}'"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
|
||||
def test_register_benchmark(self):
|
||||
eval_config = {
|
||||
"type": "custom",
|
||||
"params": {"parallelism": 8},
|
||||
"tasks": {
|
||||
"qa": {
|
||||
"type": "completion",
|
||||
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
|
||||
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
|
||||
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
benchmark = Benchmark(
|
||||
provider_id="nvidia",
|
||||
type="benchmark",
|
||||
identifier=MOCK_BENCHMARK_ID,
|
||||
dataset_id=MOCK_DATASET_ID,
|
||||
scoring_functions=["basic::equality"],
|
||||
metadata=eval_config,
|
||||
)
|
||||
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
|
||||
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Register the benchmark
|
||||
self.run_async(self.eval_impl.register_benchmark(benchmark))
|
||||
|
||||
# Verify the Evaluator API was called correctly
|
||||
self.mock_evaluator_post.assert_called_once()
|
||||
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
|
||||
|
||||
def test_run_eval(self):
|
||||
benchmark_config = BenchmarkConfig(
|
||||
eval_candidate=ModelCandidate(
|
||||
type="model",
|
||||
model=CoreModelId.llama3_1_8b_instruct.value,
|
||||
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
|
||||
)
|
||||
)
|
||||
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "created"}
|
||||
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Run the Evaluation job
|
||||
result = self.run_async(
|
||||
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
|
||||
)
|
||||
|
||||
# Verify the Evaluator API was called correctly
|
||||
self.mock_evaluator_post.assert_called_once()
|
||||
self._assert_request_body(
|
||||
{
|
||||
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
|
||||
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
|
||||
}
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, Job)
|
||||
assert result.job_id == "job-123"
|
||||
assert result.status == JobStatus.in_progress
|
||||
|
||||
def test_job_status(self):
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "completed"}
|
||||
self.mock_evaluator_get.return_value = mock_evaluator_response
|
||||
|
||||
# Get the Evaluation job
|
||||
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, Job)
|
||||
assert result.job_id == "job-123"
|
||||
assert result.status == JobStatus.completed
|
||||
|
||||
# Verify the API was called correctly
|
||||
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
|
||||
|
||||
def test_job_cancel(self):
|
||||
# Mock Evaluator API response
|
||||
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
|
||||
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||
|
||||
# Cancel the Evaluation job
|
||||
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||
|
||||
# Verify the API was called correctly
|
||||
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
|
||||
|
||||
def test_job_result(self):
|
||||
# Mock Evaluator API responses
|
||||
mock_job_status_response = {"id": "job-123", "status": "completed"}
|
||||
mock_job_results_response = {
|
||||
"id": "job-123",
|
||||
"status": "completed",
|
||||
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
|
||||
}
|
||||
self.mock_evaluator_get.side_effect = [
|
||||
mock_job_status_response, # First call to retrieve job
|
||||
mock_job_results_response, # Second call to retrieve job results
|
||||
]
|
||||
|
||||
# Get the Evaluation job results
|
||||
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EvaluateResponse)
|
||||
assert MOCK_BENCHMARK_ID in result.scores
|
||||
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
|
||||
|
||||
# Verify the API was called correctly
|
||||
assert self.mock_evaluator_get.call_count == 2
|
||||
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
|
||||
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
|
||||
|
|
@ -11,6 +11,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.post_training.post_training import (
|
||||
DataConfig,
|
||||
DatasetFormat,
|
||||
|
|
@ -21,6 +22,7 @@ from llama_stack.apis.post_training.post_training import (
|
|||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||
ListNvidiaPostTrainingJobs,
|
||||
NvidiaPostTrainingAdapter,
|
||||
|
|
@ -44,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
)
|
||||
self.mock_make_request = self.make_request_patcher.start()
|
||||
|
||||
# Mock the inference client
|
||||
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
||||
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
||||
|
||||
self.mock_client = unittest.mock.MagicMock()
|
||||
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
|
||||
self.inference_mock_make_request = self.mock_client.chat.completions.create
|
||||
self.inference_make_request_patcher = patch(
|
||||
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
|
||||
return_value=self.mock_client,
|
||||
)
|
||||
self.inference_make_request_patcher.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.make_request_patcher.stop()
|
||||
self.inference_make_request_patcher.stop()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
|
|
@ -316,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
|||
expected_params={"job_id": job_id},
|
||||
)
|
||||
|
||||
def test_inference_register_model(self):
|
||||
model_id = "default/job-1234"
|
||||
model_type = ModelType.llm
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_id="nvidia",
|
||||
provider_model_id=model_id,
|
||||
provider_resource_id=model_id,
|
||||
model_type=model_type,
|
||||
)
|
||||
result = self.run_async(self.inference_adapter.register_model(model))
|
||||
assert result == model
|
||||
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
||||
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
|
||||
|
||||
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
|
||||
self.run_async(
|
||||
self.inference_adapter.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[{"role": "user", "content": "Hello, model"}],
|
||||
)
|
||||
)
|
||||
|
||||
mock_chat_completion.assert_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
43
tests/unit/providers/utils/inference/test_openai_compat.py
Normal file
43
tests/unit/providers/utils/inference/test_openai_compat.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import TextContentItem
|
||||
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.models.llama.datatypes import StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_message_to_openai_dict():
|
||||
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
||||
assert await convert_message_to_openai_dict(message) == {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
|
||||
# Test convert_message_to_openai_dict with a tool call
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_message_to_openai_dict_with_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
|
||||
openai_dict = await convert_message_to_openai_dict(message)
|
||||
|
||||
assert openai_dict == {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": ""}],
|
||||
"tool_calls": [
|
||||
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
|
||||
],
|
||||
}
|
||||
|
|
@ -47,9 +47,45 @@ async def test_sse_generator_client_disconnected():
|
|||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
# Start reading the events, ensuring this doesn't raise an exception
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
|
||||
# We should see 1 event before the client disconnected
|
||||
assert len(seen_events) == 1
|
||||
assert seen_events[0] == create_sse_event("Test event 1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_generator_client_disconnected_before_response_starts():
|
||||
# Disconnect before the response starts
|
||||
async def async_event_gen():
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
|
||||
# No events should be seen since the client disconnected immediately
|
||||
assert len(seen_events) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_generator_error_before_response_starts():
|
||||
# Raise an error before the response starts
|
||||
async def async_event_gen():
|
||||
raise Exception("Test error")
|
||||
|
||||
sse_gen = sse_generator(async_event_gen())
|
||||
assert sse_gen is not None
|
||||
|
||||
seen_events = []
|
||||
async for event in sse_gen:
|
||||
seen_events.append(event)
|
||||
|
||||
# We should have 1 error event
|
||||
assert len(seen_events) == 1
|
||||
assert 'data: {"error":' in seen_events[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue