use provider resource id to validate for models

Dinesh Yeduguru 2024-11-12 08:21:37 -08:00
parent e4f14eafe2
commit 95b7f57d92
7 changed files with 75 additions and 46 deletions
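For context on the title: a registered Model resource carries both a stack-facing identifier (what clients pass as model_id) and a provider_resource_id (the name the backing provider knows the model by). A rough, hypothetical sketch of the two fields — the real Model class lives in the llama_stack APIs and has more to it; the sample values are borrowed from the test fixtures further down:

    # Hypothetical stand-in for the real Model resource; illustration only.
    from dataclasses import dataclass

    @dataclass
    class Model:
        identifier: str             # stack-level alias, passed by clients as model_id
        provider_resource_id: str   # provider-side model name, used for validation/routing

    m = Model(identifier="llama_8b", provider_resource_id="Llama3.1-8B-Instruct")
    # Routers resolve identifier -> provider_resource_id before calling a provider,
    # and registry helpers now validate provider_resource_id (last file in this diff).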

View file

@@ -226,7 +226,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/completion")
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -237,7 +237,7 @@ class Inference(Protocol):
     @webmethod(route="/inference/chat_completion")
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         # zero-shot tool definitions as input to the model
@@ -254,6 +254,6 @@ class Inference(Protocol):
     @webmethod(route="/inference/embeddings")
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse: ...
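For API consumers, the only visible effect of this first hunk is the keyword rename from model to model_id. A minimal call-site sketch, assuming an already-constructed Inference implementation named inference and a model registered under the alias "llama_8b" (both assumptions, not part of this diff):

    # Before: await inference.chat_completion(model="llama_8b", messages=messages)
    # After:
    response = await inference.chat_completion(
        model_id="llama_8b",  # stack-level identifier; the router resolves it below
        messages=messages,
    )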

View file

@@ -95,7 +95,7 @@ class InferenceRouter(Inference):
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -105,8 +105,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
+        model = await self.routing_table.get_model(model_id)
         params = dict(
-            model=model,
+            model_id=model.provider_resource_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -116,7 +117,7 @@ class InferenceRouter(Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        provider = self.routing_table.get_provider_impl(model)
+        provider = self.routing_table.get_provider_impl(model_id)
         if stream:
             return (chunk async for chunk in await provider.chat_completion(**params))
         else:
@@ -124,16 +125,17 @@ class InferenceRouter(Inference):
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        provider = self.routing_table.get_provider_impl(model)
+        model = await self.routing_table.get_model(model_id)
+        provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
-            model=model,
+            model_id=model.provider_resource_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -147,11 +149,12 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        return await self.routing_table.get_provider_impl(model).embeddings(
-            model=model,
+        model = await self.routing_table.get_model(model_id)
+        return await self.routing_table.get_provider_impl(model_id).embeddings(
+            model_id=model.provider_resource_id,
             contents=contents,
         )
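The router is where the new lookup actually happens: it resolves the incoming model_id to a registered Model, forwards model.provider_resource_id to the provider, and still uses model_id to select the provider implementation. A self-contained sketch of that flow with a stubbed routing table (everything except get_model, get_provider_impl, and provider_resource_id is a stand-in):

    import asyncio
    from types import SimpleNamespace

    class FakeRoutingTable:
        def __init__(self):
            self._models = {
                "llama_8b": SimpleNamespace(
                    identifier="llama_8b",
                    provider_resource_id="Llama3.1-8B-Instruct",
                )
            }

        async def get_model(self, model_id):
            return self._models[model_id]

        def get_provider_impl(self, model_id):
            # Stub provider that just echoes the kwargs it was called with.
            return SimpleNamespace(
                embeddings=lambda **kwargs: asyncio.sleep(0, result=kwargs)
            )

    async def main():
        table = FakeRoutingTable()
        model = await table.get_model("llama_8b")
        provider = table.get_provider_impl("llama_8b")
        # The provider receives the provider-side id, not the stack alias.
        print(await provider.embeddings(
            model_id=model.provider_resource_id,
            contents=[],
        ))

    asyncio.run(main())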

View file

@@ -74,7 +74,7 @@ class FireworksInferenceAdapter(
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -82,7 +82,7 @@ class FireworksInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -138,7 +138,7 @@ class FireworksInferenceAdapter(
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -149,7 +149,7 @@ class FireworksInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -229,7 +229,7 @@ class FireworksInferenceAdapter(
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -63,7 +63,7 @@ class TogetherInferenceAdapter(
     async def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -71,7 +71,7 @@ class TogetherInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = CompletionRequest(
-            model=model,
+            model=model_id,
             content=content,
             sampling_params=sampling_params,
             response_format=response_format,
@@ -135,7 +135,7 @@ class TogetherInferenceAdapter(
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         tools: Optional[List[ToolDefinition]] = None,
@@ -146,7 +146,7 @@ class TogetherInferenceAdapter(
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
-            model=model,
+            model=model_id,
             messages=messages,
             sampling_params=sampling_params,
             tools=tools or [],
@@ -221,7 +221,7 @@ class TogetherInferenceAdapter(
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()

View file

@@ -142,6 +142,31 @@ def inference_bedrock() -> ProviderFixture:
     )


+def get_model_short_name(model_name: str) -> str:
+    """Convert model name to a short test identifier.
+
+    Args:
+        model_name: Full model name like "Llama3.1-8B-Instruct"
+
+    Returns:
+        Short name like "llama_8b" suitable for test markers
+    """
+    model_name = model_name.lower()
+    if "vision" in model_name:
+        return "llama_vision"
+    elif "3b" in model_name:
+        return "llama_3b"
+    elif "8b" in model_name:
+        return "llama_8b"
+    else:
+        return model_name.replace(".", "_").replace("-", "_")
+
+
+@pytest.fixture(scope="session")
+def model_id(inference_model) -> str:
+    return get_model_short_name(inference_model)
+
+
 INFERENCE_FIXTURES = [
     "meta_reference",
     "ollama",
@@ -154,7 +179,7 @@ INFERENCE_FIXTURES = [
 @pytest_asyncio.fixture(scope="session")
-async def inference_stack(request, inference_model):
+async def inference_stack(request, inference_model, model_id):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
     impls = await resolve_impls_for_test_v2(
@@ -163,7 +188,7 @@ async def inference_stack(request, inference_model):
         inference_fixture.provider_data,
         models=[
             ModelInput(
-                model_id=inference_model,
+                model_id=model_id,
             )
         ],
     )
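The new fixture boils the configured model name down to a short, marker-friendly id. A quick self-contained check of the mapping (the function body is reproduced from the hunk above; the non-8B sample inputs are assumed for illustration):

    def get_model_short_name(model_name: str) -> str:
        model_name = model_name.lower()
        if "vision" in model_name:
            return "llama_vision"
        elif "3b" in model_name:
            return "llama_3b"
        elif "8b" in model_name:
            return "llama_8b"
        else:
            return model_name.replace(".", "_").replace("-", "_")

    assert get_model_short_name("Llama3.1-8B-Instruct") == "llama_8b"        # from the docstring
    assert get_model_short_name("Llama3.2-3B-Instruct") == "llama_3b"        # assumed input
    assert get_model_short_name("Llama-Vision-Something") == "llama_vision"  # assumed input
    assert get_model_short_name("SomeOther-Model.v1") == "someother_model_v1"  # fallback path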

View file

@@ -64,7 +64,7 @@ def sample_tool_definition():
 class TestInference:
     @pytest.mark.asyncio
-    async def test_model_list(self, inference_model, inference_stack):
+    async def test_model_list(self, inference_model, inference_stack, model_id):
         _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
@@ -73,17 +73,16 @@ class TestInference:
         model_def = None
         for model in response:
-            if model.identifier == inference_model:
+            if model.identifier == model_id:
                 model_def = model
                 break

         assert model_def is not None

     @pytest.mark.asyncio
-    async def test_completion(self, inference_model, inference_stack):
+    async def test_completion(self, inference_model, inference_stack, model_id):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::ollama",
@@ -96,7 +95,7 @@ class TestInference:
         response = await inference_impl.completion(
             content="Micheael Jordan is born in ",
             stream=False,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -110,7 +109,7 @@ class TestInference:
             async for r in await inference_impl.completion(
                 content="Roses are red,",
                 stream=True,
-                model=inference_model,
+                model_id=model_id,
                 sampling_params=SamplingParams(
                     max_tokens=50,
                 ),
@@ -125,11 +124,11 @@ class TestInference:
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completions_structured_output(
-        self, inference_model, inference_stack
+        self, inference_model, inference_stack, model_id
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::tgi",
@@ -149,7 +148,7 @@ class TestInference:
         response = await inference_impl.completion(
             content=user_input,
             stream=False,
-            model=inference_model,
+            model_id=model_id,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
@@ -167,11 +166,11 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages, model_id
     ):
         inference_impl, _ = inference_stack
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=sample_messages,
             stream=False,
             **common_params,
@@ -184,11 +183,11 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_structured_output(
-        self, inference_model, inference_stack, common_params
+        self, inference_model, inference_stack, common_params, model_id
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        provider = inference_impl.routing_table.get_provider_impl(model_id)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::fireworks",
@@ -204,7 +203,7 @@ class TestInference:
             num_seasons_in_nba: int

         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -227,7 +226,7 @@ class TestInference:
         assert answer.num_seasons_in_nba == 15

         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=[
                 SystemMessage(content="You are a helpful assistant."),
                 UserMessage(content="Please give me information about Michael Jordan."),
@@ -244,13 +243,13 @@ class TestInference:
     @pytest.mark.asyncio
     async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
+        self, inference_model, inference_stack, common_params, sample_messages, model_id
     ):
         inference_impl, _ = inference_stack
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=model_id,
                 messages=sample_messages,
                 stream=True,
                 **common_params,
@@ -277,6 +276,7 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
+        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -286,7 +286,7 @@ class TestInference:
         ]

         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=model_id,
             messages=messages,
             tools=[sample_tool_definition],
             stream=False,
@@ -316,6 +316,7 @@ class TestInference:
         common_params,
         sample_messages,
         sample_tool_definition,
+        model_id,
     ):
         inference_impl, _ = inference_stack
         messages = sample_messages + [
@@ -327,7 +328,7 @@ class TestInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=model_id,
                 messages=messages,
                 tools=[sample_tool_definition],
                 stream=True,

View file

@@ -29,7 +29,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return self.stack_to_provider_models_map[identifier]

     async def register_model(self, model: Model) -> None:
-        if model.identifier not in self.stack_to_provider_models_map:
+        if model.provider_resource_id not in self.stack_to_provider_models_map:
             raise ValueError(
-                f"Unsupported model {model.identifier}. Supported models: {self.stack_to_provider_models_map.keys()}"
+                f"Unsupported model {model.provider_resource_id}. Supported models: {self.stack_to_provider_models_map.keys()}"
             )
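With this last hunk the registry helper validates the provider_resource_id rather than the stack identifier, so a model can be registered under any alias as long as its provider_resource_id is one the provider supports. A minimal runnable sketch of that behavior (the Model stand-in and the map contents are invented; only the validation logic mirrors the hunk above):

    import asyncio
    from types import SimpleNamespace

    # Invented contents; the real map is built from the provider's supported models.
    stack_to_provider_models_map = {
        "Llama3.1-8B-Instruct": "provider-org/llama-3.1-8b-instruct",
    }

    async def register_model(model) -> None:
        if model.provider_resource_id not in stack_to_provider_models_map:
            raise ValueError(
                f"Unsupported model {model.provider_resource_id}. "
                f"Supported models: {stack_to_provider_models_map.keys()}"
            )

    ok = SimpleNamespace(identifier="my_alias", provider_resource_id="Llama3.1-8B-Instruct")
    bad = SimpleNamespace(identifier="Llama3.1-8B-Instruct", provider_resource_id="not-a-model")

    asyncio.run(register_model(ok))        # accepted: the alias is irrelevant now
    try:
        asyncio.run(register_model(bad))   # rejected: unknown provider_resource_id
    except ValueError as e:
        print(e)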