首页 > 编程语言 >ChatGLM3 源码分析(四)

ChatGLM3 源码分析(四)

时间:2024-03-11 17:24:36浏览次数:33  
标签:分析 None self ChatGLM3 content 源码 role logits history

ChatGLMForSequenceClassification

class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)
        
        # NLabels:分类或者回归的标签数
        self.num_labels = config.num_labels
        # TFM
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)

        # 输出层,[HidSize, NLabels]
        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
        # 输出层之后的 dropout
        if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
        else:
            self.dropout = None
        self.config = config

        # 如果指定了量化位数则执行量化
        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            full_attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 单词 ID:[BatchSize, SeqLen]
        # 将单词 ID 等东西传入 TFM,得到最终隐藏状态,KVCache,所有隐藏状态和所有层的注意力矩阵(None)
        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # 获取最终隐藏状态,[SeqLen, BatchSize, HidSize]
        hidden_states = transformer_outputs[0]
        # 取它的最后一个步骤,由于GLM是单向注意力,这个步骤根据前面所有步骤计算
        # [BatchSize, HidSize]
        pooled_hidden_states = hidden_states[-1]
        # 如果指定了 dropout 就添加
        if self.dropout is not None:
            pooled_hidden_states = self.dropout(pooled_hidden_states)
        # 将隐藏状态转入输出层得到标签的 logits,[BatchSize, NLabels]
        logits = self.classifier_head(pooled_hidden_states)

        # 如果提供了标签,计算损失
        loss = None
        if labels is not None:
            # 如果没有定义任务类型则猜测它
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    # 如果标签数为 1,则为回归
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    # 如果不为 1 但为整数,则为单标签分类
                    self.config.problem_type = "single_label_classification"
                else:
                    # 否则为多标签分类
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                # 如果执行回归,损失函数选择 MSE
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
                else:
                    loss = loss_fct(logits.float(), labels)
            elif self.config.problem_type == "single_label_classification":
                # 如果是单标签分类,损失函数选 Softmax 交叉熵
                loss_fct = CrossEntropyLoss()
                # logits 变形为 [BatchSize, NLabels]
                # labels 变形为 [BatchSize]
                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                # 如果是多标签分类,损失函数选 Sigmoid 交叉熵,所有类别单独计算
                loss_fct = BCEWithLogitsLoss()
                # labels 变形为 [BatchSize, NLabels]
                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))

        # 如果指定不返回字典,将损失,logits 和其他东西打包成元组返回
        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        # 否则返回字典
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

ChatGLMForConditionalGeneration.chat()

In [1]: q = '你好'

In [2]: r, his = model.chat(tok, q)

In [3]: r
Out[3]: '\n 你好!很高兴见到你。有什么问题我可以帮助你解答吗?'

In [4]: his
Out[4]:
[{'role': 'user', 'content': '你好'},
 {'role': 'assistant', 'metadata': '', 'content': '你好!很高兴见到你。有什么问题我可以帮助你解答吗?'}]

In [5]: q = '你可以做什么?'

In [6]: r, his = model.chat(tok, q, history=his)

In [7]: r
Out[7]: '\n 作为人工智能助手,我可以帮助您解答各种问题。以下是一些我擅长的领域:\n\n1. 日常生活建议:如购物建议、健康建议、旅行建议等。\n2. 学习辅导:如数学、科学、历史等学科问题。\n3. 语言学习:如中文、英文、日语等语言学习。\n4. 娱乐休闲:如音乐、电影、书籍 、游戏等推荐。\n5. 技术支持:如操作系统、软件应用、电子设备等使用问题。\n\n当然,我会不断学习和进步,随着时间的推移,我将能帮助您 解答更多领域的疑问。如果您有任何问题,请随时向我提问。'

In [8]: his
Out[8]:
[{'role': 'user', 'content': '你好'},
 {'role': 'assistant', 'metadata': '', 'content': '你好!很高兴见到你。有什么问题我可以帮助你解答吗?'},
 {'role': 'user', 'content': '你可以做什么?'},
 {'role': 'assistant',
  'metadata': '',
  'content': '作为人工智能助手,我可以帮助您解答各种问题。以下是一些我擅长的领域:\n\n1. 日常生活建议:如购物建议、健康建议、旅行 建议等。\n2. 学习辅导:如数学、科学、历史等学科问题。\n3. 语言学习:如中文、英文、日语等语言学习。\n4. 娱乐休闲:如音乐、电影、书 籍、游戏等推荐。\n5. 技术支持:如操作系统、软件应用、电子设备等使用问题。\n\n当然,我会不断学习和进步,随着时间的推移,我将能帮助 您解答更多领域的疑问。如果您有任何问题,请随时向我提问。'}]
    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
             max_length: int = 32768, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        # 如果没有提供历史,初始化为空数组
        if history is None:
            history = []
        # 如果没有提供 logits 处理器,初始化为空列表
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        # 添加后备的 logits 处理器
        logits_processor.append(InvalidScoreLogitsProcessor())
        # 定义生成配置项
        # max_length:最大长度
        # num_beams:候选集数量
        # do_sample:是否采样,或者只取 TOP1
        # top_p:候选集的概率阈值
        # temperature:候选集采样策略,0 只取最高,1 均匀采样
        # logits_processor:logits 处理器列表
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        # 提问文本加上对话格式转换为整个的提问单词 ID
        '''
        In [1]: tok.build_chat_input('你好')
        Out[1]: {'input_ids': tensor([[64790, 64792, 64795, 30910,    13, 36474, 54591, 64796]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]]), 'position_ids': tensor([[0, 1, 2, 3, 4, 5, 6, 7]])}
        In [2]: tok.decode(_1['input_ids'][0])
        Out[2]: '[gMASK]sop<|user|> \n 你好<|assistant|>'
        '''
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        # 定义终止符,<EOS>,或者用户和观察者的角色符号
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        # 调用 HF 库生成回答
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        # 取第一个回答,并且忽略前面的提问部分
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        # 将回答 ID 转换成文本
        response = tokenizer.decode(outputs)
        # 历史对话中加入当前提问
        history.append({"role": role, "content": query})
        # 处理回答,解析其中的角色、元信息等等,并将当前回答添加到历史记录
        response, history = self.process_response(response, history)
        # 返回回答和历史记录
        return response, history


    # 处理模型回答中的角色和元信息
    def process_response(self, output, history):
        content = ""
        history = deepcopy(history)
        # 将回答按照机器人角色分割,得到每一段回答
        for response in output.split("<|assistant|>"):
            # 将每段回答按照第一个换行分割,得到元信息和内容
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                # 如果元信息为空,将内容添加到历史中,替换训练时间占位符
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                # 否则解析工具调用
                # 首先将元信息和回答加入历史中
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                # 如果历史记录第一条角色为系统,并且其中定义了工具
                if history[0]["role"] == "system" and "tools" in history[0]:
                    # 忽略内容的第一行和最后一行
                    content = "\n".join(content.split("\n")[1:-1])

                    def tool_call(**kwargs):
                        return kwargs
                    # 将内容当作代码执行
                    parameters = eval(content)
                    # 将内容设为字典,`name`为元信息,`parameters`为执行结果
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    # 否则不执行工具调用
                    # 将内容设为字典,`name`为元信息,`parameters`内容本身
                    content = {"name": metadata.strip(), "content": content}
        # 返回回答和历史记录
        return content, history

ChatGLMForConditionalGeneration.stream_chat()

In [19]: q = '你好'

In [23]: it = model.stream_chat(tok, q)

In [24]: for r, his in it: print(repr(r)); print(repr(his))
'\n'
[{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'metadata': '', 'content': ''}]
'\n 你'
[{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'metadata': '', 'content': '你'}]
'\n 你好'
[{'role': 'user', 'content': '你好'}, {'role': 'assistant', 'metadata': '', 'content': '你好'}]
...
'\n 你好

标签:分析,None,self,ChatGLM3,content,源码,role,logits,history
From: https://www.cnblogs.com/apachecn/p/18066605

相关文章

  • 河北稳控科技振弦采集仪在岩土工程应力分析中的应用及效果评估
    振弦采集仪在岩土工程应力分析中的应用及效果评估河北稳控科技振弦采集仪是一种常用于岩土工程中的应力分析工具。它通过测量岩土体中的应变波动情况,间接地推测出岩土体中的应力状态。振弦采集仪的应用能够提供岩土体中的应力分布情况,对于岩土体的工程设计和施工具有重要的指导作......
  • Kubernetes: kube-controller-manager 源码分析
    0.前言在Kubernetes架构中,controllermanager是一个永不休止的控制回路组件,其负责控制集群资源的状态。通过监控kube-apiserver的资源状态,比较当前资源状态和期望状态,如果不一致,更新kube-apiserver的资源状态以保持当前资源状态和期望状态一致。1.kube-controller-ma......
  • drf源码剖析----版本、reverse
    点击查看代码classAPIView(View):defdispatch(self,request,*args,**kwargs):self.args=argsself.kwargs=kwargsrequest=self.initialize_request(request,*args,**kwargs)self.request=requestself.headers......
  • 网络流量监测分析,国产、高性能、高可用
        随着网络规模不断扩大,复杂程度不断增加,给运维工作带来更大挑战。为保障网络正常、稳定、高效运行,对网络流量进行监测、存储、回溯成为不可或缺的手段,通过对流量的分析,运维人员可以更加全面的了解整体网络的运行状态,快速定位、解决网络中存在问题。    智和信......
  • 第一期:分析一下新能源汽车充电桩行业的市场情况和痛点!
    1:政策环境自从我国提出“新基建”以来,充电基础设施产业也成为行业的话题与关注焦点。“十三五”期间,我国充电基础设施实现了跨越式发展,标准体系逐步完备,产业生态稳步形成。国务院《新能源汽车产业发展规划(2021—2035年)的通知》指出,到“十四五”末,我国电动汽车充电保障能力进......
  • 当利用数据分析和改进过头了怎么办?
    当利用数据分析和改进过头时,可能会出现几种情况:过度依赖数据:有时候,团队可能会过度依赖数据,忽视其他重要因素,如用户反馈、创意灵感等。这可能导致创新的缺失和决策的僵化。数据误解:有时候,数据分析可能会被错误地解释或应用。这可能会导致错误的结论和不良的决策。局限性:数据分析......
  • 最新二次注入攻击和代码分析技术
    二次注入攻击二次注入攻击的测试地址在本书第2章。double1.php页面的功能是添加用户。第一步,输入用户名test'和密码123456,如图4-45所示,单击“send”按钮提交。 图4-45  页面返回链接/4.3/double2.php?id=4,是添加的新用户个人信息的页面,访问该链接,结果如图4-46所示。......
  • zookeeper源码(10)node增删改查及监听
    本文将从leader处理器入手,详细分析node的增删改查流程及监听器原理。回顾数据读写流程leaderZookeeperServer.processPacket封装Request并提交给业务处理器LeaderRequestProcessor做本地事务升级PrepRequestProcessor做事务准备ProposalRequestProcessor事务操作发proposal......
  • Swoole 源码分析之 epoll 多路复用模块
    首发原文链接:Swoole源码分析之HttpServer模块大家好,我是码农先森。引言在传统的IO模型中,每个IO操作都需要创建一个单独的线程或进程来处理,这样的操作会导致系统资源的大量消耗和管理开销。而IO多路复用技术通过使用少量的线程或进程同时监视多个IO事件,能够更高效地处理大......
  • 常用数据分析模型与方法
    一、背景数据分析中,会有一些分析方法来处理不同的问题。简单总结一下。方法汇总:https://share.mindmanager.com/#publish/5v_9k6Z9J3gqPL9sQwAGGKL5DgNrclp4iq_q8C7L    方法链接: 二、RFM分析2.1 定义R(Recency): 客户距离最近的一次采购时间的间隔。F( Freq......