一架梯子,一头程序猿,仰望星空!
LangChain教程(Python版本) > 内容正文

LangChain 自定义模型


自定义 LLM

目前AI模型领域百家争鸣,LangChain官方也没有对接好所有模型,有时候你需要自定义模型,接入LangChain框架。

本章将介绍如何创建自定义 LLM 包装器,方便你使用自己的模型或者LangChain官方不支持的模型。

在LangChain中,如果想要使用自己的LLM或与LangChain支持的不同包装器不同的包装器,可以创建一个自定义LLM包装器。自定义LLM只需要实现两个必需的方法:

  • 一个 _call 方法,接受一个字符串作为输入,一些可选的停用词,并返回一个字符串,在_call方法实现模型调用。
  • 一个 _llm_type 属性,返回一个代表模型名称的字符串,仅用于日志记录目的。

除了必需的方法之外,自定义LLM还可以实现一个可选的方法:

  • 一个 _identifying_params 属性,用于帮助打印该类。应返回一个字典。

实现一个简单的自定义LLM

让我们实现一个非常简单的自定义LLM,它只返回输入的前n个字符。

from typing import Any, List, Mapping, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

# 集成`LLM`
class CustomLLM(LLM):
    n: int

    @property
    def _llm_type(self) -> str:
        # 返回我们自定义的模型标记
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # 在这里实现模型api调用
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return prompt[: self.n]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        # 可选debug信息
        """Get the identifying parameters."""
        return {"n": self.n}

现在我们可以像使用其他LLM一样使用这个自定义LLM。

使用自定义LLM

我们可以实例化并使用自定义LLM对象,演示如何调用自定义LLM以及如何自定义打印输出。

llm = CustomLLM(n=10)
llm.invoke("This is a foobar thing")
'This is a '

我们还可以打印LLM并查看其自定义打印输出。

print(llm)
CustomLLM
Params: {'n': 10}

自定义Chat model(聊天模型)

这里讲解如何自定义LangChain的Chat model(聊天模型)。

消息输入和输出

在聊天模型中,消息是输入和输出的重点。消息(message)是指用户输入的内容以及模型生成的回复。

消息

聊天模型将消息作为输入,然后生成一个或多个消息作为输出。在LangChain中,有几种内置的消息类型,包括:

  • SystemMessage:用于初始化AI行为,通常作为一系列输入消息中的第一个消息。
  • HumanMessage:表示用户与聊天模型进行交互的消息。
  • AIMessage:表示来自聊天模型的消息,可以是文本或请求调用工具。
  • FunctionMessage / ToolMessage:用于将工具调用的结果传递回模型。

这些消息类型的使用可以根据具体需求进行扩展和定制,例如按照OpenAI的functiontool参数进行调整。

from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)

流式变体

所有聊天消息都有一个流式变体,名称中包含Chunk

from langchain_core.messages import (
    AIMessageChunk,
    FunctionMessageChunk,
    HumanMessageChunk,
    SystemMessageChunk,
    ToolMessageChunk,
)

这些Chunk在从聊天模型流式输出时使用,并且它们都定义了一个可累加的属性!

例子

AIMessageChunk(content="你好") + AIMessageChunk(content=" 世界!")

返回结果

AIMessageChunk(content='你好 世界!')

简单聊天模型

SimpleChatModel继承可以快速实现一个简单的聊天模型(chat model)

虽然它不能实现聊天模型可能需要的所有功能,但是它快速实现,并且如果需要更多功能,可以过渡到下面介绍的BaseChatModel

继承SimpleChatModel需要实现以下接口:

  • _call方法 - 实现外部模型API调用。

此外,还可以指定以下内容:

  • _identifying_params属性 - 用于记录模型参数化信息。

可选的:

  • _stream方法 - 用于实现流式输出。

基本聊天模型

BaseChatModel继承,实现_generate方法以及_llm_type属性。可选择实现_stream_agenerate_astream以及_identifying_params

自定义聊天模型的例子

在这一部分,我们将展示一个名为CustomChatModelAdvanced的自定义聊天模型的代码实现,包括生成聊天结果、流式输出、异步流实现等。

from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, SimpleChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor


class CustomChatModelAdvanced(BaseChatModel):
    """实现一个自定义聊天模型,返回最后一条消息的前`n`个字符

   """

    n: int
    """自定义模型参数"""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """这里实现模型调用逻辑,实际情况一般是调用第三方模型的api,然后把api返回的结果封装成langchain可以识别的格式
           关键参数说明:
            messages: 提示词(prompt)组成的消息列表
        """
        last_message = messages[-1]
        tokens = last_message.content[: self.n]
        # 提取最后一条消息的前面`n`个字符,模拟模型生成结果,把结果封装到AIMessage中
        message = AIMessage(content=tokens)
        # 进一步包装模型结果
        generation = ChatGeneration(message=message)
        # 最后使用ChatResult包装模型输出结果
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """模型流式输出实现,跟_generate方法类似,区别是要处理流式输出
        """
        last_message = messages[-1]
        tokens = last_message.content[: self.n]

        for token in tokens:
            # 使用chunk版本的消息对象封装model返回结果,分段返回模型结果
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))

            if run_manager:
                run_manager.on_llm_new_token(token, chunk=chunk)

            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """异步处理版本的`stream`方法实现
        """
        result = await run_in_executor(
            None,
            self._stream,
            messages,
            stop=stop,
            run_manager=run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        for chunk in result:
            yield chunk

    @property
    def _llm_type(self) -> str:
        """返回自定义模型的标记"""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """返回自定义的debug信息"""
        return {"n": self.n}

测试自定义聊天模型

让我们来测试聊天模型,包括使用invokebatchstream方法以及异步流实现。

model = CustomChatModelAdvanced(n=3)

model.invoke([HumanMessage(content="hello!")])
# 输出: AIMessage(content='Meo')

model.invoke("hello")
# 输出: AIMessage(content='hel')

model.batch(["hello", "goodbye"])
# 输出: [AIMessage(content='hel'), AIMessage(content='goo')]

for chunk in model.stream("cat"):
    print(chunk.content, end="|")
# 输出: c|a|t|

async for chunk in model.astream("cat"):
    print(chunk.content, end="|")
# 输出: c|a|t|

async for event in model.astream_events("cat", version="v1"):
    print(event)


关联主题