# What does this PR do?

- The watsonx.ai provider now uses the LiteLLM mixin instead of IBM's library, which does not seem to be working (see #3165 for context).
- The watsonx.ai provider now lists all available models by calling the watsonx.ai server instead of relying on a hard-coded list of known models. (That list gets out of date quickly.)
- An edge case in [llama_stack/core/routers/inference.py](https://github.com/llamastack/llama-stack/pull/3674/files#diff-a34bc966ed9befd9f13d4883c23705dff49be0ad6211c850438cdda6113f3455) that was causing my manual tests to fail is addressed.
- Fixes `b64_encode_openai_embeddings_response`, which was trying to enumerate over a dictionary and then reference elements of the dictionary using `.field` instead of `["field"]`. That method is called by the LiteLLM mixin for embedding models, so the fix is needed to get the watsonx.ai embedding models working. (A minimal sketch of the access pattern appears below, before the test plan.)
- A unit test along the lines of the one in #3348 is added. A more comprehensive plan for automatically testing the end-to-end functionality of inference providers would be a good idea, but is out of scope for this PR.
- Updates to the watsonx distribution. Some were in response to the switch to LiteLLM (e.g., updating the required Python packages). Others were things that were already broken and that I found along the way (e.g., a reference to a watsonx-specific doc template that does not seem to exist).

Closes #3165

This is also related to a line item in #3387, but it does not really address that goal, because it uses the LiteLLM mixin rather than the OpenAI one. I tried the OpenAI mixin and it does not work with watsonx.ai, presumably because the watsonx.ai service is not OpenAI compatible. It works with LiteLLM because LiteLLM has a provider implementation for watsonx.ai.
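For illustration only, here is a minimal sketch of the dictionary-access fix described above. The helper name in the PR is `b64_encode_openai_embeddings_response`; the function name, payload shape, and packing scheme shown here are assumptions for the sketch, not the actual implementation.

```python
# Hypothetical sketch of the access-pattern fix: iterate the list under the
# "data" key and use dict-style lookups ("embedding", "index") rather than
# attribute access (item.embedding), which fails on plain dictionaries.
import base64
import struct


def encode_embeddings(response: dict) -> list[dict]:
    """Base64-encode each embedding vector in an OpenAI-style embeddings payload."""
    encoded = []
    for item in response["data"]:
        vector = item["embedding"]  # dict lookup, not item.embedding
        packed = struct.pack(f"{len(vector)}f", *vector)
        encoded.append(
            {
                "embedding": base64.b64encode(packed).decode("utf-8"),
                "index": item["index"],
            }
        )
    return encoded


if __name__ == "__main__":
    # Tiny fake payload to show the expected structure.
    fake = {"data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}]}
    print(encode_embeddings(fake))
```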
## Test Plan

The test script below goes back and forth between the OpenAI and watsonx providers. The idea is that the OpenAI provider shows how it should work, and the watsonx provider output then shows that it is also working with watsonx. Note that the result from the MCP test is not as good (the Llama 3.3 70B model does not choose tools as wisely as gpt-4o), but it is still working and providing a valid response. For more details on setup and the MCP server used for testing, see [the AI Alliance sample notebook](https://github.com/The-AI-Alliance/llama-stack-examples/blob/main/notebooks/01-responses/) that these examples are drawn from.

```python
#!/usr/bin/env python3

import json

from llama_stack_client import LlamaStackClient
from litellm import completion
import http.client


def print_response(response):
    """Print response in a nicely formatted way"""
    print(f"ID: {response.id}")
    print(f"Status: {response.status}")
    print(f"Model: {response.model}")
    print(f"Created at: {response.created_at}")
    print(f"Output items: {len(response.output)}")

    for i, output_item in enumerate(response.output):
        if len(response.output) > 1:
            print(f"\n--- Output Item {i+1} ---")
        print(f"Output type: {output_item.type}")

        if output_item.type in ("text", "message"):
            print(f"Response content: {output_item.content[0].text}")
        elif output_item.type == "file_search_call":
            print(f"  Tool Call ID: {output_item.id}")
            print(f"  Tool Status: {output_item.status}")
            # 'queries' is a list, so we join it for clean printing
            print(f"  Queries: {', '.join(output_item.queries)}")
            # Display results if they exist, otherwise note they are empty
            print(f"  Results: {output_item.results if output_item.results else 'None'}")
        elif output_item.type == "mcp_list_tools":
            print_mcp_list_tools(output_item)
        elif output_item.type == "mcp_call":
            print_mcp_call(output_item)
        else:
            print(f"Response content: {output_item.content}")


def print_mcp_call(mcp_call):
    """Print MCP call in a nicely formatted way"""
    print(f"\n🛠️ MCP Tool Call: {mcp_call.name}")
    print(f"   Server: {mcp_call.server_label}")
    print(f"   ID: {mcp_call.id}")
    print(f"   Arguments: {mcp_call.arguments}")

    if mcp_call.error:
        print(f"Error: {mcp_call.error}")
    elif mcp_call.output:
        print("Output:")
        # Try to format JSON output nicely
        try:
            parsed_output = json.loads(mcp_call.output)
            print(json.dumps(parsed_output, indent=4))
        except:
            # If not valid JSON, print as-is
            print(f"   {mcp_call.output}")
    else:
        print("   ⏳ No output yet")


def print_mcp_list_tools(mcp_list_tools):
    """Print MCP list tools in a nicely formatted way"""
    print(f"\n🔧 MCP Server: {mcp_list_tools.server_label}")
    print(f"   ID: {mcp_list_tools.id}")
    print(f"   Available Tools: {len(mcp_list_tools.tools)}")
    print("=" * 80)

    for i, tool in enumerate(mcp_list_tools.tools, 1):
        print(f"\n{i}. {tool.name}")
        print(f"   Description: {tool.description}")

        # Parse and display input schema
        schema = tool.input_schema
        if schema and 'properties' in schema:
            properties = schema['properties']
            required = schema.get('required', [])

            print("   Parameters:")
            for param_name, param_info in properties.items():
                param_type = param_info.get('type', 'unknown')
                param_desc = param_info.get('description', 'No description')
                required_marker = " (required)" if param_name in required else " (optional)"
                print(f"     • {param_name} ({param_type}){required_marker}")
                if param_desc:
                    print(f"       {param_desc}")

        if i < len(mcp_list_tools.tools):
            print("-" * 40)


def main():
    """Main function to run all the tests"""
    # Configuration
    LLAMA_STACK_URL = "http://localhost:8321/"
    LLAMA_STACK_MODEL_IDS = [
        "openai/gpt-3.5-turbo",
        "openai/gpt-4o",
        "llama-openai-compat/Llama-3.3-70B-Instruct",
        "watsonx/meta-llama/llama-3-3-70b-instruct"
    ]

    # Using gpt-4o for this demo, but feel free to try one of the others or add more to run.yaml.
    OPENAI_MODEL_ID = LLAMA_STACK_MODEL_IDS[1]
    WATSONX_MODEL_ID = LLAMA_STACK_MODEL_IDS[-1]
    NPS_MCP_URL = "http://localhost:3005/sse/"

    print("=== Llama Stack Testing Script ===")
    print(f"Using OpenAI model: {OPENAI_MODEL_ID}")
    print(f"Using WatsonX model: {WATSONX_MODEL_ID}")
    print(f"MCP URL: {NPS_MCP_URL}")
    print()

    # Initialize client
    print("Initializing LlamaStackClient...")
    client = LlamaStackClient(base_url="http://localhost:8321")

    # Test 1: List models
    print("\n=== Test 1: List Models ===")
    try:
        models = client.models.list()
        print(f"Found {len(models)} models")
    except Exception as e:
        print(f"Error listing models: {e}")
        raise e

    # Test 2: Basic chat completion with OpenAI
    print("\n=== Test 2: Basic Chat Completion (OpenAI) ===")
    try:
        chat_completion_response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}]
        )
        print("OpenAI Response:")
        for chunk in chat_completion_response.choices[0].message.content:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with OpenAI chat completion: {e}")
        raise e

    # Test 3: Basic chat completion with WatsonX
    print("\n=== Test 3: Basic Chat Completion (WatsonX) ===")
    try:
        chat_completion_response_wxai = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
        )
        print("WatsonX Response:")
        for chunk in chat_completion_response_wxai.choices[0].message.content:
            print(chunk, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with WatsonX chat completion: {e}")
        raise e

    # Test 4: Tool calling with OpenAI
    print("\n=== Test 4: Tool Calling (OpenAI) ===")
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a specific location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g., San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        },
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    messages = [
        {"role": "user", "content": "What's the weather like in Boston, MA?"}
    ]

    try:
        print("--- Initial API Call ---")
        response = client.chat.completions.create(
            model=OPENAI_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("OpenAI tool calling response received")
    except Exception as e:
        print(f"Error with OpenAI tool calling: {e}")
        raise e

    # Test 5: Tool calling with WatsonX
    print("\n=== Test 5: Tool Calling (WatsonX) ===")
    try:
        wxai_response = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=messages,
            tools=tools,
            tool_choice="auto",  # "auto" is the default
        )
        print("WatsonX tool calling response received")
    except Exception as e:
        print(f"Error with WatsonX tool calling: {e}")
        raise e

    # Test 6: Streaming with WatsonX
    print("\n=== Test 6: Streaming Response (WatsonX) ===")
    try:
        chat_completion_response_wxai_stream = client.chat.completions.create(
            model=WATSONX_MODEL_ID,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
            stream=True
        )
        print("Model response: ", end="")
        for chunk in chat_completion_response_wxai_stream:
            # Each 'chunk' is a ChatCompletionChunk object.
            # We want the content from the 'delta' attribute.
            if hasattr(chunk, 'choices') and chunk.choices is not None:
                content = chunk.choices[0].delta.content
                # The first few chunks might have None content, so we check for it.
                if content is not None:
                    print(content, end="", flush=True)
        print()
    except Exception as e:
        print(f"Error with streaming: {e}")
        raise e

    # Test 7: MCP with OpenAI
    print("\n=== Test 7: MCP Integration (OpenAI) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=OPENAI_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (OpenAI): {e}")
        raise e

    # Test 8: MCP with WatsonX
    print("\n=== Test 8: MCP Integration (WatsonX) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="What is the capital of France?"
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (WatsonX): {e}")
        raise e

    # Test 9: MCP with Llama 3.3
    print("\n=== Test 9: MCP Integration (Llama 3.3) ===")
    try:
        mcp_llama_stack_client_response = client.responses.create(
            model=WATSONX_MODEL_ID,
            input="Tell me about some parks in Rhode Island, and let me know if there are any upcoming events at them.",
            tools=[
                {
                    "type": "mcp",
                    "server_url": NPS_MCP_URL,
                    "server_label": "National Parks Service tools",
                    "allowed_tools": ["search_parks", "get_park_events"],
                }
            ]
        )
        print_response(mcp_llama_stack_client_response)
    except Exception as e:
        print(f"Error with MCP (Llama 3.3): {e}")
        raise e

    # Test 10: Embeddings
    print("\n=== Test 10: Embeddings ===")
    try:
        conn = http.client.HTTPConnection("localhost:8321")
        payload = json.dumps({
            "model": "watsonx/ibm/granite-embedding-278m-multilingual",
            "input": "Hello, world!",
        })
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        conn.request("POST", "/v1/openai/v1/embeddings", payload, headers)
        res = conn.getresponse()
        data = res.read()
        print(data.decode("utf-8"))
    except Exception as e:
        print(f"Error with Embeddings: {e}")
        raise e

    print("\n=== Testing Complete ===")


if __name__ == "__main__":
    main()
```

---------

Signed-off-by: Bill Murdock <bmurdock@redhat.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any

from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    CompletionMessage,
    CompletionResponse,
    CompletionResponseStreamChunk,
    Inference,
    ListOpenAIChatCompletionResponse,
    Message,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIChoice,
    OpenAIChoiceLogprobs,
    OpenAICompletion,
    OpenAICompletionWithInputMessages,
    OpenAIEmbeddingsResponse,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
    Order,
    StopReason,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span

logger = get_logger(name=__name__, category="core::routers")

class InferenceRouter(Inference):
    """Routes to a provider based on the model"""

    def __init__(
        self,
        routing_table: RoutingTable,
        telemetry: Telemetry | None = None,
        store: InferenceStore | None = None,
    ) -> None:
        logger.debug("Initializing InferenceRouter")
        self.routing_table = routing_table
        self.telemetry = telemetry
        self.store = store
        if self.telemetry:
            self.tokenizer = Tokenizer.get_instance()
            self.formatter = ChatFormat(self.tokenizer)

    async def initialize(self) -> None:
        logger.debug("InferenceRouter.initialize")

    async def shutdown(self) -> None:
        logger.debug("InferenceRouter.shutdown")
        if self.store:
            try:
                await self.store.shutdown()
            except Exception as e:
                logger.warning(f"Error during InferenceStore shutdown: {e}")

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> None:
        logger.debug(
            f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
        )
        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

    def _construct_metrics(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: Model,
    ) -> list[MetricEvent]:
        """Constructs a list of MetricEvent objects containing token usage metrics.

        Args:
            prompt_tokens: Number of tokens in the prompt
            completion_tokens: Number of tokens in the completion
            total_tokens: Total number of tokens used
            model: Model object containing model_id and provider_id

        Returns:
            List of MetricEvent objects with token usage metrics
        """
        span = get_current_span()
        if span is None:
            logger.warning("No span found for token usage metrics")
            return []

        metrics = [
            ("prompt_tokens", prompt_tokens),
            ("completion_tokens", completion_tokens),
            ("total_tokens", total_tokens),
        ]
        metric_events = []
        for metric_name, value in metrics:
            metric_events.append(
                MetricEvent(
                    trace_id=span.trace_id,
                    span_id=span.span_id,
                    metric=metric_name,
                    value=value,
                    timestamp=datetime.now(UTC),
                    unit="tokens",
                    attributes={
                        "model_id": model.model_id,
                        "provider_id": model.provider_id,
                    },
                )
            )
        return metric_events

    async def _compute_and_log_token_usage(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        model: Model,
    ) -> list[MetricInResponse]:
        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
        if self.telemetry:
            for metric in metrics:
                enqueue_event(metric)
        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

    async def _count_tokens(
        self,
        messages: list[Message] | InterleavedContent,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> int | None:
        if not hasattr(self, "formatter") or self.formatter is None:
            return None

        if isinstance(messages, list):
            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
        else:
            encoded = self.formatter.encode_content(messages)
        return len(encoded.tokens) if encoded and encoded.tokens else 0

    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
        model = await self.routing_table.get_model(model_id)
        if model is None:
            raise ModelNotFoundError(model_id)
        if model.model_type != expected_model_type:
            raise ModelTypeError(model_id, model.model_type, expected_model_type)
        return model

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        logger.debug(
            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
        )
        model_obj = await self._get_model(model, ModelType.llm)
        params = dict(
            model=model_obj.identifier,
            prompt=prompt,
            best_of=best_of,
            echo=echo,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            top_p=top_p,
            user=user,
            guided_choice=guided_choice,
            prompt_logprobs=prompt_logprobs,
            suffix=suffix,
        )
        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        if stream:
            return await provider.openai_completion(**params)
            # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
            # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
            # response_stream = await provider.openai_completion(**params)

        response = await provider.openai_completion(**params)
        if self.telemetry:
            metrics = self._construct_metrics(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
                model=model_obj,
            )
            for metric in metrics:
                enqueue_event(metric)

            # these metrics will show up in the client response.
            response.metrics = (
                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
            )
        return response

    async def openai_chat_completion(
        self,
        model: str,
        messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        logger.debug(
            f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
        )
        model_obj = await self._get_model(model, ModelType.llm)

        # Use the OpenAI client for a bit of extra input validation without
        # exposing the OpenAI client itself as part of our API surface
        if tool_choice:
            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
            if tools is None:
                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
        if tools:
            for tool in tools:
                TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

        # Some providers make tool calls even when tool_choice is "none"
        # so just clear them both out to avoid unexpected tool calls
        if tool_choice == "none" and tools is not None:
            tool_choice = None
            tools = None

        params = dict(
            model=model_obj.identifier,
            messages=messages,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            presence_penalty=presence_penalty,
            response_format=response_format,
            seed=seed,
            stop=stop,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
        )
        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        if stream:
            response_stream = await provider.openai_chat_completion(**params)

            # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
            # We need to add metrics to each chunk and store the final completion
            return self.stream_tokens_and_compute_metrics_openai_chat(
                response=response_stream,
                model=model_obj,
                messages=messages,
            )

        response = await self._nonstream_openai_chat_completion(provider, params)

        # Store the response with the ID that will be returned to the client
        if self.store:
            asyncio.create_task(self.store.store_chat_completion(response, messages))

        if self.telemetry:
            metrics = self._construct_metrics(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
                model=model_obj,
            )
            for metric in metrics:
                enqueue_event(metric)
            # these metrics will show up in the client response.
            response.metrics = (
                metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
            )
        return response

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        logger.debug(
            f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
        )
        model_obj = await self._get_model(model, ModelType.embedding)
        params = dict(
            model=model_obj.identifier,
            input=input,
            encoding_format=encoding_format,
            dimensions=dimensions,
            user=user,
        )

        provider = await self.routing_table.get_provider_impl(model_obj.identifier)
        return await provider.openai_embeddings(**params)

    async def list_chat_completions(
        self,
        after: str | None = None,
        limit: int | None = 20,
        model: str | None = None,
        order: Order | None = Order.desc,
    ) -> ListOpenAIChatCompletionResponse:
        if self.store:
            return await self.store.list_chat_completions(after, limit, model, order)
        raise NotImplementedError("List chat completions is not supported: inference store is not configured.")

    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
        if self.store:
            return await self.store.get_chat_completion(completion_id)
        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
        response = await provider.openai_chat_completion(**params)
        for choice in response.choices:
            # some providers return an empty list for no tool calls in non-streaming responses
            # but the OpenAI API returns None. So, set tool_calls to None if it's empty
            if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
                choice.message.tool_calls = None
        return response

    async def health(self) -> dict[str, HealthResponse]:
        health_statuses = {}
        timeout = 1  # increasing the timeout to 1 second for health checks
        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
            try:
                # check if the provider has a health method
                if not hasattr(impl, "health"):
                    continue
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                health_statuses[provider_id] = health
            except TimeoutError:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR,
                    message=f"Health check timed out after {timeout} seconds",
                )
            except NotImplementedError:
                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
            except Exception as e:
                health_statuses[provider_id] = HealthResponse(
                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                )
        return health_statuses

    async def stream_tokens_and_compute_metrics(
        self,
        response,
        prompt_tokens,
        model,
        tool_prompt_format: ToolPromptFormat | None = None,
    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
        completion_text = ""
        async for chunk in response:
            complete = False
            if hasattr(chunk, "event"):  # only ChatCompletions have .event
                if chunk.event.event_type == ChatCompletionResponseEventType.progress:
                    if chunk.event.delta.type == "text":
                        completion_text += chunk.event.delta.text
                if chunk.event.event_type == ChatCompletionResponseEventType.complete:
                    complete = True
                    completion_tokens = await self._count_tokens(
                        [
                            CompletionMessage(
                                content=completion_text,
                                stop_reason=StopReason.end_of_turn,
                            )
                        ],
                        tool_prompt_format=tool_prompt_format,
                    )
            else:
                if hasattr(chunk, "delta"):
                    completion_text += chunk.delta
                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
                    complete = True
                    completion_tokens = await self._count_tokens(completion_text)
            # if we are done receiving tokens
            if complete:
                total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

                # Create a separate span for streaming completion metrics
                if self.telemetry:
                    # Log metrics in the new span context
                    completion_metrics = self._construct_metrics(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=total_tokens,
                        model=model,
                    )
                    for metric in completion_metrics:
                        if metric.metric in [
                            "completion_tokens",
                            "total_tokens",
                        ]:  # Only log completion and total tokens
                            enqueue_event(metric)

                    # Return metrics in response
                    async_metrics = [
                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                    ]
                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
                else:
                    # Fallback if no telemetry
                    completion_metrics = self._construct_metrics(
                        prompt_tokens or 0,
                        completion_tokens or 0,
                        total_tokens,
                        model,
                    )
                    async_metrics = [
                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
                    ]
                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
            yield chunk

    async def count_tokens_and_compute_metrics(
        self,
        response: ChatCompletionResponse | CompletionResponse,
        prompt_tokens,
        model,
        tool_prompt_format: ToolPromptFormat | None = None,
    ):
        if isinstance(response, ChatCompletionResponse):
            content = [response.completion_message]
        else:
            content = response.content
        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)

        # Create a separate span for completion metrics
        if self.telemetry:
            # Log metrics in the new span context
            completion_metrics = self._construct_metrics(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                model=model,
            )
            for metric in completion_metrics:
                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
                    enqueue_event(metric)

            # Return metrics in response
            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]

        # Fallback if no telemetry
        metrics = self._construct_metrics(
            prompt_tokens or 0,
            completion_tokens or 0,
            total_tokens,
            model,
        )
        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

    async def stream_tokens_and_compute_metrics_openai_chat(
        self,
        response: AsyncIterator[OpenAIChatCompletionChunk],
        model: Model,
        messages: list[OpenAIMessageParam] | None = None,
    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
        """Stream OpenAI chat completion chunks, compute metrics, and store the final completion."""
        id = None
        created = None
        choices_data: dict[int, dict[str, Any]] = {}

        try:
            async for chunk in response:
                # Skip None chunks
                if chunk is None:
                    continue

                # Capture ID and created timestamp from first chunk
                if id is None and chunk.id:
                    id = chunk.id
                if created is None and chunk.created:
                    created = chunk.created

                # Accumulate choice data for final assembly
                if chunk.choices:
                    for choice_delta in chunk.choices:
                        idx = choice_delta.index
                        if idx not in choices_data:
                            choices_data[idx] = {
                                "content_parts": [],
                                "tool_calls_builder": {},
                                "finish_reason": "stop",
                                "logprobs_content_parts": [],
                            }
                        current_choice_data = choices_data[idx]

                        if choice_delta.delta:
                            delta = choice_delta.delta
                            if delta.content:
                                current_choice_data["content_parts"].append(delta.content)
                            if delta.tool_calls:
                                for tool_call_delta in delta.tool_calls:
                                    tc_idx = tool_call_delta.index
                                    if tc_idx not in current_choice_data["tool_calls_builder"]:
                                        current_choice_data["tool_calls_builder"][tc_idx] = {
                                            "id": None,
                                            "type": "function",
                                            "function_name_parts": [],
                                            "function_arguments_parts": [],
                                        }
                                    builder = current_choice_data["tool_calls_builder"][tc_idx]
                                    if tool_call_delta.id:
                                        builder["id"] = tool_call_delta.id
                                    if tool_call_delta.type:
                                        builder["type"] = tool_call_delta.type
                                    if tool_call_delta.function:
                                        if tool_call_delta.function.name:
                                            builder["function_name_parts"].append(tool_call_delta.function.name)
                                        if tool_call_delta.function.arguments:
                                            builder["function_arguments_parts"].append(
                                                tool_call_delta.function.arguments
                                            )
                        if choice_delta.finish_reason:
                            current_choice_data["finish_reason"] = choice_delta.finish_reason
                        if choice_delta.logprobs and choice_delta.logprobs.content:
                            current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content)

                # Compute metrics on final chunk
                if chunk.choices and chunk.choices[0].finish_reason:
                    completion_text = ""
                    for choice_data in choices_data.values():
                        completion_text += "".join(choice_data["content_parts"])

                    # Add metrics to the chunk
                    if self.telemetry and hasattr(chunk, "usage") and chunk.usage:
                        metrics = self._construct_metrics(
                            prompt_tokens=chunk.usage.prompt_tokens,
                            completion_tokens=chunk.usage.completion_tokens,
                            total_tokens=chunk.usage.total_tokens,
                            model=model,
                        )
                        for metric in metrics:
                            enqueue_event(metric)

                yield chunk
        finally:
            # Store the final assembled completion
            if id and self.store and messages:
                assembled_choices: list[OpenAIChoice] = []
                for choice_idx, choice_data in choices_data.items():
                    content_str = "".join(choice_data["content_parts"])
                    assembled_tool_calls: list[OpenAIChatCompletionToolCall] = []
                    if choice_data["tool_calls_builder"]:
                        for tc_build_data in choice_data["tool_calls_builder"].values():
                            if tc_build_data["id"]:
                                func_name = "".join(tc_build_data["function_name_parts"])
                                func_args = "".join(tc_build_data["function_arguments_parts"])
                                assembled_tool_calls.append(
                                    OpenAIChatCompletionToolCall(
                                        id=tc_build_data["id"],
                                        type=tc_build_data["type"],
                                        function=OpenAIChatCompletionToolCallFunction(
                                            name=func_name, arguments=func_args
                                        ),
                                    )
                                )
                    message = OpenAIAssistantMessageParam(
                        role="assistant",
                        content=content_str if content_str else None,
                        tool_calls=assembled_tool_calls if assembled_tool_calls else None,
                    )
                    logprobs_content = choice_data["logprobs_content_parts"]
                    final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None

                    assembled_choices.append(
                        OpenAIChoice(
                            finish_reason=choice_data["finish_reason"],
                            index=choice_idx,
                            message=message,
                            logprobs=final_logprobs,
                        )
                    )

                final_response = OpenAIChatCompletion(
                    id=id,
                    choices=assembled_choices,
                    created=created or int(time.time()),
                    model=model.identifier,
                    object="chat.completion",
                )
                logger.debug(f"InferenceRouter.completion_response: {final_response}")
                asyncio.create_task(self.store.store_chat_completion(final_response, messages))