mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 03:19:52 +00:00
address feedback
This commit is contained in:
parent
e167e9eb93
commit
5821ec9ef3
12 changed files with 61 additions and 76 deletions
|
|
@ -12,7 +12,6 @@ from llama_models.datatypes import CoreModelId
|
|||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import Message
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from openai import OpenAI
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
|
|
@ -29,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
|
@ -105,9 +105,6 @@ class FireworksInferenceAdapter(
|
|||
fireworks_api_key = self._get_api_key()
|
||||
return Fireworks(api_key=fireworks_api_key)
|
||||
|
||||
def _get_openai_client(self) -> OpenAI:
|
||||
return OpenAI(base_url=self.config.url, api_key=self._get_api_key())
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -272,12 +269,16 @@ class FireworksInferenceAdapter(
|
|||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
client = self._get_openai_client()
|
||||
kwargs = {}
|
||||
if model.metadata.get("embedding_dimensions"):
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
|
||||
response = client.embeddings.create(
|
||||
model=model.provider_resource_id, input=contents, **kwargs
|
||||
assert all(
|
||||
not content_has_media(content) for content in contents
|
||||
), "Fireworks does not support media for embeddings"
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_text_media_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue