# --- Models -----------------------------------------------------------------
# Vertex AI PaLM text model that writes the SQL and the final answer.
# temperature=0.0 keeps generation as deterministic as the API allows.
llm = VertexAI(
    model_name="text-bison@001",
    max_output_tokens=256,
    temperature=0.0,
    top_p=0.8,
    top_k=40,
    verbose=True,
)

# LangChain's Vertex AI embeddings, wrapped so llama_index can use them when
# building its vector indexes.
embed_model = LangchainEmbedding(VertexAIEmbeddings())

# Register both models as the process-wide defaults so later llama_index
# objects pick them up without being passed a service context explicitly.
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
)
set_global_service_context(service_context)
# --- Database ---------------------------------------------------------------
# Tables exposed to the text-to-SQL engine. Each 'description' is used as the
# table's context string for retrieval. Spelling out the cryptic column names
# (eml_adr, progrm_n) is presumably meant to help the LLM map questions onto
# the real schema.
table_schema = dict(
    users={
        'description': 'list of users with email addresses in column eml_adr',
    },
    programmes={
        'description': 'list of programmes users can be enrolled on. Programme names in column progrm_n',
    },
)

# Restrict the SQLDatabase wrapper to exactly the tables described above,
# in the same order as table_schema.
tables_to_include = [*table_schema]
sql_database = SQLDatabase(engine, include_tables=tables_to_include)
# --- Table retrieval + query engine -----------------------------------------
# Map each table of sql_database to an index node so that tables can be
# retrieved by similarity to the user's question.
table_node_mapping = SQLTableNodeMapping(sql_database)

# One SQLTableSchema per entry in table_schema; the human-written description
# becomes the table's context string.
table_schema_objs = [
    SQLTableSchema(table_name=name, context_str=info['description'])
    for name, info in table_schema.items()
]

obj_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex,
)

# How many table schemas the retriever hands to the text-to-SQL step per
# question. The previous value of 1 forwarded only the single most similar
# table. A question spanning users *and* programmes then never saw both
# schemas, so the LLM could not write the join. Raise this if queries need
# more tables; keep it small so the prompt stays short as the schema grows.
table_retriever_top_k = 2

table_retriever_query_engine = SQLTableRetrieverQueryEngine(
    sql_database,
    obj_index.as_retriever(similarity_top_k=table_retriever_top_k),
)
# --- Interactive loop -------------------------------------------------------
# Read a natural-language question, answer it with the table-retriever query
# engine, and print the result. End the session with Ctrl-D (EOF) or Ctrl-C at
# the prompt.
while True:
    try:
        query = input("Enter query >>> ")
    except (EOFError, KeyboardInterrupt):
        # input() raises these on Ctrl-D / Ctrl-C; exit cleanly rather than
        # dumping a traceback.
        print()
        break
    # Don't spend an LLM round-trip on an empty question.
    if not query.strip():
        continue
    response = table_retriever_query_engine.query(query)
    print(str(response))