diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html index cd92a10f5..33112012b 100644 --- a/docs/resources/llama-stack-spec.html +++ b/docs/resources/llama-stack-spec.html @@ -2598,6 +2598,22 @@ } ] }, + "Message": { + "oneOf": [ + { + "$ref": "#/components/schemas/UserMessage" + }, + { + "$ref": "#/components/schemas/SystemMessage" + }, + { + "$ref": "#/components/schemas/ToolResponseMessage" + }, + { + "$ref": "#/components/schemas/CompletionMessage" + } + ] + }, "SamplingParams": { "type": "object", "properties": { @@ -2893,9 +2909,16 @@ ] }, "URL": { - "type": "string", - "format": "uri", - "pattern": "^(https?://|file://|data:)" + "type": "object", + "properties": { + "uri": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "uri" + ] }, "UserMessage": { "type": "object", @@ -2929,20 +2952,7 @@ "items": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } } }, @@ -3052,6 +3062,90 @@ "job_uuid" ] }, + "ResponseFormat": { + "oneOf": [ + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "json_schema", + "default": "json_schema" + }, + "json_schema": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "json_schema" + ] + }, + { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "grammar", + "default": "grammar" + }, + "bnf": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "type", + "bnf" + ] + } + ] + }, "ChatCompletionRequest": { "type": "object", "properties": { @@ -3061,20 +3155,7 @@ "messages": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } }, "sampling_params": { @@ -3093,88 +3174,7 @@ "$ref": "#/components/schemas/ToolPromptFormat" }, "response_format": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "json_schema", - "default": "json_schema" - }, - "json_schema": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "json_schema" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "grammar", - "default": "grammar" - }, - "bnf": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": 
"string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "bnf" - ] - } - ] + "$ref": "#/components/schemas/ResponseFormat" }, "stream": { "type": "boolean" @@ -3329,88 +3329,7 @@ "$ref": "#/components/schemas/SamplingParams" }, "response_format": { - "oneOf": [ - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "json_schema", - "default": "json_schema" - }, - "json_schema": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "json_schema" - ] - }, - { - "type": "object", - "properties": { - "type": { - "type": "string", - "const": "grammar", - "default": "grammar" - }, - "bnf": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - }, - "additionalProperties": false, - "required": [ - "type", - "bnf" - ] - } - ] + "$ref": "#/components/schemas/ResponseFormat" }, "stream": { "type": "boolean" @@ -7278,20 +7197,7 @@ "messages": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } }, "params": { @@ -7657,20 +7563,7 @@ "dialogs": { "type": "array", "items": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ] + "$ref": "#/components/schemas/Message" } }, "filtering_function": { @@ -8129,6 +8022,10 @@ "name": "MemoryToolDefinition", "description": "" }, + { + "name": "Message", + "description": "" + }, { "name": "MetricEvent", "description": "" @@ -8247,6 +8144,10 @@ "name": "RegisterShieldRequest", "description": "" }, + { + "name": "ResponseFormat", + "description": "" + }, { "name": "RestAPIExecutionConfig", "description": "" @@ -8591,6 +8492,7 @@ "MemoryBankDocument", "MemoryRetrievalStep", "MemoryToolDefinition", + "Message", "MetricEvent", "Model", "ModelCandidate", @@ -8619,6 +8521,7 @@ "RegisterModelRequest", "RegisterScoringFunctionRequest", "RegisterShieldRequest", + "ResponseFormat", "RestAPIExecutionConfig", "RestAPIMethod", "RouteInfo", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 08db0699e..abd57e17e 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -313,11 +313,7 @@ components: messages_batch: items: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array type: array model: @@ -422,56 +418,12 @@ components: type: object messages: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: 
'#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array model_id: type: string response_format: - oneOf: - - additionalProperties: false - properties: - json_schema: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: json_schema - default: json_schema - type: string - required: - - type - - json_schema - type: object - - additionalProperties: false - properties: - bnf: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: grammar - default: grammar - type: string - required: - - type - - bnf - type: object + $ref: '#/components/schemas/ResponseFormat' sampling_params: $ref: '#/components/schemas/SamplingParams' stream: @@ -598,47 +550,7 @@ components: model_id: type: string response_format: - oneOf: - - additionalProperties: false - properties: - json_schema: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: json_schema - default: json_schema - type: string - required: - - type - - json_schema - type: object - - additionalProperties: false - properties: - bnf: - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: object - type: - const: grammar - default: grammar - type: string - required: - - type - - bnf - type: object + $ref: '#/components/schemas/ResponseFormat' sampling_params: $ref: '#/components/schemas/SamplingParams' stream: @@ -1467,6 +1379,12 @@ components: - max_tokens_in_context - max_chunks type: object + Message: + oneOf: + - $ref: '#/components/schemas/UserMessage' + - $ref: '#/components/schemas/SystemMessage' + - $ref: '#/components/schemas/ToolResponseMessage' + - $ref: '#/components/schemas/CompletionMessage' MetricEvent: additionalProperties: false properties: @@ -2121,6 +2039,48 @@ components: required: - shield_id type: object + ResponseFormat: + oneOf: + - additionalProperties: false + properties: + json_schema: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: json_schema + default: json_schema + type: string + required: + - type + - json_schema + type: object + - additionalProperties: false + properties: + bnf: + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: object + type: + const: grammar + default: grammar + type: string + required: + - type + - bnf + type: object RestAPIExecutionConfig: additionalProperties: false properties: @@ -2203,11 +2163,7 @@ components: properties: messages: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: '#/components/schemas/Message' type: array params: additionalProperties: @@ -2744,11 +2700,7 @@ components: properties: dialogs: items: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' + $ref: 
'#/components/schemas/Message' type: array filtering_function: enum: @@ -3105,9 +3057,13 @@ components: title: A single turn in an interaction with an Agentic System. type: object URL: - format: uri - pattern: ^(https?://|file://|data:) - type: string + additionalProperties: false + properties: + uri: + type: string + required: + - uri + type: object UnregisterDatasetRequest: additionalProperties: false properties: @@ -5020,6 +4976,8 @@ tags: - description: name: MemoryToolDefinition +- description: + name: Message - description: name: MetricEvent - description: @@ -5104,6 +5062,8 @@ tags: - description: name: RegisterShieldRequest +- description: + name: ResponseFormat - description: name: RestAPIExecutionConfig @@ -5367,6 +5327,7 @@ x-tagGroups: - MemoryBankDocument - MemoryRetrievalStep - MemoryToolDefinition + - Message - MetricEvent - Model - ModelCandidate @@ -5395,6 +5356,7 @@ x-tagGroups: - RegisterModelRequest - RegisterScoringFunctionRequest - RegisterShieldRequest + - ResponseFormat - RestAPIExecutionConfig - RestAPIMethod - RouteInfo diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 316a4a5d6..121218a29 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -11,15 +11,10 @@ from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, model_validator -@json_schema_type( - schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"} -) +@json_schema_type class URL(BaseModel): uri: str - def __str__(self) -> str: - return self.uri - class _URLOrData(BaseModel): url: Optional[URL] = None diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index c481d04d7..28b9d9106 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -25,7 +25,7 @@ from llama_models.llama3.api.datatypes import ( ToolPromptFormat, ) -from llama_models.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated @@ -100,15 +100,18 @@ class CompletionMessage(BaseModel): tool_calls: List[ToolCall] = Field(default_factory=list) -Message = Annotated[ - Union[ - UserMessage, - SystemMessage, - ToolResponseMessage, - CompletionMessage, +Message = register_schema( + Annotated[ + Union[ + UserMessage, + SystemMessage, + ToolResponseMessage, + CompletionMessage, + ], + Field(discriminator="role"), ], - Field(discriminator="role"), -] + name="Message", +) @json_schema_type @@ -187,10 +190,13 @@ class GrammarResponseFormat(BaseModel): bnf: Dict[str, Any] -ResponseFormat = Annotated[ - Union[JsonSchemaResponseFormat, GrammarResponseFormat], - Field(discriminator="type"), -] +ResponseFormat = register_schema( + Annotated[ + Union[JsonSchemaResponseFormat, GrammarResponseFormat], + Field(discriminator="type"), + ], + name="ResponseFormat", +) @json_schema_type diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 5671082d5..f5180b0db 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -144,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any: if default_val is None: raise EnvVarError(env_var, path) else: - value = default_val if default_val != "null" else None + value = default_val # expand "~" from the values return 
os.path.expanduser(value) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f80f72a8e..ddf59fda8 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -6,20 +6,25 @@ from typing import * # noqa: F403 import json -import uuid from botocore.client import BaseClient from llama_models.datatypes import CoreModelId from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import ToolParamDefinition from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.providers.utils.inference.model_registry import ( build_model_alias, ModelRegistryHelper, ) +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAICompatCompletionChoice, + OpenAICompatCompletionResponse, + process_chat_completion_response, + process_chat_completion_stream_response, +) from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, content_has_media, interleaved_content_as_str, ) @@ -46,7 +51,6 @@ MODEL_ALIASES = [ ] -# NOTE: this is not quite tested after the recent refactors class BedrockInferenceAdapter(ModelRegistryHelper, Inference): def __init__(self, config: BedrockConfig) -> None: ModelRegistryHelper.__init__(self, MODEL_ALIASES) @@ -76,232 +80,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) -> AsyncGenerator: raise NotImplementedError() - @staticmethod - def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason: - if bedrock_stop_reason == "max_tokens": - return StopReason.out_of_tokens - return StopReason.end_of_turn - - @staticmethod - def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]: - for builtin_tool in BuiltinTool: - if builtin_tool.value == tool_name_str: - return builtin_tool - else: - return tool_name_str - - @staticmethod - def _bedrock_message_to_message(converse_api_res: Dict) -> Message: - stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( - converse_api_res["stopReason"] - ) - - bedrock_message = converse_api_res["output"]["message"] - - role = bedrock_message["role"] - contents = bedrock_message["content"] - - tool_calls = [] - text_content = "" - for content in contents: - if "toolUse" in content: - tool_use = content["toolUse"] - tool_calls.append( - ToolCall( - tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum( - tool_use["name"] - ), - arguments=tool_use["input"] if "input" in tool_use else None, - call_id=tool_use["toolUseId"], - ) - ) - elif "text" in content: - text_content += content["text"] - - return CompletionMessage( - role=role, - content=text_content, - stop_reason=stop_reason, - tool_calls=tool_calls, - ) - - @staticmethod - def _messages_to_bedrock_messages( - messages: List[Message], - ) -> Tuple[List[Dict], Optional[List[Dict]]]: - bedrock_messages = [] - system_bedrock_messages = [] - - user_contents = [] - assistant_contents = None - for message in messages: - role = message.role - content_list = ( - message.content - if isinstance(message.content, list) - else [message.content] - ) - if role == "ipython" or role == "user": - if not user_contents: - user_contents = [] - - if role == "ipython": - user_contents.extend( - [ - { - "toolResult": { - "toolUseId": message.call_id or str(uuid.uuid4()), - "content": [ - {"text": content} for content in content_list - ], - } - } - ] - ) - else: - 
user_contents.extend( - [{"text": content} for content in content_list] - ) - - if assistant_contents: - bedrock_messages.append( - {"role": "assistant", "content": assistant_contents} - ) - assistant_contents = None - elif role == "system": - system_bedrock_messages.extend( - [{"text": content} for content in content_list] - ) - elif role == "assistant": - if not assistant_contents: - assistant_contents = [] - - assistant_contents.extend( - [ - { - "text": content, - } - for content in content_list - ] - + [ - { - "toolUse": { - "input": tool_call.arguments, - "name": ( - tool_call.tool_name - if isinstance(tool_call.tool_name, str) - else tool_call.tool_name.value - ), - "toolUseId": tool_call.call_id, - } - } - for tool_call in message.tool_calls - ] - ) - - if user_contents: - bedrock_messages.append({"role": "user", "content": user_contents}) - user_contents = None - else: - # Unknown role - pass - - if user_contents: - bedrock_messages.append({"role": "user", "content": user_contents}) - if assistant_contents: - bedrock_messages.append( - {"role": "assistant", "content": assistant_contents} - ) - - if system_bedrock_messages: - return bedrock_messages, system_bedrock_messages - - return bedrock_messages, None - - @staticmethod - def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict: - inference_config = {} - if sampling_params: - param_mapping = { - "max_tokens": "maxTokens", - "temperature": "temperature", - "top_p": "topP", - } - - for k, v in param_mapping.items(): - if getattr(sampling_params, k): - inference_config[v] = getattr(sampling_params, k) - - return inference_config - - @staticmethod - def _tool_parameters_to_input_schema( - tool_parameters: Optional[Dict[str, ToolParamDefinition]], - ) -> Dict: - input_schema = {"type": "object"} - if not tool_parameters: - return input_schema - - json_properties = {} - required = [] - for name, param in tool_parameters.items(): - json_property = { - "type": param.param_type, - } - - if param.description: - json_property["description"] = param.description - if param.required: - required.append(name) - json_properties[name] = json_property - - input_schema["properties"] = json_properties - if required: - input_schema["required"] = required - return input_schema - - @staticmethod - def _tools_to_tool_config( - tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice] - ) -> Optional[Dict]: - if not tools: - return None - - bedrock_tools = [] - for tool in tools: - tool_name = ( - tool.tool_name - if isinstance(tool.tool_name, str) - else tool.tool_name.value - ) - - tool_spec = { - "toolSpec": { - "name": tool_name, - "inputSchema": { - "json": BedrockInferenceAdapter._tool_parameters_to_input_schema( - tool.parameters - ), - }, - } - } - - if tool.description: - tool_spec["toolSpec"]["description"] = tool.description - - bedrock_tools.append(tool_spec) - tool_config = { - "tools": bedrock_tools, - } - - if tool_choice: - tool_config["toolChoice"] = ( - {"any": {}} - if tool_choice.value == ToolChoice.required - else {"auto": {}} - ) - return tool_config - async def chat_completion( self, model_id: str, @@ -337,118 +115,70 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def _nonstream_chat_completion( self, request: ChatCompletionRequest ) -> ChatCompletionResponse: - params = self._get_params_for_chat_completion(request) - converse_api_res = self.client.converse(**params) + params = await self._get_params_for_chat_completion(request) + res = 
self.client.invoke_model(**params) + chunk = next(res["body"]) + result = json.loads(chunk.decode("utf-8")) - output_message = BedrockInferenceAdapter._bedrock_message_to_message( - converse_api_res + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"], + text=result["generation"], ) - return ChatCompletionResponse( - completion_message=output_message, - logprobs=None, - ) + response = OpenAICompatCompletionResponse(choices=[choice]) + return process_chat_completion_response(response, self.formatter) async def _stream_chat_completion( self, request: ChatCompletionRequest ) -> AsyncGenerator: - params = self._get_params_for_chat_completion(request) - converse_stream_api_res = self.client.converse_stream(**params) - event_stream = converse_stream_api_res["stream"] + params = await self._get_params_for_chat_completion(request) + res = self.client.invoke_model_with_response_stream(**params) + event_stream = res["body"] - for chunk in event_stream: - if "messageStart" in chunk: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", - ) + async def _generate_and_convert_to_openai_compat(): + for chunk in event_stream: + chunk = chunk["chunk"]["bytes"] + result = json.loads(chunk.decode("utf-8")) + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"], + text=result["generation"], ) - elif "contentBlockStart" in chunk: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - content=ToolCall( - tool_name=chunk["contentBlockStart"]["toolUse"]["name"], - call_id=chunk["contentBlockStart"]["toolUse"][ - "toolUseId" - ], - ), - parse_status=ToolCallParseStatus.started, - ), - ) - ) - elif "contentBlockDelta" in chunk: - if "text" in chunk["contentBlockDelta"]["delta"]: - delta = chunk["contentBlockDelta"]["delta"]["text"] - else: - delta = ToolCallDelta( - content=ToolCall( - arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][ - "input" - ] - ), - parse_status=ToolCallParseStatus.success, - ) + yield OpenAICompatCompletionResponse(choices=[choice]) - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - ) - ) - elif "contentBlockStop" in chunk: - # Ignored - pass - elif "messageStop" in chunk: - stop_reason = ( - BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason( - chunk["messageStop"]["stopReason"] - ) - ) + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_chat_completion_stream_response( + stream, self.formatter + ): + yield chunk - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta="", - stop_reason=stop_reason, - ) - ) - elif "metadata" in chunk: - # Ignored - pass - else: - # Ignored - pass - - def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: + async def _get_params_for_chat_completion( + self, request: ChatCompletionRequest + ) -> Dict: bedrock_model = request.model - inference_config = BedrockInferenceAdapter.get_bedrock_inference_config( - request.sampling_params - ) - tool_config = BedrockInferenceAdapter._tools_to_tool_config( - request.tools, request.tool_choice - ) - bedrock_messages, system_bedrock_messages = ( - BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages) - ) - - 
converse_api_params = { - "modelId": bedrock_model, - "messages": bedrock_messages, + inference_config = {} + param_mapping = { + "max_tokens": "max_gen_len", + "temperature": "temperature", + "top_p": "top_p", } - if inference_config: - converse_api_params["inferenceConfig"] = inference_config - # Tool use is not supported in streaming mode - if tool_config and not request.stream: - converse_api_params["toolConfig"] = tool_config - if system_bedrock_messages: - converse_api_params["system"] = system_bedrock_messages + for k, v in param_mapping.items(): + if getattr(request.sampling_params, k): + inference_config[v] = getattr(request.sampling_params, k) - return converse_api_params + prompt = await chat_completion_request_to_prompt( + request, self.get_llama_model(request.model), self.formatter + ) + return { + "modelId": bedrock_model, + "body": json.dumps( + { + "prompt": prompt, + **inference_config, + } + ), + } async def embeddings( self, diff --git a/requirements.txt b/requirements.txt index f57f688b7..304467ddc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,8 @@ blobfile fire httpx huggingface-hub -llama-models>=0.0.62 -llama-stack-client>=0.0.62 +llama-models>=0.0.63 +llama-stack-client>=0.0.63 prompt-toolkit python-dotenv pydantic>=2 diff --git a/setup.py b/setup.py index e8e3de5b2..c0f8cf575 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_stack", - version="0.0.62", + version="0.0.63", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama Stack",
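The spec changes above fall out of the Python type changes: register_schema gives the Message and ResponseFormat annotated unions a name, so the generator emits one shared schema and $refs it everywhere instead of inlining the same oneOf four times. A minimal sketch of how such a discriminated union behaves under pydantic v2 — simplified stand-ins, not the real llama_stack message classes:

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter


    class UserMessage(BaseModel):
        role: Literal["user"] = "user"
        content: str


    class SystemMessage(BaseModel):
        role: Literal["system"] = "system"
        content: str


    # Field(discriminator="role") makes pydantic dispatch on the "role" value,
    # mirroring the Message union in llama_stack/apis/inference/inference.py.
    Message = Annotated[Union[UserMessage, SystemMessage], Field(discriminator="role")]

    msg = TypeAdapter(Message).validate_python({"role": "user", "content": "hi"})
    assert isinstance(msg, UserMessage)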
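The URL change is wire-visible: a URL is now an object with a required "uri" property rather than a bare string, and with __str__ removed from the model, call sites that relied on str(url) need url.uri instead. An illustrative payload fragment (field layout from the schema above; the surrounding document shape is hypothetical):

    # before this patch: URL serialized as a plain string
    doc = {"content": "...", "url": "https://example.com/paper.pdf"}

    # after: URL serializes as an object with a required "uri" key
    doc = {"content": "...", "url": {"uri": "https://example.com/paper.pdf"}}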
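The one-line change in replace_env_vars is behavioral, not just cosmetic: a default of null in a ${env.VAR:default} expression is no longer special-cased to Python None, so with the variable unset it now resolves to the literal string "null".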
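The Bedrock adapter now bypasses the Converse API entirely: it renders the chat into a raw Llama prompt, calls invoke_model, and maps the JSON response ("generation", "stop_reason") through the shared OpenAI-compat helpers. A sketch of the request the new _get_params_for_chat_completion assembles — model id, prompt text, and sampling values are illustrative:

    import json

    params = {
        "modelId": "meta.llama3-8b-instruct-v1:0",  # illustrative Bedrock model id
        "body": json.dumps(
            {
                # produced by chat_completion_request_to_prompt(...) above
                "prompt": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>",
                # sampling params remapped per the adapter: max_tokens -> max_gen_len
                "max_gen_len": 512,
                "temperature": 0.7,
                "top_p": 0.9,
            }
        ),
    }
    # res = client.invoke_model(**params)
    # json.loads(next(res["body"]))  # -> {"generation": "...", "stop_reason": ...}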