# multimodal_embedding.py (5.7 KB)
  1. from typing import List
  2. from PIL import Image
  3. import base64
  4. import io
  5. from langchain_openai import OpenAIEmbeddings
  6. from dashscope import MultiModalEmbedding
  7. from src.conf.settings import model_settings
  8. from langfuse import observe
  9. from src.utils.file.image_util import image_util
  10. class Embedding:
  11. """Embedding模型工具"""
  12. def __init__(self, model_name: str = None, api_key: str = None):
  13. """
  14. 初始化Embedding模型
  15. Args:
  16. model_name: 模型名称,若为None则使用配置文件中的值
  17. api_key: API密钥,若为None则使用配置文件中的值
  18. """
  19. # 获取模型配置
  20. self.model_provider = model_settings.model_provider
  21. self.embedding_model_name = model_name or model_settings.embedding_model_name
  22. self.multi_embedding_model_name = model_settings.multimodal_embedding_model_name
  23. self.base_url = model_settings.base_url
  24. self.api_key = api_key or model_settings.api_key
  25. self.dashscope_api_key = model_settings.dashscope_api_key
  26. @observe(name="text_embedding", as_type="embedding")
  27. def get_text_embedding(self, text: str) -> List[float]:
  28. """
  29. 获取文本的embedding
  30. Args:
  31. text: 要获取embedding的文本
  32. Returns:
  33. List[float]: 文本的embedding向量
  34. """
  35. try:
  36. # 使用langchain_openai初始化OpenAI Embeddings模型
  37. embeddings = OpenAIEmbeddings(
  38. model=self.embedding_model_name,
  39. base_url=self.base_url,
  40. api_key=self.api_key
  41. )
  42. embedding = embeddings.embed_query(text)
  43. return embedding
  44. except Exception as e:
  45. raise Exception(f"文本embedding生成失败: {str(e)}")
  46. @observe(name="texts_embedding", as_type="embedding")
  47. def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
  48. """
  49. 获取多个文本的embedding
  50. Args:
  51. texts: 要获取embedding的文本列表
  52. Returns:
  53. List[List[float]]: 文本列表的embedding向量列表
  54. """
  55. try:
  56. # 使用langchain_openai初始化OpenAI Embeddings模型
  57. embeddings = OpenAIEmbeddings(
  58. model=self.embedding_model_name,
  59. base_url=self.base_url,
  60. api_key=self.api_key
  61. )
  62. embeddings = embeddings.embed_documents(texts)
  63. return embeddings
  64. except Exception as e:
  65. raise Exception(f"多个文本embedding生成失败: {str(e)}")
  66. @observe(name="image_embedding", as_type="embedding")
  67. def get_image_embedding(self, image: Image.Image) -> List[float]:
  68. """
  69. 获取图像的embedding
  70. Args:
  71. image: PIL图像对象
  72. Returns:
  73. List[float]: 图像的embedding向量
  74. """
  75. try:
  76. # 将图像转换为base64
  77. buffer = io.BytesIO()
  78. image.save(buffer, format="PNG")
  79. buffer.seek(0)
  80. # 压缩图片字节流
  81. compressed_bytes = image_util._compress_image_to_bytes(buffer)
  82. image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
  83. # 构建输入项
  84. item = [
  85. {
  86. "image": f"data:image/png;base64,{image_base64}"
  87. }
  88. ]
  89. response = MultiModalEmbedding.call(
  90. model=self.multi_embedding_model_name,
  91. api_key=self.dashscope_api_key,
  92. input=item
  93. )
  94. if response.status_code == 200:
  95. return response.output["embeddings"][0]["embedding"]
  96. except Exception as e:
  97. raise Exception(f"图像embedding生成失败: {str(e)}")
  98. @observe(name="multimodal_embedding", as_type="embedding")
  99. def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
  100. """
  101. 获取多模态(文本+图像)的embedding
  102. Args:
  103. text: 文本内容
  104. image: PIL图像对象
  105. Returns:
  106. List[float]: 多模态的embedding向量
  107. """
  108. try:
  109. item = []
  110. if image is not None:
  111. buffer = io.BytesIO()
  112. image.save(buffer, format="PNG")
  113. buffer.seek(0)
  114. # 压缩图片字节流
  115. compressed_bytes = image_util._compress_image_to_bytes(buffer)
  116. image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
  117. item.append({'image': f"data:image/png;base64,{image_base64}"})
  118. # 判断text部位None或者不为空字符串
  119. if text is not None and text.strip() != "":
  120. item.append({'text': text})
  121. response = MultiModalEmbedding.call(
  122. model=self.multi_embedding_model_name,
  123. api_key=self.dashscope_api_key,
  124. input=item
  125. )
  126. if response.status_code == 200:
  127. return response.output["embeddings"][0]["embedding"]
  128. else:
  129. raise Exception(f"Error: {response.message}")
  130. except Exception as e:
  131. raise Exception(f"多模态embedding生成失败: {str(e)}")
  132. # 全局单例
  133. def get_embedding_model() -> Embedding:
  134. """
  135. 获取全局单例的Embedding模型
  136. Returns:
  137. Embedding: 全局单例的Embedding模型
  138. """
  139. return Embedding(
  140. model_name=model_settings.embedding_model_name,
  141. api_key=model_settings.api_key
  142. )