mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
fix tests
This commit is contained in:
parent
d9d827ff25
commit
0edd3ce78b
3 changed files with 7 additions and 8 deletions
|
@ -7,6 +7,7 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
|
||||
from llama_models.llama3.api.datatypes import BuiltinTool
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
|
@ -22,12 +23,8 @@ from llama_stack.apis.agents import (
|
|||
ToolExecutionStep,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.safety import ViolationLevel
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from .utils import group_chunks
|
||||
|
||||
|
||||
|
@ -476,7 +477,7 @@ class TestInference:
|
|||
last = grouped[ChatCompletionResponseEventType.progress][-1]
|
||||
# assert last.event.stop_reason == expected_stop_reason
|
||||
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
|
||||
assert last.event.delta.content.type == "tool_call"
|
||||
assert isinstance(last.event.delta.content, ToolCall)
|
||||
|
||||
call = last.event.delta.content
|
||||
assert call.tool_name == "get_weather"
|
||||
|
|
|
@ -74,6 +74,7 @@ def get_sampling_options(params: SamplingParams) -> dict:
|
|||
options = {}
|
||||
if params:
|
||||
options.update(get_sampling_strategy_options(params))
|
||||
if params.max_tokens:
|
||||
options["max_tokens"] = params.max_tokens
|
||||
|
||||
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue