Merge branch 'main' into inference_refactor

Botao Chen 2024-12-18 14:14:14 -08:00
commit 75c881770a
8 changed files with 272 additions and 676 deletions


@@ -2598,6 +2598,22 @@
         }
       ]
     },
+    "Message": {
+      "oneOf": [
+        {
+          "$ref": "#/components/schemas/UserMessage"
+        },
+        {
+          "$ref": "#/components/schemas/SystemMessage"
+        },
+        {
+          "$ref": "#/components/schemas/ToolResponseMessage"
+        },
+        {
+          "$ref": "#/components/schemas/CompletionMessage"
+        }
+      ]
+    },
     "SamplingParams": {
       "type": "object",
       "properties": {
@@ -2893,9 +2909,16 @@
       ]
     },
     "URL": {
-      "type": "string",
-      "format": "uri",
-      "pattern": "^(https?://|file://|data:)"
+      "type": "object",
+      "properties": {
+        "uri": {
+          "type": "string"
+        }
+      },
+      "additionalProperties": false,
+      "required": [
+        "uri"
+      ]
     },
     "UserMessage": {
       "type": "object",
@@ -2929,20 +2952,7 @@
         "items": {
           "type": "array",
           "items": {
-            "oneOf": [
-              {
-                "$ref": "#/components/schemas/UserMessage"
-              },
-              {
-                "$ref": "#/components/schemas/SystemMessage"
-              },
-              {
-                "$ref": "#/components/schemas/ToolResponseMessage"
-              },
-              {
-                "$ref": "#/components/schemas/CompletionMessage"
-              }
-            ]
+            "$ref": "#/components/schemas/Message"
           }
         }
       },
@@ -3052,47 +3062,7 @@
         "job_uuid"
       ]
     },
-    "ChatCompletionRequest": {
-      "type": "object",
-      "properties": {
-        "model_id": {
-          "type": "string"
-        },
-        "messages": {
-          "type": "array",
-          "items": {
-            "oneOf": [
-              {
-                "$ref": "#/components/schemas/UserMessage"
-              },
-              {
-                "$ref": "#/components/schemas/SystemMessage"
-              },
-              {
-                "$ref": "#/components/schemas/ToolResponseMessage"
-              },
-              {
-                "$ref": "#/components/schemas/CompletionMessage"
-              }
-            ]
-          }
-        },
-        "sampling_params": {
-          "$ref": "#/components/schemas/SamplingParams"
-        },
-        "tools": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/ToolDefinition"
-          }
-        },
-        "tool_choice": {
-          "$ref": "#/components/schemas/ToolChoice"
-        },
-        "tool_prompt_format": {
-          "$ref": "#/components/schemas/ToolPromptFormat"
-        },
-        "response_format": {
+    "ResponseFormat": {
       "oneOf": [
         {
           "type": "object",
@@ -3176,6 +3146,36 @@
         }
       ]
     },
+    "ChatCompletionRequest": {
+      "type": "object",
+      "properties": {
+        "model_id": {
+          "type": "string"
+        },
+        "messages": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/Message"
+          }
+        },
+        "sampling_params": {
+          "$ref": "#/components/schemas/SamplingParams"
+        },
+        "tools": {
+          "type": "array",
+          "items": {
+            "$ref": "#/components/schemas/ToolDefinition"
+          }
+        },
+        "tool_choice": {
+          "$ref": "#/components/schemas/ToolChoice"
+        },
+        "tool_prompt_format": {
+          "$ref": "#/components/schemas/ToolPromptFormat"
+        },
+        "response_format": {
+          "$ref": "#/components/schemas/ResponseFormat"
+        },
         "stream": {
           "type": "boolean"
         },
@@ -3329,88 +3329,7 @@
         "sampling_params": {
          "$ref": "#/components/schemas/SamplingParams"
         },
         "response_format": {
-          "oneOf": [
-            {
-              "type": "object",
-              "properties": {
-                "type": {
-                  "type": "string",
-                  "const": "json_schema",
-                  "default": "json_schema"
-                },
-                "json_schema": {
-                  "type": "object",
-                  "additionalProperties": {
-                    "oneOf": [
-                      {
-                        "type": "null"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "array"
-                      },
-                      {
-                        "type": "object"
-                      }
-                    ]
-                  }
-                }
-              },
-              "additionalProperties": false,
-              "required": [
-                "type",
-                "json_schema"
-              ]
-            },
-            {
-              "type": "object",
-              "properties": {
-                "type": {
-                  "type": "string",
-                  "const": "grammar",
-                  "default": "grammar"
-                },
-                "bnf": {
-                  "type": "object",
-                  "additionalProperties": {
-                    "oneOf": [
-                      {
-                        "type": "null"
-                      },
-                      {
-                        "type": "boolean"
-                      },
-                      {
-                        "type": "number"
-                      },
-                      {
-                        "type": "string"
-                      },
-                      {
-                        "type": "array"
-                      },
-                      {
-                        "type": "object"
-                      }
-                    ]
-                  }
-                }
-              },
-              "additionalProperties": false,
-              "required": [
-                "type",
-                "bnf"
-              ]
-            }
-          ]
+          "$ref": "#/components/schemas/ResponseFormat"
         },
         "stream": {
           "type": "boolean"
@@ -7278,20 +7197,7 @@
         "messages": {
           "type": "array",
           "items": {
-            "oneOf": [
-              {
-                "$ref": "#/components/schemas/UserMessage"
-              },
-              {
-                "$ref": "#/components/schemas/SystemMessage"
-              },
-              {
-                "$ref": "#/components/schemas/ToolResponseMessage"
-              },
-              {
-                "$ref": "#/components/schemas/CompletionMessage"
-              }
-            ]
+            "$ref": "#/components/schemas/Message"
           }
         },
         "params": {
@@ -7657,20 +7563,7 @@
         "dialogs": {
           "type": "array",
           "items": {
-            "oneOf": [
-              {
-                "$ref": "#/components/schemas/UserMessage"
-              },
-              {
-                "$ref": "#/components/schemas/SystemMessage"
-              },
-              {
-                "$ref": "#/components/schemas/ToolResponseMessage"
-              },
-              {
-                "$ref": "#/components/schemas/CompletionMessage"
-              }
-            ]
+            "$ref": "#/components/schemas/Message"
           }
         },
         "filtering_function": {
@@ -8129,6 +8022,10 @@
       "name": "MemoryToolDefinition",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/MemoryToolDefinition\" />"
     },
+    {
+      "name": "Message",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/Message\" />"
+    },
     {
       "name": "MetricEvent",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/MetricEvent\" />"
@@ -8247,6 +8144,10 @@
       "name": "RegisterShieldRequest",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/RegisterShieldRequest\" />"
     },
+    {
+      "name": "ResponseFormat",
+      "description": "<SchemaDefinition schemaRef=\"#/components/schemas/ResponseFormat\" />"
+    },
     {
       "name": "RestAPIExecutionConfig",
       "description": "<SchemaDefinition schemaRef=\"#/components/schemas/RestAPIExecutionConfig\" />"
@ -8591,6 +8492,7 @@
"MemoryBankDocument", "MemoryBankDocument",
"MemoryRetrievalStep", "MemoryRetrievalStep",
"MemoryToolDefinition", "MemoryToolDefinition",
"Message",
"MetricEvent", "MetricEvent",
"Model", "Model",
"ModelCandidate", "ModelCandidate",
@@ -8619,6 +8521,7 @@
         "RegisterModelRequest",
         "RegisterScoringFunctionRequest",
         "RegisterShieldRequest",
+        "ResponseFormat",
         "RestAPIExecutionConfig",
         "RestAPIMethod",
         "RouteInfo",


@@ -313,11 +313,7 @@ components:
       messages_batch:
         items:
           items:
-            oneOf:
-            - $ref: '#/components/schemas/UserMessage'
-            - $ref: '#/components/schemas/SystemMessage'
-            - $ref: '#/components/schemas/ToolResponseMessage'
-            - $ref: '#/components/schemas/CompletionMessage'
+            $ref: '#/components/schemas/Message'
           type: array
         type: array
       model:
@@ -422,56 +418,12 @@ components:
         type: object
       messages:
         items:
-          oneOf:
-          - $ref: '#/components/schemas/UserMessage'
-          - $ref: '#/components/schemas/SystemMessage'
-          - $ref: '#/components/schemas/ToolResponseMessage'
-          - $ref: '#/components/schemas/CompletionMessage'
+          $ref: '#/components/schemas/Message'
         type: array
       model_id:
         type: string
       response_format:
-        oneOf:
-        - additionalProperties: false
-          properties:
-            json_schema:
-              additionalProperties:
-                oneOf:
-                - type: 'null'
-                - type: boolean
-                - type: number
-                - type: string
-                - type: array
-                - type: object
-              type: object
-            type:
-              const: json_schema
-              default: json_schema
-              type: string
-          required:
-          - type
-          - json_schema
-          type: object
-        - additionalProperties: false
-          properties:
-            bnf:
-              additionalProperties:
-                oneOf:
-                - type: 'null'
-                - type: boolean
-                - type: number
-                - type: string
-                - type: array
-                - type: object
-              type: object
-            type:
-              const: grammar
-              default: grammar
-              type: string
-          required:
-          - type
-          - bnf
-          type: object
+        $ref: '#/components/schemas/ResponseFormat'
       sampling_params:
         $ref: '#/components/schemas/SamplingParams'
       stream:
@@ -598,47 +550,7 @@ components:
       model_id:
         type: string
       response_format:
-        oneOf:
-        - additionalProperties: false
-          properties:
-            json_schema:
-              additionalProperties:
-                oneOf:
-                - type: 'null'
-                - type: boolean
-                - type: number
-                - type: string
-                - type: array
-                - type: object
-              type: object
-            type:
-              const: json_schema
-              default: json_schema
-              type: string
-          required:
-          - type
-          - json_schema
-          type: object
-        - additionalProperties: false
-          properties:
-            bnf:
-              additionalProperties:
-                oneOf:
-                - type: 'null'
-                - type: boolean
-                - type: number
-                - type: string
-                - type: array
-                - type: object
-              type: object
-            type:
-              const: grammar
-              default: grammar
-              type: string
-          required:
-          - type
-          - bnf
-          type: object
+        $ref: '#/components/schemas/ResponseFormat'
       sampling_params:
         $ref: '#/components/schemas/SamplingParams'
       stream:
@@ -1467,6 +1379,12 @@ components:
       - max_tokens_in_context
       - max_chunks
       type: object
+    Message:
+      oneOf:
+      - $ref: '#/components/schemas/UserMessage'
+      - $ref: '#/components/schemas/SystemMessage'
+      - $ref: '#/components/schemas/ToolResponseMessage'
+      - $ref: '#/components/schemas/CompletionMessage'
     MetricEvent:
       additionalProperties: false
       properties:
@@ -2121,6 +2039,48 @@ components:
       required:
       - shield_id
       type: object
+    ResponseFormat:
+      oneOf:
+      - additionalProperties: false
+        properties:
+          json_schema:
+            additionalProperties:
+              oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+            type: object
+          type:
+            const: json_schema
+            default: json_schema
+            type: string
+        required:
+        - type
+        - json_schema
+        type: object
+      - additionalProperties: false
+        properties:
+          bnf:
+            additionalProperties:
+              oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+            type: object
+          type:
+            const: grammar
+            default: grammar
+            type: string
+        required:
+        - type
+        - bnf
+        type: object
     RestAPIExecutionConfig:
       additionalProperties: false
       properties:
@@ -2203,11 +2163,7 @@ components:
     properties:
       messages:
         items:
-          oneOf:
-          - $ref: '#/components/schemas/UserMessage'
-          - $ref: '#/components/schemas/SystemMessage'
-          - $ref: '#/components/schemas/ToolResponseMessage'
-          - $ref: '#/components/schemas/CompletionMessage'
+          $ref: '#/components/schemas/Message'
         type: array
       params:
         additionalProperties:
@@ -2744,11 +2700,7 @@ components:
    properties:
      dialogs:
        items:
-          oneOf:
-          - $ref: '#/components/schemas/UserMessage'
-          - $ref: '#/components/schemas/SystemMessage'
-          - $ref: '#/components/schemas/ToolResponseMessage'
-          - $ref: '#/components/schemas/CompletionMessage'
+          $ref: '#/components/schemas/Message'
        type: array
      filtering_function:
        enum:
@@ -3105,9 +3057,13 @@ components:
       title: A single turn in an interaction with an Agentic System.
       type: object
     URL:
-      format: uri
-      pattern: ^(https?://|file://|data:)
-      type: string
+      additionalProperties: false
+      properties:
+        uri:
+          type: string
+      required:
+      - uri
+      type: object
     UnregisterDatasetRequest:
       additionalProperties: false
       properties:
@@ -5020,6 +4976,8 @@ tags:
 - description: <SchemaDefinition schemaRef="#/components/schemas/MemoryToolDefinition"
     />
   name: MemoryToolDefinition
+- description: <SchemaDefinition schemaRef="#/components/schemas/Message" />
+  name: Message
 - description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
   name: MetricEvent
 - description: <SchemaDefinition schemaRef="#/components/schemas/Model" />
@@ -5104,6 +5062,8 @@ tags:
 - description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
     />
   name: RegisterShieldRequest
+- description: <SchemaDefinition schemaRef="#/components/schemas/ResponseFormat" />
+  name: ResponseFormat
 - description: <SchemaDefinition schemaRef="#/components/schemas/RestAPIExecutionConfig"
     />
   name: RestAPIExecutionConfig
@@ -5367,6 +5327,7 @@ x-tagGroups:
   - MemoryBankDocument
   - MemoryRetrievalStep
   - MemoryToolDefinition
+  - Message
   - MetricEvent
   - Model
   - ModelCandidate
@@ -5395,6 +5356,7 @@ x-tagGroups:
   - RegisterModelRequest
   - RegisterScoringFunctionRequest
   - RegisterShieldRequest
+  - ResponseFormat
   - RestAPIExecutionConfig
   - RestAPIMethod
   - RouteInfo


@@ -11,15 +11,10 @@ from llama_models.schema_utils import json_schema_type, register_schema
 from pydantic import BaseModel, Field, model_validator
 
 
-@json_schema_type(
-    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
-)
+@json_schema_type
 class URL(BaseModel):
     uri: str
 
-    def __str__(self) -> str:
-        return self.uri
-
 
 class _URLOrData(BaseModel):
     url: Optional[URL] = None
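
Dropping `__str__` is a small behavioral change: callers that relied on implicit string conversion of a URL must now read the field explicitly. A minimal sketch, assuming the class lives at the import path shown:

    # Import path assumed for illustration.
    from llama_stack.apis.common.content_types import URL

    url = URL(uri="https://example.com/data.json")

    # With __str__ gone, str(url) falls back to pydantic's default
    # ("uri='https://...'") rather than the bare URI; use the attribute.
    print(url.uri)
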


@@ -25,7 +25,7 @@ from llama_models.llama3.api.datatypes import (
     ToolPromptFormat,
 )
 
-from llama_models.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated
@@ -100,7 +100,8 @@ class CompletionMessage(BaseModel):
     tool_calls: List[ToolCall] = Field(default_factory=list)
 
 
-Message = Annotated[
+Message = register_schema(
+    Annotated[
         Union[
             UserMessage,
             SystemMessage,
@@ -108,7 +109,9 @@
             CompletionMessage,
         ],
         Field(discriminator="role"),
-]
+    ],
+    name="Message",
+)
 
 
 @json_schema_type
@@ -187,10 +190,13 @@ class GrammarResponseFormat(BaseModel):
     bnf: Dict[str, Any]
 
 
-ResponseFormat = Annotated[
+ResponseFormat = register_schema(
+    Annotated[
         Union[JsonSchemaResponseFormat, GrammarResponseFormat],
         Field(discriminator="type"),
-]
+    ],
+    name="ResponseFormat",
+)
 
 
 @json_schema_type
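
Runtime validation is unchanged: both unions still dispatch on their discriminator fields. `register_schema` only attaches a stable name so the OpenAPI generator emits one shared `#/components/schemas` entry instead of inlining the oneOf at every use site, which is exactly the spec diff above. A minimal sketch, assuming pydantic v2 and that `register_schema` returns the annotated type unchanged:

    from pydantic import TypeAdapter

    # Message is the discriminated union registered above.
    adapter = TypeAdapter(Message)
    msg = adapter.validate_python({"role": "user", "content": "hello"})
    assert isinstance(msg, UserMessage)  # selected via the "role" discriminator
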


@@ -144,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
             if default_val is None:
                 raise EnvVarError(env_var, path)
             else:
-                value = default_val if default_val != "null" else None
+                value = default_val
 
         # expand "~" from the values
         return os.path.expanduser(value)
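
The removed special case means a default spelled as the literal string "null" is no longer coerced to None. A minimal sketch of the `${env.VAR:default}` substitution this function performs (import path assumed):

    import os

    # Import path assumed for illustration.
    from llama_stack.distribution.stack import replace_env_vars

    os.environ.pop("MY_PATH", None)

    print(replace_env_vars({"path": "${env.MY_PATH:null}"}))
    # previously: {'path': None}; after this change: {'path': 'null'}
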


@@ -6,20 +6,25 @@
 from typing import *  # noqa: F403
 
 import json
-import uuid
 
 from botocore.client import BaseClient
 
 from llama_models.datatypes import CoreModelId
 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import ToolParamDefinition
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.providers.utils.inference.model_registry import (
     build_model_alias,
     ModelRegistryHelper,
 )
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
 from llama_stack.providers.utils.inference.prompt_adapter import (
+    chat_completion_request_to_prompt,
     content_has_media,
     interleaved_content_as_str,
 )
@@ -46,7 +51,6 @@ MODEL_ALIASES = [
 ]
 
 
-# NOTE: this is not quite tested after the recent refactors
 class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ALIASES)
@@ -76,232 +80,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         raise NotImplementedError()
 
-    @staticmethod
-    def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
-        if bedrock_stop_reason == "max_tokens":
-            return StopReason.out_of_tokens
-        return StopReason.end_of_turn
-
-    @staticmethod
-    def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
-        for builtin_tool in BuiltinTool:
-            if builtin_tool.value == tool_name_str:
-                return builtin_tool
-        else:
-            return tool_name_str
-
-    @staticmethod
-    def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
-        stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-            converse_api_res["stopReason"]
-        )
-
-        bedrock_message = converse_api_res["output"]["message"]
-
-        role = bedrock_message["role"]
-        contents = bedrock_message["content"]
-
-        tool_calls = []
-        text_content = ""
-        for content in contents:
-            if "toolUse" in content:
-                tool_use = content["toolUse"]
-                tool_calls.append(
-                    ToolCall(
-                        tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
-                            tool_use["name"]
-                        ),
-                        arguments=tool_use["input"] if "input" in tool_use else None,
-                        call_id=tool_use["toolUseId"],
-                    )
-                )
-            elif "text" in content:
-                text_content += content["text"]
-
-        return CompletionMessage(
-            role=role,
-            content=text_content,
-            stop_reason=stop_reason,
-            tool_calls=tool_calls,
-        )
-
-    @staticmethod
-    def _messages_to_bedrock_messages(
-        messages: List[Message],
-    ) -> Tuple[List[Dict], Optional[List[Dict]]]:
-        bedrock_messages = []
-        system_bedrock_messages = []
-
-        user_contents = []
-        assistant_contents = None
-        for message in messages:
-            role = message.role
-            content_list = (
-                message.content
-                if isinstance(message.content, list)
-                else [message.content]
-            )
-            if role == "ipython" or role == "user":
-                if not user_contents:
-                    user_contents = []
-
-                if role == "ipython":
-                    user_contents.extend(
-                        [
-                            {
-                                "toolResult": {
-                                    "toolUseId": message.call_id or str(uuid.uuid4()),
-                                    "content": [
-                                        {"text": content} for content in content_list
-                                    ],
-                                }
-                            }
-                        ]
-                    )
-                else:
-                    user_contents.extend(
-                        [{"text": content} for content in content_list]
-                    )
-
-                if assistant_contents:
-                    bedrock_messages.append(
-                        {"role": "assistant", "content": assistant_contents}
-                    )
-                    assistant_contents = None
-            elif role == "system":
-                system_bedrock_messages.extend(
-                    [{"text": content} for content in content_list]
-                )
-            elif role == "assistant":
-                if not assistant_contents:
-                    assistant_contents = []
-
-                assistant_contents.extend(
-                    [
-                        {
-                            "text": content,
-                        }
-                        for content in content_list
-                    ]
-                    + [
-                        {
-                            "toolUse": {
-                                "input": tool_call.arguments,
-                                "name": (
-                                    tool_call.tool_name
-                                    if isinstance(tool_call.tool_name, str)
-                                    else tool_call.tool_name.value
-                                ),
-                                "toolUseId": tool_call.call_id,
-                            }
-                        }
-                        for tool_call in message.tool_calls
-                    ]
-                )
-
-                if user_contents:
-                    bedrock_messages.append({"role": "user", "content": user_contents})
-                    user_contents = None
-            else:
-                # Unknown role
-                pass
-
-        if user_contents:
-            bedrock_messages.append({"role": "user", "content": user_contents})
-        if assistant_contents:
-            bedrock_messages.append(
-                {"role": "assistant", "content": assistant_contents}
-            )
-
-        if system_bedrock_messages:
-            return bedrock_messages, system_bedrock_messages
-
-        return bedrock_messages, None
-
-    @staticmethod
-    def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
-        inference_config = {}
-        if sampling_params:
-            param_mapping = {
-                "max_tokens": "maxTokens",
-                "temperature": "temperature",
-                "top_p": "topP",
-            }
-
-            for k, v in param_mapping.items():
-                if getattr(sampling_params, k):
-                    inference_config[v] = getattr(sampling_params, k)
-
-        return inference_config
-
-    @staticmethod
-    def _tool_parameters_to_input_schema(
-        tool_parameters: Optional[Dict[str, ToolParamDefinition]],
-    ) -> Dict:
-        input_schema = {"type": "object"}
-        if not tool_parameters:
-            return input_schema
-
-        json_properties = {}
-        required = []
-        for name, param in tool_parameters.items():
-            json_property = {
-                "type": param.param_type,
-            }
-
-            if param.description:
-                json_property["description"] = param.description
-
-            if param.required:
-                required.append(name)
-
-            json_properties[name] = json_property
-
-        input_schema["properties"] = json_properties
-        if required:
-            input_schema["required"] = required
-
-        return input_schema
-
-    @staticmethod
-    def _tools_to_tool_config(
-        tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
-    ) -> Optional[Dict]:
-        if not tools:
-            return None
-
-        bedrock_tools = []
-        for tool in tools:
-            tool_name = (
-                tool.tool_name
-                if isinstance(tool.tool_name, str)
-                else tool.tool_name.value
-            )
-
-            tool_spec = {
-                "toolSpec": {
-                    "name": tool_name,
-                    "inputSchema": {
-                        "json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
-                            tool.parameters
-                        ),
-                    },
-                }
-            }
-
-            if tool.description:
-                tool_spec["toolSpec"]["description"] = tool.description
-
-            bedrock_tools.append(tool_spec)
-        tool_config = {
-            "tools": bedrock_tools,
-        }
-
-        if tool_choice:
-            tool_config["toolChoice"] = (
-                {"any": {}}
-                if tool_choice.value == ToolChoice.required
-                else {"auto": {}}
-            )
-        return tool_config
-
     async def chat_completion(
         self,
         model_id: str,
@@ -337,118 +115,70 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params_for_chat_completion(request)
-        converse_api_res = self.client.converse(**params)
-
-        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-            converse_api_res
-        )
-
-        return ChatCompletionResponse(
-            completion_message=output_message,
-            logprobs=None,
-        )
+        params = await self._get_params_for_chat_completion(request)
+        res = self.client.invoke_model(**params)
+        chunk = next(res["body"])
+        result = json.loads(chunk.decode("utf-8"))
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=result["stop_reason"],
+            text=result["generation"],
+        )
+
+        response = OpenAICompatCompletionResponse(choices=[choice])
+        return process_chat_completion_response(response, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params_for_chat_completion(request)
-        converse_stream_api_res = self.client.converse_stream(**params)
-        event_stream = converse_stream_api_res["stream"]
-
-        for chunk in event_stream:
-            if "messageStart" in chunk:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
-            elif "contentBlockStart" in chunk:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=ToolCall(
-                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
-                                call_id=chunk["contentBlockStart"]["toolUse"][
-                                    "toolUseId"
-                                ],
-                            ),
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-            elif "contentBlockDelta" in chunk:
-                if "text" in chunk["contentBlockDelta"]["delta"]:
-                    delta = chunk["contentBlockDelta"]["delta"]["text"]
-                else:
-                    delta = ToolCallDelta(
-                        content=ToolCall(
-                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
-                                "input"
-                            ]
-                        ),
-                        parse_status=ToolCallParseStatus.success,
-                    )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                    )
-                )
-            elif "contentBlockStop" in chunk:
-                # Ignored
-                pass
-            elif "messageStop" in chunk:
-                stop_reason = (
-                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                        chunk["messageStop"]["stopReason"]
-                    )
-                )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-            elif "metadata" in chunk:
-                # Ignored
-                pass
-            else:
-                # Ignored
-                pass
-
-    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+        params = await self._get_params_for_chat_completion(request)
+        res = self.client.invoke_model_with_response_stream(**params)
+        event_stream = res["body"]
+
+        async def _generate_and_convert_to_openai_compat():
+            for chunk in event_stream:
+                chunk = chunk["chunk"]["bytes"]
+                result = json.loads(chunk.decode("utf-8"))
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=result["stop_reason"],
+                    text=result["generation"],
+                )
+                yield OpenAICompatCompletionResponse(choices=[choice])
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk
+
+    async def _get_params_for_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> Dict:
         bedrock_model = request.model
-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            request.sampling_params
-        )
 
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
-            request.tools, request.tool_choice
-        )
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
-        )
-
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-
-        # Tool use is not supported in streaming mode
-        if tool_config and not request.stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-
-        return converse_api_params
+        inference_config = {}
+        param_mapping = {
+            "max_tokens": "max_gen_len",
+            "temperature": "temperature",
+            "top_p": "top_p",
+        }
+
+        for k, v in param_mapping.items():
+            if getattr(request.sampling_params, k):
+                inference_config[v] = getattr(request.sampling_params, k)
+
+        prompt = await chat_completion_request_to_prompt(
+            request, self.get_llama_model(request.model), self.formatter
+        )
+        return {
+            "modelId": bedrock_model,
+            "body": json.dumps(
+                {
+                    "prompt": prompt,
+                    **inference_config,
+                }
+            ),
+        }
 
     async def embeddings(
         self,
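
Net effect of the rewrite: the adapter stops using Bedrock's Converse API (and its hand-rolled message/tool translation above) and instead renders the chat into a raw Llama prompt, calls invoke_model, and parses Bedrock's generation/stop_reason JSON through the shared OpenAI-compat helpers. A minimal sketch of the params the refactored _get_params_for_chat_completion returns; the model ID and values are illustrative:

    import json

    params = {
        "modelId": "meta.llama3-8b-instruct-v1:0",  # illustrative model ID
        "body": json.dumps(
            {
                "prompt": "<|begin_of_text|>... rendered chat prompt ...",
                "max_gen_len": 512,  # mapped from sampling_params.max_tokens
                "temperature": 0.7,
                "top_p": 0.9,
            }
        ),
    }
    # Non-streaming path then calls: self.client.invoke_model(**params)
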


@@ -2,8 +2,8 @@ blobfile
 fire
 httpx
 huggingface-hub
-llama-models>=0.0.62
-llama-stack-client>=0.0.62
+llama-models>=0.0.63
+llama-stack-client>=0.0.63
 prompt-toolkit
 python-dotenv
 pydantic>=2


@@ -16,7 +16,7 @@ def read_requirements():
 
 setup(
     name="llama_stack",
-    version="0.0.62",
+    version="0.0.63",
     author="Meta Llama",
     author_email="llama-oss@meta.com",
     description="Llama Stack",