Custom LLM

Hello all, from my understanding, the LlamaIndex PaLM 2 abstraction uses the google-generative-ai API to talk to PaLM 2. Is there any way I can configure it to use the Vertex AI API instead?
Thanks for your response. My intent is to use Vertex AI from Google, which is an enterprise offering and the one available to me at the moment. I have an API endpoint URL that talks to PaLM 2 and gives me back a response through curl or Postman, but I don't see any way to get LlamaIndex to make that API call for me. For testing purposes I downloaded Facebook's 13B 8-bit quantized model and was able to use a custom LLM. However, in this case I don't have access to the LLM directly, only through an API, so I was wondering if there is a way I could enable this integration.
Yeah, you can use the API endpoint in a custom LLM, something like this.

Plain Text
from typing import Any

from llama_index.llms import (
    CustomLLM,
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)
from llama_index.llms.base import llm_completion_callback

# context_window, num_output, and model_name are assumed to be defined above
class OurLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=model_name,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Define the API endpoint for your PaLM 2 model and call it here
        response = ...  # the API call to PaLM goes here
        return CompletionResponse(text=response.text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

# define our LLM
llm = OurLLM()
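To answer the original Vertex AI question, the body of complete() could call the PaLM 2 REST endpoint on Vertex AI directly. A minimal sketch, assuming the text-bison publisher model and an OAuth access token obtained elsewhere (e.g. from gcloud auth print-access-token); the project ID, region, and parameter values are placeholders, not something from this thread:

import requests

def call_vertex_palm2(prompt: str, access_token: str) -> str:
    # Placeholder project and region -- substitute your own
    project, region = "my-gcp-project", "us-central1"
    url = (
        f"https://{region}-aiplatform.googleapis.com/v1/projects/{project}"
        f"/locations/{region}/publishers/google/models/text-bison:predict"
    )
    headers = {"Authorization": f"Bearer {access_token}"}
    body = {
        "instances": [{"prompt": prompt}],
        "parameters": {"temperature": 0.2, "maxOutputTokens": 256},
    }
    resp = requests.post(url, headers=headers, json=body)
    resp.raise_for_status()
    # Vertex AI text models respond with {"predictions": [{"content": "..."}], ...}
    return resp.json()["predictions"][0]["content"]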
Hey mate, just wanted to thank you for your help; you saved me hours of research and going through the documentation to get to the right place. What you advised worked like a charm. Really appreciate your help!
Happy to help Goku!!
Hey @WhiteFang_Jr, so if I want to write it out, it would be something like this:
API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-alpha"
headers = {"Authorization": "Bearer xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
assuming I am calling a Hugging Face-hosted LLM
Yes, and you would need to return the response in the required format as well.
return CompletionResponse(text=response.text)
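For the Hugging Face Inference API specifically, "the required format" means unwrapping the JSON the endpoint returns before building the CompletionResponse. A minimal sketch, assuming the text-generation endpoint from the snippet above (the token is a placeholder):

import requests

API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-alpha"
headers = {"Authorization": "Bearer <your-hf-token>"}

resp = requests.post(API_URL, headers=headers, json={"inputs": "Hello, world"})
# Text-generation endpoints return a JSON list: [{"generated_text": "..."}]
text = resp.json()[0]["generated_text"]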
class OurLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=API_URL,
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        API_URL = ""
        headers = {"Authorization": "Bearer xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
        response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
        return CompletionResponse(text=response.text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
Notice in the original code the model_name was defined as:
model_name = "facebook/opt-iml-max-30b"
Yes, as per the doc, that model was used.
So should I keep the model_name and change it to model_name=API_URL?
You could also put it as model_name="zephyr-7b-alpha"
Yes, I guess I can do that as well.
For the sake of accuracy, the code should then look like the following:

import requests
from typing import Any

API_URL = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-alpha"
headers = {"Authorization": "Bearer xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}

class OurLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window,
            num_output=num_output,
            model_name=API_URL,  # (or "zephyr-7b-alpha")
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
        # The endpoint returns a JSON list: [{"generated_text": "..."}]
        return CompletionResponse(text=response.json()[0]["generated_text"])

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
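Once defined, the custom LLM can be wired into LlamaIndex like a built-in one. A quick usage sketch, assuming the legacy ServiceContext API that was current when this thread was written (newer releases use Settings instead), with context_window and num_output defined as before and a ./data folder of documents:

from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex

llm = OurLLM()
service_context = ServiceContext.from_defaults(
    llm=llm,
    context_window=context_window,
    num_output=num_output,
    embed_model="local",  # avoid defaulting to the OpenAI embedding model
)

documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
print(index.as_query_engine().query("What is this document about?"))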
Hi, I'm trying something like this. _get_prompt and _llm are not accessible, although they are available in the LLM class, which is the parent class of CustomLLM:
class OurLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=4096,
            num_output=256,
            model_name="Orca-2-13b",
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Define the API endpoint for your Palm2 model here
        API_URL = ""
        response = requests.post(API_URL, json=prompt)
        return CompletionResponse(text=response.text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

    def predict(self, prompt, output_cls, promptArgs) -> str:
        # Define the API endpoint for your Palm2 model here
        formatted_prompt = self._get_prompt(prompt, promptArgs)
        response = self.complete(formatted_prompt)

        response = self._llm.complete(formatted_prompt)
        return response.text
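For reference, the usual pattern here, as a sketch assuming a llama_index version around 0.9 in which LLM.predict() renders the prompt template and calls complete() internally, is to implement only complete() and stream_complete() and call the inherited predict() with a PromptTemplate rather than overriding it:

from llama_index.prompts import PromptTemplate

llm = OurLLM()
template = PromptTemplate("Summarize the following text:\n{text}")
# predict() is inherited from LLM: it formats the template with the given
# arguments, calls complete(), and returns the response text as a string.
output = llm.predict(template, text="LlamaIndex is a data framework for LLMs.")
print(output)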