remove conflicting default for tool prompt format in chat completion (#742)

# What does this PR do?
We were setting a default value of `json` for `tool_prompt_format`, which
conflicts with Llama 3.2/3.3 models since they use the `python_list` format.
This PR changes the default to `None`; the code then infers the right format
from the model.

Addresses: #695 
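
To illustrate the inference step, here is a minimal sketch of model-based resolution when a caller leaves `tool_prompt_format` unset; the helper names here are hypothetical, not the exact ones in this commit:

```python
from typing import Optional

from llama_models.llama3.api.datatypes import ToolPromptFormat


def infer_tool_prompt_format(model_id: str) -> ToolPromptFormat:
    # Hypothetical helper: Llama 3.2/3.3 models expect python_list for
    # custom tools, while earlier Llama 3.x models expect json.
    if "3.2" in model_id or "3.3" in model_id:
        return ToolPromptFormat.python_list
    return ToolPromptFormat.json


def resolve_tool_prompt_format(
    requested: Optional[ToolPromptFormat], model_id: str
) -> ToolPromptFormat:
    # An explicit caller choice always wins; None now means "infer from the model".
    return requested or infer_tool_prompt_format(model_id)
```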

Tests:

```bash
❯ LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v \
    tests/client-sdk/inference/test_inference.py -k "test_text_chat_completion"

❯ pytest llama_stack/providers/tests/inference/test_prompt_adapter.py
```
Dinesh Yeduguru authored on 2025-01-10 10:41:53 -08:00, committed by GitHub.
commit 8af6951106, parent 24fa1adc2f
21 changed files with 27 additions and 91 deletions


```diff
@@ -4503,10 +4503,6 @@
             }
           ]
         }
-      },
-      "tool_prompt_format": {
-        "$ref": "#/components/schemas/ToolPromptFormat",
-        "default": "json"
       }
     },
     "additionalProperties": false,
@@ -6522,10 +6518,6 @@
             }
           ]
         }
-      },
-      "tool_prompt_format": {
-        "$ref": "#/components/schemas/ToolPromptFormat",
-        "default": "json"
       }
     },
     "additionalProperties": false,
```


```diff
@@ -2600,9 +2600,6 @@ components:
         type: string
       tool_host:
         $ref: '#/components/schemas/ToolHost'
-      tool_prompt_format:
-        $ref: '#/components/schemas/ToolPromptFormat'
-        default: json
       toolgroup_id:
         type: string
       type:
@@ -2704,9 +2701,6 @@ components:
       items:
         $ref: '#/components/schemas/ToolParameter'
       type: array
-      tool_prompt_format:
-        $ref: '#/components/schemas/ToolPromptFormat'
-        default: json
     required:
     - name
     type: object
```


```diff
@@ -7,7 +7,6 @@
 from typing import List, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
-
 from llama_stack.apis.inference import (
@@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None
@@ -75,6 +72,6 @@ class BatchInference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...
```


```diff
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from enum import Enum
-
 from typing import (
     Any,
     AsyncIterator,
@@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_models.schema_utils import json_schema_type, register_schema, webmethod
-
 from pydantic import BaseModel, Field, field_validator
-
 from typing_extensions import Annotated
-
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.models import Model
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
@@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
 
     response_format: Optional[ResponseFormat] = None
     stream: Optional[bool] = False
@@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
     # zero-shot tool definitions as input to the model
     tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
     tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
+    tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
     logprobs: Optional[LogProbConfig] = None
@@ -334,7 +325,7 @@ class Inference(Protocol):
         # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
```


```diff
@@ -7,7 +7,6 @@
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional
 
-from llama_models.llama3.api.datatypes import ToolPromptFormat
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Protocol, runtime_checkable
@@ -41,9 +40,6 @@ class Tool(Resource):
     description: str
     parameters: List[ToolParameter]
     metadata: Optional[Dict[str, Any]] = None
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
 
 
 @json_schema_type
@@ -52,9 +48,6 @@ class ToolDef(BaseModel):
     description: Optional[str] = None
     parameters: Optional[List[ToolParameter]] = None
     metadata: Optional[Dict[str, Any]] = None
-    tool_prompt_format: Optional[ToolPromptFormat] = Field(
-        default=ToolPromptFormat.json
-    )
 
 
 @json_schema_type
```


```diff
@@ -127,7 +127,7 @@ class InferenceRouter(Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
```


```diff
@@ -523,7 +523,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
                     description=tool_def.description or "",
                     parameters=tool_def.parameters or [],
                     provider_id=provider_id,
-                    tool_prompt_format=tool_def.tool_prompt_format,
                     provider_resource_id=tool_def.name,
                     metadata=tool_def.metadata,
                     tool_host=tool_host,
```


```diff
@@ -6,7 +6,6 @@
 import asyncio
 import logging
 from typing import AsyncGenerator, List, Optional, Union
-
 from llama_models.llama3.api.datatypes import (
@@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
     ToolCallParseStatus,
     ToolChoice,
 )
-
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
@@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
```


```diff
@@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
+
 from .config import SentenceTransformersInferenceConfig
 
 log = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
```


```diff
@@ -10,10 +10,8 @@ import uuid
 from typing import AsyncGenerator, List, Optional
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import resolve_model
-
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
@@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMConfig
 
-
 log = logging.getLogger(__name__)
@@ -146,7 +142,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
```


```diff
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
-
 MODEL_ALIASES = [
     build_model_alias(
         "meta.llama3-1-8b-instruct-v1:0",
@@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
```


```diff
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional, Union
 
 from cerebras.cloud.sdk import AsyncCerebras
-
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import CerebrasImplConfig
 
-
 model_aliases = [
     build_model_alias(
         "llama3.1-8b",
@@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
```


```diff
@@ -7,11 +7,8 @@
 from typing import AsyncGenerator, List, Optional
 
 from llama_models.datatypes import CoreModelId
-
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
-
 from openai import OpenAI
-
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import DatabricksImplConfig
 
-
 model_aliases = [
     build_model_alias(
         "databricks-meta-llama-3-1-70b-instruct",
@@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
```


```diff
@@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks
-
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
@@ -52,7 +51,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig
 
-
 MODEL_ALIASES = [
     build_model_alias(
         "fireworks/llama-v3p1-8b-instruct",
@@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
```


```diff
@@ -33,6 +33,7 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias_with_just_provider_model_id,
     ModelRegistryHelper,
 )
+
 from .groq_utils import (
     convert_chat_completion_request,
     convert_chat_completion_response,
@@ -94,9 +95,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
```


```diff
@@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[
-            ToolPromptFormat
-        ] = None,  # API default is ToolPromptFormat.json, we default to None to detect user input
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[
```


```diff
@@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union
 import httpx
-
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     build_model_alias_with_just_provider_model_id,
@@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
```


```diff
@@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
-
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     OpenAICompatCompletionChoice,
@@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
```


```diff
@@ -184,7 +184,7 @@ class TogetherInferenceAdapter(
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
```


```diff
@@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
-
 from openai import OpenAI
 
 from llama_stack.apis.common.content_types import InterleavedContent
@@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
-
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
@@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import VLLMInferenceAdapterConfig
 
-
 log = logging.getLogger(__name__)
@@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
```


```diff
@@ -358,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(
     has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
     if has_custom_tools:
-        if request.tool_prompt_format == ToolPromptFormat.json:
+        fmt = request.tool_prompt_format or ToolPromptFormat.json
+        if fmt == ToolPromptFormat.json:
             tool_gen = JsonCustomToolGenerator()
-        elif request.tool_prompt_format == ToolPromptFormat.function_tag:
+        elif fmt == ToolPromptFormat.function_tag:
             tool_gen = FunctionTagCustomToolGenerator()
         else:
-            raise ValueError(
-                f"Non supported ToolPromptFormat {request.tool_prompt_format}"
-            )
+            raise ValueError(f"Non supported ToolPromptFormat {fmt}")
 
         custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
         custom_template = tool_gen.gen(custom_tools)
@@ -410,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(
     custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
     if custom_tools:
-        if request.tool_prompt_format != ToolPromptFormat.python_list:
+        fmt = request.tool_prompt_format or ToolPromptFormat.python_list
+        if fmt != ToolPromptFormat.python_list:
             raise ValueError(
                 f"Non supported ToolPromptFormat {request.tool_prompt_format}"
             )
```
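
The behavioral change in these two augment functions boils down to the `or` fallback shown above. A minimal sketch of the effect (the request class below is a stand-in, not the real `ChatCompletionRequest`):

```python
from llama_models.llama3.api.datatypes import ToolPromptFormat


class _Request:
    # Stand-in for ChatCompletionRequest, reduced to the one field that matters here.
    tool_prompt_format = None  # the caller left the format unset


request = _Request()

# Llama 3.1 custom-tool path: an unset format now falls back to json.
assert (request.tool_prompt_format or ToolPromptFormat.json) == ToolPromptFormat.json

# Llama 3.2 custom-tool path: an unset format now falls back to python_list,
# so the "Non supported ToolPromptFormat" error is no longer raised by default.
assert (
    request.tool_prompt_format or ToolPromptFormat.python_list
) == ToolPromptFormat.python_list
```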