From 8af695110697d5e9a611e71cb80b0289f25bc6d8 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Fri, 10 Jan 2025 10:41:53 -0800
Subject: [PATCH] remove conflicting default for tool prompt format in chat completion (#742)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

We currently set a default value of `json` for the tool prompt format, which
conflicts with Llama 3.2/3.3 models since they use the `python_list` format.
This PR changes the default to `None`; the code then infers the appropriate
default based on the model.

Addresses: #695

Tests:

❯ LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v tests/client-sdk/inference/test_inference.py -k "test_text_chat_completion"

pytest llama_stack/providers/tests/inference/test_prompt_adapter.py
---
 docs/resources/llama-stack-spec.html | 8 --------
 docs/resources/llama-stack-spec.yaml | 6 ------
 .../apis/batch_inference/batch_inference.py | 7 ++-----
 llama_stack/apis/inference/inference.py | 15 +++------------
 llama_stack/apis/tools/tools.py | 7 -------
 llama_stack/distribution/routers/routers.py | 2 +-
 .../distribution/routers/routing_tables.py | 1 -
 .../inline/inference/meta_reference/inference.py | 4 +---
 .../sentence_transformers.py | 3 ++-
 .../providers/inline/inference/vllm/vllm.py | 6 +-----
 .../providers/remote/inference/bedrock/bedrock.py | 5 +----
 .../remote/inference/cerebras/cerebras.py | 7 +------
 .../remote/inference/databricks/databricks.py | 7 +------
 .../remote/inference/fireworks/fireworks.py | 4 +---
 .../providers/remote/inference/groq/groq.py | 5 ++---
 .../providers/remote/inference/nvidia/nvidia.py | 4 +---
 .../providers/remote/inference/ollama/ollama.py | 4 +---
 llama_stack/providers/remote/inference/tgi/tgi.py | 4 +---
 .../remote/inference/together/together.py | 2 +-
 .../providers/remote/inference/vllm/vllm.py | 5 +----
 .../providers/utils/inference/prompt_adapter.py | 12 ++++++------
 21 files changed, 27 insertions(+), 91 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 7ace983f8..0ce216479 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -4503,10 +4503,6 @@
                             }
                         ]
                     }
-                },
-                "tool_prompt_format": {
-                    "$ref": "#/components/schemas/ToolPromptFormat",
-                    "default": "json"
                 }
             },
             "additionalProperties": false,
@@ -6522,10 +6518,6 @@
                             }
                         ]
                     }
-                },
-                "tool_prompt_format": {
-                    "$ref": "#/components/schemas/ToolPromptFormat",
-                    "default": "json"
                 }
             },
             "additionalProperties": false,
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index a2f6bc005..031178ce9 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -2600,9 +2600,6 @@ components:
         type: string
       tool_host:
         $ref: '#/components/schemas/ToolHost'
-      tool_prompt_format:
-        $ref: '#/components/schemas/ToolPromptFormat'
-        default: json
       toolgroup_id:
         type: string
       type:
@@ -2704,9 +2701,6 @@ components:
         items:
           $ref: '#/components/schemas/ToolParameter'
         type: array
-      tool_prompt_format:
-        $ref: '#/components/schemas/ToolPromptFormat'
-        default: json
     required:
     - name
     type: object
diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py
index f7b8b4387..81826a7b1 100644
--- a/llama_stack/apis/batch_inference/batch_inference.py
+++ b/llama_stack/apis/batch_inference/batch_inference.py
@@ -7,7 +7,6 @@
 from typing import List, Optional, Protocol, runtime_checkable

 from llama_models.schema_utils import json_schema_type, webmethod
-
 from pydantic import BaseModel, Field

 from llama_stack.apis.inference import (
@@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None
@@ -75,6 +72,6 @@ class BatchInference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index e48042091..a6a096041 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from enum import Enum
-
 from typing import (
     Any,
     AsyncIterator,
@@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

 from llama_stack.apis.common.content_types import InterleavedContent
-
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
@@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None
@@ -334,7 +325,7 @@ class Inference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index e430ec46d..d2bdf9873 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -7,7 +7,6 @@
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional

-from llama_models.llama3.api.datatypes import ToolPromptFormat
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
@@ -41,9 +40,6 @@ class Tool(Resource):
     description: str
     parameters: List[ToolParameter]
     metadata: Optional[Dict[str, Any]] = None
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )


 @json_schema_type
@@ -52,9 +48,6 @@ class ToolDef(BaseModel):
     description: Optional[str] = None
     parameters: Optional[List[ToolParameter]] = None
     metadata: Optional[Dict[str, Any]] = None
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )


 @json_schema_type
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 05d43ad4f..8080b9dff 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -127,7 +127,7 @@ class InferenceRouter(Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d4cb708a2..a3a64bf6b 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -523,7 +523,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
                     description=tool_def.description or "",
                     parameters=tool_def.parameters or [],
                     provider_id=provider_id,
-                    tool_prompt_format=tool_def.tool_prompt_format,
                     provider_resource_id=tool_def.name,
                     metadata=tool_def.metadata,
                     tool_host=tool_host,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index d89bb21f7..5b502a581 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -6,7 +6,6 @@
 import asyncio
 import logging
-
 from typing import AsyncGenerator, List, Optional, Union

 from llama_models.llama3.api.datatypes import (
@@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
     ToolCallParseStatus,
     ToolChoice,
 )
-
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 0896b44af..3920ee1ad 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
+
 from .config import SentenceTransformersInferenceConfig

 log = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 73f7adecd..03bcad3e9 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -10,10 +10,8 @@ import uuid
 from typing import AsyncGenerator, List, Optional

 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
-
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import VLLMConfig

-
 log = logging.getLogger(__name__)
@@ -146,7 +142,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index d340bbbea..59f30024e 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )

-
 MODEL_ALIASES = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
@@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 586447012..b78471787 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -7,11 +7,8 @@ from typing import AsyncGenerator, List, Optional, Union

 from cerebras.cloud.sdk import AsyncCerebras
-
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import CerebrasImplConfig

-
 model_aliases = [
     build_model_alias(
         "llama3.1-8b",
@@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 3d88423c5..2964b2aaa 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -7,11 +7,8 @@ from typing import AsyncGenerator, List, Optional

 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
-
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import DatabricksImplConfig

-
 model_aliases = [
     build_model_alias(
         "databricks-meta-llama-3-1-70b-instruct",
@@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e0603a5dc..84dd28102 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer

@@ -52,7 +51,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import FireworksImplConfig

-
 MODEL_ALIASES = [
     build_model_alias(
         "fireworks/llama-v3p1-8b-instruct",
@@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 5db4c0894..2fbe48c44 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -33,6 +33,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias_with_just_provider_model_id,
     ModelRegistryHelper,
 )
+
 from .groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,
@@ -94,9 +95,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 42c4db53e..81751e038 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 2de5a994e..38721ea22 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union

 import httpx
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
@@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 25d2e0cb8..985fd3606 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
-
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 76f411c45..8f679cb56 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -184,7 +184,7 @@ class TogetherInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9f9072922..317d05207 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
-
 from openai import OpenAI

 from llama_stack.apis.common.content_types import InterleavedContent
@@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (

 from .config import VLLMInferenceAdapterConfig

-
 log = logging.getLogger(__name__)
@@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index d296105e0..2d66dc60b 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -358,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(
     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        if request.tool_prompt_format == ToolPromptFormat.json:
+        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
-        elif request.tool_prompt_format == ToolPromptFormat.function_tag:
+        elif fmt == ToolPromptFormat.function_tag:
             tool_gen = FunctionTagCustomToolGenerator()
         else:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {fmt}")

         custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
         custom_template = tool_gen.gen(custom_tools)
@@ -410,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(
     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
     if custom_tools:
-        if request.tool_prompt_format != ToolPromptFormat.python_list:
+        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        if fmt != ToolPromptFormat.python_list:
             raise ValueError(
                 f"Non supported ToolPromptFormat {request.tool_prompt_format}"
             )
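
Reviewer note: with this change, a caller that leaves `tool_prompt_format` unset gets a model-appropriate default from the prompt adapter (JSON for Llama 3.1 custom tools, `python_list` for Llama 3.2/3.3). Below is a minimal sketch of that fallback behavior; `resolve_tool_prompt_format` and its `model_family` argument are illustrative only and are not part of this patch.

```python
from typing import Optional

from llama_models.llama3.api.datatypes import ToolPromptFormat


def resolve_tool_prompt_format(
    model_family: str,  # e.g. "llama3_1" or "llama3_2"; illustrative naming
    requested: Optional[ToolPromptFormat] = None,
) -> ToolPromptFormat:
    """Return the caller's explicit choice, or a model-appropriate default."""
    if requested is not None:
        return requested
    # Mirrors the new fallback in prompt_adapter.py: Llama 3.1 custom tools
    # default to the JSON format, while Llama 3.2/3.3 expect python_list.
    if model_family == "llama3_1":
        return ToolPromptFormat.json
    return ToolPromptFormat.python_list


# Example: an omitted tool_prompt_format now resolves per model family
# instead of being forced to json by the API default.
assert resolve_tool_prompt_format("llama3_2") == ToolPromptFormat.python_list
```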