Fix conversion to RawMessage everywhere

This commit is contained in:
Ashwin Bharambe 2024-12-17 13:38:01 -08:00
parent fbca51d6da
commit b7a7caa9a8
11 changed files with 87 additions and 78 deletions

View file

@ -94,14 +94,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = self._get_params(request)
params = await self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
@ -141,7 +141,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def _nonstream_chat_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = self._get_params(request)
params = await self._get_params(request)
r = await self.client.completions.create(**params)
@ -150,7 +150,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def _stream_chat_completion(
self, request: CompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
@ -159,7 +159,7 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
):
yield chunk
def _get_params(
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
if request.sampling_params and request.sampling_params.top_k:
@ -167,11 +167,11 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
prompt = ""
if isinstance(request, ChatCompletionRequest):
prompt = chat_completion_request_to_prompt(
prompt = await chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
elif isinstance(request, CompletionRequest):
prompt = completion_request_to_prompt(request, self.formatter)
prompt = await completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")