eleuther generate until

Xi Yan 2024-10-08 23:57:22 -07:00
parent 6abef716dd
commit 9c38d9ae13
3 changed files with 34 additions and 40 deletions

@@ -91,52 +91,47 @@ class MetaReferenceInferenceImpl(Inference):
         else:
             return self._nonstream_chat_completion(request)
 
-    async def _nonstream_chat_completion(
+    def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        async with SEMAPHORE:
-            messages = chat_completion_request_to_messages(request)
+        messages = chat_completion_request_to_messages(request)
 
-            tokens = []
-            logprobs = []
-            stop_reason = None
+        tokens = []
+        logprobs = []
+        stop_reason = None
 
-            for token_result in self.generator.chat_completion(
-                messages=messages,
-                temperature=request.sampling_params.temperature,
-                top_p=request.sampling_params.top_p,
-                max_gen_len=request.sampling_params.max_tokens,
-                logprobs=request.logprobs,
-                tool_prompt_format=request.tool_prompt_format,
-            ):
-                tokens.append(token_result.token)
+        for token_result in self.generator.chat_completion(
+            messages=messages,
+            temperature=request.sampling_params.temperature,
+            top_p=request.sampling_params.top_p,
+            max_gen_len=request.sampling_params.max_tokens,
+            logprobs=request.logprobs,
+            tool_prompt_format=request.tool_prompt_format,
+        ):
+            tokens.append(token_result.token)
 
-                if token_result.text == "<|eot_id|>":
-                    stop_reason = StopReason.end_of_turn
-                elif token_result.text == "<|eom_id|>":
-                    stop_reason = StopReason.end_of_message
+            if token_result.text == "<|eot_id|>":
+                stop_reason = StopReason.end_of_turn
+            elif token_result.text == "<|eom_id|>":
+                stop_reason = StopReason.end_of_message
 
-                if request.logprobs:
-                    assert len(token_result.logprobs) == 1
+            if request.logprobs:
+                assert len(token_result.logprobs) == 1
 
-                    logprobs.append(
-                        TokenLogProbs(
-                            logprobs_by_token={
-                                token_result.text: token_result.logprobs[0]
-                            }
-                        )
-                    )
+                logprobs.append(
+                    TokenLogProbs(
+                        logprobs_by_token={token_result.text: token_result.logprobs[0]}
+                    )
+                )
 
-            if stop_reason is None:
-                stop_reason = StopReason.out_of_tokens
+        if stop_reason is None:
+            stop_reason = StopReason.out_of_tokens
 
-            message = self.generator.formatter.decode_assistant_message(
-                tokens, stop_reason
-            )
-            return ChatCompletionResponse(
-                completion_message=message,
-                logprobs=logprobs if request.logprobs else None,
-            )
+        message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
+        return ChatCompletionResponse(
+            completion_message=message,
+            logprobs=logprobs if request.logprobs else None,
+        )
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
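
The only behavioral change in this hunk is that _nonstream_chat_completion becomes a plain synchronous method and no longer serializes callers through "async with SEMAPHORE:"; the body is de-indented accordingly and two multi-line expressions are collapsed. The generate-until loop itself is unchanged: tokens accumulate until the generator emits a Llama 3 stop token, and finishing without one means the token budget ran out. Below is a minimal self-contained sketch of that loop; the stop-token strings and StopReason values come from the diff, while TokenResult and collect_until_stop are simplified stand-ins for the generator's real types:

from dataclasses import dataclass
from enum import Enum
from typing import Iterable, List, Optional, Tuple


class StopReason(Enum):
    end_of_turn = "end_of_turn"        # model emitted <|eot_id|>
    end_of_message = "end_of_message"  # model emitted <|eom_id|>
    out_of_tokens = "out_of_tokens"    # budget exhausted, no stop token seen


@dataclass
class TokenResult:
    token: int  # token id appended to the running sequence
    text: str   # decoded text, used to recognize stop tokens


def collect_until_stop(results: Iterable[TokenResult]) -> Tuple[List[int], StopReason]:
    """Accumulate generated tokens and classify why generation stopped."""
    tokens: List[int] = []
    stop_reason: Optional[StopReason] = None
    for r in results:
        tokens.append(r.token)
        if r.text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn
        elif r.text == "<|eom_id|>":
            stop_reason = StopReason.end_of_message
    # Falling out of the loop without a stop token means max_gen_len was hit.
    return tokens, stop_reason or StopReason.out_of_tokens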

@@ -102,9 +102,8 @@ class EleutherEvalsWrapper(LM):
                 stream=False,
             )
             print(response)
-            res.append(response.completion_message)
-            print(response)
+            res.append(response.completion_message.content)
         return res
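
This hunk fixes what the Eleuther wrapper hands back: generate_until in lm-evaluation-harness expects one plain string per request, so the wrapper must append the message text rather than the whole completion_message object, and a duplicate print(response) is dropped. A hedged sketch of the surrounding method under that reading; the client call and request shape here are illustrative, not the repo's exact API:

def generate_until(self, requests) -> list:
    res = []
    for request in requests:
        # Illustrative call: the real wrapper builds a chat completion request
        # against the inference API; only stream=False is taken from the diff.
        response = self.client.chat_completion(
            messages=[{"role": "user", "content": request.prompt}],
            stream=False,
        )
        print(response)
        # The harness post-processes plain strings (stop-sequence trimming,
        # metric computation), so append the text, not the message object.
        res.append(response.completion_message.content)
    return res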

@@ -14,8 +14,8 @@ apis:
 - evals
 providers:
   evals:
-  - provider_id: meta-reference
-    provider_type: meta-reference
+  - provider_id: eleuther
+    provider_type: eleuther
     config: {}
   inference:
   - provider_id: meta-reference
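
After this hunk the run config routes the evals API to the new eleuther provider while inference stays on meta-reference. The resulting stanza, reconstructed from the hunk (keys outside the shown context may differ):

apis:
- evals
providers:
  evals:
  - provider_id: eleuther
    provider_type: eleuther
    config: {}
  inference:
  - provider_id: meta-reference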