feat: add (openai, anthropic, gemini) providers via litellm (#1267)

# What does this PR do?

This PR introduces more non-llama model support to llama stack.
Providers introduced: openai, anthropic and gemini. All of these
providers use essentially the same piece of code -- the implementation
works via the `litellm` library.

We will expose only specific models for providers we enable making sure
they all work well and pass tests. This setup (instead of automatically
enabling _all_ providers and models allowed by LiteLLM) ensures we can
also perform any needed prompt tuning on a per-model basis as needed
(just like we do it for llama models.)

## Test Plan

```bash
#!/bin/bash

args=("$@")
for model in openai/gpt-4o anthropic/claude-3-5-sonnet-latest gemini/gemini-1.5-flash; do
    LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/inference/test_text_inference.py \
        --embedding-model=all-MiniLM-L6-v2 \
        --vision-inference-model="" \
        --inference-model=$model "${args[@]}"
done
```
This commit is contained in:
Ashwin Bharambe 2025-02-25 22:07:33 -08:00 committed by GitHub
parent b0310af177
commit 63e6acd0c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1048 additions and 33 deletions

View file

@ -116,12 +116,14 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
providers = [p for p in client.providers.list() if p.api == "inference"]
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
if text_model_id:
model_ids = [m.identifier for m in client.models.list()]
if text_model_id and text_model_id not in model_ids:
client.models.register(model_id=text_model_id, provider_id=inference_providers[0])
if vision_model_id:
if vision_model_id and vision_model_id not in model_ids:
client.models.register(model_id=vision_model_id, provider_id=inference_providers[0])
if embedding_model_id and embedding_dimension:
if embedding_model_id and embedding_dimension and embedding_model_id not in model_ids:
# try to find a provider that supports embeddings, if sentence-transformers is not available
selected_provider = None
for p in providers:

View file

@ -19,6 +19,16 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
def skip_if_model_doesnt_support_completion(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
print(f"Provider: {provider.provider_type} for model {model_id}")
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini"):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
@pytest.fixture(scope="session")
def provider_tool_format(inference_provider_type):
return (
@ -35,6 +45,7 @@ def provider_tool_format(inference_provider_type):
],
)
def test_text_completion_non_streaming(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
response = client_with_models.inference.completion(
@ -56,6 +67,7 @@ def test_text_completion_non_streaming(client_with_models, text_model_id, test_c
],
)
def test_text_completion_streaming(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
response = client_with_models.inference.completion(
@ -79,6 +91,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
],
)
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@ -107,6 +120,7 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_
],
)
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@ -139,6 +153,8 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
],
)
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
class AnswerFormat(BaseModel):
name: str
year_born: str
@ -237,9 +253,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
tool_prompt_format=tool_prompt_format,
stream=False,
)
# No content is returned for the system message since we expect the
# response to be a tool call
assert response.completion_message.content == ""
# some models can return content for the response in addition to the tool call
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1