diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 2206aa641..7685cbd34 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -87,9 +87,9 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
 
 class BedrockInferenceAdapter(
     ModelRegistryHelper,
-    Inference,
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
+    Inference,
 ):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -155,7 +155,7 @@ class BedrockInferenceAdapter(
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params_for_chat_completion(request)
         res = self.client.invoke_model(**params)
-        chunk = next(res["body"])
+        chunk = res["body"].read()
         result = json.loads(chunk.decode("utf-8"))
 
         choice = OpenAICompatCompletionChoice(
@@ -172,14 +172,16 @@ class BedrockInferenceAdapter(
         event_stream = res["body"]
 
         async def _generate_and_convert_to_openai_compat():
-            for chunk in event_stream:
-                chunk = chunk["chunk"]["bytes"]
-                result = json.loads(chunk.decode("utf-8"))
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"],
-                    text=result["generation"],
-                )
-                yield OpenAICompatCompletionResponse(choices=[choice])
+            for event in event_stream:
+                if "chunk" in event:
+                    chunk_data = event["chunk"]["bytes"]
+                    result = json.loads(chunk_data.decode("utf-8"))
+                    if "generation" in result:
+                        choice = OpenAICompatCompletionChoice(
+                            finish_reason=result.get("stop_reason"),
+                            text=result["generation"],
+                        )
+                        yield OpenAICompatCompletionResponse(choices=[choice])
 
         stream = _generate_and_convert_to_openai_compat()
         async for chunk in process_chat_completion_stream_response(stream, request):
@@ -193,8 +195,9 @@ class BedrockInferenceAdapter(
 
         if sampling_params.max_tokens:
            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
+        # Note: repetition_penalty is not supported by AWS Bedrock Llama models
+        # if sampling_params.repetition_penalty > 0:
+        #     options["repetition_penalty"] = sampling_params.repetition_penalty
 
         prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
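
For context, a minimal standalone sketch of the two boto3 `bedrock-runtime` response shapes this patch adapts to: `invoke_model` returns a `StreamingBody` whose payload is obtained with `.read()` (not `next()`), while `invoke_model_with_response_stream` yields events that may carry a `"chunk"` payload or service metadata, so both the event and the decoded JSON are checked before emitting a completion chunk. The model ID and request body below are illustrative placeholders, not values taken from the adapter.

```python
import json

import boto3

client = boto3.client("bedrock-runtime")
MODEL_ID = "meta.llama3-8b-instruct-v1:0"  # placeholder model ID for illustration
body = json.dumps({"prompt": "Hello", "max_gen_len": 64})

# Non-streaming: the response body is a StreamingBody; read the whole payload.
res = client.invoke_model(modelId=MODEL_ID, body=body)
result = json.loads(res["body"].read().decode("utf-8"))
print(result.get("generation"), result.get("stop_reason"))

# Streaming: iterate the event stream and skip events without a "chunk" key
# (e.g. metadata/metrics events), then skip decoded chunks without "generation".
res = client.invoke_model_with_response_stream(modelId=MODEL_ID, body=body)
for event in res["body"]:
    if "chunk" in event:
        piece = json.loads(event["chunk"]["bytes"].decode("utf-8"))
        if "generation" in piece:
            print(piece["generation"], end="", flush=True)
```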