We'd like to remove the dependence of `llama-stack` on `llama-stack-client`. This is a necessary step. A few small cleanups:

- Enables `embeddings` now also
- Remove ModelRegistryHelper dependency (unused)
- Consolidate to auth_credential field via RemoteInferenceProviderConfig
- Implement list_models() to fetch from downstream /v1/models

## Test Plan

Tested using this script: https://gist.github.com/ashwinb/6356463d10f989c0682ab3bff8589581

Output:

```
Listing models from downstream server...
Available models: ['passthrough/ollama/nomic-embed-text:latest', 'passthrough/ollama/all-minilm:l6-v2', 'passthrough/ollama/llama3.2-vision:11b', 'passthrough/ollama/llama3.2-vision:latest', 'passthrough/ollama/llama-guard3:1b', 'passthrough/ollama/llama3.2:1b', 'passthrough/ollama/all-minilm:latest', 'passthrough/ollama/llama3.2:3b', 'passthrough/ollama/llama3.2:3b-instruct-fp16', 'passthrough/bedrock/meta.llama3-1-8b-instruct-v1:0', 'passthrough/bedrock/meta.llama3-1-70b-instruct-v1:0', 'passthrough/bedrock/meta.llama3-1-405b-instruct-v1:0', 'passthrough/sentence-transformers/nomic-ai/nomic-embed-text-v1.5']

Using LLM model: passthrough/ollama/llama3.2-vision:11b

Making inference request...
Response: 4.

--- Testing streaming ---
Streamed response:
ChatCompletionChunk(id='chatcmpl-64', choices=[Choice(delta=ChoiceDelta(content='1', reasoning_content=None, refusal=None, role='assistant', tool_calls=None), finish_reason='', index=0, logprobs=None)], created=1762381674, model='passthrough/ollama/llama3.2-vision:11b', object='chat.completion.chunk', usage=None)
...
5
ChatCompletionChunk(id='chatcmpl-64', choices=[Choice(delta=ChoiceDelta(content='', reasoning_content=None, refusal=None, role='assistant', tool_calls=None), finish_reason='stop', index=0, logprobs=None)], created=1762381674, model='passthrough/ollama/llama3.2-vision:11b', object='chat.completion.chunk', usage=None)
```
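For reference, a rough sketch of the kind of check the gist performs, using the plain `openai` client against a running Llama Stack server. The server address, route prefix, and model id below are placeholders/assumptions, not taken from the gist:

```python
# Minimal sketch of the test flow: list models from a running Llama Stack
# server, then exercise a passthrough model with a non-streaming and a
# streaming chat completion.
# Assumptions: the server is reachable at http://localhost:8321 and exposes
# OpenAI-compatible routes under /v1; the model id is a placeholder.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8321/v1", api_key="none")

    # Passthrough models are listed with the provider-id prefix,
    # e.g. "passthrough/ollama/llama3.2-vision:11b".
    models = await client.models.list()
    print("Available models:", [m.id for m in models.data])

    # Non-streaming chat completion against one of the passthrough models.
    response = await client.chat.completions.create(
        model="passthrough/ollama/llama3.2-vision:11b",  # placeholder
        messages=[{"role": "user", "content": "What is 2 + 2?"}],
    )
    print("Response:", response.choices[0].message.content)

    # Streaming variant: chunks arrive as ChatCompletionChunk objects.
    stream = await client.chat.completions.create(
        model="passthrough/ollama/llama3.2-vision:11b",  # placeholder
        messages=[{"role": "user", "content": "Count from 1 to 5."}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


if __name__ == "__main__":
    asyncio.run(main())
```

The `passthrough/` prefix on the model ids comes from the `list_models()` implementation in the adapter below, which prepends the provider id to each downstream model id.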
The resulting adapter (135 lines · 5 KiB · Python):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncIterator

from openai import AsyncOpenAI

from llama_stack.apis.inference import (
    Inference,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAICompletion,
    OpenAICompletionRequestWithExtraBody,
    OpenAIEmbeddingsRequestWithExtraBody,
    OpenAIEmbeddingsResponse,
)
from llama_stack.apis.models import Model
from llama_stack.core.request_headers import NeedsRequestProviderData

from .config import PassthroughImplConfig


class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
    def __init__(self, config: PassthroughImplConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def unregister_model(self, model_id: str) -> None:
        pass

    async def register_model(self, model: Model) -> Model:
        return model

    async def list_models(self) -> list[Model]:
        """List models by calling the downstream /v1/models endpoint."""
        client = self._get_openai_client()

        response = await client.models.list()

        # Convert from OpenAI format to Llama Stack Model format
        models = []
        for model_data in response.data:
            downstream_model_id = model_data.id
            custom_metadata = getattr(model_data, "custom_metadata", {}) or {}

            # Prefix identifier with provider ID for local registry
            local_identifier = f"{self.__provider_id__}/{downstream_model_id}"

            model = Model(
                identifier=local_identifier,
                provider_id=self.__provider_id__,
                provider_resource_id=downstream_model_id,
                model_type=custom_metadata.get("model_type", "llm"),
                metadata=custom_metadata,
            )
            models.append(model)

        return models

    async def should_refresh_models(self) -> bool:
        """Passthrough should refresh models since they come from downstream dynamically."""
        return self.config.refresh_models

    def _get_openai_client(self) -> AsyncOpenAI:
        """Get an AsyncOpenAI client configured for the downstream server."""
        base_url = self._get_passthrough_url()
        api_key = self._get_passthrough_api_key()

        return AsyncOpenAI(
            base_url=f"{base_url.rstrip('/')}/v1",
            api_key=api_key,
        )

    def _get_passthrough_url(self) -> str:
        """Get the passthrough URL from config or provider data."""
        if self.config.url is not None:
            return self.config.url

        provider_data = self.get_request_provider_data()
        if provider_data is None:
            raise ValueError(
                'Pass url of the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_url": <your passthrough url>}'
            )
        return provider_data.passthrough_url

    def _get_passthrough_api_key(self) -> str:
        """Get the passthrough API key from config or provider data."""
        if self.config.auth_credential is not None:
            return self.config.auth_credential.get_secret_value()

        provider_data = self.get_request_provider_data()
        if provider_data is None:
            raise ValueError(
                'Pass API Key for the passthrough endpoint in the header X-LlamaStack-Provider-Data as { "passthrough_api_key": <your api key>}'
            )
        return provider_data.passthrough_api_key

    async def openai_completion(
        self,
        params: OpenAICompletionRequestWithExtraBody,
    ) -> OpenAICompletion:
        """Forward completion request to downstream using OpenAI client."""
        client = self._get_openai_client()
        request_params = params.model_dump(exclude_none=True)
        response = await client.completions.create(**request_params)
        return response  # type: ignore

    async def openai_chat_completion(
        self,
        params: OpenAIChatCompletionRequestWithExtraBody,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """Forward chat completion request to downstream using OpenAI client."""
        client = self._get_openai_client()
        request_params = params.model_dump(exclude_none=True)
        response = await client.chat.completions.create(**request_params)
        return response  # type: ignore

    async def openai_embeddings(
        self,
        params: OpenAIEmbeddingsRequestWithExtraBody,
    ) -> OpenAIEmbeddingsResponse:
        """Forward embeddings request to downstream using OpenAI client."""
        client = self._get_openai_client()
        request_params = params.model_dump(exclude_none=True)
        response = await client.embeddings.create(**request_params)
        return response  # type: ignore
```
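When `url` or `auth_credential` is not set in the provider config, the adapter expects the caller to supply them per request through the `X-LlamaStack-Provider-Data` header, as the error messages above describe. A rough sketch of what such a request could look like; the Llama Stack server address, route, downstream URL, and model id are placeholders/assumptions, only the header name and JSON keys come from the code above:

```python
# Rough sketch: supply the downstream endpoint and key per request via the
# X-LlamaStack-Provider-Data header, which the adapter reads through
# get_request_provider_data(). Addresses, route, and model id are placeholders.
import json

import httpx

provider_data = {
    "passthrough_url": "http://downstream-stack:8321",  # placeholder downstream server
    "passthrough_api_key": "sk-...",                    # placeholder credential
}

response = httpx.post(
    "http://localhost:8321/v1/chat/completions",  # assumed Llama Stack route
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
    json={
        "model": "passthrough/ollama/llama3.2-vision:11b",  # placeholder model id
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(response.json())
```

If neither the config fields nor the header are provided, `_get_passthrough_url()` and `_get_passthrough_api_key()` raise the `ValueError`s shown above.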