That helps!
Here is what I have thus far:
import os
from typing import List, Optional, Mapping, Any
from google.cloud import aiplatform
from google.protobuf.struct_pb2 import Value
from google.protobuf import json_format
from llama_index import (
    LangchainEmbedding,
    ListIndex,
    LLMPredictor,
    ServiceContext,
    SimpleDirectoryReader,
)
from langchain.llms.base import LLM
from langchain.llms import VertexAI
from langchain.embeddings import VertexAIEmbeddings
# Placeholder chat client that just echoes the prompt back
class Chat:
    def __init__(self):
        self.messages = []

    def send_message(self, message):
        self.messages.append(message)
        return "Response to " + message
os.environ[
    "GOOGLE_APPLICATION_CREDENTIALS"
] = "../gcp-enterprise-data-chat-1c02e4fff19e.json"
project = "gcp-enterprise-data-chat"
location = "us-west1"
chat = Chat()
# set context window size
context_window = 4096
# set number of output tokens
num_output = 256
# Custom LangChain LLM wrapper that routes prompts through the chat client above
class PaLM(VertexAI):
    model_name: str = "text-bison@001"
    total_tokens_used: int = 0
    last_token_usage: int = 0

    def __init__(self):
        super().__init__()

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        response = query_google_llm(chat, prompt)
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"name_of_model": self.model_name}

    @property
    def _llm_type(self) -> str:
        return "custom"
def query_google_llm(chat, query):
    response = chat.send_message(query)
    return response
# wrap the Vertex AI embedding model from LangChain
embed_model = LangchainEmbedding(VertexAIEmbeddings())
# define our LLM
llm_predictor = LLMPredictor(llm=PaLM())
service_context = ServiceContext.from_defaults(
    llm_predictor=llm_predictor,
    context_window=context_window,
    num_output=num_output,
    embed_model=embed_model,
)
# Load your data
documents = SimpleDirectoryReader("../data/llama_index").load_data()
index = ListIndex.from_documents(documents, service_context=service_context)
# Query and print response
query_engine = index.as_query_engine()
response = query_engine.query("<query_text>")
print(response)
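Right now the Chat class is just an echo stub, so the index never sees a real PaLM completion. If I swap it for the actual Vertex AI chat model, I think the wiring would look roughly like this (a sketch assuming the vertexai preview module from google-cloud-aiplatform, reusing the project/location values above):

import vertexai
from vertexai.preview.language_models import ChatModel

# Reuse the project/location defined above (assumes the chat model is available in that region)
vertexai.init(project=project, location=location)

# Start a PaLM chat session instead of the echo stub
chat_model = ChatModel.from_pretrained("chat-bison@001")
chat = chat_model.start_chat()


def query_google_llm(chat, query):
    # Vertex AI returns a response object; the generated text is on .text
    response = chat.send_message(query)
    return response.text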