Merge branch 'main' into add-nvidia-inference-adapter

Matthew Farrellee 2024-11-17 15:47:13 -05:00
commit c24f882f31
6 changed files with 51 additions and 22 deletions


@@ -84,7 +84,7 @@ _MODEL_ALIASES = [
 ]
 
 
-class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
+class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
@@ -117,7 +117,7 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
 
     def completion(
         self,
-        model: str,
+        model_id: str,
         content: InterleavedTextMedia,
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -128,14 +128,14 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
 
     async def embeddings(
         self,
-        model: str,
+        model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
 
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
         response_format: Optional[ResponseFormat] = None,
@@ -156,7 +156,7 @@ class NVIDIAInferenceAdapter(ModelRegistryHelper, Inference):
         request = convert_chat_completion_request(
             request=ChatCompletionRequest(
-                model=self.get_provider_model_id(model),
+                model=self.get_provider_model_id(model_id),
                 messages=messages,
                 sampling_params=sampling_params,
                 tools=tools,
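
Aside from the base-class reorder, the change is a mechanical rename: the caller-facing parameter is now model_id, and it still flows through get_provider_model_id to resolve the provider-side model name registered via _MODEL_ALIASES in __init__. Below is a minimal sketch of that lookup; the ModelAlias shape and the alias entry are assumptions for illustration, not the adapter's actual registry.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ModelAlias:
    # hypothetical shape: caller-facing model id -> provider-side model id
    llama_model: str
    provider_model_id: str


# illustrative entry only; the real alias table lives in the adapter module
_MODEL_ALIASES: List[ModelAlias] = [
    ModelAlias(
        llama_model="Llama3.1-8B-Instruct",
        provider_model_id="meta/llama-3.1-8b-instruct",
    ),
]


class ModelRegistryHelper:
    def __init__(self, model_aliases: List[ModelAlias]) -> None:
        # index aliases by the caller-facing id
        self._alias_map: Dict[str, str] = {
            a.llama_model: a.provider_model_id for a in model_aliases
        }

    def get_provider_model_id(self, model_id: str) -> str:
        # fall back to the given id when no alias is registered
        return self._alias_map.get(model_id, model_id)


helper = ModelRegistryHelper(model_aliases=_MODEL_ALIASES)
print(helper.get_provider_model_id("Llama3.1-8B-Instruct"))  # meta/llama-3.1-8b-instruct
```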