diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index bda5e54c1..d9b3cefd8 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -91,52 +91,47 @@ class MetaReferenceInferenceImpl(Inference):
         else:
             return self._nonstream_chat_completion(request)
 
-    async def _nonstream_chat_completion(
+    def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        async with SEMAPHORE:
-            messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request)
 
-            tokens = []
-            logprobs = []
-            stop_reason = None
+        tokens = []
+        logprobs = []
+        stop_reason = None
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
-                tokens.append(token_result.token)
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
 
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
 
-                if request.logprobs:
-                    assert len(token_result.logprobs) == 1
+            if request.logprobs:
+                assert len(token_result.logprobs) == 1
 
-                    logprobs.append(
-                        TokenLogProbs(
-                            logprobs_by_token={
-                                token_result.text: token_result.logprobs[0]
-                            }
-                        )
+                logprobs.append(
+                    TokenLogProbs(
+                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
                     )
+                )
 
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
 
-            message = self.generator.formatter.decode_assistant_message(
-                tokens, stop_reason
-            )
-            return ChatCompletionResponse(
-                completion_message=message,
-                logprobs=logprobs if request.logprobs else None,
-            )
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+        return ChatCompletionResponse(
+            completion_message=message,
+            logprobs=logprobs if request.logprobs else None,
+        )
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
diff --git a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
index c22ddfef9..7f307a9d3 100644
--- a/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
+++ b/llama_stack/providers/impls/third_party/evals/eleuther/eleuther.py
@@ -102,9 +102,8 @@ class EleutherEvalsWrapper(LM):
                 stream=False,
             )
             print(response)
-            res.append(response.completion_message)
+            res.append(response.completion_message.content)
 
-            print(response)
         return res
diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml
index 7b4bd66c6..a09736cd4 100644
--- a/tests/examples/local-run.yaml
+++ b/tests/examples/local-run.yaml
@@ -14,8 +14,8 @@ apis:
 - evals
 providers:
   evals:
-  - provider_id: meta-reference
-    provider_type: meta-reference
+  - provider_id: eleuther
+    provider_type: eleuther
     config: {}
   inference:
   - provider_id: meta-reference