import logging

from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Surface the retriever's progress logs (question generation, URL loading, ...).
logging.basicConfig()
logging.getLogger("langchain.retrievers.web_research").setLevel(logging.INFO)

# NOTE(review): `llm` and `web_research_retriever` are constructed in an
# earlier part of this article; they must be in scope before this snippet runs.
user_input = "How do LLM Powered Autonomous Agents work?"
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm, retriever=web_research_retriever
)
result = qa_chain({"question": user_input})
print(result)
# Wrapper around the Google Custom Search API; required (Field(...) has no default).
search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
再看下其构造过程:`from_llm` 函数。
@classmethod
def from_llm(
    cls,
    vectorstore: VectorStore,
    llm: BaseLLM,
    search: GoogleSearchAPIWrapper,
    prompt: Optional[BasePromptTemplate] = None,
    num_search_results: int = 1,
    # NOTE(review): a call-expression default is evaluated once at definition
    # time, so this splitter instance is shared across calls — upstream's
    # choice; it is stateless in practice, but worth confirming.
    text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
        chunk_size=1500, chunk_overlap=150
    ),
) -> "WebResearchRetriever":
    """Initialize from llm using default template.

    Args:
        vectorstore: Vector store for storing web pages
        llm: llm for search question generation
        search: GoogleSearchAPIWrapper
        prompt: prompt to generating search questions
        num_search_results: Number of pages per Google search
        text_splitter: Text splitter for splitting web pages into chunks
    """
    # NOTE(review): the method body is elided in this excerpt; see the
    # upstream WebResearchRetriever source for the full implementation.
    ...
# Default prompt: ask the LLM for THREE Google-style reformulations of the
# user's question, returned as a numbered list, each ending with "?".
DEFAULT_SEARCH_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
)
def _get_relevant_documents(
    self,
    query: str,
    *,
    run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
    """Search Google for documents related to the query input.

    Args:
        query: user query

    Returns:
        Relevant documents from all various urls.
    """
    # Step 1: expand the user query into several search questions via the LLM.
    logger.info("Generating questions for Google Search ...")
    result = self.llm_chain({"question": query})
    logger.info(f"Questions for Google Search (raw): {result}")
    questions = result["text"]
    logger.info(f"Questions for Google Search: {questions}")

    # Step 2: run one Google search per question and collect result links.
    # Loop variable renamed from `query` to avoid shadowing the parameter.
    logger.info("Searching for relevant urls...")
    urls_to_look = []
    for question in questions:
        search_results = self.search_tool(question, self.num_search_results)
        logger.info("Searching for relevant urls...")
        logger.info(f"Search results: {search_results}")
        for res in search_results:
            if res.get("link", None):
                urls_to_look.append(res["link"])

    # Deduplicate and keep only urls we have not indexed before.
    urls = set(urls_to_look)
    new_urls = list(urls.difference(self.url_database))
    logger.info(f"New URLs to load: {new_urls}")

    # Step 3: load, convert to text, split, and add new urls to the vectorstore.
    if new_urls:
        loader = AsyncHtmlLoader(new_urls, ignore_load_errors=True)
        html2text = Html2TextTransformer()
        logger.info("Indexing new urls...")
        docs = loader.load()
        docs = list(html2text.transform_documents(docs))
        docs = self.text_splitter.split_documents(docs)
        self.vectorstore.add_documents(docs)
        self.url_database.extend(new_urls)

    # Step 4: similarity-search the vectorstore once per generated question.
    # TODO: make this async
    logger.info("Grabbing most relevant splits from urls...")
    docs = []
    for question in questions:
        docs.extend(self.vectorstore.similarity_search(question))

    # Deduplicate by (content, metadata), preserving first-seen order.
    unique_documents_dict = {
        (doc.page_content, tuple(sorted(doc.metadata.items()))): doc
        for doc in docs
    }
    return list(unique_documents_dict.values())
1.2.2 GoogleSearchAPIWrapper
这是 Google CSE(自定义搜索引擎)检索 API 的封装类。
class GoogleSearchAPIWrapper(BaseModel):
    """Wrapper for Google Search API.

    NOTE(review): only the class header is shown in this excerpt; the
    configuration fields and search methods are elided.
    """
class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
    """Question-answering with sources over an index."""

    # Excluded from serialization: retrievers are not generally serializable.
    retriever: BaseRetriever = Field(exclude=True)
    """Index to connect to."""
    reduce_k_below_max_tokens: bool = False
    """Reduce the number of results to return from store based on tokens limit"""
    max_tokens_limit: int = 3375
    """Restrict the docs to return from store based on tokens, enforced only for
    StuffDocumentChain and if reduce_k_below_max_tokens is set to true"""