eleuther generate until

This commit is contained in:
Xi Yan 2024-10-08 23:57:22 -07:00
parent 6abef716dd
commit 9c38d9ae13
3 changed files with 34 additions and 40 deletions

View file

@@ -91,52 +91,47 @@ class MetaReferenceInferenceImpl(Inference):
else:
return self._nonstream_chat_completion(request)
def _nonstream_chat_completion(
    self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
    """Run a chat completion to the end (non-streaming) and return the result.

    Converts *request* into model messages, drives ``self.generator`` token by
    token, records how generation stopped, and optionally collects per-token
    logprobs when ``request.logprobs`` is set.

    NOTE(review): this region of SOURCE is a rendered diff with old (async +
    SEMAPHORE) and new (sync) lines interleaved; this body is the reconstructed
    post-commit version — confirm against the repository.
    """
    messages = chat_completion_request_to_messages(request)

    tokens = []
    logprobs = []
    stop_reason = None

    for token_result in self.generator.chat_completion(
        messages=messages,
        temperature=request.sampling_params.temperature,
        top_p=request.sampling_params.top_p,
        max_gen_len=request.sampling_params.max_tokens,
        logprobs=request.logprobs,
        tool_prompt_format=request.tool_prompt_format,
    ):
        tokens.append(token_result.token)

        # Sentinel end tokens tell us how the turn terminated.
        if token_result.text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
        elif token_result.text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message

        if request.logprobs:
            # The generator emits exactly one logprob entry per token here.
            assert len(token_result.logprobs) == 1
            logprobs.append(
                TokenLogProbs(
                    logprobs_by_token={token_result.text: token_result.logprobs[0]}
                )
            )

    # No explicit end token was produced: generation ran out of token budget.
    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens

    message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
    return ChatCompletionResponse(
        completion_message=message,
        logprobs=logprobs if request.logprobs else None,
    )
async def _stream_chat_completion(
self, request: ChatCompletionRequest

View file

@@ -102,9 +102,8 @@ class EleutherEvalsWrapper(LM):
stream=False,
)
print(response)
res.append(response.completion_message)
res.append(response.completion_message.content)
print(response)
return res

View file

@@ -14,8 +14,8 @@ apis:
- evals
providers:
evals:
- provider_id: meta-reference
provider_type: meta-reference
- provider_id: eleuther
provider_type: eleuther
config: {}
inference:
- provider_id: meta-reference