找回密码
 立即注册
首页 业界区 业界 大模型基础补全计划(八)---相关知识点回顾与Qwen3-VL-2B ...

大模型基础补全计划(八)---相关知识点回顾与Qwen3-VL-2B-Instruct实例分析(终章)

忿惺噱 7 天前
PS:要转载请注明出处,本人版权所有。

PS: 这个只是基于《我自己》的理解,

如果和你的原则及想法相冲突,请谅解,勿喷。

环境说明

  无
前言

   本文是这个系列第八篇,也是本系列的终章,它们是:

  • 《大模型基础补全计划(一)---重温一些深度学习相关的数学知识》 https://www.cnblogs.com/Iflyinsky/p/18717317
  • 《大模型基础补全计划(二)---词嵌入(word embedding) 》 https://www.cnblogs.com/Iflyinsky/p/18775451
  • 《大模型基础补全计划(三)---RNN实例与测试》 https://www.cnblogs.com/Iflyinsky/p/18967569
  • 《大模型基础补全计划(四)---LSTM的实例与测试(RNN的改进)》 https://www.cnblogs.com/Iflyinsky/p/19091089
  • 《大模型基础补全计划(五)---seq2seq实例与测试(编码器、解码器架构)》 https://www.cnblogs.com/Iflyinsky/p/19150535
  • 《大模型基础补全计划(六)---带注意力机制的seq2seq实例与测试(Bahdanau Attention)》 https://www.cnblogs.com/Iflyinsky/p/19184558
  • 《大模型基础补全计划(七)---Transformer(多头注意力、自注意力、位置编码)及实例与测试》https://www.cnblogs.com/Iflyinsky/p/19228410
  本文主要是用一个实际的大模型例子来联系和回顾之前的知识点,让大家能够感受一些,前面文中的一些知识点是真正用到了实际大模型里面的哪些地方。
  由于近期正在学习和应用的Qwen3-VL系列相关模型,因此这里挑了一个Qwen3-VL-2B-Instruct来独立分析,并联系和回顾之前的知识点。
  注意:本文不会详细介绍Qwen3-VL-2B-Instruct的推理过程及原理,如果想学习详细的技术原理,请忽略本文内容,并查看其它相关的文章。




Qwen3-VL-2B-Instruct 简介

  


下载及运行

   首先qwen3-vl的官方工程是 https://github.com/QwenLM/Qwen3-VL ,下面的官方示例的下载及变更推理代码(由于国内的原因,从魔塔下载):
  1. modelscope download --model Qwen/Qwen3-VL-2B-Instruct  --local_dir ./cache
复制代码
  1. from transformers import AutoModelForImageTextToText, AutoProcessor
  2. model_path = "./cache"
  3. # default: Load the model on the available device(s)
  4. model = AutoModelForImageTextToText.from_pretrained(
  5.     model_path, cache_dir=model_path, dtype="auto", device_map="auto"
  6. )
  7. # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
  8. # model = AutoModelForImageTextToText.from_pretrained(
  9. #     "Qwen/Qwen3-VL-235B-A22B-Instruct",
  10. #     dtype=torch.bfloat16,
  11. #     attn_implementation="flash_attention_2",
  12. #     device_map="auto",
  13. # )
  14. processor = AutoProcessor.from_pretrained(model_path, cache_dir=model_path)
  15. messages = [
  16.     {
  17.         "role": "user",
  18.         "content": [
  19.             {
  20.                 "type": "image",
  21.                 "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
  22.             },
  23.             {"type": "text", "text": "Describe this image."},
  24.         ],
  25.     }
  26. ]
  27. # Preparation for inference
  28. inputs = processor.apply_chat_template(
  29.     messages,
  30.     tokenize=True,
  31.     add_generation_prompt=True,
  32.     return_dict=True,
  33.     return_tensors="pt"
  34. )
  35. inputs = inputs.to(model.device)
  36. # Inference: Generation of the output
  37. generated_ids = model.generate(**inputs, max_new_tokens=128)
  38. generated_ids_trimmed = [
  39.     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
  40. ]
  41. output_text = processor.batch_decode(
  42.     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
  43. )
  44. print(output_text)
复制代码



模型结构

   我们在上面的例子基础上,添加如下代码打印其模型结构:
  1.     print(model)
  2.     vit_model = model.visual
  3.     llm_model = model.language_model
复制代码
  得到的模型结构如下:
  1. Qwen3VLForConditionalGeneration(
  2.   (model): Qwen3VLModel(
  3.     (visual): Qwen3VLVisionModel(
  4.       (patch_embed): Qwen3VLVisionPatchEmbed(
  5.         (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
  6.       )
  7.       (pos_embed): Embedding(2304, 1024)
  8.       (rotary_pos_emb): Qwen3VLVisionRotaryEmbedding()
  9.       (blocks): ModuleList(
  10.         (0-23): 24 x Qwen3VLVisionBlock(
  11.           (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  12.           (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  13.           (attn): Qwen3VLVisionAttention(
  14.             (qkv): Linear(in_features=1024, out_features=3072, bias=True)
  15.             (proj): Linear(in_features=1024, out_features=1024, bias=True)
  16.           )
  17.           (mlp): Qwen3VLVisionMLP(
  18.             (linear_fc1): Linear(in_features=1024, out_features=4096, bias=True)
  19.             (linear_fc2): Linear(in_features=4096, out_features=1024, bias=True)
  20.             (act_fn): GELUTanh()
  21.           )
  22.         )
  23.       )
  24.       (merger): Qwen3VLVisionPatchMerger(
  25.         (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
  26.         (linear_fc1): Linear(in_features=4096, out_features=4096, bias=True)
  27.         (act_fn): GELU(approximate='none')
  28.         (linear_fc2): Linear(in_features=4096, out_features=2048, bias=True)
  29.       )
  30.       (deepstack_merger_list): ModuleList(
  31.         (0-2): 3 x Qwen3VLVisionPatchMerger(
  32.           (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=True)
  33.           (linear_fc1): Linear(in_features=4096, out_features=4096, bias=True)
  34.           (act_fn): GELU(approximate='none')
  35.           (linear_fc2): Linear(in_features=4096, out_features=2048, bias=True)
  36.         )
  37.       )
  38.     )
  39.     (language_model): Qwen3VLTextModel(
  40.       (embed_tokens): Embedding(151936, 2048)
  41.       (layers): ModuleList(
  42.         (0-27): 28 x Qwen3VLTextDecoderLayer(
  43.           (self_attn): Qwen3VLTextAttention(
  44.             (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  45.             (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
  46.             (v_proj): Linear(in_features=2048, out_features=1024, bias=False)
  47.             (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
  48.             (q_norm): Qwen3VLTextRMSNorm((128,), eps=1e-06)
  49.             (k_norm): Qwen3VLTextRMSNorm((128,), eps=1e-06)
  50.           )
  51.           (mlp): Qwen3VLTextMLP(
  52.             (gate_proj): Linear(in_features=2048, out_features=6144, bias=False)
  53.             (up_proj): Linear(in_features=2048, out_features=6144, bias=False)
  54.             (down_proj): Linear(in_features=6144, out_features=2048, bias=False)
  55.             (act_fn): SiLUActivation()
  56.           )
  57.           (input_layernorm): Qwen3VLTextRMSNorm((2048,), eps=1e-06)
  58.           (post_attention_layernorm): Qwen3VLTextRMSNorm((2048,), eps=1e-06)
  59.         )
  60.       )
  61.       (norm): Qwen3VLTextRMSNorm((2048,), eps=1e-06)
  62.       (rotary_emb): Qwen3VLTextRotaryEmbedding()
  63.     )
  64.   )
  65.   (lm_head): Linear(in_features=2048, out_features=151936, bias=False)
  66. )
复制代码
  从上面的模型结构来看,我们可以知道其分为两个部分,一个是visual,一个是language_model,这也是现在的视觉多模态的常见结构。




Qwen3-VL-2B-Instruct 的模型结构简单分析 及 知识回顾

  还记得我们前面的模型中的词表这个概念吗?当时的做法是直接将将整个训练用到的文字映射成对应的id,将所有的id组合在一起作为一个词表。在现在的大模型中,其实就有类似的东西,一般放在tokenizer.json文件里面。对于当前这个模型来说,这里有几个特殊的东西说明一下:

  • 以前文章中的/对应的是当前这个模型的/
  • 由于是视觉多模态模型,当前这个模型还会有几个本文会用到的特殊token://,他们是用来描述一张图怎么被输入到大语言模型中被理解的。
  • 一个token不一定对应一个文字,可能对应多个、或者零点几个字,感兴趣可以私下了解一下,其和文字编码有关系。
  当上文的 processor.apply_chat_template执行后,然后得到的inputs会有如下四个内容:

  • input_ids (做完tokenizer之后的输出,已经将输入的文字“Describe this image.”和图片占位符“*N”转换为了对应的token id)
  • attention_mask (input_ids的掩码,用于屏蔽无效或者pad输入序列)
  • pixel_values (图片预处理好的矩阵,不仅仅做了归一化,还做了分patch操作,本文不用太关注)
  • image_grid_thw (本文用不上,别管。)
  对于input_ids来说,我们知道里面有图片的占位符的token_id,这里后面会替换为真实的图像数据,这样才能把图、文字送入到大语言模型,当然,语音等也是一样的。
  我们首先来看看上文model.generate调用之后发生了什么,他会经过一系列变化后,到达如下的Qwen3VLModel的forward的入口:
  1. def forward(
  2.     self,
  3.     input_ids: torch.LongTensor = None,
  4.     attention_mask: Optional[torch.Tensor] = None,
  5.     position_ids: Optional[torch.LongTensor] = None,
  6.     past_key_values: Optional[Cache] = None,
  7.     inputs_embeds: Optional[torch.FloatTensor] = None,
  8.     pixel_values: Optional[torch.Tensor] = None,
  9.     pixel_values_videos: Optional[torch.FloatTensor] = None,
  10.     image_grid_thw: Optional[torch.LongTensor] = None,
  11.     video_grid_thw: Optional[torch.LongTensor] = None,
  12.     cache_position: Optional[torch.LongTensor] = None,
  13.     **kwargs: Unpack[TransformersKwargs],
  14. ) -> Union[tuple, Qwen3VLModelOutputWithPast]:
  15.     r"""
  16.     image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  17.         The temporal, height and width of feature shape of each image in LLM.
  18.     video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  19.         The temporal, height and width of feature shape of each video in LLM.
  20.     """
  21.     if (input_ids is None) ^ (inputs_embeds is not None):
  22.         raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  23.     if inputs_embeds is None:
  24.         inputs_embeds = self.get_input_embeddings()(input_ids)
  25.     image_mask = None
  26.     video_mask = None
  27.     if pixel_values is not None:
  28.         image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
  29.         image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  30.         image_mask, _ = self.get_placeholder_mask(
  31.             input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
  32.         )
  33.         inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
  34.     if pixel_values_videos is not None:
  35.         video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
  36.         video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  37.         _, video_mask = self.get_placeholder_mask(
  38.             input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
  39.         )
  40.         inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
  41.     visual_pos_masks = None
  42.     deepstack_visual_embeds = None
  43.     if image_mask is not None and video_mask is not None:
  44.         # aggregate visual_pos_masks and deepstack_visual_embeds
  45.         image_mask = image_mask[..., 0]
  46.         video_mask = video_mask[..., 0]
  47.         visual_pos_masks = image_mask | video_mask
  48.         deepstack_visual_embeds = []
  49.         image_mask_joint = image_mask[visual_pos_masks]
  50.         video_mask_joint = video_mask[visual_pos_masks]
  51.         for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
  52.             embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
  53.             embed_joint[image_mask_joint, :] = img_embed
  54.             embed_joint[video_mask_joint, :] = vid_embed
  55.             deepstack_visual_embeds.append(embed_joint)
  56.     elif image_mask is not None:
  57.         image_mask = image_mask[..., 0]
  58.         visual_pos_masks = image_mask
  59.         deepstack_visual_embeds = deepstack_image_embeds
  60.     elif video_mask is not None:
  61.         video_mask = video_mask[..., 0]
  62.         visual_pos_masks = video_mask
  63.         deepstack_visual_embeds = deepstack_video_embeds
  64.     if position_ids is None:
  65.         attention_mask_tensor = (
  66.             attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
  67.         )
  68.         if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
  69.             attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
  70.             # Only apply conversion for floating point tensors (inverted masks)
  71.             if attention_mask_tensor.dtype.is_floating_point:
  72.                 attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
  73.                 attention_mask_tensor = (1.0 - attention_mask_tensor).int()
  74.         # Calculate RoPE index once per generation in the pre-fill stage only.
  75.         # When compiling, we can't check tensor values thus we check only input length
  76.         # It is safe to assume that `length!=1` means we're in pre-fill because compiled
  77.         # models currently cannot do asssisted decoding
  78.         prefill_compiled_stage = is_torchdynamo_compiling() and (
  79.             (input_ids is not None and input_ids.shape[1] != 1)
  80.             or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
  81.         )
  82.         prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
  83.             (cache_position is not None and cache_position[0] == 0)
  84.             or (past_key_values is None or past_key_values.get_seq_length() == 0)
  85.         )
  86.         if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
  87.             position_ids, rope_deltas = self.get_rope_index(
  88.                 input_ids,
  89.                 image_grid_thw,
  90.                 video_grid_thw,
  91.                 attention_mask=attention_mask_tensor,
  92.             )
  93.             self.rope_deltas = rope_deltas
  94.         # then use the prev pre-calculated rope-deltas to get the correct position ids
  95.         else:
  96.             batch_size, seq_length, _ = inputs_embeds.shape
  97.             delta = (
  98.                 (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
  99.                 if cache_position is not None
  100.                 else 0
  101.             )
  102.             position_ids = torch.arange(seq_length, device=inputs_embeds.device)
  103.             position_ids = position_ids.view(1, -1).expand(batch_size, -1)
  104.             if cache_position is not None:  # otherwise `deltas` is an int `0`
  105.                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
  106.             position_ids = position_ids.add(delta)
  107.             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
  108.     outputs = self.language_model(
  109.         input_ids=None,
  110.         position_ids=position_ids,
  111.         attention_mask=attention_mask,
  112.         past_key_values=past_key_values,
  113.         inputs_embeds=inputs_embeds,
  114.         cache_position=cache_position,
  115.         visual_pos_masks=visual_pos_masks,
  116.         deepstack_visual_embeds=deepstack_visual_embeds,
  117.         **kwargs,
  118.     )
  119.     return Qwen3VLModelOutputWithPast(
  120.         last_hidden_state=outputs.last_hidden_state,
  121.         past_key_values=outputs.past_key_values,
  122.         rope_deltas=self.rope_deltas,
  123.     )
复制代码
   看上面的代码,我们来看看 input_ids 中的主要的几个数据分别做了什么:

  • input_ids 通过get_input_embeddings获取了input_ids对应的原始inputs_embeds,这一步和我们以前文章中做embedding是一样的。唯一注意的,这里的embedding向量里面包含对应的嵌入向量,是占位的,后面要替换为真实的数据。
  • pixel_values 通过get_image_features获取了图像数据对应的image_embeds,这里对应Qwen3VLVisionModel的推理过程,下面会简单说明一下。
  • 在masked_scatter中,将inputs_embeds中的占位向量替换为image_embeds。
  • 根据输入的inputs_embeds,获取token对应的position_ids,也就是获取位置信息,在前面的文中提到了为什么transformer需要位置信息。
  • 将最终的position_ids,attention_mask,inputs_embeds,past_key_values(此项内容在下文解释)给Qwen3VLTextModel进行推理得到logits序列
  • 然后将logits按采样参数进行采样,得到最终的输出的文字token,然后进行tokenizer解码,得到最终输出的文字。(此部分不在上面所在代码范围内部,但是是大模型的后处理部分的必要逻辑部分。)
   我们从上文已经知道,其模型分为两个部分,下面分别简单介绍这两部分的forward过程,看看我们之前提到的知识点在真实的多模态大模型中是怎么样的存在。


visual 部分简单分析

  本系列文章严格来说是不应该涉及到多模态大模型的,但是现在常见的多模态大模型应用场景已经逐渐扩大,因此这里用视觉多模态大模型为例子,看看视觉多模态大模型和普通的大模型有什么区别,首先visual部分的forward代码如下:
  1.     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
  2.         """
  3.         Args:
  4.             hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
  5.                 The final hidden states of the model.
  6.             grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
  7.                 The temporal, height and width of feature shape of each image in LLM.
  8.         Returns:
  9.             `torch.Tensor`: hidden_states.
  10.         """
  11.         hidden_states = self.patch_embed(hidden_states)
  12.         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
  13.         hidden_states = hidden_states + pos_embeds
  14.         rotary_pos_emb = self.rot_pos_emb(grid_thw)
  15.         seq_len, _ = hidden_states.size()
  16.         hidden_states = hidden_states.reshape(seq_len, -1)
  17.         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
  18.         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
  19.         position_embeddings = (emb.cos(), emb.sin())
  20.         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
  21.             dim=0,
  22.             # Select dtype based on the following factors:
  23.             #  - FA2 requires that cu_seqlens_q must have dtype int32
  24.             #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
  25.             # See https://github.com/huggingface/transformers/pull/34852 for more information
  26.             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
  27.         )
  28.         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
  29.         deepstack_feature_lists = []
  30.         for layer_num, blk in enumerate(self.blocks):
  31.             hidden_states = blk(
  32.                 hidden_states,
  33.                 cu_seqlens=cu_seqlens,
  34.                 position_embeddings=position_embeddings,
  35.                 **kwargs,
  36.             )
  37.             if layer_num in self.deepstack_visual_indexes:
  38.                 deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
  39.                     hidden_states
  40.                 )
  41.                 deepstack_feature_lists.append(deepstack_feature)
  42.         hidden_states = self.merger(hidden_states)
  43.         return hidden_states, deepstack_feature_lists
复制代码
  由于本文并不是要细节介绍这个模型的结构,因此这里我们只需要知道其输入是:预处理好的图片数据+grid_thw,其输出是:hidden_states+deepstack_feature_lists。其中最重要的就是输出的hidden_states,它含义是图片token的ebedding向量矩阵,在上面已经提到了其作用。




language_model 部分分析

  对于语言模型部分来说,这个部分才是和我们前面训练的模型比较像的,下面我们先来看看其forward过程:
  1. def forward(
  2.     self,
  3.     input_ids: Optional[torch.LongTensor] = None,
  4.     attention_mask: Optional[torch.Tensor] = None,
  5.     position_ids: Optional[torch.LongTensor] = None,
  6.     past_key_values: Optional[Cache] = None,
  7.     inputs_embeds: Optional[torch.FloatTensor] = None,
  8.     use_cache: Optional[bool] = None,
  9.     cache_position: Optional[torch.LongTensor] = None,
  10.     # args for deepstack
  11.     visual_pos_masks: Optional[torch.Tensor] = None,
  12.     deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
  13.     **kwargs: Unpack[FlashAttentionKwargs],
  14. ) -> Union[tuple, BaseModelOutputWithPast]:
  15.     r"""
  16.     visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
  17.         The mask of the visual positions.
  18.     deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
  19.         The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
  20.         The feature is extracted from the different visual encoder layers, and fed to the decoder
  21.         hidden states. It's from the paper DeepStack(https://arxiv.org/abs/2406.04334).
  22.     """
  23.     if (input_ids is None) ^ (inputs_embeds is not None):
  24.         raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  25.     # torch.jit.trace() doesn't support cache objects in the output
  26.     if use_cache and past_key_values is None and not torch.jit.is_tracing():
  27.         past_key_values = DynamicCache(config=self.config)
  28.     if inputs_embeds is None:
  29.         inputs_embeds = self.embed_tokens(input_ids)
  30.     if cache_position is None:
  31.         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  32.         cache_position = torch.arange(
  33.             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  34.         )
  35.     # the hard coded `3` is for temporal, height and width.
  36.     if position_ids is None:
  37.         position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
  38.     elif position_ids.ndim == 2:
  39.         position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
  40.     if position_ids.ndim == 3 and position_ids.shape[0] == 4:
  41.         text_position_ids = position_ids[0]
  42.         position_ids = position_ids[1:]
  43.     else:
  44.         text_position_ids = position_ids[0]
  45.     attention_mask = create_causal_mask(
  46.         config=self.config,
  47.         input_embeds=inputs_embeds,
  48.         attention_mask=attention_mask,
  49.         cache_position=cache_position,
  50.         past_key_values=past_key_values,
  51.         position_ids=text_position_ids,
  52.     )
  53.     hidden_states = inputs_embeds
  54.     # create position embeddings to be shared across the decoder layers
  55.     position_embeddings = self.rotary_emb(hidden_states, position_ids)
  56.     # decoder layers
  57.     for layer_idx, decoder_layer in enumerate(self.layers):
  58.         layer_outputs = decoder_layer(
  59.             hidden_states,
  60.             attention_mask=attention_mask,
  61.             position_ids=text_position_ids,
  62.             past_key_values=past_key_values,
  63.             cache_position=cache_position,
  64.             position_embeddings=position_embeddings,
  65.             **kwargs,
  66.         )
  67.         hidden_states = layer_outputs
  68.         # add visual features to the hidden states of first several layers
  69.         if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
  70.             hidden_states = self._deepstack_process(
  71.                 hidden_states,
  72.                 visual_pos_masks,
  73.                 deepstack_visual_embeds[layer_idx],
  74.             )
  75.     hidden_states = self.norm(hidden_states)
  76.     return BaseModelOutputWithPast(
  77.         last_hidden_state=hidden_states,
  78.         past_key_values=past_key_values,
  79.     )
复制代码
  我们看到了将position_ids,attention_mask,inputs_embeds,past_key_values传入推理过程后,得到了两个重要的内容,一个logits,一个past_key_values,下面重点介绍一下这两个是什么:

  • logits 输出的是一次推理后,词表大小的一个概率矩阵,然后根据我们的采样相关参数(例如我们常见的:Temperature/Top P/Frequency Penalty等就是在这一阶段生效),选择对应的token_id,然后转换为文字。
  • past_key_values 保存的是每一层decoder layer的注意力机制里面的K/V内容,也就是我们常见的KV Cache一词的存在的地方。
  最后我们来看看现在常见的KV cache(缓存命中、缓存未命中)到底意味着什么?我们举一个简单直观的例子:我们保存了“你好”的KV cache,那我们再一次推理“你好世界。”,那么我们可以直接使用“你好”的KV cache,不用重复计算前面部分,可以直接计算新的部分,加快推理速度、减少了计算资源使用。




后记

  本文基于Qwen3-VL-2B-Instruct,回顾了之前的一些知识,从这里我们可以看到,当前大模型里面用到的好多知识点,其实都来自于以前的某个地方。
  本系列到此,完结散花。
参考文献


  • https://github.com/QwenLM/Qwen3-VL


                    打赏、订阅、收藏、丢香蕉、硬币,请关注公众号(攻城狮的搬砖之路)               
1.jpeg
    PS: 请尊重原创,不喜勿喷。

PS: 要转载请注明出处,本人版权所有。

PS: 有问题请留言,看到后我会第一时间回复。


来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册