Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Add inference test
Run it as:

```
PROVIDER_ID=test-remote \
  PROVIDER_CONFIG=$PWD/llama_stack/providers/tests/inference/provider_config_example.yaml \
  pytest -s llama_stack/providers/tests/inference/test_inference.py \
  --tb=auto \
  --disable-warnings
```
parent 4fa467731e
commit 3ae2b712e8

8 changed files with 356 additions and 54 deletions
Changes to InferenceClient:

```diff
@@ -67,6 +67,7 @@ class InferenceClient(Inference):
             logprobs=logprobs,
         )
         async with httpx.AsyncClient() as client:
+            if stream:
                 async with client.stream(
                     "POST",
                     f"{self.base_url}/inference/chat_completion",
@@ -77,7 +78,8 @@ class InferenceClient(Inference):
                     if response.status_code != 200:
                         content = await response.aread()
                         cprint(
-                            f"Error: HTTP {response.status_code} {content.decode()}", "red"
+                            f"Error: HTTP {response.status_code} {content.decode()}",
+                            "red",
                         )
                         return

@@ -85,7 +87,6 @@ class InferenceClient(Inference):
                         if line.startswith("data:"):
                             data = line[len("data: ") :]
                             try:
-                                if request.stream:
                                 if "error" in data:
                                     cprint(data, "red")
                                     continue
@@ -93,11 +94,20 @@ class InferenceClient(Inference):
                                 yield ChatCompletionResponseStreamChunk(
                                     **json.loads(data)
                                 )
-                                else:
-                                    yield ChatCompletionResponse(**json.loads(data))
                             except Exception as e:
                                 print(data)
                                 print(f"Error with parsing or validation: {e}")
+            else:
+                response = await client.post(
+                    f"{self.base_url}/inference/chat_completion",
+                    json=encodable_dict(request),
+                    headers={"Content-Type": "application/json"},
+                    timeout=20,
+                )
+
+                response.raise_for_status()
+                j = response.json()
+                yield ChatCompletionResponse(**j)


 async def run_main(
```
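The new branch follows a common httpx pattern: when streaming, the server replies with server-sent `data: <json>` lines; otherwise it returns a single JSON body. Below is a minimal standalone sketch of that request shape (not part of the commit; the base URL and payload are placeholders, and the stack's response models are omitted):

```python
# Sketch only: how the stream / non-stream branches of the client behave,
# assuming a hypothetical server at BASE_URL that speaks SSE when stream=True.
import json

import httpx

BASE_URL = "http://localhost:5000"  # placeholder endpoint


async def chat_completion(payload: dict, stream: bool):
    async with httpx.AsyncClient() as client:
        if stream:
            # Streaming: iterate SSE lines and yield each decoded "data:" payload.
            async with client.stream(
                "POST",
                f"{BASE_URL}/inference/chat_completion",
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=20,
            ) as response:
                async for line in response.aiter_lines():
                    if line.startswith("data:"):
                        yield json.loads(line[len("data: ") :])
        else:
            # Non-streaming: one POST, one JSON body.
            response = await client.post(
                f"{BASE_URL}/inference/chat_completion",
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=20,
            )
            response.raise_for_status()
            yield response.json()
```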
Changes to the distribution config datatype:

```diff
@@ -109,6 +109,7 @@ this could be just a hash
         description="Reference to the conda environment if this package refers to a conda environment",
     )
     apis: List[str] = Field(
+        default_factory=list,
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
     )
```
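Adding `default_factory=list` makes `apis` optional: leaving it out now yields an empty list, which the description says means "serve everything in the provider map". A minimal pydantic sketch (hypothetical `ExampleConfig` model, not from the codebase):

```python
from typing import List

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # apis is optional; omitting it produces a fresh empty list per instance.
    apis: List[str] = Field(
        default_factory=list,
        description="APIs to serve; empty means serve everything in the provider map",
    )


print(ExampleConfig().apis)                    # []
print(ExampleConfig(apis=["inference"]).apis)  # ['inference']
```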
Changes to OllamaInferenceAdapter:

```diff
@@ -36,8 +36,8 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
             self, stack_to_provider_models_map=OLLAMA_SUPPORTED_MODELS
         )
         self.url = url
-        tokenizer = Tokenizer.get_instance()
-        self.formatter = ChatFormat(tokenizer)
+        self.tokenizer = Tokenizer.get_instance()
+        self.formatter = ChatFormat(self.tokenizer)

     @property
     def client(self) -> AsyncClient:
@@ -65,17 +65,6 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()

-    def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
-        ollama_messages = []
-        for message in messages:
-            if message.role == "ipython":
-                role = "tool"
-            else:
-                role = message.role
-            ollama_messages.append({"role": role, "content": message.content})
-
-        return ollama_messages
-
     def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
         options = {}
         if request.sampling_params is not None:
@@ -113,6 +102,9 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
         )

         messages = augment_messages_for_tools(request)
+        model_input = self.formatter.encode_dialog_prompt(messages)
+        prompt = self.tokenizer.decode(model_input.tokens)
+
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
         ollama_model = self.map_to_provider_model(request.model)
@@ -131,13 +123,16 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                 status["status"] == "success"
             ), f"Failed to pull model {self.model} in ollama"

+        common_params = {
+            "model": ollama_model,
+            "prompt": prompt,
+            "options": options,
+            "raw": True,
+            "stream": request.stream,
+        }
+
         if not request.stream:
-            r = await self.client.chat(
-                model=ollama_model,
-                messages=self._messages_to_ollama_messages(messages),
-                stream=False,
-                options=options,
-            )
+            r = await self.client.generate(**common_params)
             stop_reason = None
             if r["done"]:
                 if r["done_reason"] == "stop":
@@ -146,7 +141,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                     stop_reason = StopReason.out_of_tokens

             completion_message = self.formatter.decode_assistant_message_from_content(
-                r["message"]["content"], stop_reason
+                r["response"], stop_reason
             )
             yield ChatCompletionResponse(
                 completion_message=completion_message,
@@ -159,12 +154,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                     delta="",
                 )
             )
-            stream = await self.client.chat(
-                model=ollama_model,
-                messages=self._messages_to_ollama_messages(messages),
-                stream=True,
-                options=options,
-            )
+            stream = await self.client.generate(**common_params)

             buffer = ""
             ipython = False
@@ -178,8 +168,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
                         stop_reason = StopReason.out_of_tokens
                     break

-                text = chunk["message"]["content"]
-
+                text = chunk["response"]
                 # check if its a tool call ( aka starts with <|python_tag|> )
                 if not ipython and text.startswith("<|python_tag|>"):
                     ipython = True
```
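In short, the adapter now renders the Llama chat prompt itself (via `ChatFormat` and the tokenizer) and sends it to Ollama's generate endpoint with `raw=True`, instead of handing role/content messages to the chat endpoint and letting Ollama apply its own template. A rough standalone illustration of that difference using the `ollama` Python client (the host and model tag are assumptions, and the prompt here is plain text rather than a formatted Llama dialog):

```python
# Sketch only: chat() lets Ollama format the conversation; generate(raw=True)
# sends the caller's prompt verbatim and returns text under "response".
import asyncio

from ollama import AsyncClient


async def main():
    client = AsyncClient(host="http://localhost:11434")  # assumed local Ollama

    # Before the change: Ollama applies its own chat template.
    r = await client.chat(
        model="llama3.1:8b-instruct-fp16",  # assumed model tag
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )
    print(r["message"]["content"])

    # After the change: the caller owns prompt formatting; raw=True skips the template.
    r = await client.generate(
        model="llama3.1:8b-instruct-fp16",
        prompt="Hello",
        raw=True,
        stream=False,
    )
    print(r["response"])


asyncio.run(main())
```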
Changes to _HfAdapter:

```diff
@@ -100,8 +100,6 @@ class _HfAdapter(Inference):
             self.max_tokens - input_tokens - 1,
         )

-        print(f"Calculated max_new_tokens: {max_new_tokens}")
-
         options = self.get_chat_options(request)
         if not request.stream:
             response = await self.client.text_generation(
@@ -119,8 +117,9 @@ class _HfAdapter(Inference):
             elif response.details.finish_reason == "length":
                 stop_reason = StopReason.out_of_tokens

+            generated_text = "".join(t.text for t in response.details.tokens)
             completion_message = self.formatter.decode_assistant_message_from_content(
-                response.generated_text,
+                generated_text,
                 stop_reason,
             )
             yield ChatCompletionResponse(
```
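The TGI path now rebuilds the generation from the per-token details rather than using `response.generated_text` directly. A rough sketch of the underlying huggingface_hub call (the endpoint URL is a placeholder; assumes a TGI server reachable there):

```python
# Sketch only: with details=True, text_generation returns per-token texts that
# can be joined back into the generated string the adapter decodes.
import asyncio

from huggingface_hub import AsyncInferenceClient

TGI_URL = "http://localhost:7001"  # placeholder TGI endpoint


async def main():
    client = AsyncInferenceClient(model=TGI_URL)
    response = await client.text_generation(
        "Hello, my name is",
        max_new_tokens=16,
        details=True,
    )
    generated_text = "".join(t.text for t in response.details.tokens)
    print(response.details.finish_reason, generated_text)


asyncio.run(main())
```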
llama_stack/providers/tests/__init__.py (new file, 5 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
llama_stack/providers/tests/inference/__init__.py (new file, 5 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
```
llama_stack/providers/tests/inference/provider_config_example.yaml (new file, 15 lines):

```yaml
providers:
  - provider_id: test-ollama
    provider_type: remote::ollama
    config:
      host: localhost
      port: 11434
  - provider_id: test-tgi
    provider_type: remote::tgi
    config:
      url: http://localhost:7001
  - provider_id: test-remote
    provider_type: remote
    config:
      host: localhost
      port: 7002
```
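As a quick sanity check (not part of the commit), the example config can be loaded to list the values that PROVIDER_ID may select:

```python
# Parse the example config and show the selectable provider ids.
import yaml

with open("llama_stack/providers/tests/inference/provider_config_example.yaml") as f:
    cfg = yaml.safe_load(f)

print([p["provider_id"] for p in cfg["providers"]])
# ['test-ollama', 'test-tgi', 'test-remote']
```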
llama_stack/providers/tests/inference/test_inference.py (new file, 278 lines):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import itertools
import os
from datetime import datetime

import pytest
import pytest_asyncio
import yaml

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403

from llama_stack.distribution.datatypes import *  # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.resolver import resolve_impls_with_routing


def group_chunks(response):
    return {
        event_type: list(group)
        for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }


Llama_8B = "Llama3.1-8B-Instruct"
Llama_3B = "Llama3.2-3B-Instruct"


def get_expected_stop_reason(model: str):
    return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn


async def stack_impls(model):
    if "PROVIDER_CONFIG" not in os.environ:
        raise ValueError(
            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
        )

    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
        config_dict = yaml.safe_load(f)

    if "providers" not in config_dict:
        raise ValueError("Config file should contain a `providers` key")

    providers_by_id = {x["provider_id"]: x for x in config_dict["providers"]}
    if len(providers_by_id) == 0:
        raise ValueError("No providers found in config file")

    if "PROVIDER_ID" in os.environ:
        provider_id = os.environ["PROVIDER_ID"]
        if provider_id not in providers_by_id:
            raise ValueError(f"Provider ID {provider_id} not found in config file")
        provider = providers_by_id[provider_id]
    else:
        provider = list(providers_by_id.values())[0]
        print(f"No provider ID specified, picking first {provider['provider_id']}")

    config_dict = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=[
            Api.inference,
            Api.models,
        ],
        providers=dict(
            inference=[
                Provider(**provider),
            ]
        ),
        models=[
            ModelDef(
                identifier=model,
                llama_model=model,
                provider_id=provider["provider_id"],
            )
        ],
        shields=[],
        memory_banks=[],
    )
    run_config = parse_and_maybe_upgrade_config(config_dict)
    impls = await resolve_impls_with_routing(run_config)
    return impls


# This is going to create multiple Stack impls without tearing down the previous one
# Fix that!
@pytest_asyncio.fixture(
    scope="session",
    params=[
        {"model": Llama_8B},
        {"model": Llama_3B},
    ],
)
async def inference_settings(request):
    model = request.param["model"]
    impls = await stack_impls(model)
    return {
        "impl": impls[Api.inference],
        "common_params": {
            "model": model,
            "tool_choice": ToolChoice.auto,
            "tool_prompt_format": (
                ToolPromptFormat.json
                if "Llama3.1" in model
                else ToolPromptFormat.python_list
            ),
        },
    }


@pytest.fixture
def sample_messages():
    return [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What's the weather like today?"),
    ]


@pytest.fixture
def sample_tool_definition():
    return ToolDefinition(
        tool_name="get_weather",
        description="Get the current weather",
        parameters={
            "location": ToolParamDefinition(
                param_type="string",
                description="The city and state, e.g. San Francisco, CA",
            ),
        },
    )


@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=sample_messages,
            stream=False,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) == 1
    assert isinstance(response[0], ChatCompletionResponse)
    assert response[0].completion_message.role == "assistant"
    assert isinstance(response[0].completion_message.content, str)
    assert len(response[0].completion_message.content) > 0


@pytest.mark.asyncio
async def test_chat_completion_streaming(inference_settings, sample_messages):
    inference_impl = inference_settings["impl"]
    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=sample_messages,
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]
    assert end.event.stop_reason == StopReason.end_of_turn


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=False,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) == 1
    assert isinstance(response[0], ChatCompletionResponse)

    message = response[0].completion_message

    stop_reason = get_expected_stop_reason(inference_settings["common_params"]["model"])
    assert message.stop_reason == stop_reason
    assert message.tool_calls is not None
    assert len(message.tool_calls) > 0

    call = message.tool_calls[0]
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]


@pytest.mark.asyncio
async def test_chat_completion_with_tool_calling_streaming(
    inference_settings,
    sample_messages,
    sample_tool_definition,
):
    inference_impl = inference_settings["impl"]
    messages = sample_messages + [
        UserMessage(
            content="What's the weather like in San Francisco?",
        )
    ]

    response = [
        r
        async for r in inference_impl.chat_completion(
            messages=messages,
            tools=[sample_tool_definition],
            stream=True,
            **inference_settings["common_params"],
        )
    ]

    assert len(response) > 0
    assert all(
        isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
    )
    grouped = group_chunks(response)
    assert len(grouped[ChatCompletionResponseEventType.start]) == 1
    assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
    assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

    end = grouped[ChatCompletionResponseEventType.complete][0]
    expected_stop_reason = get_expected_stop_reason(
        inference_settings["common_params"]["model"]
    )
    assert end.event.stop_reason == expected_stop_reason

    model = inference_settings["common_params"]["model"]
    if "Llama3.1" in model:
        assert all(
            isinstance(chunk.event.delta, ToolCallDelta)
            for chunk in grouped[ChatCompletionResponseEventType.progress]
        )
        first = grouped[ChatCompletionResponseEventType.progress][0]
        assert first.event.delta.parse_status == ToolCallParseStatus.started

    last = grouped[ChatCompletionResponseEventType.progress][-1]
    assert last.event.stop_reason == expected_stop_reason
    assert last.event.delta.parse_status == ToolCallParseStatus.success
    assert isinstance(last.event.delta.content, ToolCall)

    call = last.event.delta.content
    assert call.tool_name == "get_weather"
    assert "location" in call.arguments
    assert "San Francisco" in call.arguments["location"]
```
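To make the grouping assertions concrete, here is a toy, self-contained illustration (not from the test file) of what group_chunks returns when chunks arrive in start, progress, complete order:

```python
# Toy example: groupby buckets consecutive chunks by their event type, so the
# tests can assert exactly one start, one complete, and at least one progress.
import itertools
from types import SimpleNamespace


def group_chunks(response):
    return {
        event_type: list(group)
        for event_type, group in itertools.groupby(
            response, key=lambda chunk: chunk.event.event_type
        )
    }


def chunk(event_type):
    # Stand-in for a ChatCompletionResponseStreamChunk.
    return SimpleNamespace(event=SimpleNamespace(event_type=event_type))


chunks = [chunk("start"), chunk("progress"), chunk("progress"), chunk("complete")]
grouped = group_chunks(chunks)
print({k: len(v) for k, v in grouped.items()})
# {'start': 1, 'progress': 2, 'complete': 1}
```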