Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-15 17:23:06 +00:00)
fix: Fix embedding model listing and usage for watsonx

Signed-off-by: Bill Murdock <bmurdock@redhat.com>

parent 999c28e809 · commit ecafe40a84
2 changed files with 36 additions and 26 deletions
In the watsonx adapter, list_models() now inspects each model spec's advertised functions and registers embedding models with their metadata, instead of treating every model as an LLM:

```diff
@@ -56,15 +56,40 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     async def list_models(self) -> list[Model] | None:
         models = []
         for model_spec in self._get_model_specs():
-            models.append(
-                Model(
-                    identifier=model_spec["model_id"],
-                    provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
-                    provider_id=self.__provider_id__,
-                    metadata={},
-                    model_type=ModelType.llm,
-                )
-            )
+            functions = [f['id'] for f in model_spec.get("functions", [])]
+            # Format: {"embedding_dimension": 1536, "context_length": 8192}
+            # Example of an embedding model:
+            # {'model_id': 'ibm/granite-embedding-278m-multilingual',
+            #  'label': 'granite-embedding-278m-multilingual',
+            #  'model_limits': {'max_sequence_length': 512, 'embedding_dimension': 768},
+            #  ...
+            if "embedding" in functions:
+                embedding_dimension = model_spec["model_limits"]["embedding_dimension"]
+                context_length = model_spec["model_limits"]["max_sequence_length"]
+                embedding_metadata = {
+                    "embedding_dimension": embedding_dimension,
+                    "context_length": context_length,
+                }
+                models.append(
+                    Model(
+                        identifier=model_spec["model_id"],
+                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
+                        provider_id=self.__provider_id__,
+                        metadata=embedding_metadata,
+                        model_type=ModelType.embedding,
+                    )
+                )
+            if "text_chat" in functions:
+                models.append(
+                    Model(
+                        identifier=model_spec["model_id"],
+                        provider_resource_id=f"{self.__provider_id__}/{model_spec['model_id']}",
+                        provider_id=self.__provider_id__,
+                        metadata={},
+                        model_type=ModelType.llm,
+                    )
+                )
         return models
 
     # LiteLLM provides methods to list models for many providers, but not for watsonx.ai.
```
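For context, here is a minimal, self-contained sketch of the classification logic this hunk introduces. `Model`, `ModelType`, and `classify` below are simplified stand-ins for illustration, not the actual llama-stack classes:

```python
from dataclasses import dataclass, field
from enum import Enum


class ModelType(str, Enum):  # stand-in for llama-stack's ModelType
    llm = "llm"
    embedding = "embedding"


@dataclass
class Model:  # stand-in for llama-stack's Model
    identifier: str
    provider_resource_id: str
    provider_id: str
    model_type: ModelType
    metadata: dict = field(default_factory=dict)


def classify(model_spec: dict, provider_id: str = "watsonx") -> list[Model]:
    """Mirror the new list_models() branching: one entry per advertised function."""
    models = []
    functions = [f["id"] for f in model_spec.get("functions", [])]
    if "embedding" in functions:
        limits = model_spec["model_limits"]
        models.append(
            Model(
                identifier=model_spec["model_id"],
                provider_resource_id=f"{provider_id}/{model_spec['model_id']}",
                provider_id=provider_id,
                model_type=ModelType.embedding,
                metadata={
                    "embedding_dimension": limits["embedding_dimension"],
                    "context_length": limits["max_sequence_length"],
                },
            )
        )
    if "text_chat" in functions:
        models.append(
            Model(
                identifier=model_spec["model_id"],
                provider_resource_id=f"{provider_id}/{model_spec['model_id']}",
                provider_id=provider_id,
                model_type=ModelType.llm,
            )
        )
    return models


# A spec shaped like the granite embedding example in the diff comments:
spec = {
    "model_id": "ibm/granite-embedding-278m-multilingual",
    "functions": [{"id": "embedding"}],
    "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
}
print([(m.model_type, m.metadata) for m in classify(spec)])
# [(<ModelType.embedding: 'embedding'>, {'embedding_dimension': 768, 'context_length': 512})]
```

A spec advertising both "embedding" and "text_chat" yields two entries, which matches the new behavior: a single watsonx spec can back both an embedding model and an LLM.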
The same file also drops a leftover test main method:

```diff
@@ -91,18 +116,3 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         if "resources" not in response_data:
             raise ValueError("Resources not found in response")
         return response_data["resources"]
-
-
-# TO DO: Delete the test main method.
-if __name__ == "__main__":
-    config = WatsonXConfig(url="https://us-south.ml.cloud.ibm.com", api_key="xxx", project_id="xxx", timeout=60)
-    adapter = WatsonXInferenceAdapter(config)
-    model_specs = adapter._get_model_specs()
-    models = asyncio.run(adapter.list_models())
-    for model in models:
-        print(model.identifier)
-        print(model.provider_resource_id)
-        print(model.provider_id)
-        print(model.metadata)
-        print(model.model_type)
-        print("--------------------------------")
```
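If a smoke test for this path is still wanted, a pytest-style sketch like the following could replace the deleted block; the stubbed `_get_model_specs` and the fake spec are assumptions for illustration, not part of this commit:

```python
import asyncio
from unittest.mock import patch

# Imports of WatsonXConfig, WatsonXInferenceAdapter, and ModelType are omitted;
# their module paths depend on the tree layout.

# Hypothetical fake spec; real specs come from the watsonx.ai API.
FAKE_SPECS = [
    {
        "model_id": "ibm/granite-embedding-278m-multilingual",
        "functions": [{"id": "embedding"}],
        "model_limits": {"max_sequence_length": 512, "embedding_dimension": 768},
    }
]


def test_list_models_classifies_embedding_models():
    config = WatsonXConfig(url="https://us-south.ml.cloud.ibm.com", api_key="xxx", project_id="xxx", timeout=60)
    adapter = WatsonXInferenceAdapter(config)
    # Stub the spec fetch to avoid a live API call.
    with patch.object(adapter, "_get_model_specs", return_value=FAKE_SPECS):
        models = asyncio.run(adapter.list_models())
    assert len(models) == 1
    assert models[0].model_type == ModelType.embedding
    assert models[0].metadata == {"embedding_dimension": 768, "context_length": 512}
```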
In the OpenAI embeddings helper, the response data is now typed and accessed as a list of dicts:

```diff
@@ -1405,7 +1405,7 @@ def prepare_openai_embeddings_params(


 def b64_encode_openai_embeddings_response(
-    response_data: dict, encoding_format: str | None = "float"
+    response_data: list[dict], encoding_format: str | None = "float"
 ) -> list[OpenAIEmbeddingData]:
     """
     Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
@@ -1414,12 +1414,12 @@ def b64_encode_openai_embeddings_response(
     for i, embedding_data in enumerate(response_data):
         if encoding_format == "base64":
             byte_array = bytearray()
-            for embedding_value in embedding_data.embedding:
+            for embedding_value in embedding_data["embedding"]:
                 byte_array.extend(struct.pack("f", float(embedding_value)))
 
             response_embedding = base64.b64encode(byte_array).decode("utf-8")
         else:
-            response_embedding = embedding_data.embedding
+            response_embedding = embedding_data["embedding"]
         data.append(
             OpenAIEmbeddingData(
                 embedding=response_embedding,
```
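The switch from attribute access (`embedding_data.embedding`) to dict access (`embedding_data["embedding"]`) goes with the corrected `list[dict]` annotation: callers pass a list of plain embedding dicts, not objects. The base64 branch packs each embedding value as a 4-byte float and encodes the buffer. A standalone round-trip sketch of that encoding, using only the standard library (not llama-stack code):

```python
import base64
import struct

# One element of response_data, shaped as the new list[dict] annotation expects.
embedding_data = {"embedding": [0.1, -0.25, 3.5]}

# Pack each value as a 4-byte float, then base64-encode the buffer,
# mirroring the encoding_format == "base64" branch above.
byte_array = bytearray()
for embedding_value in embedding_data["embedding"]:
    byte_array.extend(struct.pack("f", float(embedding_value)))
encoded = base64.b64encode(byte_array).decode("utf-8")

# Decoding reverses the process; float32 packing can lose a little precision.
count = len(embedding_data["embedding"])
decoded = struct.unpack(f"{count}f", base64.b64decode(encoded))
print(encoded)
print([round(v, 6) for v in decoded])  # ~ [0.1, -0.25, 3.5]
```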