error handling

This commit is contained in:
Xi Yan 2024-10-07 17:40:43 -07:00
parent 8a67e7a2bd
commit f1d31fe9b5
2 changed files with 6 additions and 2 deletions

View file

@@ -101,7 +101,7 @@ class InferenceRouter(Inference):
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = False, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
params = dict( params = dict(
model=model, model=model,
@@ -125,7 +125,7 @@ class InferenceRouter(Inference):
content: InterleavedTextMedia, content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = False, logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ) -> Union[CompletionResponse, CompletionResponseStreamChunk]:
return await self.routing_table.get_provider_impl(model).completion( return await self.routing_table.get_provider_impl(model).completion(
model=model, model=model,

View file

@@ -135,6 +135,10 @@ class MetaReferenceInferenceImpl(Inference, RoutableProvider):
assert ( assert (
len(token_result.logprobs) == 1 len(token_result.logprobs) == 1
), "Expected logprob to contain 1 result for the current token" ), "Expected logprob to contain 1 result for the current token"
assert (
logprobs.top_k == 1
), "Only top_k=1 is supported for LogProbConfig"
logprobs.append( logprobs.append(
TokenLogProbs( TokenLogProbs(
logprobs_by_token={ logprobs_by_token={