multimodal_embedding.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
import base64
import io
from typing import List, Optional

from PIL import Image
from langchain_openai import OpenAIEmbeddings
from dashscope import MultiModalEmbedding
from langfuse import observe

from conf.settings import model_settings
from utils.file.image_util import image_util
  10. class Embedding:
  11. """Embedding模型工具"""
  12. def __init__(self, model_name: str = None, api_key: str = None):
  13. """
  14. 初始化Embedding模型
  15. Args:
  16. model_name: 模型名称,若为None则使用配置文件中的值
  17. api_key: API密钥,若为None则使用配置文件中的值
  18. """
  19. # 获取模型配置
  20. self.model_provider = model_settings.model_provider
  21. self.model_name = model_name or model_settings.model_name
  22. self.base_url = model_settings.base_url
  23. self.api_key = api_key or model_settings.api_key
  24. @observe(name="text_embedding", as_type="embedding")
  25. def get_text_embedding(self, text: str) -> List[float]:
  26. """
  27. 获取文本的embedding
  28. Args:
  29. text: 要获取embedding的文本
  30. Returns:
  31. List[float]: 文本的embedding向量
  32. """
  33. try:
  34. # 使用langchain_openai初始化OpenAI Embeddings模型
  35. embeddings = OpenAIEmbeddings(
  36. model=self.model_name,
  37. base_url=self.base_url,
  38. api_key=self.api_key
  39. )
  40. embedding = embeddings.embed_query(text)
  41. return embedding
  42. except Exception as e:
  43. raise Exception(f"文本embedding生成失败: {str(e)}")
  44. @observe(name="texts_embedding", as_type="embedding")
  45. def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
  46. """
  47. 获取多个文本的embedding
  48. Args:
  49. texts: 要获取embedding的文本列表
  50. Returns:
  51. List[List[float]]: 文本列表的embedding向量列表
  52. """
  53. try:
  54. # 使用langchain_openai初始化OpenAI Embeddings模型
  55. embeddings = OpenAIEmbeddings(
  56. model=self.model_name,
  57. base_url=self.base_url,
  58. api_key=self.api_key
  59. )
  60. embeddings = embeddings.embed_documents(texts)
  61. return embeddings
  62. except Exception as e:
  63. raise Exception(f"多个文本embedding生成失败: {str(e)}")
  64. @observe(name="image_embedding", as_type="embedding")
  65. def get_image_embedding(self, image: Image.Image) -> List[float]:
  66. """
  67. 获取图像的embedding
  68. Args:
  69. image: PIL图像对象
  70. Returns:
  71. List[float]: 图像的embedding向量
  72. """
  73. try:
  74. # 将图像转换为base64
  75. buffer = io.BytesIO()
  76. image.save(buffer, format="PNG")
  77. buffer.seek(0)
  78. # 压缩图片字节流
  79. compressed_bytes = image_util._compress_image_to_bytes(buffer)
  80. image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
  81. # 构建输入项
  82. item = [
  83. {
  84. "image": f"data:image/png;base64,{image_base64}"
  85. }
  86. ]
  87. response = MultiModalEmbedding.call(
  88. model=self.model_name,
  89. api_key=self.api_key,
  90. input=item
  91. )
  92. if response.status_code == 200:
  93. return response.output["embeddings"][0]["embedding"]
  94. except Exception as e:
  95. raise Exception(f"图像embedding生成失败: {str(e)}")
  96. @observe(name="multimodal_embedding", as_type="embedding")
  97. def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
  98. """
  99. 获取多模态(文本+图像)的embedding
  100. Args:
  101. text: 文本内容
  102. image: PIL图像对象
  103. Returns:
  104. List[float]: 多模态的embedding向量
  105. """
  106. try:
  107. item = []
  108. if image is not None:
  109. buffer = io.BytesIO()
  110. image.save(buffer, format="PNG")
  111. buffer.seek(0)
  112. # 压缩图片字节流
  113. compressed_bytes = image_util._compress_image_to_bytes(buffer)
  114. image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
  115. item.append({'image': f"data:image/png;base64,{image_base64}"})
  116. # 判断text部位None或者不为空字符串
  117. if text is not None and text.strip() != "":
  118. item.append({'text': text})
  119. response = MultiModalEmbedding.call(
  120. model=self.model_name,
  121. api_key=self.api_key,
  122. input=item
  123. )
  124. if response.status_code == 200:
  125. return response.output["embeddings"][0]["embedding"]
  126. else:
  127. raise Exception(f"Error: {response.message}")
  128. except Exception as e:
  129. raise Exception(f"多模态embedding生成失败: {str(e)}")
  130. # 全局单例
  131. def get_embedding_model() -> Embedding:
  132. """
  133. 获取全局单例的Embedding模型
  134. Returns:
  135. Embedding: 全局单例的Embedding模型
  136. """
  137. return Embedding(
  138. model_name=model_settings.embedding_model_name,
  139. api_key=model_settings.api_key
  140. )