From 95b7f57d92d4d96110e06aeeafdd96bc455109c8 Mon Sep 17 00:00:00 2001
From: Dinesh Yeduguru
Date: Tue, 12 Nov 2024 08:21:37 -0800
Subject: [PATCH] use provider resource id to validate for models

---
 llama_stack/apis/inference/inference.py      |  6 +--
 llama_stack/distribution/routers/routers.py  | 21 ++++++----
 .../remote/inference/fireworks/fireworks.py  | 10 ++---
 .../remote/inference/together/together.py    | 10 ++---
 .../providers/tests/inference/fixtures.py    | 29 ++++++++++++-
 .../tests/inference/test_text_inference.py   | 41 ++++++++++---------
 .../utils/inference/model_registry.py        |  4 +-
 7 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 1e7b29722..b2681e578 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -226,7 +226,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -237,7 +237,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
@@ -254,6 +254,6 @@ class Inference(Protocol):
     @webmethod(route="/inference/embeddings")
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 220dfdb56..7d4e43f60 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -95,7 +95,7 @@ class InferenceRouter(Inference):
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -105,8 +105,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.routing_table.get_model(model_id)
         params = dict(
-            model=model,
+            model_id=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -116,7 +117,7 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        provider = self.routing_table.get_provider_impl(model)
+        provider = self.routing_table.get_provider_impl(model_id)
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
@@ -124,16 +125,17 @@ class InferenceRouter(Inference):
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        provider = self.routing_table.get_provider_impl(model)
+        model = await self.routing_table.get_model(model_id)
+        provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
-            model=model,
+            model_id=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -147,11 +149,12 @@ class InferenceRouter(Inference):
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        return await self.routing_table.get_provider_impl(model).embeddings(
-            model=model,
+        model = await self.routing_table.get_model(model_id)
+        return await self.routing_table.get_provider_impl(model_id).embeddings(
+            model_id=model.provider_resource_id,
             contents=contents,
         )
 
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 57e851c5b..67bf1cf47 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -74,7 +74,7 @@ class FireworksInferenceAdapter(
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -82,7 +82,7 @@ class FireworksInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -138,7 +138,7 @@ class FireworksInferenceAdapter(
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -149,7 +149,7 @@ class FireworksInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -229,7 +229,7 @@ class FireworksInferenceAdapter(
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 28a566415..1b04ae556 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -63,7 +63,7 @@ class TogetherInferenceAdapter(
 
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -71,7 +71,7 @@ class TogetherInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -135,7 +135,7 @@ class TogetherInferenceAdapter(
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -146,7 +146,7 @@ class TogetherInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -221,7 +221,7 @@ class TogetherInferenceAdapter(
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index d35ebab28..2a4eba3ad 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
     )
 
 
+def get_model_short_name(model_name: str) -> str:
+    """Convert model name to a short test identifier.
+
+    Args:
+        model_name: Full model name like "Llama3.1-8B-Instruct"
+
+    Returns:
+        Short name like "llama_8b" suitable for test markers
+    """
+    model_name = model_name.lower()
+    if "vision" in model_name:
+        return "llama_vision"
+    elif "3b" in model_name:
+        return "llama_3b"
+    elif "8b" in model_name:
+        return "llama_8b"
+    else:
+        return model_name.replace(".", "_").replace("-", "_")
+
+
+@pytest.fixture(scope="session")
+def model_id(inference_model) -> str:
+    return get_model_short_name(inference_model)
+
+
 INFERENCE_FIXTURES = [
     "meta_reference",
     "ollama",
@@ -154,7 +179,7 @@ INFERENCE_FIXTURES = [
 
 
 @pytest_asyncio.fixture(scope="session")
-async def inference_stack(request, inference_model):
+async def inference_stack(request, inference_model, model_id):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
@@ -163,7 +188,7 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=inference_model,
+                model_id=model_id,
             )
         ],
     )
diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py
index e7bfbc135..9850b328e 100644
--- a/llama_stack/providers/tests/inference/test_text_inference.py
+++ b/llama_stack/providers/tests/inference/test_text_inference.py
@@ -64,7 +64,7 @@ def sample_tool_definition():
 
 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack):
+    async def test_model_list(self, inference_model, inference_stack, model_id):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,17 +73,16 @@ class TestInference:
 
         model_def = None
         for model in response:
-            if model.identifier == inference_model:
+            if model.identifier == model_id:
                 model_def = model
                 break
 
         assert model_def is not None
 
     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack):
+    async def test_completion(self, inference_model, inference_stack, model_id):
         inference_impl, _ = inference_stack
-
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -96,7 +95,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -110,7 +109,7 @@ class TestInference:
         async for r in await inference_impl.completion(
             content="Roses are red,",
             stream=True,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -125,11 +124,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack
+        self, inference_model, inference_stack, model_id
    ):
         inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -149,7 +148,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -167,11 +166,11 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages, model_id
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -184,11 +183,11 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params
+        self, inference_model, inference_stack, common_params, model_id
     ):
         inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -204,7 +203,7 @@ class TestInference:
             num_seasons_in_nba: int
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -227,7 +226,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -244,13 +243,13 @@ class TestInference:
 
     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages, model_id
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=model_id,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -277,6 +276,7 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
+        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,6 +316,7 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
+        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -327,7 +328,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=model_id,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 141e4af31..bdc5af0f9 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -29,7 +29,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return self.stack_to_provider_models_map[identifier]
 
     async def register_model(self, model: Model) -> None:
-        if model.identifier not in self.stack_to_provider_models_map:
+        if model.provider_resource_id not in self.stack_to_provider_models_map:
             raise ValueError(
-                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
+                f"Unsupported model {model.provider_resource_id}. Supported models: {self.stack_to_provider_models_map.keys()}"
             )
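
Editor's note, a minimal sketch (not part of the patch): the change separates the stack-level model_id that callers pass to the Inference API from the provider_resource_id that providers and the model registry validate, with the router resolving one into the other through its routing table. The toy classes below (Model, ToyRoutingTable, ToyRegistryHelper) and their method names are assumptions for illustration only, and the flow is shown synchronously; they are not the real llama_stack APIs.

# Illustrative only: toy stand-ins for the routing table and registry helper,
# showing the model_id -> provider_resource_id flow introduced by this patch.
from dataclasses import dataclass


@dataclass
class Model:
    identifier: str            # stack-level model_id, e.g. "llama_8b"
    provider_resource_id: str  # provider-side name, e.g. "Llama3.1-8B-Instruct"


class ToyRegistryHelper:
    def __init__(self, supported):
        # maps provider_resource_id -> provider-specific model string
        self.stack_to_provider_models_map = supported

    def register_model(self, model):
        # Mirrors the patched check: validate provider_resource_id, not identifier.
        if model.provider_resource_id not in self.stack_to_provider_models_map:
            raise ValueError(
                f"Unsupported model {model.provider_resource_id}. "
                f"Supported models: {list(self.stack_to_provider_models_map)}"
            )


class ToyRoutingTable:
    def __init__(self):
        self.models = {}

    def register(self, model):
        self.models[model.identifier] = model

    def get_model(self, model_id):
        return self.models[model_id]


if __name__ == "__main__":
    registry = ToyRegistryHelper({"Llama3.1-8B-Instruct": "provider/llama-3.1-8b"})
    table = ToyRoutingTable()

    model = Model(identifier="llama_8b", provider_resource_id="Llama3.1-8B-Instruct")
    registry.register_model(model)  # passes: provider_resource_id is supported
    table.register(model)

    # Conceptually what the router now does before calling a provider:
    # resolve the caller's model_id to the provider resource id.
    resolved = table.get_model("llama_8b")
    print(resolved.provider_resource_id)  # the provider sees this, not "llama_8b"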