mirror of
				https://github.com/meta-llama/llama-stack.git
				synced 2025-10-25 01:01:13 +00:00 
			
		
		
		
	
		
			Some checks failed
		
		
	
	SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
				
			SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
				
			Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 3s
				
			Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
				
			Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
				
			Python Package Build Test / build (3.13) (push) Failing after 2s
				
			Python Package Build Test / build (3.12) (push) Failing after 3s
				
			Vector IO Integration Tests / test-matrix (push) Failing after 7s
				
			Test Llama Stack Build / generate-matrix (push) Successful in 6s
				
			Test Llama Stack Build / build-single-provider (push) Failing after 4s
				
			Test Llama Stack Build / build-custom-container-distribution (push) Failing after 5s
				
			Test External API and Providers / test-external (venv) (push) Failing after 4s
				
			Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 6s
				
			Unit Tests / unit-tests (3.13) (push) Failing after 4s
				
			API Conformance Tests / check-schema-compatibility (push) Successful in 12s
				
			Test Llama Stack Build / build (push) Failing after 3s
				
			Unit Tests / unit-tests (3.12) (push) Failing after 5s
				
			UI Tests / ui-tests (22) (push) Successful in 32s
				
			Pre-commit / pre-commit (push) Successful in 1m29s
				
			# What does this PR do? - The watsonx.ai provider now uses the LiteLLM mixin instead of using IBM's library, which does not seem to be working (see #3165 for context). - The watsonx.ai provider now lists all the models available by calling the watsonx.ai server instead of having a hard-coded list of known models (that list gets out of date quickly). - An edge case in [llama_stack/core/routers/inference.py](https://github.com/llamastack/llama-stack/pull/3674/files#diff-a34bc966ed9befd9f13d4883c23705dff49be0ad6211c850438cdda6113f3455) that was causing my manual tests to fail is addressed. - Fixes `b64_encode_openai_embeddings_response`, which was trying to enumerate over a dictionary and then reference elements of the dictionary using .field instead of ["field"]. That method is called by the LiteLLM mixin for embedding models, so it is needed to get the watsonx.ai embedding models to work. - A unit test along the lines of the one in #3348 is added. A more comprehensive plan for automatically testing the end-to-end functionality for inference providers would be a good idea, but is out of scope for this PR. - Updates to the watsonx distribution. Some were in response to the switch to LiteLLM (e.g., updating the Python packages needed). Others seem to be things that were already broken that I found along the way (e.g., a reference to a watsonx-specific doc template that doesn't seem to exist). Closes #3165. Also, it is related to a line-item in #3387 but doesn't really address that goal (because it uses the LiteLLM mixin, not the OpenAI one). I tried the OpenAI one and it doesn't work with watsonx.ai, presumably because the watsonx.ai service is not OpenAI compatible. It works with LiteLLM because LiteLLM has a provider implementation for watsonx.ai. ## Test Plan The test script below goes back and forth between the OpenAI and watsonx providers. 
The idea is that the OpenAI provider shows how it should work and then the watsonx provider output shows that it is also working with watsonx. Note that the result from the MCP test is not as good (the Llama 3.3 70b model does not choose tools as wisely as gpt-4o), but it is still working and providing a valid response. For more details on setup and the MCP server being used for testing, see [the AI Alliance sample notebook](https://github.com/The-AI-Alliance/llama-stack-examples/blob/main/notebooks/01-responses/) that these examples are drawn from. ```python #!/usr/bin/env python3 import json from llama_stack_client import LlamaStackClient from litellm import completion import http.client def print_response(response): """Print response in a nicely formatted way""" print(f"ID: {response.id}") print(f"Status: {response.status}") print(f"Model: {response.model}") print(f"Created at: {response.created_at}") print(f"Output items: {len(response.output)}") for i, output_item in enumerate(response.output): if len(response.output) > 1: print(f"\n--- Output Item {i+1} ---") print(f"Output type: {output_item.type}") if output_item.type in ("text", "message"): print(f"Response content: {output_item.content[0].text}") elif output_item.type == "file_search_call": print(f" Tool Call ID: {output_item.id}") print(f" Tool Status: {output_item.status}") # 'queries' is a list, so we join it for clean printing print(f" Queries: {', '.join(output_item.queries)}") # Display results if they exist, otherwise note they are empty print(f" Results: {output_item.results if output_item.results else 'None'}") elif output_item.type == "mcp_list_tools": print_mcp_list_tools(output_item) elif output_item.type == "mcp_call": print_mcp_call(output_item) else: print(f"Response content: {output_item.content}") def print_mcp_call(mcp_call): """Print MCP call in a nicely formatted way""" print(f"\n🛠️ MCP Tool Call: {mcp_call.name}") print(f" Server: {mcp_call.server_label}") print(f" ID: {mcp_call.id}") 
print(f" Arguments: {mcp_call.arguments}") if mcp_call.error: print(f"Error: {mcp_call.error}") elif mcp_call.output: print("Output:") # Try to format JSON output nicely try: parsed_output = json.loads(mcp_call.output) print(json.dumps(parsed_output, indent=4)) except: # If not valid JSON, print as-is print(f" {mcp_call.output}") else: print(" ⏳ No output yet") def print_mcp_list_tools(mcp_list_tools): """Print MCP list tools in a nicely formatted way""" print(f"\n🔧 MCP Server: {mcp_list_tools.server_label}") print(f" ID: {mcp_list_tools.id}") print(f" Available Tools: {len(mcp_list_tools.tools)}") print("=" * 80) for i, tool in enumerate(mcp_list_tools.tools, 1): print(f"\n{i}. {tool.name}") print(f" Description: {tool.description}") # Parse and display input schema schema = tool.input_schema if schema and 'properties' in schema: properties = schema['properties'] required = schema.get('required', []) print(" Parameters:") for param_name, param_info in properties.items(): param_type = param_info.get('type', 'unknown') param_desc = param_info.get('description', 'No description') required_marker = " (required)" if param_name in required else " (optional)" print(f" • {param_name} ({param_type}){required_marker}") if param_desc: print(f" {param_desc}") if i < len(mcp_list_tools.tools): print("-" * 40) def main(): """Main function to run all the tests""" # Configuration LLAMA_STACK_URL = "http://localhost:8321/" LLAMA_STACK_MODEL_IDS = [ "openai/gpt-3.5-turbo", "openai/gpt-4o", "llama-openai-compat/Llama-3.3-70B-Instruct", "watsonx/meta-llama/llama-3-3-70b-instruct" ] # Using gpt-4o for this demo, but feel free to try one of the others or add more to run.yaml. 
OPENAI_MODEL_ID = LLAMA_STACK_MODEL_IDS[1] WATSONX_MODEL_ID = LLAMA_STACK_MODEL_IDS[-1] NPS_MCP_URL = "http://localhost:3005/sse/" print("=== Llama Stack Testing Script ===") print(f"Using OpenAI model: {OPENAI_MODEL_ID}") print(f"Using WatsonX model: {WATSONX_MODEL_ID}") print(f"MCP URL: {NPS_MCP_URL}") print() # Initialize client print("Initializing LlamaStackClient...") client = LlamaStackClient(base_url="http://localhost:8321") # Test 1: List models print("\n=== Test 1: List Models ===") try: models = client.models.list() print(f"Found {len(models)} models") except Exception as e: print(f"Error listing models: {e}") raise e # Test 2: Basic chat completion with OpenAI print("\n=== Test 2: Basic Chat Completion (OpenAI) ===") try: chat_completion_response = client.chat.completions.create( model=OPENAI_MODEL_ID, messages=[{"role": "user", "content": "What is the capital of France?"}] ) print("OpenAI Response:") for chunk in chat_completion_response.choices[0].message.content: print(chunk, end="", flush=True) print() except Exception as e: print(f"Error with OpenAI chat completion: {e}") raise e # Test 3: Basic chat completion with WatsonX print("\n=== Test 3: Basic Chat Completion (WatsonX) ===") try: chat_completion_response_wxai = client.chat.completions.create( model=WATSONX_MODEL_ID, messages=[{"role": "user", "content": "What is the capital of France?"}], ) print("WatsonX Response:") for chunk in chat_completion_response_wxai.choices[0].message.content: print(chunk, end="", flush=True) print() except Exception as e: print(f"Error with WatsonX chat completion: {e}") raise e # Test 4: Tool calling with OpenAI print("\n=== Test 4: Tool Calling (OpenAI) ===") tools = [ { "type": "function", "function": { "name": "get_current_weather", "description": "Get the current weather for a specific location", "parameters": { "type": "object", "properties": { "location": { "type": "string", "description": "The city and state, e.g., San Francisco, CA", }, "unit": { "type": 
"string", "enum": ["celsius", "fahrenheit"] }, }, "required": ["location"], }, }, } ] messages = [ {"role": "user", "content": "What's the weather like in Boston, MA?"} ] try: print("--- Initial API Call ---") response = client.chat.completions.create( model=OPENAI_MODEL_ID, messages=messages, tools=tools, tool_choice="auto", # "auto" is the default ) print("OpenAI tool calling response received") except Exception as e: print(f"Error with OpenAI tool calling: {e}") raise e # Test 5: Tool calling with WatsonX print("\n=== Test 5: Tool Calling (WatsonX) ===") try: wxai_response = client.chat.completions.create( model=WATSONX_MODEL_ID, messages=messages, tools=tools, tool_choice="auto", # "auto" is the default ) print("WatsonX tool calling response received") except Exception as e: print(f"Error with WatsonX tool calling: {e}") raise e # Test 6: Streaming with WatsonX print("\n=== Test 6: Streaming Response (WatsonX) ===") try: chat_completion_response_wxai_stream = client.chat.completions.create( model=WATSONX_MODEL_ID, messages=[{"role": "user", "content": "What is the capital of France?"}], stream=True ) print("Model response: ", end="") for chunk in chat_completion_response_wxai_stream: # Each 'chunk' is a ChatCompletionChunk object. # We want the content from the 'delta' attribute. if hasattr(chunk, 'choices') and chunk.choices is not None: content = chunk.choices[0].delta.content # The first few chunks might have None content, so we check for it. 
if content is not None: print(content, end="", flush=True) print() except Exception as e: print(f"Error with streaming: {e}") raise e # Test 7: MCP with OpenAI print("\n=== Test 7: MCP Integration (OpenAI) ===") try: mcp_llama_stack_client_response = client.responses.create( model=OPENAI_MODEL_ID, input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.", tools=[ { "type": "mcp", "server_url": NPS_MCP_URL, "server_label": "National Parks Service tools", "allowed_tools": ["search_parks", "get_park_events"], } ] ) print_response(mcp_llama_stack_client_response) except Exception as e: print(f"Error with MCP (OpenAI): {e}") raise e # Test 8: MCP with WatsonX print("\n=== Test 8: MCP Integration (WatsonX) ===") try: mcp_llama_stack_client_response = client.responses.create( model=WATSONX_MODEL_ID, input="What is the capital of France?" ) print_response(mcp_llama_stack_client_response) except Exception as e: print(f"Error with MCP (WatsonX): {e}") raise e # Test 9: MCP with Llama 3.3 print("\n=== Test 9: MCP Integration (Llama 3.3) ===") try: mcp_llama_stack_client_response = client.responses.create( model=WATSONX_MODEL_ID, input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.", tools=[ { "type": "mcp", "server_url": NPS_MCP_URL, "server_label": "National Parks Service tools", "allowed_tools": ["search_parks", "get_park_events"], } ] ) print_response(mcp_llama_stack_client_response) except Exception as e: print(f"Error with MCP (Llama 3.3): {e}") raise e # Test 10: Embeddings print("\n=== Test 10: Embeddings ===") try: conn = http.client.HTTPConnection("localhost:8321") payload = json.dumps({ "model": "watsonx/ibm/granite-embedding-278m-multilingual", "input": "Hello, world!", }) headers = { 'Content-Type': 'application/json', 'Accept': 'application/json' } conn.request("POST", "/v1/openai/v1/embeddings", payload, headers) res = conn.getresponse() data = res.read() 
print(data.decode("utf-8")) except Exception as e: print(f"Error with Embeddings: {e}") raise e print("\n=== Testing Complete ===") if __name__ == "__main__": main() ``` --------- Signed-off-by: Bill Murdock <bmurdock@redhat.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
		
			
				
	
	
		
			378 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			378 lines
		
	
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import base64
 | |
| import struct
 | |
| from collections.abc import AsyncIterator
 | |
| from typing import Any
 | |
| 
 | |
| import litellm
 | |
| 
 | |
| from llama_stack.apis.inference import (
 | |
|     ChatCompletionRequest,
 | |
|     InferenceProvider,
 | |
|     JsonSchemaResponseFormat,
 | |
|     OpenAIChatCompletion,
 | |
|     OpenAIChatCompletionChunk,
 | |
|     OpenAICompletion,
 | |
|     OpenAIEmbeddingData,
 | |
|     OpenAIEmbeddingsResponse,
 | |
|     OpenAIEmbeddingUsage,
 | |
|     OpenAIMessageParam,
 | |
|     OpenAIResponseFormatParam,
 | |
|     ToolChoice,
 | |
| )
 | |
| from llama_stack.core.request_headers import NeedsRequestProviderData
 | |
| from llama_stack.log import get_logger
 | |
| from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
 | |
| from llama_stack.providers.utils.inference.openai_compat import (
 | |
|     convert_message_to_openai_dict_new,
 | |
|     convert_tooldef_to_openai_tool,
 | |
|     get_sampling_options,
 | |
|     prepare_openai_completion_params,
 | |
| )
 | |
| 
 | |
| logger = get_logger(name=__name__, category="providers::utils")
 | |
| 
 | |
| 
 | |
class LiteLLMOpenAIMixin(
    ModelRegistryHelper,
    InferenceProvider,
    NeedsRequestProviderData,
):
    """Inference-provider mixin that forwards llama-stack inference calls to LiteLLM.

    Concrete providers supply a LiteLLM provider name and their API-key sources;
    this mixin builds the request parameters and delegates to
    ``litellm.acompletion``, ``litellm.atext_completion`` and ``litellm.embedding``.
    """

    # TODO: avoid exposing the litellm specific model names to the user.
    #       potential change: add a prefix param that gets added to the model name
    #                         when calling litellm.
    def __init__(
        self,
        litellm_provider_name: str,
        api_key_from_config: str | None,
        provider_data_api_key_field: str | None = None,
        model_entries: list[ProviderModelEntry] | None = None,
        openai_compat_api_base: str | None = None,
        download_images: bool = False,
        json_schema_strict: bool = True,
    ):
        """
        Initialize the LiteLLMOpenAIMixin.

        :param litellm_provider_name: The name of the provider, used for model lookups and,
            in OpenAI-compatible mode, as the prefix of the LiteLLM model name.
        :param api_key_from_config: The API key to use from the config.
        :param provider_data_api_key_field: The field in the provider data that contains the API key (optional).
        :param model_entries: The model entries to register.
        :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility.
        :param download_images: Whether to download images and convert to base64 for message conversion.
        :param json_schema_strict: Whether to use strict mode for JSON schema validation.
        """
        ModelRegistryHelper.__init__(self, model_entries=model_entries)

        self.litellm_provider_name = litellm_provider_name
        self.api_key_from_config = api_key_from_config
        self.provider_data_api_key_field = provider_data_api_key_field
        self.api_base = openai_compat_api_base
        self.download_images = download_images
        self.json_schema_strict = json_schema_strict

        # Any non-empty base URL switches the mixin into OpenAI-compatible mode.
        self.is_openai_compat = bool(openai_compat_api_base)

    async def initialize(self):
        pass

    async def shutdown(self):
        pass

    def get_litellm_model_name(self, model_id: str) -> str:
        """Return the model name to pass to LiteLLM for ``model_id``.

        In OpenAI-compatible mode the id is prefixed with ``"<litellm_provider_name>/"``
        unless it already carries that prefix; otherwise it is returned unchanged.
        """
        if not self.is_openai_compat:
            return model_id
        # Users may be using the openai/ prefix in their model names (openai/models.py
        # did this by default), so an already-prefixed id is accepted as-is for
        # backwards compatibility. The check includes the "/" separator so that ids
        # that merely share leading characters (e.g. "openai-foo") are still prefixed.
        prefix = f"{self.litellm_provider_name}/"
        if model_id.startswith(prefix):
            return model_id
        return f"{prefix}{model_id}"

    def _add_additional_properties_recursive(self, schema):
        """
        Recursively add additionalProperties: False to all object schemas.

        Every object schema that has properties also gets ``required`` set to all of its
        property keys. The walk descends into ``properties``, array ``items``,
        ``anyOf``/``allOf``/``oneOf``, ``not`` and ``$defs``. The schema is modified in
        place and also returned.
        """
        if isinstance(schema, dict):
            if schema.get("type") == "object":
                schema["additionalProperties"] = False

                # Add required field with all property keys if properties exist
                if "properties" in schema and schema["properties"]:
                    schema["required"] = list(schema["properties"].keys())

            if "properties" in schema:
                for prop_schema in schema["properties"].values():
                    self._add_additional_properties_recursive(prop_schema)

            # Array element schemas may contain objects too: "items" is a single schema
            # (or, in older drafts, a list of positional schemas).
            items = schema.get("items")
            if isinstance(items, list):
                for item_schema in items:
                    self._add_additional_properties_recursive(item_schema)
            elif isinstance(items, dict):
                self._add_additional_properties_recursive(items)

            for key in ["anyOf", "allOf", "oneOf"]:
                if key in schema:
                    for sub_schema in schema[key]:
                        self._add_additional_properties_recursive(sub_schema)

            if "not" in schema:
                self._add_additional_properties_recursive(schema["not"])

            # Handle $defs/$ref
            if "$defs" in schema:
                for def_schema in schema["$defs"].values():
                    self._add_additional_properties_recursive(def_schema)

        return schema

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Build the keyword arguments for a LiteLLM chat-completion call from ``request``.

        :raises ValueError: if the response format is not a JsonSchemaResponseFormat.
        """
        import copy

        input_dict: dict[str, Any] = {}

        input_dict["messages"] = [
            await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
        ]
        if fmt := request.response_format:
            if not isinstance(fmt, JsonSchemaResponseFormat):
                raise ValueError(
                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                )

            # Work on a deep copy: the schema is rewritten below (title removed, strict
            # object constraints added) and must not mutate the caller's request. The old
            # in-place edit made a second use of the same request fail on the missing title.
            schema = copy.deepcopy(fmt.json_schema)
            # The schema title becomes the json_schema "name"; fall back to a generic
            # name when the schema has no title instead of failing with KeyError.
            name = schema.pop("title", "response")
            schema["additionalProperties"] = False

            # Apply additionalProperties: False recursively to all objects
            schema = self._add_additional_properties_recursive(schema)

            input_dict["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": name,
                    "schema": schema,
                    "strict": self.json_schema_strict,
                },
            }
        if request.tools:
            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
            if request.tool_config.tool_choice:
                input_dict["tool_choice"] = (
                    request.tool_config.tool_choice.value
                    if isinstance(request.tool_config.tool_choice, ToolChoice)
                    else request.tool_config.tool_choice
                )

        return {
            "model": request.model,
            "api_key": self.get_api_key(),
            "api_base": self.api_base,
            **input_dict,
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }

    def get_api_key(self) -> str:
        """Return the API key for the current request.

        The key from the request's provider data (field ``provider_data_api_key_field``)
        takes precedence over ``api_key_from_config``.

        :raises ValueError: if neither source yields a non-empty key.
        """
        provider_data = self.get_request_provider_data()
        key_field = self.provider_data_api_key_field
        api_key = None
        # getattr() requires a string attribute name, so the provider-data lookup is
        # skipped entirely when no key field was configured (the default).
        if provider_data and key_field:
            api_key = getattr(provider_data, key_field, None)
        if not api_key:
            api_key = self.api_key_from_config
        if not api_key:
            raise ValueError(
                "API key is not set. Please provide a valid API key in the "
                "provider data header, e.g. x-llamastack-provider-data: "
                f'{{"{key_field}": "<API_KEY>"}}, or in the provider config.'
            )
        return api_key

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        """Embed ``input`` via LiteLLM and return an OpenAI-style embeddings response.

        ``encoding_format="base64"`` is applied locally after LiteLLM returns floats.
        """
        model_obj = await self.model_store.get_model(model)

        # Convert input to list if it's a string
        input_list = [input] if isinstance(input, str) else input

        # Call litellm embedding function.
        # NOTE(review): this is a synchronous call inside an async method and blocks the
        # event loop for its duration; consider litellm.aembedding.
        response = litellm.embedding(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            input=input_list,
            api_key=self.get_api_key(),
            api_base=self.api_base,
            dimensions=dimensions,
        )

        # Convert response to OpenAI format
        data = b64_encode_openai_embeddings_response(response.data, encoding_format)

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response["usage"]["prompt_tokens"],
            total_tokens=response["usage"]["total_tokens"],
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=model_obj.provider_resource_id,
            usage=usage,
        )

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        """Run an OpenAI-style text completion through ``litellm.atext_completion``."""
        model_obj = await self.model_store.get_model(model)
        params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
            guided_choice=guided_choice,
            prompt_logprobs=prompt_logprobs,
            # Previously accepted but silently dropped.
            suffix=suffix,
            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        return await litellm.atext_completion(**params)

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """Run an OpenAI-style chat completion through ``litellm.acompletion``.

        When streaming while a telemetry span is active, ``include_usage`` is added to
        ``stream_options`` (unless the caller already set it) so usage can be tracked.
        """
        # Add usage tracking for streaming when telemetry is active
        from llama_stack.providers.utils.telemetry.tracing import get_current_span

        if stream and get_current_span() is not None:
            if stream_options is None:
                stream_options = {"include_usage": True}
            elif "include_usage" not in stream_options:
                stream_options = {**stream_options, "include_usage": True}
        model_obj = await self.model_store.get_model(model)
        params = await prepare_openai_completion_params(
            model=self.get_litellm_model_name(model_obj.provider_resource_id),
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
            api_key=self.get_api_key(),
            api_base=self.api_base,
        )
        return await litellm.acompletion(**params)

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available via LiteLLM for the current
        provider (self.litellm_provider_name).

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        if self.litellm_provider_name not in litellm.models_by_provider:
            logger.error(f"Provider {self.litellm_provider_name} is not registered in litellm.")
            return False

        return model in litellm.models_by_provider[self.litellm_provider_name]
 | |
| 
 | |
| 
 | |
def b64_encode_openai_embeddings_response(
    response_data: list[Any], encoding_format: str | None = "float"
) -> list[OpenAIEmbeddingData]:
    """
    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.

    :param response_data: Embedding entries as returned by LiteLLM. Each entry is either a
        dict with an ``"embedding"`` key or an object with an ``embedding`` attribute.
    :param encoding_format: ``"base64"`` to pack each embedding as little-endian float32
        bytes encoded in base64 (the OpenAI wire format); any other value keeps the
        embedding as a list of floats.
    :return: One OpenAIEmbeddingData per entry, indexed by its position in ``response_data``.
    """
    data = []
    for i, embedding_data in enumerate(response_data):
        # LiteLLM may hand back plain dicts or typed objects; accept both.
        if isinstance(embedding_data, dict):
            values = embedding_data["embedding"]
        else:
            values = embedding_data.embedding

        if encoding_format == "base64":
            # Explicit "<" (little-endian) instead of native byte order so the output
            # matches OpenAI's format on every host; one bulk pack instead of per-value calls.
            packed = struct.pack(f"<{len(values)}f", *(float(v) for v in values))
            response_embedding = base64.b64encode(packed).decode("utf-8")
        else:
            response_embedding = values
        data.append(
            OpenAIEmbeddingData(
                embedding=response_embedding,
                index=i,
            )
        )
    return data
 |