mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
fix tests
This commit is contained in:
parent
d9d827ff25
commit
0edd3ce78b
3 changed files with 7 additions and 8 deletions
|
@ -7,6 +7,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
|
||||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
|
@ -22,12 +23,8 @@ from llama_stack.apis.agents import (
|
||||||
ToolExecutionStep,
|
ToolExecutionStep,
|
||||||
Turn,
|
Turn,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
|
||||||
CompletionMessage,
|
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||||
SamplingParams,
|
|
||||||
TopPSamplingStrategy,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.safety import ViolationLevel
|
from llama_stack.apis.safety import ViolationLevel
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
|
|
||||||
from .utils import group_chunks
|
from .utils import group_chunks
|
||||||
|
|
||||||
|
|
||||||
|
@ -476,7 +477,7 @@ class TestInference:
|
||||||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||||
# assert last.event.stop_reason == expected_stop_reason
|
# assert last.event.stop_reason == expected_stop_reason
|
||||||
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
|
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
|
||||||
assert last.event.delta.content.type == "tool_call"
|
assert isinstance(last.event.delta.content, ToolCall)
|
||||||
|
|
||||||
call = last.event.delta.content
|
call = last.event.delta.content
|
||||||
assert call.tool_name == "get_weather"
|
assert call.tool_name == "get_weather"
|
||||||
|
|
|
@ -74,7 +74,8 @@ def get_sampling_options(params: SamplingParams) -> dict:
|
||||||
options = {}
|
options = {}
|
||||||
if params:
|
if params:
|
||||||
options.update(get_sampling_strategy_options(params))
|
options.update(get_sampling_strategy_options(params))
|
||||||
options["max_tokens"] = params.max_tokens
|
if params.max_tokens:
|
||||||
|
options["max_tokens"] = params.max_tokens
|
||||||
|
|
||||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||||
options["repeat_penalty"] = params.repetition_penalty
|
options["repeat_penalty"] = params.repetition_penalty
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue