diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 7ace983f8..0ce216479 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -4503,10 +4503,6 @@
}
]
}
- },
- "tool_prompt_format": {
- "$ref": "#/components/schemas/ToolPromptFormat",
- "default": "json"
}
},
"additionalProperties": false,
@@ -6522,10 +6518,6 @@
}
]
}
- },
- "tool_prompt_format": {
- "$ref": "#/components/schemas/ToolPromptFormat",
- "default": "json"
}
},
"additionalProperties": false,
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index a2f6bc005..031178ce9 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -2600,9 +2600,6 @@ components:
type: string
tool_host:
$ref: '#/components/schemas/ToolHost'
- tool_prompt_format:
- $ref: '#/components/schemas/ToolPromptFormat'
- default: json
toolgroup_id:
type: string
type:
@@ -2704,9 +2701,6 @@ components:
items:
$ref: '#/components/schemas/ToolParameter'
type: array
- tool_prompt_format:
- $ref: '#/components/schemas/ToolPromptFormat'
- default: json
required:
- name
type: object
diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py
index f7b8b4387..81826a7b1 100644
--- a/llama_stack/apis/batch_inference/batch_inference.py
+++ b/llama_stack/apis/batch_inference/batch_inference.py
@@ -7,7 +7,6 @@
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
-
from pydantic import BaseModel, Field
from llama_stack.apis.inference import (
@@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
- tool_prompt_format: Optional[ToolPromptFormat] = Field(
- default=ToolPromptFormat.json
- )
+ tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
logprobs: Optional[LogProbConfig] = None
@@ -75,6 +72,6 @@ class BatchInference(Protocol):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index e48042091..a6a096041 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -5,7 +5,6 @@
# the root directory of this source tree.
from enum import Enum
-
from typing import (
Any,
AsyncIterator,
@@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
-
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import InterleavedContent
-
from llama_stack.apis.models import Model
-
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
- tool_prompt_format: Optional[ToolPromptFormat] = Field(
- default=ToolPromptFormat.json
- )
+ tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
response_format: Optional[ResponseFormat] = None
stream: Optional[bool] = False
@@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
- tool_prompt_format: Optional[ToolPromptFormat] = Field(
- default=ToolPromptFormat.json
- )
+ tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
logprobs: Optional[LogProbConfig] = None
@@ -334,7 +325,7 @@ class Inference(Protocol):
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
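
Note (editor's illustration, not part of the patch): the effect of switching the Field default from ToolPromptFormat.json to None is that an omitted tool_prompt_format stays distinguishable from an explicit request for JSON, so providers can pick a model-appropriate format later. A minimal, self-contained sketch under assumed names (a hypothetical ToolRequest model and a local stand-in enum, not the real ChatCompletionRequest or llama_models.ToolPromptFormat):

    from enum import Enum
    from typing import Optional

    from pydantic import BaseModel, Field


    class ToolPromptFormat(Enum):  # local stand-in for llama_models' enum
        json = "json"
        function_tag = "function_tag"
        python_list = "python_list"


    class ToolRequest(BaseModel):  # hypothetical model mirroring the new default
        tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)


    # Omitted field is now None, so "caller said nothing" is detectable downstream.
    assert ToolRequest().tool_prompt_format is None
    # An explicit choice still comes through unchanged.
    assert (
        ToolRequest(tool_prompt_format=ToolPromptFormat.python_list).tool_prompt_format
        is ToolPromptFormat.python_list
    )
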
diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py
index e430ec46d..d2bdf9873 100644
--- a/llama_stack/apis/tools/tools.py
+++ b/llama_stack/apis/tools/tools.py
@@ -7,7 +7,6 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
-from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
@@ -41,9 +40,6 @@ class Tool(Resource):
description: str
parameters: List[ToolParameter]
metadata: Optional[Dict[str, Any]] = None
- tool_prompt_format: Optional[ToolPromptFormat] = Field(
- default=ToolPromptFormat.json
- )
@json_schema_type
@@ -52,9 +48,6 @@ class ToolDef(BaseModel):
description: Optional[str] = None
parameters: Optional[List[ToolParameter]] = None
metadata: Optional[Dict[str, Any]] = None
- tool_prompt_format: Optional[ToolPromptFormat] = Field(
- default=ToolPromptFormat.json
- )
@json_schema_type
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 05d43ad4f..8080b9dff 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -127,7 +127,7 @@ class InferenceRouter(Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d4cb708a2..a3a64bf6b 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -523,7 +523,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
- tool_prompt_format=tool_def.tool_prompt_format,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index d89bb21f7..5b502a581 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -6,7 +6,6 @@
import asyncio
import logging
-
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.datatypes import (
@@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
ToolCallParseStatus,
ToolChoice,
)
-
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 0896b44af..3920ee1ad 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
+
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 73f7adecd..03bcad3e9 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -10,10 +10,8 @@ import uuid
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
-
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
-
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
-
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
@@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMConfig
-
log = logging.getLogger(__name__)
@@ -146,7 +142,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index d340bbbea..59f30024e 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
-
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
@@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
-
MODEL_ALIASES = [
build_model_alias(
"meta.llama3-1-8b-instruct-v1:0",
@@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 586447012..b78471787 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
-
from llama_models.datatypes import CoreModelId
-
from llama_models.llama3.api.chat_format import ChatFormat
-
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.common.content_types import InterleavedContent
@@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
-
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import CerebrasImplConfig
-
model_aliases = [
build_model_alias(
"llama3.1-8b",
@@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 3d88423c5..2964b2aaa 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -7,11 +7,8 @@
from typing import AsyncGenerator, List, Optional
from llama_models.datatypes import CoreModelId
-
from llama_models.llama3.api.chat_format import ChatFormat
-
from llama_models.llama3.api.tokenizer import Tokenizer
-
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
-
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig
-
model_aliases = [
build_model_alias(
"databricks-meta-llama-3-1-70b-instruct",
@@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e0603a5dc..84dd28102 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_models.datatypes import CoreModelId
-
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
@@ -52,7 +51,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
-
MODEL_ALIASES = [
build_model_alias(
"fireworks/llama-v3p1-8b-instruct",
@@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 5db4c0894..2fbe48c44 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -33,6 +33,7 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_alias_with_just_provider_model_id,
ModelRegistryHelper,
)
+
from .groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
@@ -94,9 +95,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[
- ToolPromptFormat
- ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 42c4db53e..81751e038 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[
- ToolPromptFormat
- ] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 2de5a994e..38721ea22 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union
import httpx
from llama_models.datatypes import CoreModelId
-
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
@@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
build_model_alias_with_just_provider_model_id,
@@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 25d2e0cb8..985fd3606 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
ToolPromptFormat,
)
from llama_stack.apis.models import Model
-
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
-
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
OpenAICompatCompletionChoice,
@@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 76f411c45..8f679cb56 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -184,7 +184,7 @@ class TogetherInferenceAdapter(
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9f9072922..317d05207 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models
-
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
@@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
@@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
-
log = logging.getLogger(__name__)
@@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
- tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+ tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index d296105e0..2d66dc60b 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -358,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
- if request.tool_prompt_format == ToolPromptFormat.json:
+ fmt = request.tool_prompt_format or ToolPromptFormat.json
+ if fmt == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
- elif request.tool_prompt_format == ToolPromptFormat.function_tag:
+ elif fmt == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
- raise ValueError(
- f"Non supported ToolPromptFormat {request.tool_prompt_format}"
- )
+ raise ValueError(f"Non supported ToolPromptFormat {fmt}")
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
custom_template = tool_gen.gen(custom_tools)
@@ -410,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
if custom_tools:
- if request.tool_prompt_format != ToolPromptFormat.python_list:
+ fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+ if fmt != ToolPromptFormat.python_list:
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
)