Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)

Commit dfbc61fb67 (parent 9aee115abe): fix

1 changed file with 22 additions and 10 deletions
```diff
@@ -130,13 +130,19 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await self._get_client().completions.create(**params)
+        r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        stream = await self._get_client().completions.create(**params)
+        # Wrapper for async generator similar
+        async def _to_async_generator():
+            stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream):
             yield chunk
 
```
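The streaming path above switches from an awaited `completions.create(...)` call to the SDK's synchronous `completion.create(...)`, wrapped in a local async generator. A minimal, self-contained sketch of that wrapper pattern, with a hypothetical `fake_sync_stream` standing in for the SDK call:

```python
import asyncio
from typing import AsyncGenerator, Iterator


def fake_sync_stream() -> Iterator[str]:
    """Stand-in for the blocking streaming call (hypothetical, not the SDK)."""
    yield from ["Hello", ", ", "world"]


async def to_async_generator() -> AsyncGenerator[str, None]:
    # Re-yield each chunk from the sync iterator so callers can use `async for`.
    for chunk in fake_sync_stream():
        yield chunk


async def main() -> None:
    async for chunk in to_async_generator():
        print(chunk, end="")
    print()


asyncio.run(main())
```

One caveat worth noting: the wrapper only changes the interface, not the scheduling; the sync iterator still blocks the event loop while it waits for each chunk.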
```diff
@@ -205,18 +211,23 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await self._get_client().chat.completions.create(**params)
+            r = await self._get_client().chat.completions.acreate(**params)
         else:
-            r = await self._get_client().completions.create(**params)
+            r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, request)
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        if "messages" in params:
-            stream = await self._get_client().chat.completions.create(**params)
-        else:
-            stream = await self._get_client().completions.create(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.acreate(**params)
+            async for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk
 
```
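Two call conventions are mixed in this hunk: the non-streaming path awaits `acreate(**params)` (which suggests it returns a coroutine there), while the streaming wrapper calls `acreate(**params)` without `await` and consumes the result with `async for` (which suggests it returns an async iterator when streaming). The `"messages"` check dispatches chat-formatted requests to the chat endpoint and raw-prompt requests to the completion endpoint. A runnable sketch of that dispatch, with stub streams in place of the Fireworks SDK (the stub names and shapes are assumptions for illustration):

```python
import asyncio
from typing import Any, AsyncGenerator


async def fake_chat_stream(**params: Any) -> AsyncGenerator[str, None]:
    # Stand-in for chat.completions.acreate(**params) in streaming mode.
    yield f"chat chunk for {len(params['messages'])} message(s)"


async def fake_completion_stream(**params: Any) -> AsyncGenerator[str, None]:
    # Stand-in for completion.acreate(**params) in streaming mode.
    yield f"completion chunk for prompt {params['prompt']!r}"


async def stream_chat(params: dict[str, Any]) -> AsyncGenerator[str, None]:
    # Chat-formatted requests carry "messages"; raw-prompt requests carry
    # "prompt" -- the same branch as in the hunk above.
    if "messages" in params:
        stream = fake_chat_stream(**params)
    else:
        stream = fake_completion_stream(**params)
    async for chunk in stream:
        yield chunk


async def main() -> None:
    async for chunk in stream_chat({"messages": [{"role": "user", "content": "hi"}]}):
        print(chunk)
    async for chunk in stream_chat({"prompt": "hi"}):
        print(chunk)


asyncio.run(main())
```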
```diff
@@ -239,7 +250,8 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
-            input_dict["prompt"] = self._preprocess_prompt_for_fireworks(input_dict["prompt"])
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
         params = {
             "model": request.model,
```
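This hunk replaces the `_preprocess_prompt_for_fireworks(...)` helper call with inline logic: since, per the comment, Fireworks always prepends the BOS token server-side, a prompt that already starts with `<|begin_of_text|>` would otherwise carry it twice. A standalone sketch of the same string handling:

```python
# Fireworks prepends the BOS token server-side, so strip a leading one
# from the client-supplied prompt to avoid doubling it.
BOS = "<|begin_of_text|>"


def strip_leading_bos(prompt: str) -> str:
    # Mirrors: input_dict["prompt"][len("<|begin_of_text|>") :]
    return prompt[len(BOS):] if prompt.startswith(BOS) else prompt


assert strip_leading_bos(BOS + "Hello") == "Hello"
assert strip_leading_bos("Hello") == "Hello"
```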
```diff
@@ -267,7 +279,7 @@ class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, Nee
         assert all(not content_has_media(content) for content in contents), (
             "Fireworks does not support media for embeddings"
         )
-        response = await self._get_client().embeddings.create(
+        response = self._get_client().embeddings.create(
             model=model.provider_resource_id,
             input=[interleaved_content_as_str(content) for content in contents],
             **kwargs,
```
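The `await` is dropped here because this client's `embeddings.create(...)` is a synchronous call. A blocking call inside an `async def` holds up the event loop for its duration; if that becomes a concern, one possible mitigation (a sketch of an option, not something this commit does) is to offload the call with `asyncio.to_thread`:

```python
import asyncio


def blocking_embeddings_create(model: str, input: list[str]) -> list[list[float]]:
    """Stand-in for a synchronous SDK call such as embeddings.create (hypothetical)."""
    return [[0.0] * 4 for _ in input]


async def embed(contents: list[str]) -> list[list[float]]:
    # asyncio.to_thread runs the blocking function on a worker thread and
    # returns an awaitable, so the event loop stays responsive.
    return await asyncio.to_thread(
        blocking_embeddings_create, model="example-model", input=contents
    )


print(asyncio.run(embed(["hello", "world"])))
```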