diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index cd92a10f5..33112012b 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -2598,6 +2598,22 @@
}
]
},
+ "Message": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/UserMessage"
+ },
+ {
+ "$ref": "#/components/schemas/SystemMessage"
+ },
+ {
+ "$ref": "#/components/schemas/ToolResponseMessage"
+ },
+ {
+ "$ref": "#/components/schemas/CompletionMessage"
+ }
+ ]
+ },
"SamplingParams": {
"type": "object",
"properties": {
@@ -2893,9 +2909,16 @@
]
},
"URL": {
- "type": "string",
- "format": "uri",
- "pattern": "^(https?://|file://|data:)"
+ "type": "object",
+ "properties": {
+ "uri": {
+ "type": "string"
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "uri"
+ ]
},
"UserMessage": {
"type": "object",
@@ -2929,20 +2952,7 @@
"items": {
"type": "array",
"items": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/UserMessage"
- },
- {
- "$ref": "#/components/schemas/SystemMessage"
- },
- {
- "$ref": "#/components/schemas/ToolResponseMessage"
- },
- {
- "$ref": "#/components/schemas/CompletionMessage"
- }
- ]
+ "$ref": "#/components/schemas/Message"
}
}
},
@@ -3052,6 +3062,90 @@
"job_uuid"
]
},
+ "ResponseFormat": {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "json_schema",
+ "default": "json_schema"
+ },
+ "json_schema": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "json_schema"
+ ]
+ },
+ {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "grammar",
+ "default": "grammar"
+ },
+ "bnf": {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "null"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "string"
+ },
+ {
+ "type": "array"
+ },
+ {
+ "type": "object"
+ }
+ ]
+ }
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "bnf"
+ ]
+ }
+ ]
+ },
"ChatCompletionRequest": {
"type": "object",
"properties": {
@@ -3061,20 +3155,7 @@
"messages": {
"type": "array",
"items": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/UserMessage"
- },
- {
- "$ref": "#/components/schemas/SystemMessage"
- },
- {
- "$ref": "#/components/schemas/ToolResponseMessage"
- },
- {
- "$ref": "#/components/schemas/CompletionMessage"
- }
- ]
+ "$ref": "#/components/schemas/Message"
}
},
"sampling_params": {
@@ -3093,88 +3174,7 @@
"$ref": "#/components/schemas/ToolPromptFormat"
},
"response_format": {
- "oneOf": [
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "json_schema",
- "default": "json_schema"
- },
- "json_schema": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "json_schema"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "grammar",
- "default": "grammar"
- },
- "bnf": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "bnf"
- ]
- }
- ]
+ "$ref": "#/components/schemas/ResponseFormat"
},
"stream": {
"type": "boolean"
@@ -3329,88 +3329,7 @@
"$ref": "#/components/schemas/SamplingParams"
},
"response_format": {
- "oneOf": [
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "json_schema",
- "default": "json_schema"
- },
- "json_schema": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "json_schema"
- ]
- },
- {
- "type": "object",
- "properties": {
- "type": {
- "type": "string",
- "const": "grammar",
- "default": "grammar"
- },
- "bnf": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "null"
- },
- {
- "type": "boolean"
- },
- {
- "type": "number"
- },
- {
- "type": "string"
- },
- {
- "type": "array"
- },
- {
- "type": "object"
- }
- ]
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "type",
- "bnf"
- ]
- }
- ]
+ "$ref": "#/components/schemas/ResponseFormat"
},
"stream": {
"type": "boolean"
@@ -7278,20 +7197,7 @@
"messages": {
"type": "array",
"items": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/UserMessage"
- },
- {
- "$ref": "#/components/schemas/SystemMessage"
- },
- {
- "$ref": "#/components/schemas/ToolResponseMessage"
- },
- {
- "$ref": "#/components/schemas/CompletionMessage"
- }
- ]
+ "$ref": "#/components/schemas/Message"
}
},
"params": {
@@ -7657,20 +7563,7 @@
"dialogs": {
"type": "array",
"items": {
- "oneOf": [
- {
- "$ref": "#/components/schemas/UserMessage"
- },
- {
- "$ref": "#/components/schemas/SystemMessage"
- },
- {
- "$ref": "#/components/schemas/ToolResponseMessage"
- },
- {
- "$ref": "#/components/schemas/CompletionMessage"
- }
- ]
+ "$ref": "#/components/schemas/Message"
}
},
"filtering_function": {
@@ -8129,6 +8022,10 @@
"name": "MemoryToolDefinition",
"description": ""
},
+ {
+ "name": "Message",
+ "description": ""
+ },
{
"name": "MetricEvent",
"description": ""
@@ -8247,6 +8144,10 @@
"name": "RegisterShieldRequest",
"description": ""
},
+ {
+ "name": "ResponseFormat",
+ "description": ""
+ },
{
"name": "RestAPIExecutionConfig",
"description": ""
@@ -8591,6 +8492,7 @@
"MemoryBankDocument",
"MemoryRetrievalStep",
"MemoryToolDefinition",
+ "Message",
"MetricEvent",
"Model",
"ModelCandidate",
@@ -8619,6 +8521,7 @@
"RegisterModelRequest",
"RegisterScoringFunctionRequest",
"RegisterShieldRequest",
+ "ResponseFormat",
"RestAPIExecutionConfig",
"RestAPIMethod",
"RouteInfo",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 08db0699e..abd57e17e 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -313,11 +313,7 @@ components:
messages_batch:
items:
items:
- oneOf:
- - $ref: '#/components/schemas/UserMessage'
- - $ref: '#/components/schemas/SystemMessage'
- - $ref: '#/components/schemas/ToolResponseMessage'
- - $ref: '#/components/schemas/CompletionMessage'
+ $ref: '#/components/schemas/Message'
type: array
type: array
model:
@@ -422,56 +418,12 @@ components:
type: object
messages:
items:
- oneOf:
- - $ref: '#/components/schemas/UserMessage'
- - $ref: '#/components/schemas/SystemMessage'
- - $ref: '#/components/schemas/ToolResponseMessage'
- - $ref: '#/components/schemas/CompletionMessage'
+ $ref: '#/components/schemas/Message'
type: array
model_id:
type: string
response_format:
- oneOf:
- - additionalProperties: false
- properties:
- json_schema:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- type:
- const: json_schema
- default: json_schema
- type: string
- required:
- - type
- - json_schema
- type: object
- - additionalProperties: false
- properties:
- bnf:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- type:
- const: grammar
- default: grammar
- type: string
- required:
- - type
- - bnf
- type: object
+ $ref: '#/components/schemas/ResponseFormat'
sampling_params:
$ref: '#/components/schemas/SamplingParams'
stream:
@@ -598,47 +550,7 @@ components:
model_id:
type: string
response_format:
- oneOf:
- - additionalProperties: false
- properties:
- json_schema:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- type:
- const: json_schema
- default: json_schema
- type: string
- required:
- - type
- - json_schema
- type: object
- - additionalProperties: false
- properties:
- bnf:
- additionalProperties:
- oneOf:
- - type: 'null'
- - type: boolean
- - type: number
- - type: string
- - type: array
- - type: object
- type: object
- type:
- const: grammar
- default: grammar
- type: string
- required:
- - type
- - bnf
- type: object
+ $ref: '#/components/schemas/ResponseFormat'
sampling_params:
$ref: '#/components/schemas/SamplingParams'
stream:
@@ -1467,6 +1379,12 @@ components:
- max_tokens_in_context
- max_chunks
type: object
+ Message:
+ oneOf:
+ - $ref: '#/components/schemas/UserMessage'
+ - $ref: '#/components/schemas/SystemMessage'
+ - $ref: '#/components/schemas/ToolResponseMessage'
+ - $ref: '#/components/schemas/CompletionMessage'
MetricEvent:
additionalProperties: false
properties:
@@ -2121,6 +2039,48 @@ components:
required:
- shield_id
type: object
+ ResponseFormat:
+ oneOf:
+ - additionalProperties: false
+ properties:
+ json_schema:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ type:
+ const: json_schema
+ default: json_schema
+ type: string
+ required:
+ - type
+ - json_schema
+ type: object
+ - additionalProperties: false
+ properties:
+ bnf:
+ additionalProperties:
+ oneOf:
+ - type: 'null'
+ - type: boolean
+ - type: number
+ - type: string
+ - type: array
+ - type: object
+ type: object
+ type:
+ const: grammar
+ default: grammar
+ type: string
+ required:
+ - type
+ - bnf
+ type: object
RestAPIExecutionConfig:
additionalProperties: false
properties:
@@ -2203,11 +2163,7 @@ components:
properties:
messages:
items:
- oneOf:
- - $ref: '#/components/schemas/UserMessage'
- - $ref: '#/components/schemas/SystemMessage'
- - $ref: '#/components/schemas/ToolResponseMessage'
- - $ref: '#/components/schemas/CompletionMessage'
+ $ref: '#/components/schemas/Message'
type: array
params:
additionalProperties:
@@ -2744,11 +2700,7 @@ components:
properties:
dialogs:
items:
- oneOf:
- - $ref: '#/components/schemas/UserMessage'
- - $ref: '#/components/schemas/SystemMessage'
- - $ref: '#/components/schemas/ToolResponseMessage'
- - $ref: '#/components/schemas/CompletionMessage'
+ $ref: '#/components/schemas/Message'
type: array
filtering_function:
enum:
@@ -3105,9 +3057,13 @@ components:
title: A single turn in an interaction with an Agentic System.
type: object
URL:
- format: uri
- pattern: ^(https?://|file://|data:)
- type: string
+ additionalProperties: false
+ properties:
+ uri:
+ type: string
+ required:
+ - uri
+ type: object
UnregisterDatasetRequest:
additionalProperties: false
properties:
@@ -5020,6 +4976,8 @@ tags:
- description:
name: MemoryToolDefinition
+- description:
+ name: Message
- description:
name: MetricEvent
- description:
@@ -5104,6 +5062,8 @@ tags:
- description:
name: RegisterShieldRequest
+- description:
+ name: ResponseFormat
- description:
name: RestAPIExecutionConfig
@@ -5367,6 +5327,7 @@ x-tagGroups:
- MemoryBankDocument
- MemoryRetrievalStep
- MemoryToolDefinition
+ - Message
- MetricEvent
- Model
- ModelCandidate
@@ -5395,6 +5356,7 @@ x-tagGroups:
- RegisterModelRequest
- RegisterScoringFunctionRequest
- RegisterShieldRequest
+ - ResponseFormat
- RestAPIExecutionConfig
- RestAPIMethod
- RouteInfo
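
Note the breaking change present in both spec files: `URL` is no longer a pattern-validated string but an object wrapping a required `uri` field, and the `^(https?://|file://|data:)` pattern is gone entirely. A minimal sketch of the client-visible difference, using a stand-in pydantic model that mirrors the new schema:

```python
from pydantic import BaseModel

class URL(BaseModel):  # stand-in mirroring the updated schema: one required string field
    uri: str

old_payload = "https://example.com/paper.pdf"           # before: a bare string
new_payload = {"uri": "https://example.com/paper.pdf"}  # after: an object with a "uri" key

url = URL.model_validate(new_payload)
relaxed = URL(uri="example.com/paper.pdf")  # accepted: the old scheme pattern is no longer enforced
```
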
diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py
index 316a4a5d6..121218a29 100644
--- a/llama_stack/apis/common/content_types.py
+++ b/llama_stack/apis/common/content_types.py
@@ -11,15 +11,10 @@ from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator
-@json_schema_type(
- schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
-)
+@json_schema_type
class URL(BaseModel):
uri: str
- def __str__(self) -> str:
- return self.uri
-
class _URLOrData(BaseModel):
url: Optional[URL] = None
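
With the custom `__str__` removed, `str(url)` falls back to pydantic's default rendering instead of returning the raw uri, so any call site that relied on implicit string conversion must read `.uri` explicitly. A minimal sketch of the behavioural difference, assuming the same stand-in model as above:

```python
from pydantic import BaseModel

class URL(BaseModel):
    uri: str

url = URL(uri="file:///tmp/data.jsonl")
print(url.uri)   # explicit access, the supported path after this change
print(str(url))  # now pydantic's default rendering (roughly "uri='file:///tmp/data.jsonl'")
```
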
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index c481d04d7..28b9d9106 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -25,7 +25,7 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat,
)
-from llama_models.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
@@ -100,15 +100,18 @@ class CompletionMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
-Message = Annotated[
- Union[
- UserMessage,
- SystemMessage,
- ToolResponseMessage,
- CompletionMessage,
+Message = register_schema(
+ Annotated[
+ Union[
+ UserMessage,
+ SystemMessage,
+ ToolResponseMessage,
+ CompletionMessage,
+ ],
+ Field(discriminator="role"),
],
- Field(discriminator="role"),
-]
+ name="Message",
+)
@json_schema_type
@@ -187,10 +190,13 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
-ResponseFormat = Annotated[
- Union[JsonSchemaResponseFormat, GrammarResponseFormat],
- Field(discriminator="type"),
-]
+ResponseFormat = register_schema(
+ Annotated[
+ Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+ Field(discriminator="type"),
+ ],
+ name="ResponseFormat",
+)
@json_schema_type
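
`register_schema` hands the generator a name for the `Annotated` union, which is what lets the OpenAPI output above collapse into single `Message`/`ResponseFormat` components; at runtime the values are still ordinary pydantic discriminated unions. A minimal sketch of that runtime behaviour, with stand-in message classes in place of the real ones:

```python
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

class UserMessage(BaseModel):      # stand-in for the real UserMessage
    role: Literal["user"] = "user"
    content: str

class SystemMessage(BaseModel):    # stand-in for the real SystemMessage
    role: Literal["system"] = "system"
    content: str

Message = Annotated[Union[UserMessage, SystemMessage], Field(discriminator="role")]

msg = TypeAdapter(Message).validate_python({"role": "user", "content": "hi"})
assert isinstance(msg, UserMessage)  # the "role" discriminator selects the variant
```
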
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 5671082d5..f5180b0db 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -144,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
if default_val is None:
raise EnvVarError(env_var, path)
else:
- value = default_val if default_val != "null" else None
+ value = default_val
# expand "~" from the values
return os.path.expanduser(value)
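
This one-line change alters default handling in `${env.VAR:default}` substitution: the literal default `null` is no longer coerced to Python `None` and is now passed through verbatim. A minimal sketch of the resulting semantics under a simplified pattern; the real function also raises `EnvVarError` when the variable is unset and no default exists:

```python
import os
import re

def expand(value: str) -> str:
    # Simplified ${env.VAR:default} expansion reflecting the new behaviour:
    # the default is substituted verbatim, never converted to None.
    def sub(m: re.Match) -> str:
        var, default = m.group(1), m.group(2)
        return os.environ.get(var, default)
    return os.path.expanduser(re.sub(r"\$\{env\.(\w+):([^}]*)\}", sub, value))

assert expand("${env.MISSING_VAR:null}") == "null"  # previously this became None
```
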
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index f80f72a8e..ddf59fda8 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -6,20 +6,25 @@
from typing import * # noqa: F403
import json
-import uuid
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
+from llama_stack.providers.utils.inference.openai_compat import (
+ OpenAICompatCompletionChoice,
+ OpenAICompatCompletionResponse,
+ process_chat_completion_response,
+ process_chat_completion_stream_response,
+)
from llama_stack.providers.utils.inference.prompt_adapter import (
+ chat_completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
)
@@ -46,7 +51,6 @@ MODEL_ALIASES = [
]
-# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
@@ -76,232 +80,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) -> AsyncGenerator:
raise NotImplementedError()
- @staticmethod
- def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
- if bedrock_stop_reason == "max_tokens":
- return StopReason.out_of_tokens
- return StopReason.end_of_turn
-
- @staticmethod
- def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
- for builtin_tool in BuiltinTool:
- if builtin_tool.value == tool_name_str:
- return builtin_tool
- else:
- return tool_name_str
-
- @staticmethod
- def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
- stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
- converse_api_res["stopReason"]
- )
-
- bedrock_message = converse_api_res["output"]["message"]
-
- role = bedrock_message["role"]
- contents = bedrock_message["content"]
-
- tool_calls = []
- text_content = ""
- for content in contents:
- if "toolUse" in content:
- tool_use = content["toolUse"]
- tool_calls.append(
- ToolCall(
- tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
- tool_use["name"]
- ),
- arguments=tool_use["input"] if "input" in tool_use else None,
- call_id=tool_use["toolUseId"],
- )
- )
- elif "text" in content:
- text_content += content["text"]
-
- return CompletionMessage(
- role=role,
- content=text_content,
- stop_reason=stop_reason,
- tool_calls=tool_calls,
- )
-
- @staticmethod
- def _messages_to_bedrock_messages(
- messages: List[Message],
- ) -> Tuple[List[Dict], Optional[List[Dict]]]:
- bedrock_messages = []
- system_bedrock_messages = []
-
- user_contents = []
- assistant_contents = None
- for message in messages:
- role = message.role
- content_list = (
- message.content
- if isinstance(message.content, list)
- else [message.content]
- )
- if role == "ipython" or role == "user":
- if not user_contents:
- user_contents = []
-
- if role == "ipython":
- user_contents.extend(
- [
- {
- "toolResult": {
- "toolUseId": message.call_id or str(uuid.uuid4()),
- "content": [
- {"text": content} for content in content_list
- ],
- }
- }
- ]
- )
- else:
- user_contents.extend(
- [{"text": content} for content in content_list]
- )
-
- if assistant_contents:
- bedrock_messages.append(
- {"role": "assistant", "content": assistant_contents}
- )
- assistant_contents = None
- elif role == "system":
- system_bedrock_messages.extend(
- [{"text": content} for content in content_list]
- )
- elif role == "assistant":
- if not assistant_contents:
- assistant_contents = []
-
- assistant_contents.extend(
- [
- {
- "text": content,
- }
- for content in content_list
- ]
- + [
- {
- "toolUse": {
- "input": tool_call.arguments,
- "name": (
- tool_call.tool_name
- if isinstance(tool_call.tool_name, str)
- else tool_call.tool_name.value
- ),
- "toolUseId": tool_call.call_id,
- }
- }
- for tool_call in message.tool_calls
- ]
- )
-
- if user_contents:
- bedrock_messages.append({"role": "user", "content": user_contents})
- user_contents = None
- else:
- # Unknown role
- pass
-
- if user_contents:
- bedrock_messages.append({"role": "user", "content": user_contents})
- if assistant_contents:
- bedrock_messages.append(
- {"role": "assistant", "content": assistant_contents}
- )
-
- if system_bedrock_messages:
- return bedrock_messages, system_bedrock_messages
-
- return bedrock_messages, None
-
- @staticmethod
- def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
- inference_config = {}
- if sampling_params:
- param_mapping = {
- "max_tokens": "maxTokens",
- "temperature": "temperature",
- "top_p": "topP",
- }
-
- for k, v in param_mapping.items():
- if getattr(sampling_params, k):
- inference_config[v] = getattr(sampling_params, k)
-
- return inference_config
-
- @staticmethod
- def _tool_parameters_to_input_schema(
- tool_parameters: Optional[Dict[str, ToolParamDefinition]],
- ) -> Dict:
- input_schema = {"type": "object"}
- if not tool_parameters:
- return input_schema
-
- json_properties = {}
- required = []
- for name, param in tool_parameters.items():
- json_property = {
- "type": param.param_type,
- }
-
- if param.description:
- json_property["description"] = param.description
- if param.required:
- required.append(name)
- json_properties[name] = json_property
-
- input_schema["properties"] = json_properties
- if required:
- input_schema["required"] = required
- return input_schema
-
- @staticmethod
- def _tools_to_tool_config(
- tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
- ) -> Optional[Dict]:
- if not tools:
- return None
-
- bedrock_tools = []
- for tool in tools:
- tool_name = (
- tool.tool_name
- if isinstance(tool.tool_name, str)
- else tool.tool_name.value
- )
-
- tool_spec = {
- "toolSpec": {
- "name": tool_name,
- "inputSchema": {
- "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
- tool.parameters
- ),
- },
- }
- }
-
- if tool.description:
- tool_spec["toolSpec"]["description"] = tool.description
-
- bedrock_tools.append(tool_spec)
- tool_config = {
- "tools": bedrock_tools,
- }
-
- if tool_choice:
- tool_config["toolChoice"] = (
- {"any": {}}
- if tool_choice.value == ToolChoice.required
- else {"auto": {}}
- )
- return tool_config
-
async def chat_completion(
self,
model_id: str,
@@ -337,118 +115,70 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
- params = self._get_params_for_chat_completion(request)
- converse_api_res = self.client.converse(**params)
+ params = await self._get_params_for_chat_completion(request)
+ res = self.client.invoke_model(**params)
+ chunk = next(res["body"])
+ result = json.loads(chunk.decode("utf-8"))
- output_message = BedrockInferenceAdapter._bedrock_message_to_message(
- converse_api_res
+ choice = OpenAICompatCompletionChoice(
+ finish_reason=result["stop_reason"],
+ text=result["generation"],
)
- return ChatCompletionResponse(
- completion_message=output_message,
- logprobs=None,
- )
+ response = OpenAICompatCompletionResponse(choices=[choice])
+ return process_chat_completion_response(response, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
- params = self._get_params_for_chat_completion(request)
- converse_stream_api_res = self.client.converse_stream(**params)
- event_stream = converse_stream_api_res["stream"]
+ params = await self._get_params_for_chat_completion(request)
+ res = self.client.invoke_model_with_response_stream(**params)
+ event_stream = res["body"]
- for chunk in event_stream:
- if "messageStart" in chunk:
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.start,
- delta="",
- )
+ async def _generate_and_convert_to_openai_compat():
+ for chunk in event_stream:
+ chunk = chunk["chunk"]["bytes"]
+ result = json.loads(chunk.decode("utf-8"))
+ choice = OpenAICompatCompletionChoice(
+ finish_reason=result["stop_reason"],
+ text=result["generation"],
)
- elif "contentBlockStart" in chunk:
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.progress,
- delta=ToolCallDelta(
- content=ToolCall(
- tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
- call_id=chunk["contentBlockStart"]["toolUse"][
- "toolUseId"
- ],
- ),
- parse_status=ToolCallParseStatus.started,
- ),
- )
- )
- elif "contentBlockDelta" in chunk:
- if "text" in chunk["contentBlockDelta"]["delta"]:
- delta = chunk["contentBlockDelta"]["delta"]["text"]
- else:
- delta = ToolCallDelta(
- content=ToolCall(
- arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
- "input"
- ]
- ),
- parse_status=ToolCallParseStatus.success,
- )
+ yield OpenAICompatCompletionResponse(choices=[choice])
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.progress,
- delta=delta,
- )
- )
- elif "contentBlockStop" in chunk:
- # Ignored
- pass
- elif "messageStop" in chunk:
- stop_reason = (
- BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
- chunk["messageStop"]["stopReason"]
- )
- )
+ stream = _generate_and_convert_to_openai_compat()
+ async for chunk in process_chat_completion_stream_response(
+ stream, self.formatter
+ ):
+ yield chunk
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.complete,
- delta="",
- stop_reason=stop_reason,
- )
- )
- elif "metadata" in chunk:
- # Ignored
- pass
- else:
- # Ignored
- pass
-
- def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+ async def _get_params_for_chat_completion(
+ self, request: ChatCompletionRequest
+ ) -> Dict:
bedrock_model = request.model
- inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
- request.sampling_params
- )
- tool_config = BedrockInferenceAdapter._tools_to_tool_config(
- request.tools, request.tool_choice
- )
- bedrock_messages, system_bedrock_messages = (
- BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
- )
-
- converse_api_params = {
- "modelId": bedrock_model,
- "messages": bedrock_messages,
+ inference_config = {}
+ param_mapping = {
+ "max_tokens": "max_gen_len",
+ "temperature": "temperature",
+ "top_p": "top_p",
}
- if inference_config:
- converse_api_params["inferenceConfig"] = inference_config
- # Tool use is not supported in streaming mode
- if tool_config and not request.stream:
- converse_api_params["toolConfig"] = tool_config
- if system_bedrock_messages:
- converse_api_params["system"] = system_bedrock_messages
+ for k, v in param_mapping.items():
+ if getattr(request.sampling_params, k):
+ inference_config[v] = getattr(request.sampling_params, k)
- return converse_api_params
+ prompt = await chat_completion_request_to_prompt(
+ request, self.get_llama_model(request.model), self.formatter
+ )
+ return {
+ "modelId": bedrock_model,
+ "body": json.dumps(
+ {
+ "prompt": prompt,
+ **inference_config,
+ }
+ ),
+ }
async def embeddings(
self,
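
The rewrite drops the Converse-API path (message/tool translation, stream-event handling) in favour of rendering the chat into a single Llama prompt and calling `invoke_model`, with responses funnelled through the shared OpenAI-compat helpers. A minimal sketch of the request shape the new `_get_params_for_chat_completion` produces; the model id, prompt text, and sampling values here are illustrative:

```python
import json

params = {
    "modelId": "meta.llama3-1-8b-instruct-v1:0",  # illustrative Bedrock model id
    "body": json.dumps(
        {
            # prompt rendered by chat_completion_request_to_prompt (abbreviated here)
            "prompt": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>",
            "max_gen_len": 512,   # mapped from sampling_params.max_tokens
            "temperature": 0.7,
            "top_p": 0.9,
        }
    ),
}
# res = client.invoke_model(**params)             # boto3 bedrock-runtime client
# result = json.loads(next(res["body"]).decode("utf-8"))
# result["generation"], result["stop_reason"]     # consumed by the OpenAI-compat helpers
```
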
diff --git a/requirements.txt b/requirements.txt
index f57f688b7..304467ddc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,8 @@ blobfile
fire
httpx
huggingface-hub
-llama-models>=0.0.62
-llama-stack-client>=0.0.62
+llama-models>=0.0.63
+llama-stack-client>=0.0.63
prompt-toolkit
python-dotenv
pydantic>=2
diff --git a/setup.py b/setup.py
index e8e3de5b2..c0f8cf575 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
- version="0.0.62",
+ version="0.0.63",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",