Merge branch 'main' into inference_refactor

Botao Chen 2024-12-18 14:14:14 -08:00
commit 75c881770a
8 changed files with 272 additions and 676 deletions


@@ -2598,6 +2598,22 @@
}
]
},
"Message": {
"oneOf": [
{
"$ref": "#/components/schemas/UserMessage"
},
{
"$ref": "#/components/schemas/SystemMessage"
},
{
"$ref": "#/components/schemas/ToolResponseMessage"
},
{
"$ref": "#/components/schemas/CompletionMessage"
}
]
},
"SamplingParams": {
"type": "object",
"properties": {
@@ -2893,9 +2909,16 @@
]
},
"URL": {
"type": "string",
"format": "uri",
"pattern": "^(https?://|file://|data:)"
"type": "object",
"properties": {
"uri": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"uri"
]
},
"UserMessage": {
"type": "object",
@@ -2929,20 +2952,7 @@
"items": {
"type": "array",
"items": {
"oneOf": [
{
"$ref": "#/components/schemas/UserMessage"
},
{
"$ref": "#/components/schemas/SystemMessage"
},
{
"$ref": "#/components/schemas/ToolResponseMessage"
},
{
"$ref": "#/components/schemas/CompletionMessage"
}
]
"$ref": "#/components/schemas/Message"
}
}
},
@@ -3052,6 +3062,90 @@
"job_uuid"
]
},
"ResponseFormat": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_schema",
"default": "json_schema"
},
"json_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"json_schema"
]
},
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "grammar",
"default": "grammar"
},
"bnf": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"bnf"
]
}
]
},
"ChatCompletionRequest": {
"type": "object",
"properties": {
@@ -3061,20 +3155,7 @@
"messages": {
"type": "array",
"items": {
"oneOf": [
{
"$ref": "#/components/schemas/UserMessage"
},
{
"$ref": "#/components/schemas/SystemMessage"
},
{
"$ref": "#/components/schemas/ToolResponseMessage"
},
{
"$ref": "#/components/schemas/CompletionMessage"
}
]
"$ref": "#/components/schemas/Message"
}
},
"sampling_params": {
@@ -3093,88 +3174,7 @@
"$ref": "#/components/schemas/ToolPromptFormat"
},
"response_format": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_schema",
"default": "json_schema"
},
"json_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"json_schema"
]
},
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "grammar",
"default": "grammar"
},
"bnf": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"bnf"
]
}
]
"$ref": "#/components/schemas/ResponseFormat"
},
"stream": {
"type": "boolean"
@@ -3329,88 +3329,7 @@
"$ref": "#/components/schemas/SamplingParams"
},
"response_format": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "json_schema",
"default": "json_schema"
},
"json_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"json_schema"
]
},
{
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "grammar",
"default": "grammar"
},
"bnf": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
}
},
"additionalProperties": false,
"required": [
"type",
"bnf"
]
}
]
"$ref": "#/components/schemas/ResponseFormat"
},
"stream": {
"type": "boolean"
@@ -7278,20 +7197,7 @@
"messages": {
"type": "array",
"items": {
"oneOf": [
{
"$ref": "#/components/schemas/UserMessage"
},
{
"$ref": "#/components/schemas/SystemMessage"
},
{
"$ref": "#/components/schemas/ToolResponseMessage"
},
{
"$ref": "#/components/schemas/CompletionMessage"
}
]
"$ref": "#/components/schemas/Message"
}
},
"params": {
@@ -7657,20 +7563,7 @@
"dialogs": {
"type": "array",
"items": {
"oneOf": [
{
"$ref": "#/components/schemas/UserMessage"
},
{
"$ref": "#/components/schemas/SystemMessage"
},
{
"$ref": "#/components/schemas/ToolResponseMessage"
},
{
"$ref": "#/components/schemas/CompletionMessage"
}
]
"$ref": "#/components/schemas/Message"
}
},
"filtering_function": {
@@ -8129,6 +8022,10 @@
"name": "MemoryToolDefinition",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MemoryToolDefinition\" />"
},
{
"name": "Message",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/Message\" />"
},
{
"name": "MetricEvent",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/MetricEvent\" />"
@@ -8247,6 +8144,10 @@
"name": "RegisterShieldRequest",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RegisterShieldRequest\" />"
},
{
"name": "ResponseFormat",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/ResponseFormat\" />"
},
{
"name": "RestAPIExecutionConfig",
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/RestAPIExecutionConfig\" />"
@@ -8591,6 +8492,7 @@
"MemoryBankDocument",
"MemoryRetrievalStep",
"MemoryToolDefinition",
"Message",
"MetricEvent",
"Model",
"ModelCandidate",
@@ -8619,6 +8521,7 @@
"RegisterModelRequest",
"RegisterScoringFunctionRequest",
"RegisterShieldRequest",
"ResponseFormat",
"RestAPIExecutionConfig",
"RestAPIMethod",
"RouteInfo",


@@ -313,11 +313,7 @@ components:
messages_batch:
items:
items:
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage'
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
$ref: '#/components/schemas/Message'
type: array
type: array
model:
@@ -422,56 +418,12 @@ components:
type: object
messages:
items:
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage'
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
$ref: '#/components/schemas/Message'
type: array
model_id:
type: string
response_format:
oneOf:
- additionalProperties: false
properties:
json_schema:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: json_schema
default: json_schema
type: string
required:
- type
- json_schema
type: object
- additionalProperties: false
properties:
bnf:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: grammar
default: grammar
type: string
required:
- type
- bnf
type: object
$ref: '#/components/schemas/ResponseFormat'
sampling_params:
$ref: '#/components/schemas/SamplingParams'
stream:
@@ -598,47 +550,7 @@ components:
model_id:
type: string
response_format:
oneOf:
- additionalProperties: false
properties:
json_schema:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: json_schema
default: json_schema
type: string
required:
- type
- json_schema
type: object
- additionalProperties: false
properties:
bnf:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: grammar
default: grammar
type: string
required:
- type
- bnf
type: object
$ref: '#/components/schemas/ResponseFormat'
sampling_params:
$ref: '#/components/schemas/SamplingParams'
stream:
@@ -1467,6 +1379,12 @@ components:
- max_tokens_in_context
- max_chunks
type: object
Message:
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage'
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
MetricEvent:
additionalProperties: false
properties:
@@ -2121,6 +2039,48 @@ components:
required:
- shield_id
type: object
ResponseFormat:
oneOf:
- additionalProperties: false
properties:
json_schema:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: json_schema
default: json_schema
type: string
required:
- type
- json_schema
type: object
- additionalProperties: false
properties:
bnf:
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
type: object
type:
const: grammar
default: grammar
type: string
required:
- type
- bnf
type: object
RestAPIExecutionConfig:
additionalProperties: false
properties:
@@ -2203,11 +2163,7 @@ components:
properties:
messages:
items:
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage'
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
$ref: '#/components/schemas/Message'
type: array
params:
additionalProperties:
@@ -2744,11 +2700,7 @@ components:
properties:
dialogs:
items:
oneOf:
- $ref: '#/components/schemas/UserMessage'
- $ref: '#/components/schemas/SystemMessage'
- $ref: '#/components/schemas/ToolResponseMessage'
- $ref: '#/components/schemas/CompletionMessage'
$ref: '#/components/schemas/Message'
type: array
filtering_function:
enum:
@@ -3105,9 +3057,13 @@ components:
title: A single turn in an interaction with an Agentic System.
type: object
URL:
format: uri
pattern: ^(https?://|file://|data:)
type: string
additionalProperties: false
properties:
uri:
type: string
required:
- uri
type: object
UnregisterDatasetRequest:
additionalProperties: false
properties:
@@ -5020,6 +4976,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/MemoryToolDefinition"
/>
name: MemoryToolDefinition
- description: <SchemaDefinition schemaRef="#/components/schemas/Message" />
name: Message
- description: <SchemaDefinition schemaRef="#/components/schemas/MetricEvent" />
name: MetricEvent
- description: <SchemaDefinition schemaRef="#/components/schemas/Model" />
@@ -5104,6 +5062,8 @@ tags:
- description: <SchemaDefinition schemaRef="#/components/schemas/RegisterShieldRequest"
/>
name: RegisterShieldRequest
- description: <SchemaDefinition schemaRef="#/components/schemas/ResponseFormat" />
name: ResponseFormat
- description: <SchemaDefinition schemaRef="#/components/schemas/RestAPIExecutionConfig"
/>
name: RestAPIExecutionConfig
@@ -5367,6 +5327,7 @@ x-tagGroups:
- MemoryBankDocument
- MemoryRetrievalStep
- MemoryToolDefinition
- Message
- MetricEvent
- Model
- ModelCandidate
@@ -5395,6 +5356,7 @@ x-tagGroups:
- RegisterModelRequest
- RegisterScoringFunctionRequest
- RegisterShieldRequest
- ResponseFormat
- RestAPIExecutionConfig
- RestAPIMethod
- RouteInfo


@@ -11,15 +11,10 @@ from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator
-@json_schema_type(
-    schema={"type": "string", "format": "uri", "pattern": "^(https?://|file://|data:)"}
-)
+@json_schema_type
class URL(BaseModel):
    uri: str

    def __str__(self) -> str:
        return self.uri


class _URLOrData(BaseModel):
    url: Optional[URL] = None
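A short usage sketch of the reworked model, assuming pydantic v2 semantics as in the surrounding code:

url = URL(uri="https://example.com/doc.pdf")
assert str(url) == "https://example.com/doc.pdf"  # __str__ unwraps .uri
# Serialization nests the string, matching the new OpenAPI object schema:
assert url.model_dump() == {"uri": "https://example.com/doc.pdf"}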


@@ -25,7 +25,7 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat,
)
-from llama_models.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
@@ -100,15 +100,18 @@ class CompletionMessage(BaseModel):
tool_calls: List[ToolCall] = Field(default_factory=list)
-Message = Annotated[
-    Union[
-        UserMessage,
-        SystemMessage,
-        ToolResponseMessage,
-        CompletionMessage,
-    ],
-    Field(discriminator="role"),
-]
+Message = register_schema(
+    Annotated[
+        Union[
+            UserMessage,
+            SystemMessage,
+            ToolResponseMessage,
+            CompletionMessage,
+        ],
+        Field(discriminator="role"),
+    ],
+    name="Message",
+)
@json_schema_type
@@ -187,10 +190,13 @@ class GrammarResponseFormat(BaseModel):
bnf: Dict[str, Any]
-ResponseFormat = Annotated[
-    Union[JsonSchemaResponseFormat, GrammarResponseFormat],
-    Field(discriminator="type"),
-]
+ResponseFormat = register_schema(
+    Annotated[
+        Union[JsonSchemaResponseFormat, GrammarResponseFormat],
+        Field(discriminator="type"),
+    ],
+    name="ResponseFormat",
+)
@json_schema_type
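A sketch of what the annotated unions buy at runtime, assuming pydantic v2's TypeAdapter: the discriminator field picks the concrete class during validation, while register_schema publishes the union under a stable name in the generated OpenAPI spec instead of an anonymous inline oneOf:

from pydantic import TypeAdapter

adapter = TypeAdapter(Message)
msg = adapter.validate_python({"role": "user", "content": "hello"})
# msg is a UserMessage; a payload with an unknown "role" fails fast with a
# discriminator error rather than being tried against every variant.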


@@ -144,7 +144,7 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
if default_val is None:
raise EnvVarError(env_var, path)
else:
-value = default_val if default_val != "null" else None
+value = default_val
# expand "~" from the values
return os.path.expanduser(value)
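The dropped special case changes how defaults in env-var references resolve. A hedged sketch of the new branch in isolation, assuming the ${env.VAR:default} reference syntax this function handles:

import os

def resolve_default(default_val: str) -> str:
    # Mirrors the new code path: the default is taken verbatim (no
    # "null" -> None mapping) and "~" is expanded afterwards.
    value = default_val
    return os.path.expanduser(value)

resolve_default("null")      # -> "null" (previously this became None)
resolve_default("~/models")  # -> e.g. "/home/user/models"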


@@ -6,20 +6,25 @@
from typing import * # noqa: F403
import json
import uuid
from botocore.client import BaseClient
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ToolParamDefinition
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
)
@@ -46,7 +51,6 @@ MODEL_ALIASES = [
]
# NOTE: this is not quite tested after the recent refactors
class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ALIASES)
@@ -76,232 +80,6 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) -> AsyncGenerator:
raise NotImplementedError()
@staticmethod
def _bedrock_stop_reason_to_stop_reason(bedrock_stop_reason: str) -> StopReason:
if bedrock_stop_reason == "max_tokens":
return StopReason.out_of_tokens
return StopReason.end_of_turn
@staticmethod
def _builtin_tool_name_to_enum(tool_name_str: str) -> Union[BuiltinTool, str]:
for builtin_tool in BuiltinTool:
if builtin_tool.value == tool_name_str:
return builtin_tool
else:
return tool_name_str
@staticmethod
def _bedrock_message_to_message(converse_api_res: Dict) -> Message:
stop_reason = BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
converse_api_res["stopReason"]
)
bedrock_message = converse_api_res["output"]["message"]
role = bedrock_message["role"]
contents = bedrock_message["content"]
tool_calls = []
text_content = ""
for content in contents:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_calls.append(
ToolCall(
tool_name=BedrockInferenceAdapter._builtin_tool_name_to_enum(
tool_use["name"]
),
arguments=tool_use["input"] if "input" in tool_use else None,
call_id=tool_use["toolUseId"],
)
)
elif "text" in content:
text_content += content["text"]
return CompletionMessage(
role=role,
content=text_content,
stop_reason=stop_reason,
tool_calls=tool_calls,
)
@staticmethod
def _messages_to_bedrock_messages(
messages: List[Message],
) -> Tuple[List[Dict], Optional[List[Dict]]]:
bedrock_messages = []
system_bedrock_messages = []
user_contents = []
assistant_contents = None
for message in messages:
role = message.role
content_list = (
message.content
if isinstance(message.content, list)
else [message.content]
)
if role == "ipython" or role == "user":
if not user_contents:
user_contents = []
if role == "ipython":
user_contents.extend(
[
{
"toolResult": {
"toolUseId": message.call_id or str(uuid.uuid4()),
"content": [
{"text": content} for content in content_list
],
}
}
]
)
else:
user_contents.extend(
[{"text": content} for content in content_list]
)
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
assistant_contents = None
elif role == "system":
system_bedrock_messages.extend(
[{"text": content} for content in content_list]
)
elif role == "assistant":
if not assistant_contents:
assistant_contents = []
assistant_contents.extend(
[
{
"text": content,
}
for content in content_list
]
+ [
{
"toolUse": {
"input": tool_call.arguments,
"name": (
tool_call.tool_name
if isinstance(tool_call.tool_name, str)
else tool_call.tool_name.value
),
"toolUseId": tool_call.call_id,
}
}
for tool_call in message.tool_calls
]
)
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
user_contents = None
else:
# Unknown role
pass
if user_contents:
bedrock_messages.append({"role": "user", "content": user_contents})
if assistant_contents:
bedrock_messages.append(
{"role": "assistant", "content": assistant_contents}
)
if system_bedrock_messages:
return bedrock_messages, system_bedrock_messages
return bedrock_messages, None
@staticmethod
def get_bedrock_inference_config(sampling_params: Optional[SamplingParams]) -> Dict:
inference_config = {}
if sampling_params:
param_mapping = {
"max_tokens": "maxTokens",
"temperature": "temperature",
"top_p": "topP",
}
for k, v in param_mapping.items():
if getattr(sampling_params, k):
inference_config[v] = getattr(sampling_params, k)
return inference_config
@staticmethod
def _tool_parameters_to_input_schema(
tool_parameters: Optional[Dict[str, ToolParamDefinition]],
) -> Dict:
input_schema = {"type": "object"}
if not tool_parameters:
return input_schema
json_properties = {}
required = []
for name, param in tool_parameters.items():
json_property = {
"type": param.param_type,
}
if param.description:
json_property["description"] = param.description
if param.required:
required.append(name)
json_properties[name] = json_property
input_schema["properties"] = json_properties
if required:
input_schema["required"] = required
return input_schema
@staticmethod
def _tools_to_tool_config(
tools: Optional[List[ToolDefinition]], tool_choice: Optional[ToolChoice]
) -> Optional[Dict]:
if not tools:
return None
bedrock_tools = []
for tool in tools:
tool_name = (
tool.tool_name
if isinstance(tool.tool_name, str)
else tool.tool_name.value
)
tool_spec = {
"toolSpec": {
"name": tool_name,
"inputSchema": {
"json": BedrockInferenceAdapter._tool_parameters_to_input_schema(
tool.parameters
),
},
}
}
if tool.description:
tool_spec["toolSpec"]["description"] = tool.description
bedrock_tools.append(tool_spec)
tool_config = {
"tools": bedrock_tools,
}
if tool_choice:
tool_config["toolChoice"] = (
{"any": {}}
if tool_choice.value == ToolChoice.required
else {"auto": {}}
)
return tool_config
async def chat_completion(
self,
model_id: str,
@@ -337,118 +115,70 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
-        params = self._get_params_for_chat_completion(request)
-        converse_api_res = self.client.converse(**params)
-
-        output_message = BedrockInferenceAdapter._bedrock_message_to_message(
-            converse_api_res
-        )
-
-        return ChatCompletionResponse(
-            completion_message=output_message,
-            logprobs=None,
-        )
+        params = await self._get_params_for_chat_completion(request)
+        res = self.client.invoke_model(**params)
+        chunk = next(res["body"])
+        result = json.loads(chunk.decode("utf-8"))
+
+        choice = OpenAICompatCompletionChoice(
+            finish_reason=result["stop_reason"],
+            text=result["generation"],
+        )
+
+        response = OpenAICompatCompletionResponse(choices=[choice])
+        return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
-        params = self._get_params_for_chat_completion(request)
-        converse_stream_api_res = self.client.converse_stream(**params)
-        event_stream = converse_stream_api_res["stream"]
-
-        for chunk in event_stream:
-            if "messageStart" in chunk:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.start,
-                        delta="",
-                    )
-                )
-            elif "contentBlockStart" in chunk:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content=ToolCall(
-                                tool_name=chunk["contentBlockStart"]["toolUse"]["name"],
-                                call_id=chunk["contentBlockStart"]["toolUse"][
-                                    "toolUseId"
-                                ],
-                            ),
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
-            elif "contentBlockDelta" in chunk:
-                if "text" in chunk["contentBlockDelta"]["delta"]:
-                    delta = chunk["contentBlockDelta"]["delta"]["text"]
-                else:
-                    delta = ToolCallDelta(
-                        content=ToolCall(
-                            arguments=chunk["contentBlockDelta"]["delta"]["toolUse"][
-                                "input"
-                            ]
-                        ),
-                        parse_status=ToolCallParseStatus.success,
-                    )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=delta,
-                    )
-                )
-            elif "contentBlockStop" in chunk:
-                # Ignored
-                pass
-            elif "messageStop" in chunk:
-                stop_reason = (
-                    BedrockInferenceAdapter._bedrock_stop_reason_to_stop_reason(
-                        chunk["messageStop"]["stopReason"]
-                    )
-                )
-
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.complete,
-                        delta="",
-                        stop_reason=stop_reason,
-                    )
-                )
-            elif "metadata" in chunk:
-                # Ignored
-                pass
-            else:
-                # Ignored
-                pass
+        params = await self._get_params_for_chat_completion(request)
+        res = self.client.invoke_model_with_response_stream(**params)
+        event_stream = res["body"]
+
+        async def _generate_and_convert_to_openai_compat():
+            for chunk in event_stream:
+                chunk = chunk["chunk"]["bytes"]
+                result = json.loads(chunk.decode("utf-8"))
+                choice = OpenAICompatCompletionChoice(
+                    finish_reason=result["stop_reason"],
+                    text=result["generation"],
+                )
+                yield OpenAICompatCompletionResponse(choices=[choice])
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(
+            stream, self.formatter
+        ):
+            yield chunk

-    def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
+    async def _get_params_for_chat_completion(
+        self, request: ChatCompletionRequest
+    ) -> Dict:
        bedrock_model = request.model

-        inference_config = BedrockInferenceAdapter.get_bedrock_inference_config(
-            request.sampling_params
-        )
-        tool_config = BedrockInferenceAdapter._tools_to_tool_config(
-            request.tools, request.tool_choice
-        )
-        bedrock_messages, system_bedrock_messages = (
-            BedrockInferenceAdapter._messages_to_bedrock_messages(request.messages)
-        )
-
-        converse_api_params = {
-            "modelId": bedrock_model,
-            "messages": bedrock_messages,
-        }
-        if inference_config:
-            converse_api_params["inferenceConfig"] = inference_config
-        # Tool use is not supported in streaming mode
-        if tool_config and not request.stream:
-            converse_api_params["toolConfig"] = tool_config
-        if system_bedrock_messages:
-            converse_api_params["system"] = system_bedrock_messages
-
-        return converse_api_params
+        inference_config = {}
+        param_mapping = {
+            "max_tokens": "max_gen_len",
+            "temperature": "temperature",
+            "top_p": "top_p",
+        }
+
+        for k, v in param_mapping.items():
+            if getattr(request.sampling_params, k):
+                inference_config[v] = getattr(request.sampling_params, k)
+
+        prompt = await chat_completion_request_to_prompt(
+            request, self.get_llama_model(request.model), self.formatter
+        )
+        return {
+            "modelId": bedrock_model,
+            "body": json.dumps(
+                {
+                    "prompt": prompt,
+                    **inference_config,
+                }
+            ),
+        }

    async def embeddings(
        self,

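The refactored adapter above drops Bedrock's structured Converse API in favor of the lower-level invoke_model call with a pre-formatted Llama prompt, then funnels the raw result through the shared OpenAI-compat processors. A hedged end-to-end sketch of the request/response shapes this implies (client construction and the model id are illustrative, not taken from the diff):

import json

import boto3

client = boto3.client("bedrock-runtime")  # region/credentials assumed configured

body = json.dumps(
    {
        "prompt": "<llama-formatted chat prompt>",
        "max_gen_len": 512,  # mapped from sampling_params.max_tokens
        "temperature": 0.7,
        "top_p": 0.9,
    }
)
res = client.invoke_model(modelId="meta.llama3-1-8b-instruct-v1:0", body=body)
result = json.loads(next(res["body"]).decode("utf-8"))
# result carries "generation" and "stop_reason", which the adapter wraps in
# OpenAICompatCompletionResponse before delegating to the shared processors.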

@@ -2,8 +2,8 @@ blobfile
fire
httpx
huggingface-hub
-llama-models>=0.0.62
-llama-stack-client>=0.0.62
+llama-models>=0.0.63
+llama-stack-client>=0.0.63
prompt-toolkit
python-dotenv
pydantic>=2


@@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
version="0.0.62",
version="0.0.63",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",