fix tests

This commit is contained in:
Hardik Shah 2025-01-14 17:38:10 -08:00 committed by Ashwin Bharambe
parent d9d827ff25
commit 0edd3ce78b
3 changed files with 7 additions and 8 deletions

View file

@@ -7,6 +7,7 @@
import os import os
import pytest import pytest
from llama_models.datatypes import SamplingParams, TopPSamplingStrategy
from llama_models.llama3.api.datatypes import BuiltinTool from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
@@ -22,12 +23,8 @@ from llama_stack.apis.agents import (
ToolExecutionStep, ToolExecutionStep,
Turn, Turn,
) )
from llama_stack.apis.inference import (
CompletionMessage, from llama_stack.apis.inference import CompletionMessage, UserMessage
SamplingParams,
TopPSamplingStrategy,
UserMessage,
)
from llama_stack.apis.safety import ViolationLevel from llama_stack.apis.safety import ViolationLevel
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api

View file

@@ -32,6 +32,7 @@ from llama_stack.apis.inference import (
UserMessage, UserMessage,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from .utils import group_chunks from .utils import group_chunks
@@ -476,7 +477,7 @@ class TestInference:
last = grouped[ChatCompletionResponseEventType.progress][-1] last = grouped[ChatCompletionResponseEventType.progress][-1]
# assert last.event.stop_reason == expected_stop_reason # assert last.event.stop_reason == expected_stop_reason
assert last.event.delta.parse_status == ToolCallParseStatus.succeeded assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
assert last.event.delta.content.type == "tool_call" assert isinstance(last.event.delta.content, ToolCall)
call = last.event.delta.content call = last.event.delta.content
assert call.tool_name == "get_weather" assert call.tool_name == "get_weather"

View file

@@ -74,7 +74,8 @@ def get_sampling_options(params: SamplingParams) -> dict:
options = {} options = {}
if params: if params:
options.update(get_sampling_strategy_options(params)) options.update(get_sampling_strategy_options(params))
options["max_tokens"] = params.max_tokens if params.max_tokens:
options["max_tokens"] = params.max_tokens
if params.repetition_penalty is not None and params.repetition_penalty != 1.0: if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty options["repeat_penalty"] = params.repetition_penalty