Merge branch 'main' into vllm

This commit is contained in:
Fred Reiss 2025-01-08 15:47:58 -08:00 committed by GitHub
commit 73fede90a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
175 changed files with 7948 additions and 876 deletions

View file

@ -7,13 +7,32 @@
import pytest
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolCall,
ToolDefinition,
ToolParamDefinition,
ToolPromptFormat,
)
from pydantic import BaseModel, ValidationError
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
LogProbConfig,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
UserMessage,
)
from llama_stack.apis.models import Model
from .utils import group_chunks
@ -193,6 +212,7 @@ class TestInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::tgi",
"remote::together",
"remote::fireworks",
@ -255,6 +275,7 @@ class TestInference:
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"inline::meta-reference",
"remote::ollama",
"remote::fireworks",
"remote::tgi",
"remote::together",
@ -352,6 +373,14 @@ class TestInference:
sample_messages,
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
inference_impl, _ = inference_stack
messages = sample_messages + [
UserMessage(
@ -392,6 +421,13 @@ class TestInference:
sample_tool_definition,
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type in ("remote::groq",):
pytest.skip(
provider.__provider_spec__.provider_type
+ " doesn't support tool calling yet"
)
messages = sample_messages + [
UserMessage(
content="What's the weather like in San Francisco?",