completion() for fireworks (#329)

Dinesh Yeduguru 2024-10-25 16:12:10 -07:00 committed by GitHub
parent 7ec79f3b9d
commit 9b85d9a841
2 changed files with 44 additions and 4 deletions


@@ -20,9 +20,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
+    process_completion_response,
+    process_completion_stream_response,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
 )
 
 from .config import FireworksImplConfig
@@ -60,7 +63,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        raise NotImplementedError()
+        request = CompletionRequest(
+            model=model,
+            content=content,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+        client = Fireworks(api_key=self.config.api_key)
+        if stream:
+            return self._stream_completion(request, client)
+        else:
+            return await self._nonstream_completion(request, client)
+
+    async def _nonstream_completion(
+        self, request: CompletionRequest, client: Fireworks
+    ) -> CompletionResponse:
+        params = self._get_params(request)
+        r = await client.completion.acreate(**params)
+        return process_completion_response(r, self.formatter)
+
+    async def _stream_completion(
+        self, request: CompletionRequest, client: Fireworks
+    ) -> AsyncGenerator:
+        params = self._get_params(request)
+        stream = client.completion.acreate(**params)
+        async for chunk in process_completion_stream_response(stream, self.formatter):
+            yield chunk
 
     async def chat_completion(
         self,
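
For orientation, a minimal usage sketch (not part of this commit) of the completion() path added above; the driver function, model identifier, and prompt are illustrative assumptions, while the completion() signature and its stream handling come from the diff.

# Hypothetical driver for the new completion() method; `adapter` is assumed to
# be an already-configured FireworksInferenceAdapter instance.
async def demo_completion(adapter) -> None:
    # Non-streaming: completion() builds a CompletionRequest and awaits
    # _nonstream_completion(), returning a single CompletionResponse.
    response = await adapter.completion(
        model="fireworks/llama-v3p1-8b-instruct",  # hypothetical model id
        content="Summarize unified diffs in one sentence.",
        stream=False,
    )
    print(response)

    # Streaming: completion() returns the _stream_completion() async generator,
    # which yields chunks produced by process_completion_stream_response().
    async for chunk in await adapter.completion(
        model="fireworks/llama-v3p1-8b-instruct",  # hypothetical model id
        content="Summarize unified diffs in one sentence.",
        stream=True,
    ):
        print(chunk)
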
@@ -110,8 +141,15 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         ):
             yield chunk
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+    def _get_params(self, request) -> dict:
+        prompt = ""
+        if type(request) == ChatCompletionRequest:
+            prompt = chat_completion_request_to_prompt(request, self.formatter)
+        elif type(request) == CompletionRequest:
+            prompt = completion_request_to_prompt(request, self.formatter)
+        else:
+            raise ValueError(f"Unknown request type {type(request)}")
+
         # Fireworks always prepends with BOS
         if prompt.startswith("<|begin_of_text|>"):
             prompt = prompt[len("<|begin_of_text|>") :]
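
As a footnote on the BOS handling in _get_params() above: because the Fireworks API prepends <|begin_of_text|> on its side, the adapter strips a leading BOS token from the locally rendered prompt so it is not sent twice. A tiny standalone illustration (the helper name is ours, not from the commit):

# Standalone sketch of the BOS-stripping step; not part of the commit.
BOS = "<|begin_of_text|>"

def strip_leading_bos(prompt: str) -> str:
    # Fireworks always prepends BOS itself, so drop a duplicate leading token.
    return prompt[len(BOS):] if prompt.startswith(BOS) else prompt

assert strip_leading_bos(BOS + "Hello") == "Hello"  # leading BOS removed
assert strip_leading_bos("Hello") == "Hello"        # otherwise unchanged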