@dataclass
class CriaChatResponse(AgentChatResponse):
    """Agent chat response extended with one extra field, ``raw``.

    All other fields are inherited unchanged from ``AgentChatResponse``.
    """

    # Extra payload attached to the response. Each instance gets its own
    # fresh empty dict when no value is supplied.
    raw: Optional[dict] = field(default_factory=dict)
class CriaChatEngine(ContextChatEngine):
    """Context chat engine whose async chat returns a ``CriaChatResponse``.

    Compared with the parent ``achat``, the response additionally exposes the
    ``raw`` attribute of the LLM chat response.
    """

    @classmethod
    def from_index(cls, index: BaseIndex, **kwargs):
        """Build an engine whose retriever is created from ``index``.

        Args:
            index: Index used to create the retriever.
            **kwargs: Forwarded unchanged to BOTH ``index.as_retriever`` and
                ``cls.from_defaults``, so every key must be acceptable to
                both callables.

        Returns:
            The object produced by ``cls.from_defaults``.
        """
        # Only a retriever is required; no query engine is built here, since
        # constructing one would be discarded work that can fail on its own.
        return cls.from_defaults(
            retriever=index.as_retriever(**kwargs),
            **kwargs,
        )

    async def achat(
        self, message: str, chat_history: Optional[List[ChatMessage]] = None
    ) -> CriaChatResponse:
        """Answer ``message`` asynchronously using retrieved context.

        Should maintain parity with superclass method.

        Args:
            message: The user's message for this turn.
            chat_history: When not ``None``, replaces the engine's memory
                before the new message is recorded.

        Returns:
            A ``CriaChatResponse`` holding the reply text, a single
            ``"retriever"`` tool output built from the first prefix message,
            the retrieved nodes, and the LLM response's ``raw`` attribute.
        """
        if chat_history is not None:
            self._memory.set(chat_history)
        self._memory.put(ChatMessage(content=message, role="user"))

        context_str_template, nodes = await self._agenerate_context(message)
        prefix_messages = self._get_prefix_messages_with_context(context_str_template)
        # Prefix (context-bearing) messages go first, then the stored history,
        # which already includes the user message recorded above.
        all_messages = prefix_messages + self._memory.get()

        chat_response = await self._llm.achat(all_messages)
        ai_message = chat_response.message
        self._memory.put(ai_message)

        # The first prefix message is reported as the retriever's output.
        context_message = prefix_messages[0]
        return CriaChatResponse(  # Custom response with a bit more info
            response=str(ai_message.content),
            sources=[
                ToolOutput(
                    tool_name="retriever",
                    content=str(context_message),
                    raw_input={"message": message},
                    raw_output=context_message,
                )
            ],
            source_nodes=nodes,
            raw=chat_response.raw,  # Add raw payload info
        )