fix tests

Author: Hardik Shah, 2025-01-14 17:38:10 -08:00
Committer: Ashwin Bharambe
Parent: d9d827ff25
Commit: 0edd3ce78b

3 changed files with 7 additions and 8 deletions

@@ -7,6 +7,7 @@
 import os
 
 import pytest
+from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
 from llama_models.llama3.api.datatypes import BuiltinTool
 
 from llama_stack.apis.agents import (
@@ -22,12 +23,8 @@ from llama_stack.apis.agents import (
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.inference import (
-    CompletionMessage,
-    SamplingParams,
-    TopPSamplingStrategy,
-    UserMessage,
-)
+from llama_stack.apis.inference import CompletionMessage, UserMessage
 from llama_stack.apis.safety import ViolationLevel
 from llama_stack.providers.datatypes import Api

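Note on the import move: after the parent change, SamplingParams lives in llama_models.datatypes and carries its sampling knobs as a structured strategy object (e.g. TopPSamplingStrategy) instead of flat fields, which is why the test now imports both from there. A minimal usage sketch, assuming the strategy field names (temperature, top_p) match the new datatypes:

    from llama_models.datatypes import SamplingParams, TopPSamplingStrategy

    # Sampling knobs now hang off a strategy object rather than flat
    # fields on SamplingParams itself (field names assumed).
    params = SamplingParams(
        strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.95),
        max_tokens=64,
    )
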
@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
     UserMessage,
 )
+from llama_stack.apis.models import Model
 from .utils import group_chunks
@@ -476,7 +477,7 @@ class TestInference:
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
         assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
-        assert last.event.delta.content.type == "tool_call"
+        assert isinstance(last.event.delta.content, ToolCall)
         call = last.event.delta.content
         assert call.tool_name == "get_weather"

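Note on the assertion change: once parse_status is succeeded, the streamed delta's content is a parsed ToolCall object, so the fixed test checks the concrete type with isinstance instead of probing a .type attribute. A hedged sketch of the pattern, assuming ToolCall comes from llama_models.llama3.api.datatypes (as with BuiltinTool above) and that delta.content is a plain string while parsing is still in progress:

    from llama_models.llama3.api.datatypes import ToolCall

    def extract_tool_call(delta) -> ToolCall:
        # Until parsing succeeds the content is an accumulating string;
        # afterwards it is a structured ToolCall (shapes assumed).
        if isinstance(delta.content, ToolCall):
            return delta.content
        raise ValueError(f"tool call not parsed yet: {delta.parse_status}")
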
@@ -74,6 +74,7 @@ def get_sampling_options(params: SamplingParams) -> dict:
     options = {}
     if params:
+        options.update(get_sampling_strategy_options(params))
         if params.max_tokens:
            options["max_tokens"] = params.max_tokens
         if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
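
Note on the new call: with strategy-based SamplingParams, the strategy-specific options (temperature, top_p, top_k) are pulled out by get_sampling_strategy_options before the remaining flat fields are handled. A minimal sketch of what such a helper could look like, assuming the GreedySamplingStrategy / TopPSamplingStrategy / TopKSamplingStrategy union from llama_models.datatypes; this is not necessarily the repository's exact implementation:

    from llama_models.datatypes import (
        GreedySamplingStrategy,
        SamplingParams,
        TopKSamplingStrategy,
        TopPSamplingStrategy,
    )

    def get_sampling_strategy_options(params: SamplingParams) -> dict:
        # Map each member of the strategy union onto the flat option
        # names an OpenAI-compatible backend expects (names assumed).
        options = {}
        if isinstance(params.strategy, GreedySamplingStrategy):
            options["temperature"] = 0.0
        elif isinstance(params.strategy, TopPSamplingStrategy):
            options["temperature"] = params.strategy.temperature
            options["top_p"] = params.strategy.top_p
        elif isinstance(params.strategy, TopKSamplingStrategy):
            options["top_k"] = params.strategy.top_k
        return options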