@app.post("/document/query") def query_stream( query: str = Body(...), uuid_filename: str = Body(...), email: str = Body(...), ) -> StreamingResponse: subscription = get_user_subscription(email) model = MODEL_BASIC if subscription == "FREE" else MODEL_PREMIUM with token_counter(model, query_stream.__name__): filename_without_ext = uuid_filename.split(".")[0] # Create index index = initialize_index(model) document_is_indexed = does_document_exist_in_index(filename_without_ext) if document_is_indexed is False: logging.info("Re-adding to index...") reindex_document(filename_without_ext) if is_summary_request(query): query = modify_query_for_summary(query, filename_without_ext, model) chat_engine = initialize_chat_engine(index, filename_without_ext) streaming_response = chat_engine.stream_chat(query) # takes 10 seconds!! def generate() -> Generator[str, any, None]: yield from streaming_response.response_gen return StreamingResponse(generate(), media_type="text/plain")
def initialize_chat_engine(index: VectorStoreIndex, document_uuid: str) -> BaseChatEngine:
    """Initialize chat engine with chat history."""
    chat_history = get_chat_history(document_uuid)
    filters = MetadataFilters(
        filters=[ExactMatchFilter(key="doc_id", value=document_uuid)],
    )
    return index.as_chat_engine(
        chat_mode=ChatMode.CONTEXT,
        condense_question_prompt=PromptTemplate(CHAT_PROMPT_TEMPLATE),
        chat_history=chat_history,
        agent_chat_response_mode="StreamingAgentChatResponse",
        similarity_top_k=10,
        filters=filters,
    )
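get_chat_history is not shown above; a sketch of what it might look like follows, assuming prior turns are persisted as role/content pairs keyed by the document UUID. The fetch_chat_rows helper is hypothetical and stands in for whatever storage layer is actually used.

# Hypothetical sketch of get_chat_history; fetch_chat_rows stands in for the real storage layer.
from llama_index.llms import ChatMessage, MessageRole


def get_chat_history(document_uuid: str) -> list[ChatMessage]:
    """Load prior turns for this document as ChatMessage objects."""
    rows = fetch_chat_rows(document_uuid)  # hypothetical DB helper returning dicts
    return [
        ChatMessage(
            role=MessageRole.USER if row["role"] == "user" else MessageRole.ASSISTANT,
            content=row["content"],
        )
        for row in rows
    ]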
service_context_basic = ServiceContext.from_defaults(
    llm=OpenAI(temperature=0, model=MODEL_BASIC, timeout=180),
    callback_manager=callback_manager_basic,
    embed_model=embed_model,
    context_window=16385,
    chunk_size_limit=16385,
)

service_context_premium = ServiceContext.from_defaults(
    llm=OpenAI(temperature=0, model=MODEL_PREMIUM, timeout=180),
    callback_manager=callback_manager_premium,
    embed_model=embed_model,
    context_window=128000,
    chunk_size_limit=128000,
)


def initialize_index(model_name: str = MODEL_BASIC) -> VectorStoreIndex:
    """Initialize the index.

    Args:
    ----
        model_name (str, optional): The model name. Defaults to MODEL_BASIC.

    Returns:
    -------
        VectorStoreIndex: The initialized index.

    """
    service_context = service_context_basic if model_name == MODEL_BASIC else service_context_premium
    return VectorStoreIndex(
        nodes=[],
        storage_context=storage_context,
        service_context=service_context,
        use_async=True,
    )
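The service contexts above reference several module-level objects defined elsewhere (the callback managers, embed_model, and storage_context). A rough sketch of how they could be wired up is below; the token-counting handlers, embedding model, tokenizer, and in-memory storage context are assumptions for illustration, not the project's actual configuration.

# Hypothetical wiring of the module-level objects used above; the real project
# may use a different vector store, embedding model, and tokenizer.
import tiktoken
from llama_index import StorageContext
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from llama_index.embeddings import OpenAIEmbedding

tokenizer = tiktoken.get_encoding("cl100k_base").encode  # assumed tokenizer

callback_manager_basic = CallbackManager([TokenCountingHandler(tokenizer=tokenizer)])
callback_manager_premium = CallbackManager([TokenCountingHandler(tokenizer=tokenizer)])

embed_model = OpenAIEmbedding()
storage_context = StorageContext.from_defaults()  # in-memory stores, for illustration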