# multimodal_embedding.py
import base64
import io
from typing import Any, Dict, List, Optional

from dashscope import MultiModalEmbedding
from langchain_openai import OpenAIEmbeddings
from PIL import Image

from conf.config import ModelConfig
  8. class MultimodalEmbedding:
  9. """OpenAI Embedding模型工具"""
  10. def __init__(self, model_name: str = None, api_key: str = None):
  11. """
  12. 初始化OpenAI Embedding模型
  13. Args:
  14. model_name: 模型名称,若为None则使用配置文件中的值
  15. api_key: API密钥,若为None则使用配置文件中的值
  16. """
  17. # 获取模型配置
  18. self.model_provider = ModelConfig.get_model_provider()
  19. self.model_name = model_name or ModelConfig.get_model_name()
  20. self.base_url = ModelConfig.get_base_url()
  21. self.api_key = api_key or ModelConfig.get_api_key()
  22. # 使用langchain_openai初始化OpenAI Embeddings模型
  23. self.embeddings = OpenAIEmbeddings(
  24. model=self.model_name,
  25. base_url=self.base_url,
  26. api_key=self.api_key
  27. )
  28. def get_text_embedding(self, text: str) -> List[float]:
  29. """
  30. 获取文本的embedding
  31. Args:
  32. text: 要获取embedding的文本
  33. Returns:
  34. List[float]: 文本的embedding向量
  35. """
  36. try:
  37. embedding = self.embeddings.embed_query(text)
  38. return embedding
  39. except Exception as e:
  40. raise Exception(f"文本embedding生成失败: {str(e)}")
  41. def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
  42. """
  43. 获取多个文本的embedding
  44. Args:
  45. texts: 要获取embedding的文本列表
  46. Returns:
  47. List[List[float]]: 文本列表的embedding向量列表
  48. """
  49. try:
  50. embeddings = self.embeddings.embed_documents(texts)
  51. return embeddings
  52. except Exception as e:
  53. raise Exception(f"多个文本embedding生成失败: {str(e)}")
  54. def get_image_embedding(self, image: Image.Image) -> List[float]:
  55. """
  56. 获取图像的embedding
  57. Args:
  58. image: PIL图像对象
  59. Returns:
  60. List[float]: 图像的embedding向量
  61. """
  62. try:
  63. # 将图像转换为base64
  64. buffer = io.BytesIO()
  65. image.save(buffer, format="PNG")
  66. image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
  67. # 构建输入项
  68. item = [
  69. {
  70. "image": f"data:image/png;base64,{image_base64}"
  71. }
  72. ]
  73. response = MultiModalEmbedding.call(
  74. model=self.model_name,
  75. api_key=self.api_key,
  76. input=item
  77. )
  78. if response.status_code == 200:
  79. return response.output["embeddings"][0]["embedding"]
  80. except Exception as e:
  81. raise Exception(f"图像embedding生成失败: {str(e)}")
  82. def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
  83. """
  84. 获取多模态(文本+图像)的embedding
  85. Args:
  86. text: 文本内容
  87. image: PIL图像对象
  88. Returns:
  89. List[float]: 多模态的embedding向量
  90. """
  91. try:
  92. item = []
  93. if image is not None:
  94. # 将图像转换为base64
  95. buffer = io.BytesIO()
  96. image.save(buffer, format="PNG")
  97. image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
  98. item.append({'image': f"data:image/png;base64,{image_base64}"})
  99. if text is not None:
  100. item.append({'text': text})
  101. response = MultiModalEmbedding.call(
  102. model=self.model_name,
  103. api_key=self.api_key,
  104. input=item
  105. )
  106. if response.status_code == 200:
  107. return response.output["embeddings"][0]["embedding"]
  108. else:
  109. raise Exception(f"Error: {response.message}")
  110. except Exception as e:
  111. raise Exception(f"多模态embedding生成失败: {str(e)}")