jina_rerank.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from __future__ import annotations
  2. from copy import deepcopy
  3. from typing import Any, Dict, List, Optional, Sequence, Union
  4. import requests
  5. from langchain_core.callbacks import Callbacks
  6. from langchain_core.documents import BaseDocumentCompressor, Document
  7. from langchain_core.utils import get_from_dict_or_env
  8. from pydantic import ConfigDict, model_validator
# Default Jina rerank endpoint. Note that this value is already the complete
# endpoint URL: it ends with "/rerank".
JINA_API_URL: str = "https://api.jina.ai/v1/rerank"
  10. class JinaRerank(BaseDocumentCompressor):
  11. """Document compressor that uses `Jina Rerank API` with support for custom base_url."""
  12. session: Any = None
  13. """Requests session to communicate with API."""
  14. top_n: Optional[int] = 3
  15. """Number of documents to return."""
  16. model: str = "jina-reranker-v1-base-en"
  17. """Model to use for reranking."""
  18. jina_api_key: Optional[str] = None
  19. """Jina API key. Must be specified directly or via environment variable
  20. JINA_API_KEY."""
  21. user_agent: str = "langchain"
  22. """Identifier for the application making the request."""
  23. base_url: str = JINA_API_URL
  24. """Base URL for the Jina API. Defaults to JINA_API_URL."""
  25. model_config = ConfigDict(
  26. arbitrary_types_allowed=True,
  27. extra="forbid",
  28. )
  29. @model_validator(mode="before")
  30. @classmethod
  31. def validate_environment(cls, values: Dict) -> Any:
  32. """Validate that api key exists in environment."""
  33. jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY")
  34. user_agent = values.get("user_agent", "langchain")
  35. session = requests.Session()
  36. session.headers.update(
  37. {
  38. "Authorization": f"Bearer {jina_api_key}",
  39. "Accept-Encoding": "identity",
  40. "Content-type": "application/json",
  41. "user-agent": user_agent,
  42. }
  43. )
  44. values["session"] = session
  45. return values
  46. def rerank(
  47. self,
  48. documents: Sequence[Union[str, Document, dict]],
  49. query: str,
  50. *,
  51. model: Optional[str] = None,
  52. top_n: Optional[int] = -1,
  53. max_chunks_per_doc: Optional[int] = None,
  54. ) -> List[Dict[str, Any]]:
  55. """Returns an ordered list of documents ordered by their relevance to the provided query.
  56. Args:
  57. query: The query to use for reranking.
  58. documents: A sequence of documents to rerank.
  59. model: The model to use for re-ranking. Default to self.model.
  60. top_n : The number of results to return. If None returns all results.
  61. Defaults to self.top_n.
  62. max_chunks_per_doc : The maximum number of chunks derived from a document.
  63. """ # noqa: E501
  64. if len(documents) == 0: # to avoid empty api call
  65. return []
  66. docs = [
  67. doc.page_content if isinstance(doc, Document) else doc for doc in documents
  68. ]
  69. model = model or self.model
  70. top_n = top_n if (top_n is None or top_n > 0) else self.top_n
  71. data = {
  72. "query": query,
  73. "documents": docs,
  74. "model": model,
  75. "top_n": top_n,
  76. }
  77. # 构建完整的API路径
  78. api_url = self.base_url.rstrip('/') + '/rerank'
  79. resp = self.session.post(
  80. api_url,
  81. json=data,
  82. ).json()
  83. if "results" not in resp:
  84. raise RuntimeError(resp["detail"])
  85. results = resp["results"]
  86. result_dicts = []
  87. for res in results:
  88. result_dicts.append(
  89. {
  90. "index": res["index"],
  91. "relevance_score": res["relevance_score"],
  92. }
  93. )
  94. return result_dicts
  95. def compress_documents(
  96. self,
  97. documents: Sequence[Document],
  98. query: str,
  99. top_n: Optional[int] = None,
  100. callbacks: Optional[Callbacks] = None,
  101. ) -> Sequence[Document]:
  102. """
  103. Compress documents using Jina's Rerank API.
  104. Args:
  105. documents: A sequence of documents to compress.
  106. query: The query to use for compressing the documents.
  107. callbacks: Callbacks to run during the compression process.
  108. Returns:
  109. A sequence of compressed documents.
  110. """
  111. compressed = []
  112. for res in self.rerank(documents=documents, query=query, top_n=top_n):
  113. doc = documents[res["index"]]
  114. doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
  115. doc_copy.metadata["relevance_score"] = res["relevance_score"]
  116. compressed.append(doc_copy)
  117. return compressed