mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 13:32:32 +00:00

commit 1d941b6aa0 (parent ca771cd921)

Address review comments

Signed-off-by: Bill Murdock <bmurdock@redhat.com>

9 changed files with 61 additions and 200 deletions
llama_stack/distributions/watsonx/__init__.py

@@ -3,3 +3,5 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+
+from .watsonx import get_distribution_template  # noqa: F401
llama_stack/distributions/watsonx/build.yaml

@@ -3,44 +3,33 @@ distribution_spec:
   description: Use watsonx for running LLM inference
   providers:
     inference:
-    - provider_id: watsonx
-      provider_type: remote::watsonx
-    - provider_id: sentence-transformers
-      provider_type: inline::sentence-transformers
+    - provider_type: remote::watsonx
+    - provider_type: inline::sentence-transformers
     vector_io:
-    - provider_id: faiss
-      provider_type: inline::faiss
+    - provider_type: inline::faiss
     safety:
-    - provider_id: llama-guard
-      provider_type: inline::llama-guard
+    - provider_type: inline::llama-guard
     agents:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     telemetry:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     eval:
-    - provider_id: meta-reference
-      provider_type: inline::meta-reference
+    - provider_type: inline::meta-reference
     datasetio:
-    - provider_id: huggingface
-      provider_type: remote::huggingface
-    - provider_id: localfs
-      provider_type: inline::localfs
+    - provider_type: remote::huggingface
+    - provider_type: inline::localfs
     scoring:
-    - provider_id: basic
-      provider_type: inline::basic
-    - provider_id: llm-as-judge
-      provider_type: inline::llm-as-judge
-    - provider_id: braintrust
-      provider_type: inline::braintrust
+    - provider_type: inline::basic
+    - provider_type: inline::llm-as-judge
+    - provider_type: inline::braintrust
     tool_runtime:
     - provider_type: remote::brave-search
     - provider_type: remote::tavily-search
     - provider_type: inline::rag-runtime
     - provider_type: remote::model-context-protocol
+    files:
+    - provider_type: inline::localfs
 image_type: venv
 additional_pip_packages:
+- aiosqlite
 - sqlalchemy[asyncio]
-- aiosqlite
-- aiosqlite
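A note on the shape change above: the build spec now lists each provider by provider_type alone, and provider_id is assigned only in run configs. A minimal sketch of the same list expressed in Python with the BuildProvider datatype this commit imports in the template (the dict shape follows the distribution template; treat it as illustrative):

    from llama_stack.core.datatypes import BuildProvider

    # Build-time provider spec: provider type only, no provider_id.
    providers = {
        "inference": [
            BuildProvider(provider_type="remote::watsonx"),
            BuildProvider(provider_type="inline::sentence-transformers"),
        ],
        "vector_io": [BuildProvider(provider_type="inline::faiss")],
        "files": [BuildProvider(provider_type="inline::localfs")],
    }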
llama_stack/distributions/watsonx/run.yaml

@@ -4,13 +4,13 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - safety
 - scoring
 - telemetry
 - tool_runtime
 - vector_io
-- files
 providers:
   inference:
   - provider_id: watsonx
@@ -19,8 +19,6 @@ providers:
       url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
       api_key: ${env.WATSONX_API_KEY:=}
       project_id: ${env.WATSONX_PROJECT_ID:=}
-  - provider_id: sentence-transformers
-    provider_type: inline::sentence-transformers
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -48,7 +46,7 @@ providers:
     provider_type: inline::meta-reference
     config:
       service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
-      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
+      sinks: ${env.TELEMETRY_SINKS:=sqlite}
       sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/trace_store.db
       otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
   eval:
@@ -109,102 +107,7 @@ metadata_store:
 inference_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/inference_store.db
-models:
-- metadata: {}
-  model_id: meta-llama/llama-3-3-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.3-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-3-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-2-13b-chat
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-2-13b
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-2-13b-chat
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-70b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-70b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-1-8b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-1-8b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-11b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-1b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-1b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-3b-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-3b-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-3-2-90b-vision-instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/llama-guard-3-11b-vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata: {}
-  model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: watsonx
-  provider_model_id: meta-llama/llama-guard-3-11b-vision
-  model_type: llm
-- metadata:
-    embedding_dimension: 384
-  model_id: all-MiniLM-L6-v2
-  provider_id: sentence-transformers
-  model_type: embedding
+models: []
 shields: []
 vector_dbs: []
 datasets: []
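A note on the ${env.VAR:=default} placeholders in this run config: := substitutes the environment value when set, else the literal default (which may be empty, as in WATSONX_API_KEY). A rough Python illustration of that semantic, as a hypothetical helper rather than the stack's actual resolver:

    import os
    import re

    _ENV = re.compile(r"\$\{env\.([A-Za-z0-9_]+):=([^}]*)\}")

    def resolve_env(value: str) -> str:
        # Replace each ${env.NAME:=default} with os.environ[NAME], else the default.
        return _ENV.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

    print(resolve_env("${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}"))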
llama_stack/distributions/watsonx/watsonx.py

@@ -4,17 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pathlib import Path
 
-from llama_stack.apis.models import ModelType
-from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
-from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
+from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
 from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
-from llama_stack.providers.inline.inference.sentence_transformers import (
-    SentenceTransformersInferenceConfig,
-)
 from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
-from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
 
 
 def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
@@ -52,15 +46,6 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         config=WatsonXConfig.sample_run_config(),
     )
 
-    embedding_provider = Provider(
-        provider_id="sentence-transformers",
-        provider_type="inline::sentence-transformers",
-        config=SentenceTransformersInferenceConfig.sample_run_config(),
-    )
-
-    available_models = {
-        "watsonx": MODEL_ENTRIES,
-    }
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -72,36 +57,25 @@ def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
         ),
     ]
 
-    embedding_model = ModelInput(
-        model_id="all-MiniLM-L6-v2",
-        provider_id="sentence-transformers",
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": 384,
-        },
-    )
-
     files_provider = Provider(
         provider_id="meta-reference-files",
         provider_type="inline::localfs",
         config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
     )
-    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,
-        template_path=Path(__file__).parent / "doc_template.md",
+        template_path=None,
         providers=providers,
-        available_models_by_provider=available_models,
        run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider, embedding_provider],
+                    "inference": [inference_provider],
                     "files": [files_provider],
                 },
-                default_models=default_models + [embedding_model],
+                default_models=[],
                 default_tool_groups=default_tool_groups,
             ),
         },
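With available_models_by_provider gone and default_models=[], the template no longer preregisters models, which is why run.yaml above drops to models: []. A hedged sketch of registering a watsonx model against a running stack instead, using the llama-stack client (model ID and port illustrative):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")
    # Register a watsonx-served model at runtime rather than baking it into run.yaml.
    client.models.register(
        model_id="meta-llama/llama-3-3-70b-instruct",
        provider_id="watsonx",
        model_type="llm",
    )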
llama_stack/providers/registry/inference.py

@@ -277,7 +277,7 @@ Available Models:
             api=Api.inference,
             adapter_type="watsonx",
             provider_type="remote::watsonx",
-            pip_packages=["ibm_watsonx_ai"],
+            pip_packages=["litellm"],
             module="llama_stack.providers.remote.inference.watsonx",
             config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
             provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
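The dependency swap from ibm_watsonx_ai to litellm matches the adapter change below: inference now routes through LiteLLM rather than the IBM SDK. For orientation, a direct LiteLLM call against watsonx looks roughly like this (model ID illustrative; LiteLLM reads its WATSONX_* settings from the environment):

    import litellm

    # LiteLLM's watsonx integration uses the "watsonx/" model prefix.
    response = litellm.completion(
        model="watsonx/meta-llama/llama-3-3-70b-instruct",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)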
llama_stack/providers/remote/inference/watsonx/config.py

@@ -25,12 +25,16 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )
+    # This seems like it should be required, but none of the other remote inference
+    # providers require it, so this is optional here too for consistency.
+    # The OpenAIConfig uses default=None instead, so this is following that precedent.
     api_key: SecretStr | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
+        default=None,
         description="The watsonx.ai API key",
     )
+    # As above, this is optional here too for consistency.
     project_id: str | None = Field(
-        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
+        default=None,
         description="The watsonx.ai project ID",
     )
     timeout: int = Field(
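With default=None, the config class itself no longer reads WATSONX_API_KEY; the key reaches the provider through the run.yaml substitution shown earlier or as per-request provider data. A sketch of the per-request route (the field name is an assumption inferred from the validator pattern other remote providers use; check WatsonXProviderDataValidator for the real one):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(
        base_url="http://localhost:8321",
        # Hypothetical provider-data field name for the watsonx key.
        provider_data={"watsonx_api_key": "your-api-key"},
    )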
llama_stack/providers/remote/inference/watsonx/watsonx.py

@@ -16,9 +16,6 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 
 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
-    _config: WatsonXConfig
-    __provider_id__: str = "watsonx"
-
     def __init__(self, config: WatsonXConfig):
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -29,17 +26,9 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         self.available_models = None
         self.config = config
 
-    # get_api_key = LiteLLMOpenAIMixin.get_api_key
-
     def get_base_url(self) -> str:
         return self.config.url
 
-    async def initialize(self):
-        await super().initialize()
-
-    async def shutdown(self):
-        await super().shutdown()
-
     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
         # Get base parameters from parent
         params = await super()._get_params(request)
llama_stack/providers/utils/inference/litellm_openai_mixin.py

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import struct
 from collections.abc import AsyncIterator
 from typing import Any
 
@@ -16,6 +18,7 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -26,7 +29,6 @@ from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 from llama_stack.providers.utils.inference.openai_compat import (
-    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_tooldef_to_openai_tool,
     get_sampling_options,
@@ -334,6 +336,7 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
+        logger.info(f"params to litellm (openai compat): {params}")
         return await litellm.acompletion(**params)
 
     async def check_model_availability(self, model: str) -> bool:
@@ -349,3 +352,28 @@ class LiteLLMOpenAIMixin(
             return False
 
         return model in litellm.models_by_provider[self.litellm_provider_name]
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: list[dict], encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data["embedding"]:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data["embedding"]
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
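A quick round-trip check of the base64 branch above, where each float is packed as a 4-byte float32 with struct and the buffer is base64-encoded, showing how a client would decode the result:

    import base64
    import struct

    def decode_base64_embedding(b64: str) -> list[float]:
        raw = base64.b64decode(b64)
        # 4 bytes per float32, same "f" format the encoder uses.
        return list(struct.unpack(f"{len(raw) // 4}f", raw))

    packed = bytearray()
    for v in [0.1, -0.5, 2.0]:
        packed.extend(struct.pack("f", float(v)))
    encoded = base64.b64encode(packed).decode("utf-8")
    print(decode_base64_embedding(encoded))  # approximately [0.1, -0.5, 2.0]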
llama_stack/providers/utils/inference/openai_compat.py

@@ -3,9 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
 import json
-import struct
 import time
 import uuid
 import warnings
@@ -103,7 +101,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     OpenAIChatCompletion,
-    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     SamplingParams,
@@ -1402,28 +1399,3 @@ def prepare_openai_embeddings_params(
         params["user"] = user
 
     return params
-
-
-def b64_encode_openai_embeddings_response(
-    response_data: list[dict], encoding_format: str | None = "float"
-) -> list[OpenAIEmbeddingData]:
-    """
-    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
-    """
-    data = []
-    for i, embedding_data in enumerate(response_data):
-        if encoding_format == "base64":
-            byte_array = bytearray()
-            for embedding_value in embedding_data["embedding"]:
-                byte_array.extend(struct.pack("f", float(embedding_value)))
-
-            response_embedding = base64.b64encode(byte_array).decode("utf-8")
-        else:
-            response_embedding = embedding_data["embedding"]
-        data.append(
-            OpenAIEmbeddingData(
-                embedding=response_embedding,
-                index=i,
-            )
-        )
-    return data