Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
Move runpod provider to the correct directory
Also clean up the test code to avoid skipping tests. Let failures be known and public.
parent 22dc684da6
commit d78027f3b5
6 changed files with 7 additions and 102 deletions
@@ -200,8 +200,8 @@ def available_providers() -> List[ProviderSpec]:
             adapter=AdapterSpec(
                 adapter_type="runpod",
                 pip_packages=["openai"],
-                module="llama_stack.providers.adapters.inference.runpod",
-                config_class="llama_stack.providers.adapters.inference.runpod.RunpodImplConfig",
+                module="llama_stack.providers.remote.inference.runpod",
+                config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
             ),
         ),
         remote_provider_spec(
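Read as a whole, the registry entry this hunk touches plausibly ends up as the sketch below. Only the AdapterSpec fields come from the diff; the remote_provider_spec(api=Api.inference, ...) wrapper and the import path are assumptions added for illustration.

# Sketch of the updated runpod registry entry; only the AdapterSpec fields
# come from the hunk above, the rest is assumed for illustration.
from llama_stack.providers.datatypes import AdapterSpec, Api, remote_provider_spec  # assumed import path

runpod_spec = remote_provider_spec(
    api=Api.inference,  # assumption: not visible in the hunk
    adapter=AdapterSpec(
        adapter_type="runpod",
        pip_packages=["openai"],
        # Both dotted paths now live under providers.remote instead of providers.adapters.
        module="llama_stack.providers.remote.inference.runpod",
        config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
    ),
)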
@@ -10,7 +10,7 @@ from .runpod import RunpodInferenceAdapter

 async def get_adapter_impl(config: RunpodImplConfig, _deps):
     assert isinstance(
         config, RunpodImplConfig
     ), f"Unexpected config type: {type(config)}"
     impl = RunpodInferenceAdapter(config)
     await impl.initialize()
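For orientation, a minimal usage sketch of the factory shown above. The RunpodImplConfig field names and the endpoint URL are illustrative guesses; only get_adapter_impl and RunpodImplConfig appear in the hunk.

import asyncio

from llama_stack.providers.remote.inference.runpod import RunpodImplConfig, get_adapter_impl


async def main() -> None:
    # Field names below are assumptions for illustration only.
    config = RunpodImplConfig(url="https://<your-endpoint>/v1", api_token="<token>")
    impl = await get_adapter_impl(config, _deps={})  # the factory awaits impl.initialize() itself
    print(type(impl).__name__)  # -> RunpodInferenceAdapter


asyncio.run(main())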
@@ -12,6 +12,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403

 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

@@ -40,6 +41,8 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }


 class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     def __init__(self, config: RunpodImplConfig) -> None:
         ModelRegistryHelper.__init__(

@@ -130,4 +133,4 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         model: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
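The imports and the RUNPOD_SUPPORTED_MODELS map above suggest the adapter resolves Llama Stack model names to provider-side ids and talks to RunPod through the OpenAI-compatible client. A standalone sketch of that idea follows; the function, endpoint URL, and parameters are assumptions, not the adapter's actual code.

# Standalone sketch (not the adapter's code): map a Llama Stack model name to the
# provider-side id and issue a completion through the OpenAI-compatible client.
from openai import OpenAI

RUNPOD_SUPPORTED_MODELS = {
    "Llama3.2-1B": "meta-llama/Llama-3.2-1B",
    "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
}


def run_completion(prompt: str, llama_model: str, base_url: str, api_key: str) -> str:
    model_id = RUNPOD_SUPPORTED_MODELS[llama_model]  # provider-side model id
    client = OpenAI(base_url=base_url, api_key=api_key)
    resp = client.completions.create(model=model_id, prompt=prompt, max_tokens=64)
    return resp.choices[0].text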
@@ -109,19 +109,6 @@ class TestInference:
     async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::ollama",
-            "remote::tgi",
-            "remote::together",
-            "remote::fireworks",
-            "remote::nvidia",
-            "remote::cerebras",
-            "remote::vllm",
-        ):
-            pytest.skip("Other inference providers don't support completion() yet")
-
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
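After this cleanup the test body runs the completion unconditionally, so unsupported providers fail loudly instead of being skipped. A sketch of the resulting method is below; the model_id argument and the final assertion are assumptions (the hunk cuts off before them), and the fixtures come from the existing test harness.

import pytest


class TestInference:
    # Sketch of the cleaned-up method; fixtures are provided by the surrounding harness.
    @pytest.mark.asyncio
    async def test_completion(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack

        # No provider skip list any more: every provider exercises completion()
        # and unsupported ones surface as real failures.
        response = await inference_impl.completion(
            content="Micheael Jordan is born in ",
            stream=False,
            model_id=inference_model,  # assumption: not visible in the hunk
        )
        assert response is not None  # assumption: the real assertions are not shown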
@@ -155,12 +142,6 @@ class TestInference:
     async def test_completion_logprobs(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            # "remote::nvidia", -- provider doesn't provide all logprobs
-        ):
-            pytest.skip("Other inference providers don't support completion() yet")
-
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
@@ -212,21 +193,6 @@ class TestInference:
     async def test_completion_structured_output(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::ollama",
-            "remote::tgi",
-            "remote::together",
-            "remote::fireworks",
-            "remote::nvidia",
-            "remote::vllm",
-            "remote::cerebras",
-        ):
-            pytest.skip(
-                "Other inference providers don't support structured output in completions yet"
-            )
-
         class Output(BaseModel):
             name: str
             year_born: str
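For context on the structured-output test above, a small self-contained sketch of the kind of check it performs. The JSON payload here is invented for illustration; only the Output model comes from the hunk.

from pydantic import BaseModel


class Output(BaseModel):
    name: str
    year_born: str


# Invented payload for illustration; the real test gets this text back from completion().
answer = Output.model_validate_json('{"name": "Michael Jordan", "year_born": "1963"}')
assert answer.name == "Michael Jordan" and answer.year_born == "1963"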
@@ -275,18 +241,6 @@ class TestInference:
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::ollama",
-            "remote::fireworks",
-            "remote::tgi",
-            "remote::together",
-            "remote::vllm",
-            "remote::nvidia",
-        ):
-            pytest.skip("Other inference providers don't support structured output yet")
-
         class AnswerFormat(BaseModel):
             first_name: str
             last_name: str
@@ -377,20 +331,6 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if (
-            provider.__provider_spec__.provider_type == "remote::groq"
-            and "Llama-3.2" in inference_model
-        ):
-            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
-            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
-
-        if provider.__provider_spec__.provider_type == "remote::sambanova" and (
-            "-1B-" in inference_model or "-3B-" in inference_model
-        ):
-            # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling for 1B/ 3B
-            pytest.skip("Sambanova's tool calling for lightweight models don't work")
-
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -430,17 +370,6 @@ class TestInference:
         sample_tool_definition,
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if (
-            provider.__provider_spec__.provider_type == "remote::groq"
-            and "Llama-3.2" in inference_model
-        ):
-            # TODO(aidand): Remove this skip once Groq's tool calling for Llama3.2 works better
-            pytest.skip("Groq's tool calling for Llama3.2 doesn't work very well")
-        if provider.__provider_spec__.provider_type == "remote::sambanova":
-            # TODO(snova-edawrdm): Remove this skip once SambaNova's tool calling under streaming is supported (we are working on it)
-            pytest.skip("Sambanova's tool calling for streaming doesn't work")
-
         messages = sample_messages + [
             UserMessage(
                 content="What's the weather like in San Francisco?",
@@ -51,20 +51,6 @@ class TestVisionModelInference:
         self, inference_model, inference_stack, image, expected_strings
     ):
         inference_impl, _ = inference_stack
-
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::together",
-            "remote::fireworks",
-            "remote::ollama",
-            "remote::vllm",
-            "remote::sambanova",
-        ):
-            pytest.skip(
-                "Other inference providers don't support vision chat completion() yet"
-            )
-
         response = await inference_impl.chat_completion(
             model_id=inference_model,
             messages=[
@@ -92,19 +78,6 @@ class TestVisionModelInference:
     ):
         inference_impl, _ = inference_stack

-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
-        if provider.__provider_spec__.provider_type not in (
-            "inline::meta-reference",
-            "remote::together",
-            "remote::fireworks",
-            "remote::ollama",
-            "remote::vllm",
-            "remote::sambanova",
-        ):
-            pytest.skip(
-                "Other inference providers don't support vision chat completion() yet"
-            )
-
         images = [
             ImageContentItem(
                 image=dict(