chore: more mypy fixes (#2029)

# What does this PR do?

This mainly finishes type-checking the `llama_stack/apis` directory; only one module in it still needs fixes. Some of the existing mypy excludes turned out to be no-ops and are dropped.

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-05-06 18:52:31 +02:00 committed by GitHub
parent feb9eb8b0d
commit 1a529705da
27 changed files with 581 additions and 166 deletions


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from datetime import datetime
from enum import Enum
@ -35,6 +36,14 @@ from .openai_responses import (
OpenAIResponseObjectStream,
)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
class Attachment(BaseModel):
"""An attachment to an agent turn.
@ -73,7 +82,7 @@ class StepCommon(BaseModel):
completed_at: datetime | None = None
class StepType(Enum):
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
@ -97,7 +106,7 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
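The `Literal` rewrites throughout this commit follow from that backport: mypy rejects `Literal[StepType.inference.value]` because `.value` is not a literal expression, whereas the enum member itself is a valid `Literal` type. A minimal sketch of the resulting pattern (Pydantic v2 assumed):

```python
from enum import Enum
from typing import Literal

from pydantic import BaseModel


class StepType(str, Enum):  # stand-in for the StrEnum backport above
    inference = "inference"


class InferenceStep(BaseModel):
    # The enum member is a valid Literal type; `.value` is not.
    step_type: Literal[StepType.inference] = StepType.inference


# Because StepType mixes in str, JSON serialization is unchanged.
assert InferenceStep().model_dump(mode="json") == {"step_type": "inference"}
```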
@ -109,7 +118,7 @@ class ToolExecutionStep(StepCommon):
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@ -121,7 +130,7 @@ class ShieldCallStep(StepCommon):
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@ -133,7 +142,7 @@ class MemoryRetrievalStep(StepCommon):
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
@ -154,7 +163,7 @@ class Turn(BaseModel):
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=list)
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@ -182,10 +191,10 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[ToolDef] | None = Field(default_factory=list)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
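The recurring `default_factory=list` → `default_factory=lambda: []` swap is behavior-preserving; presumably the lambda keeps the factory's inferred return type aligned with the annotated field type under strict mypy. The pattern in isolation:

```python
from pydantic import BaseModel, Field


class AgentConfigCommon(BaseModel):
    # `lambda: []` and the bare `list` callable are equivalent at runtime;
    # the lambda form is what this commit standardizes on for mypy.
    input_shields: list[str] | None = Field(default_factory=lambda: [])
    output_shields: list[str] | None = Field(default_factory=lambda: [])


config = AgentConfigCommon()
# Each instance gets fresh, independent lists.
assert config.input_shields == [] and config.input_shields is not config.output_shields
```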
@ -246,7 +255,7 @@ class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(Enum):
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
@ -258,15 +267,15 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=dict)
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@ -276,7 +285,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
@ -285,21 +294,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = (
AgentTurnResponseEventType.turn_awaiting_input.value
)
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
@ -341,7 +348,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None


@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):
@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
type: Literal[ResourceType.benchmark] = ResourceType.benchmark
@property
def benchmark_id(self) -> str:
return self.identifier
@property
def provider_benchmark_id(self) -> str:
def provider_benchmark_id(self) -> str | None:
return self.provider_resource_id


@ -28,7 +28,7 @@ class _URLOrData(BaseModel):
url: URL | None = None
# data is a base64 encoded string, hint with contentEncoding=base64
data: str | None = Field(contentEncoding="base64", default=None)
data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
@model_validator(mode="before")
@classmethod
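This one is a Pydantic v2 API fix: `Field()` no longer accepts arbitrary keyword arguments such as `contentEncoding`, so custom JSON-schema keys move into `json_schema_extra`. Sketched:

```python
from pydantic import BaseModel, Field


class _URLOrData(BaseModel):
    # Custom JSON-schema keys go through json_schema_extra in Pydantic v2;
    # they are merged into the generated schema for the field.
    data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})


schema = _URLOrData.model_json_schema()
assert schema["properties"]["data"]["contentEncoding"] == "base64"
```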


@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value
type: Literal[ResourceType.dataset] = ResourceType.dataset
@property
def dataset_id(self) -> str:
return self.identifier
@property
def provider_dataset_id(self) -> str:
def provider_dataset_id(self) -> str | None:
return self.provider_resource_id


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from collections.abc import AsyncIterator
from enum import Enum
from typing import (
@ -35,6 +36,16 @@ register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
@json_schema_type
class GreedySamplingStrategy(BaseModel):
@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
role: Literal["assistant"] = "assistant"
content: InterleavedContent
stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=list)
tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[
@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: StopReason | None = None
class ResponseFormatType(Enum):
class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding.
:cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
"""
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value
type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
json_schema: dict[str, Any]
@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
:param bnf: The BNF grammar specification the response should conform to
"""
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value
type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
bnf: dict[str, Any]
@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=list)
tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None
@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
name: str
description: str | None = None
strict: bool | None = None
description: str | None
strict: bool | None
# Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict.
schema: dict[str, Any] | None = None
schema: dict[str, Any] | None
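The `OpenAIJSONSchema` change is a `TypedDict` correctness fix: type checkers reject `= None` defaults inside a `TypedDict` body, and with `total=False` the keys are already optional. Sketched:

```python
from typing import Any, TypedDict


class OpenAIJSONSchema(TypedDict, total=False):
    # total=False makes every key optional; `= None` defaults are not
    # allowed in a TypedDict body. Note `schema` is a fine key name here,
    # unlike on a pydantic BaseModel where it collides with an existing member.
    name: str
    description: str | None
    strict: bool | None
    schema: dict[str, Any] | None


payload: OpenAIJSONSchema = {"name": "my_schema", "strict": True}
```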
@json_schema_type


@ -29,14 +29,14 @@ class ModelType(str, Enum):
@json_schema_type
class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value
type: Literal[ResourceType.model] = ResourceType.model
@property
def model_id(self) -> str:
return self.identifier
@property
def provider_model_id(self) -> str:
def provider_model_id(self) -> str | None:
return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=())


@ -4,12 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from enum import Enum
from pydantic import BaseModel, Field
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class ResourceType(Enum):
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
class ResourceType(StrEnum):
model = "model"
shield = "shield"
vector_db = "vector_db"
@ -25,9 +36,9 @@ class Resource(BaseModel):
identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
provider_resource_id: str | None = Field(
default=None,
description="Unique identifier for this resource in the provider",
)
provider_id: str = Field(description="ID of the provider that owns this resource")
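Making `provider_resource_id` Optional on the base `Resource` is the change that ripples through all the `provider_*_id` property signatures in this commit. In isolation:

```python
from pydantic import BaseModel, Field


class Resource(BaseModel):
    identifier: str = Field(description="Unique identifier for this resource in llama stack")
    # May legitimately be unset until a provider assigns one, hence Optional
    # with a None default; every property forwarding it must widen its
    # return type to str | None to match.
    provider_resource_id: str | None = Field(
        default=None,
        description="Unique identifier for this resource in the provider",
    )


class Model(Resource):
    @property
    def provider_model_id(self) -> str | None:
        return self.provider_resource_id
```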


@ -53,5 +53,5 @@ class Safety(Protocol):
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
params: dict[str, Any],
) -> RunShieldResponse: ...
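mypy flags `params: dict[str, Any] = None` as an invalid default (None is not a dict); rather than widening the annotation to `dict[str, Any] | None`, the parameter becomes required. In protocol form, with `Message` and the response type reduced to `Any` for the sketch:

```python
from typing import Any, Protocol


class Safety(Protocol):
    async def run_shield(
        self,
        shield_id: str,
        messages: list[Any],  # list[Message] in the real API
        params: dict[str, Any],  # now required; `= None` was an invalid default
    ) -> Any: ...
```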


@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
import sys
from enum import Enum
from typing import (
Annotated,
@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
class ScoringFnParamsType(StrEnum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
basic = "basic"
@json_schema_type
class AggregationFunctionType(Enum):
class AggregationFunctionType(StrEnum):
average = "average"
weighted_average = "weighted_average"
median = "median"
@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str
prompt_template: str | None = None
judge_score_regexes: list[str] | None = Field(
judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] | None = Field(
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
default_factory=lambda: [],
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
parsing_regexes: list[str] | None = Field(
type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
default_factory=lambda: [],
)
aggregation_functions: list[AggregationFunctionType] | None = Field(
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
default_factory=lambda: [],
)
@json_schema_type
class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
aggregation_functions: list[AggregationFunctionType] | None = Field(
type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row",
default_factory=list,
)
@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id


@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
type: Literal[ResourceType.shield] = ResourceType.shield
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str:
def provider_shield_id(self) -> str | None:
return self.provider_resource_id


@ -37,7 +37,7 @@ class Span(BaseModel):
name: str
start_time: datetime
end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=dict)
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any):
if self.attributes is None:
@ -74,19 +74,19 @@ class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=dict)
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type
class UnstructuredLogEvent(EventCommon):
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str
severity: LogSeverity
@json_schema_type
class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value
type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum
value: int | float
unit: str
@ -131,14 +131,14 @@ class StructuredLogType(Enum):
@json_schema_type
class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str
parent_span_id: str | None = None
@json_schema_type
class SpanEndPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus
@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type
class StructuredLogEvent(EventCommon):
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload


@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type
class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value
type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str
tool_host: ToolHost
description: str
@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):
@json_schema_type
class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value
type: Literal[ResourceType.tool_group] = ResourceType.tool_group
mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None


@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class VectorDB(Resource):
type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value
type: Literal[ResourceType.vector_db] = ResourceType.vector_db
embedding_model: str
embedding_dimension: int
@ -25,7 +25,7 @@ class VectorDB(Resource):
return self.identifier
@property
def provider_vector_db_id(self) -> str:
def provider_vector_db_id(self) -> str | None:
return self.provider_resource_id


@ -38,7 +38,10 @@ class LlamaCLIParser:
print_subcommand_description(self.parser, subparsers)
def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
args = self.parser.parse_args()
if not isinstance(args, argparse.Namespace):
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
return args
def run(self, args: argparse.Namespace) -> None:
args.func(args)


@ -46,7 +46,7 @@ class StackListProviders(Subcommand):
else:
providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [p for api, p in providers if api in self.providable_apis]
providers = [(api, p) for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
@ -57,7 +57,7 @@ class StackListProviders(Subcommand):
rows = []
specs = [spec for p in providers for spec in p.values()]
specs = [spec for api, p in providers for spec in p.values()]
for spec in specs:
if spec.is_sample:
continue
@ -65,7 +65,7 @@ class StackListProviders(Subcommand):
[
spec.api.value,
spec.provider_type,
",".join(spec.pip_packages),
",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
]
)
print_table(


@ -73,11 +73,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, [])
if existing_providers:
logger.info(
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
logger.info(f"Re-configuring existing providers for API `{api_str}`...")
updated_providers = []
for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`")
@ -91,7 +87,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist:
raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"])
logger.info(f"Configuring API `{api_str}`...")
updated_providers = []
for i, provider_type in enumerate(plist):
if i >= 1:


@ -30,7 +30,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
@ -216,7 +216,18 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
"yellow",
)
if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers)
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
),
)
print_pip_install_help(build_config)
else:
prefix = "!" if in_notebook() else ""
cprint(
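The `library_client.py` fix adapts the call site to `print_pip_install_help`, which (per this diff) takes a `BuildConfig` rather than a raw provider mapping; the conversion collapses single-item lists to a bare string to match `DistributionSpec`'s `str | list[str]` providers field. The conversion on its own, with a hypothetical reduced `Provider` stand-in:

```python
from dataclasses import dataclass


@dataclass
class Provider:  # hypothetical stand-in; only provider_type matters here
    provider_type: str


def to_provider_types(providers: dict[str, list[Provider]]) -> dict[str, str | list[str]]:
    out: dict[str, str | list[str]] = {}
    for api, plist in providers.items():
        types = [p.provider_type for p in plist]
        out[api] = types[0] if len(types) == 1 else types  # single-item lists become strings
    return out


assert to_provider_types({"inference": [Provider("remote::ollama")]}) == {"inference": "remote::ollama"}
```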


@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}"
if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type
validator_class = spec.provider_data_validator
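Replacing `assert` with an explicit raise keeps the check active under `python -O`, and both forms narrow the Optional for mypy. The general shape, as a hypothetical helper:

```python
from typing import TypeVar

T = TypeVar("T")


def require(value: T | None, message: str) -> T:
    # An explicit raise, unlike assert, survives `python -O`; both forms
    # narrow `value` to non-None for the type checker past this point.
    if value is None:
        raise ValueError(message)
    return value
```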


@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
message_placeholder.markdown(full_response + "▌")
message_placeholder.markdown(full_response)
else:
full_response = response
message_placeholder.markdown(full_response.completion_message.content)
full_response = response.completion_message.content
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})


@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
"""
)
return PromptTemplate(
template = PromptTemplate(
template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]},
).render()
)
rendered: str = template.render()
return rendered
def data_examples(self) -> list[list[ToolDefinition]]:
return [


@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
elif model.core_model_id == CoreModelId.llama_guard_2_8b:
folder = "llama-guard-2"
else:
if model.huggingface_repo is None:
raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set")
folder = model.huggingface_repo.split("/")[-1]
if "Llama-2" in folder:
folder = folder.lower()
@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int:
return 54121549657
else:
return 100426653046
return 0
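The trailing `return 0` satisfies mypy's requirement that a function annotated `-> int` return on every path; the long `if`/`elif` chain otherwise falls off the end. In miniature (the model IDs here are illustrative):

```python
def pth_size(core_model_id: str) -> int:
    if core_model_id == "llama3_405b":
        return 54121549657
    elif core_model_id == "llama3_405b_instruct":
        return 100426653046
    # mypy: a function annotated `-> int` must return on every path,
    # so an explicit fallback is needed after the chain.
    return 0
```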


@ -139,6 +139,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -202,6 +204,8 @@ class OllamaInferenceAdapter(
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
@ -346,6 +350,8 @@ class OllamaInferenceAdapter(
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
available_models = [m["model"] for m in response["models"]]
if model.provider_resource_id is None:
raise ValueError("Model provider_resource_id cannot be None")
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
if provider_resource_id is None:
provider_resource_id = model.provider_resource_id
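These `provider_resource_id is None` guards in the Ollama and vLLM adapters all follow from `Resource.provider_resource_id` becoming Optional; the raise narrows the type to `str` at each call site. The checks could be factored as a helper along these lines (hypothetical, not part of this commit):

```python
from pydantic import BaseModel


class Model(BaseModel):  # reduced stand-in for llama_stack's Model resource
    identifier: str
    provider_resource_id: str | None = None


def require_provider_id(model: Model) -> str:
    # Raising here narrows the Optional to str for mypy in every caller.
    if model.provider_resource_id is None:
        raise ValueError(f"Model {model.identifier} has no provider_resource_id set")
    return model.provider_resource_id
```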


@ -272,6 +272,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
@ -302,6 +304,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
if model.provider_resource_id is None:
raise ValueError(f"Model {model_id} has no provider_resource_id set")
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice


@ -382,7 +382,7 @@ def augment_messages_for_tools_llama_3_1(
messages.append(SystemMessage(content=sys_content))
has_custom_tools = any(isinstance(dfn.tool_name, str) for dfn in request.tools)
has_custom_tools = request.tools is not None and any(isinstance(dfn.tool_name, str) for dfn in request.tools)
if has_custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.json
if fmt == ToolPromptFormat.json:
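The last hunk guards an Optional list before iterating: `request.tools` may be None, so the `is not None` short-circuit both satisfies mypy and avoids a runtime TypeError. The shape of the check:

```python
def has_custom_tools(tools: list[object] | None) -> bool:
    # Short-circuit: any() is only reached when tools is a real list.
    return tools is not None and any(isinstance(t, str) for t in tools)
```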