Mirror of https://github.com/meta-llama/llama-stack.git
feat: add (openai, anthropic, gemini) providers via litellm (#1267)
# What does this PR do?

This PR introduces more non-llama model support to llama stack. Providers introduced: openai, anthropic, and gemini. All of these providers share essentially the same implementation — they delegate to the `litellm` library.

For each provider we enable, we expose only a specific set of models, making sure they all work well and pass tests. This setup (instead of automatically enabling _all_ providers and models that LiteLLM allows) ensures we can also perform any needed per-model prompt tuning, just as we do for llama models.

## Test Plan

```bash
#!/bin/bash

args=("$@")
for model in openai/gpt-4o anthropic/claude-3-5-sonnet-latest gemini/gemini-1.5-flash; do
    LLAMA_STACK_CONFIG=dev pytest -s -v tests/client-sdk/inference/test_text_inference.py \
        --embedding-model=all-MiniLM-L6-v2 \
        --vision-inference-model="" \
        --inference-model=$model "${args[@]}"
done
```
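For readers unfamiliar with `litellm`, the delegation mechanism looks roughly like the sketch below: one call signature covers all three providers. This is an illustrative use of the library's public `completion` API, not the adapter code in this PR; the model IDs match the test plan above, and the snippet assumes the corresponding `OPENAI_API_KEY` / `ANTHROPIC_API_KEY` / `GEMINI_API_KEY` environment variables are set.

```python
# Illustrative sketch of litellm's unified completion API -- the mechanism
# the openai/anthropic/gemini adapters build on. Not the PR's actual code.
import litellm

for model in [
    "openai/gpt-4o",
    "anthropic/claude-3-5-sonnet-latest",
    "gemini/gemini-1.5-flash",
]:
    # litellm routes the request to the right provider based on the
    # "<provider>/<model>" prefix and returns an OpenAI-shaped response.
    response = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": "Say hello in one word."}],
    )
    print(model, "->", response.choices[0].message.content)
```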
parent b0310af177 · commit 63e6acd0c3
25 changed files with 1048 additions and 33 deletions
Excerpt from `get_distribution_template()` in the fireworks distribution template:

```diff
@@ -57,17 +57,6 @@ def get_distribution_template() -> DistributionTemplate:
         config=SentenceTransformersInferenceConfig.sample_run_config(),
     )
 
-    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
-    default_models = [
-        ModelInput(
-            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
-            provider_model_id=m.provider_model_id,
-            provider_id="fireworks",
-            metadata=m.metadata,
-            model_type=m.model_type,
-        )
-        for m in MODEL_ENTRIES
-    ]
     default_tool_groups = [
         ToolGroupInput(
             toolgroup_id="builtin::websearch",
@@ -82,6 +71,16 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="code-interpreter",
         ),
     ]
+    core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()}
+    default_models = [
+        ModelInput(
+            model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id,
+            provider_id="fireworks",
+            model_type=m.model_type,
+            metadata=m.metadata,
+        )
+        for m in MODEL_ENTRIES
+    ]
     embedding_model = ModelInput(
         model_id="all-MiniLM-L6-v2",
         provider_id="sentence-transformers",
@@ -98,7 +97,7 @@ def get_distribution_template() -> DistributionTemplate:
         container_image=None,
         template_path=None,
         providers=providers,
-        default_models=default_models,
+        default_models=default_models + [embedding_model],
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
```
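The moved block registers llama models under their Hugging Face repo alias and falls back to the raw provider model ID for non-llama entries. A minimal standalone sketch of that mapping pattern, with a hypothetical `Entry` class and sample data standing in for llama-stack's real `MODEL_ENTRIES`:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for a llama-stack model entry, used only to
# illustrate the descriptor -> HF-repo mapping in the diff above.
@dataclass
class Entry:
    provider_model_id: str
    llama_model: str | None = None
    metadata: dict = field(default_factory=dict)

# Assumed sample mapping; in llama-stack this is built from
# all_registered_models().
CORE_MODEL_TO_HF_REPO = {
    "Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
}

entries = [
    Entry(
        provider_model_id="accounts/fireworks/models/llama-v3p1-8b-instruct",
        llama_model="Llama3.1-8B-Instruct",
    ),
    Entry(provider_model_id="nomic-ai/nomic-embed-text-v1.5"),
]

# Mirrors the comprehension in the diff: llama models get their HF repo
# alias as the model_id; other entries keep the provider model id.
model_ids = [
    CORE_MODEL_TO_HF_REPO[e.llama_model] if e.llama_model else e.provider_model_id
    for e in entries
]
print(model_ids)
# ['meta-llama/Llama-3.1-8B-Instruct', 'nomic-ai/nomic-embed-text-v1.5']
```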