mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
remove conflicting default for tool prompt format in chat completion (#742)
# What does this PR do? We are setting a default value of json for tool prompt format, which conflicts with llama 3.2/3.3 models since they use python list. This PR changes the defaults to None and in the code, we infer default based on the model. Addresses: #695 Tests: ❯ LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v tests/client-sdk/inference/test_inference.py -k "test_text_chat_completion" pytest llama_stack/providers/tests/inference/test_prompt_adapter.py
This commit is contained in:
parent
24fa1adc2f
commit
8af6951106
21 changed files with 27 additions and 91 deletions
|
@ -4503,10 +4503,6 @@
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tool_prompt_format": {
|
||||
"$ref": "#/components/schemas/ToolPromptFormat",
|
||||
"default": "json"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
@ -6522,10 +6518,6 @@
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"tool_prompt_format": {
|
||||
"$ref": "#/components/schemas/ToolPromptFormat",
|
||||
"default": "json"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
|
@ -2600,9 +2600,6 @@ components:
|
|||
type: string
|
||||
tool_host:
|
||||
$ref: '#/components/schemas/ToolHost'
|
||||
tool_prompt_format:
|
||||
$ref: '#/components/schemas/ToolPromptFormat'
|
||||
default: json
|
||||
toolgroup_id:
|
||||
type: string
|
||||
type:
|
||||
|
@ -2704,9 +2701,6 @@ components:
|
|||
items:
|
||||
$ref: '#/components/schemas/ToolParameter'
|
||||
type: array
|
||||
tool_prompt_format:
|
||||
$ref: '#/components/schemas/ToolPromptFormat'
|
||||
default: json
|
||||
required:
|
||||
- name
|
||||
type: object
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
from typing import List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -44,9 +43,7 @@ class BatchChatCompletionRequest(BaseModel):
|
|||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
||||
|
||||
|
@ -75,6 +72,6 @@ class BatchInference(Protocol):
|
|||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = list,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchChatCompletionResponse: ...
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
|
@ -26,16 +25,12 @@ from llama_models.llama3.api.datatypes import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
|
@ -256,9 +251,7 @@ class ChatCompletionRequest(BaseModel):
|
|||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
stream: Optional[bool] = False
|
||||
|
@ -289,9 +282,7 @@ class BatchChatCompletionRequest(BaseModel):
|
|||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = Field(default_factory=list)
|
||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None)
|
||||
logprobs: Optional[LogProbConfig] = None
|
||||
|
||||
|
||||
|
@ -334,7 +325,7 @@ class Inference(Protocol):
|
|||
# zero-shot tool definitions as input to the model
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
@ -41,9 +40,6 @@ class Tool(Resource):
|
|||
description: str
|
||||
parameters: List[ToolParameter]
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -52,9 +48,6 @@ class ToolDef(BaseModel):
|
|||
description: Optional[str] = None
|
||||
parameters: Optional[List[ToolParameter]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||
default=ToolPromptFormat.json
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -127,7 +127,7 @@ class InferenceRouter(Inference):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -523,7 +523,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
description=tool_def.description or "",
|
||||
parameters=tool_def.parameters or [],
|
||||
provider_id=provider_id,
|
||||
tool_prompt_format=tool_def.tool_prompt_format,
|
||||
provider_resource_id=tool_def.name,
|
||||
metadata=tool_def.metadata,
|
||||
tool_host=tool_host,
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_models.llama3.api.datatypes import (
|
||||
|
@ -37,7 +36,6 @@ from llama_stack.apis.inference import (
|
|||
ToolCallParseStatus,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
|
@ -262,7 +260,7 @@ class MetaReferenceInferenceImpl(
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -22,6 +22,7 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
|||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -67,7 +68,7 @@ class SentenceTransformersInferenceImpl(
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -10,10 +10,8 @@ import uuid
|
|||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
|
@ -36,7 +34,6 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@ -50,7 +47,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import VLLMConfig
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -146,7 +142,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
|
@ -10,7 +10,6 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
|||
from botocore.client import BaseClient
|
||||
from llama_models.datatypes import CoreModelId
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
@ -30,7 +29,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
|
@ -47,7 +45,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
|
@ -101,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[
|
||||
|
|
|
@ -7,11 +7,8 @@
|
|||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
@ -29,7 +26,6 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
|
@ -48,7 +44,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import CerebrasImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
build_model_alias(
|
||||
"llama3.1-8b",
|
||||
|
@ -130,7 +125,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
|
@ -7,11 +7,8 @@
|
|||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
@ -28,7 +25,6 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
|
@ -44,7 +40,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
|
||||
model_aliases = [
|
||||
build_model_alias(
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
|
@ -91,7 +86,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from fireworks.client import Fireworks
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
|
@ -52,7 +51,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import FireworksImplConfig
|
||||
|
||||
|
||||
MODEL_ALIASES = [
|
||||
build_model_alias(
|
||||
"fireworks/llama-v3p1-8b-instruct",
|
||||
|
@ -198,7 +196,7 @@ class FireworksInferenceAdapter(
|
|||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
|
@ -33,6 +33,7 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_model_alias_with_just_provider_model_id,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
||||
from .groq_utils import (
|
||||
convert_chat_completion_request,
|
||||
convert_chat_completion_response,
|
||||
|
@ -94,9 +95,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[
|
||||
ToolPromptFormat
|
||||
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[
|
||||
|
|
|
@ -175,9 +175,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[
|
||||
ToolPromptFormat
|
||||
] = None, # API default is ToolPromptFormat.json, we default to None to detect user input
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[
|
||||
|
|
|
@ -9,7 +9,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
import httpx
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from ollama import AsyncClient
|
||||
|
@ -35,7 +34,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
build_model_alias_with_just_provider_model_id,
|
||||
|
@ -222,7 +220,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -30,13 +30,11 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
get_sampling_options,
|
||||
OpenAICompatCompletionChoice,
|
||||
|
@ -205,7 +203,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
|
@ -184,7 +184,7 @@ class TogetherInferenceAdapter(
|
|||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
|
|
|
@ -10,7 +10,6 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import all_registered_models
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
|
@ -33,7 +32,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_model_alias,
|
||||
ModelRegistryHelper,
|
||||
|
@ -54,7 +52,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import VLLMInferenceAdapterConfig
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -105,7 +102,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
|
|
|
@ -358,14 +358,13 @@ def augment_messages_for_tools_llama_3_1(
|
|||
|
||||
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
|
||||
if has_custom_tools:
|
||||
if request.tool_prompt_format == ToolPromptFormat.json:
|
||||
fmt = request.tool_prompt_format or ToolPromptFormat.json
|
||||
if fmt == ToolPromptFormat.json:
|
||||
tool_gen = JsonCustomToolGenerator()
|
||||
elif request.tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
elif fmt == ToolPromptFormat.function_tag:
|
||||
tool_gen = FunctionTagCustomToolGenerator()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
|
||||
)
|
||||
raise ValueError(f"Non supported ToolPromptFormat {fmt}")
|
||||
|
||||
custom_tools = [t for t in request.tools if isinstance(t.tool_name, str)]
|
||||
custom_template = tool_gen.gen(custom_tools)
|
||||
|
@ -410,7 +409,8 @@ def augment_messages_for_tools_llama_3_2(
|
|||
|
||||
custom_tools = [dfn for dfn in request.tools if isinstance(dfn.tool_name, str)]
|
||||
if custom_tools:
|
||||
if request.tool_prompt_format != ToolPromptFormat.python_list:
|
||||
fmt = request.tool_prompt_format or ToolPromptFormat.python_list
|
||||
if fmt != ToolPromptFormat.python_list:
|
||||
raise ValueError(
|
||||
f"Non supported ToolPromptFormat {request.tool_prompt_format}"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue