Mirror of https://github.com/meta-llama/llama-stack.git
fix: Bedrock provider returning None due to inheritance order
Bedrock provider was returning None for both streaming and non-streaming inference, causing 'NoneType' object has no attribute 'choices' errors.

Primary fix: reorder the base classes of BedrockInferenceAdapter so that the mixin classes come before the protocol class, ensuring the actual implementations are resolved and called.

Additional AWS Bedrock API compatibility fixes:
- Fix non-streaming: use res["body"].read() instead of next(res["body"])
- Fix streaming: add proper event structure checks and safe access
- Disable repetition_penalty (not supported by Bedrock Llama models)

Fixes #3621
Parent: 606f4cf281
Commit: 88e60c1bf6
1 changed file with 15 additions and 12 deletions
@@ -87,9 +87,9 @@ def _to_inference_profile_id(model_id: str, region: str = None) -> str:
 class BedrockInferenceAdapter(
     ModelRegistryHelper,
-    Inference,
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
+    Inference,
 ):
     def __init__(self, config: BedrockConfig) -> None:
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
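Why the reorder fixes the None returns: Python resolves methods left to right along the MRO, so when the Inference protocol class was listed before the mixins, its stub methods (which fall through and return None) shadowed the mixins' concrete implementations. A minimal, self-contained sketch of that behavior, using hypothetical stand-in classes rather than the real llama-stack types:

import asyncio


# Hypothetical stand-ins for illustration only; not the real llama-stack
# Inference protocol or OpenAI-compat mixins.
class InferenceProtocol:
    async def chat_completion(self, request):
        ...  # protocol-style stub: falls through and returns None


class ChatCompletionMixin:
    async def chat_completion(self, request):
        return {"choices": ["actual response"]}  # concrete implementation


class BrokenAdapter(InferenceProtocol, ChatCompletionMixin):  # protocol first: stub wins
    pass


class FixedAdapter(ChatCompletionMixin, InferenceProtocol):  # mixin first: implementation wins
    pass


print(asyncio.run(BrokenAdapter().chat_completion(None)))  # None -> later `.choices` access fails
print(asyncio.run(FixedAdapter().chat_completion(None)))   # {'choices': ['actual response']}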
@@ -155,7 +155,7 @@ class BedrockInferenceAdapter(
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params_for_chat_completion(request)
         res = self.client.invoke_model(**params)
-        chunk = next(res["body"])
+        chunk = res["body"].read()
         result = json.loads(chunk.decode("utf-8"))

         choice = OpenAICompatCompletionChoice(
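On the non-streaming path, res["body"] returned by invoke_model is a streaming body object: pulling it with next() yields at most the first fragment of the payload (or fails outright, depending on the botocore version), while .read() returns the complete JSON document. A hedged, standalone sketch of calling invoke_model directly and parsing the documented Bedrock Llama response fields; the region, model id, and prompt below are placeholders, not values from this patch:

import json

import boto3  # assumes boto3 is installed and AWS credentials are configured

client = boto3.client("bedrock-runtime", region_name="us-east-1")  # placeholder region

res = client.invoke_model(
    modelId="meta.llama3-8b-instruct-v1:0",  # placeholder Llama model id
    contentType="application/json",
    body=json.dumps({"prompt": "Say hello.", "max_gen_len": 64}),
)

# .read() drains the whole body; the Llama response documents "generation" and "stop_reason".
result = json.loads(res["body"].read().decode("utf-8"))
print(result.get("generation"), result.get("stop_reason"))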
@@ -172,14 +172,16 @@ class BedrockInferenceAdapter(
         event_stream = res["body"]

         async def _generate_and_convert_to_openai_compat():
-            for chunk in event_stream:
-                chunk = chunk["chunk"]["bytes"]
-                result = json.loads(chunk.decode("utf-8"))
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=result["stop_reason"],
-                    text=result["generation"],
-                )
-                yield OpenAICompatCompletionResponse(choices=[choice])
+            for event in event_stream:
+                if "chunk" in event:
+                    chunk_data = event["chunk"]["bytes"]
+                    result = json.loads(chunk_data.decode("utf-8"))
+                    if "generation" in result:
+                        choice = OpenAICompatCompletionChoice(
+                            finish_reason=result.get("stop_reason"),
+                            text=result["generation"],
+                        )
+                        yield OpenAICompatCompletionResponse(choices=[choice])

         stream = _generate_and_convert_to_openai_compat()
         async for chunk in process_chat_completion_stream_response(stream, request):
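The streaming fix is defensive parsing: events coming off the response stream are not guaranteed to be payload chunks, and a decoded chunk is not guaranteed to carry generated text, so both are checked before indexing. The same pattern as a standalone helper, a sketch assuming event_stream is the "body" of an invoke_model_with_response_stream response:

import json


def iter_generations(event_stream):
    # Sketch of the defensive parsing used in the patch: skip events without a
    # "chunk" payload and chunks without "generation" instead of raising KeyError.
    for event in event_stream:
        if "chunk" not in event:
            continue
        result = json.loads(event["chunk"]["bytes"].decode("utf-8"))
        if "generation" in result:
            yield result["generation"], result.get("stop_reason")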
@@ -193,8 +195,9 @@ class BedrockInferenceAdapter(

         if sampling_params.max_tokens:
             options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
+        # Note: repetition_penalty is not supported by AWS Bedrock Llama models
+        # if sampling_params.repetition_penalty > 0:
+        #     options["repetition_penalty"] = sampling_params.repetition_penalty

         prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))

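The last hunk stops forwarding repetition_penalty to Bedrock, since the Llama models there do not accept that field; the other sampling params still map onto Bedrock's documented body keys. A hedged sketch of that mapping, where the attribute names on sampling_params are assumptions for illustration:

def to_bedrock_llama_options(sampling_params):
    # Sketch only: temperature, top_p, and max_gen_len are the documented Bedrock
    # Llama body fields; repetition_penalty is intentionally dropped.
    options = {}
    if getattr(sampling_params, "temperature", None) is not None:
        options["temperature"] = sampling_params.temperature
    if getattr(sampling_params, "top_p", None) is not None:
        options["top_p"] = sampling_params.top_p
    if getattr(sampling_params, "max_tokens", None):
        options["max_gen_len"] = sampling_params.max_tokens
    return options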