"""PyTorch CodeShell model."""
import os
import math
from typing import List, Optional, Tuple, Union, Callable
from threading import Thread
from queue import Queue
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
torch_npu_available = True
try:
import torch_npu
except ModuleNotFoundError:
torch_npu_available = False
from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from .configuration_codeshell import CodeShellConfig
@torch.jit.script
def upcast_masked_softmax(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
    """Masked, scaled softmax over the last dimension, computed in a wider dtype.

    ``x`` is cast to ``softmax_dtype`` and multiplied by ``scale``; positions where ``mask`` is
    False are replaced by ``mask_value`` before the softmax. The result is cast back to ``x.dtype``.
    """
    result_dtype = x.dtype
    logits = torch.where(mask, x.to(softmax_dtype) * scale, mask_value)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs.to(result_dtype)
@torch.jit.script
def upcast_softmax(x: torch.Tensor, scale: float, softmax_dtype: torch.dtype):
    """Scaled softmax over the last dimension computed in ``softmax_dtype``, cast back to ``x.dtype``."""
    result_dtype = x.dtype
    scaled = x.to(softmax_dtype) * scale
    return torch.nn.functional.softmax(scaled, dim=-1).to(result_dtype)
@torch.jit.script
def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor):
    """Softmax over the last dimension after replacing entries where ``mask`` is False with ``mask_value``."""
    return torch.nn.functional.softmax(torch.where(mask, x, mask_value), dim=-1)
class CodeShellRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) tables with a lazily grown cos/sin cache.

    Args:
        dim: rotary dimension (the attention head size). The cached tables have
            ``2 * ceil(dim / 2)`` columns, i.e. ``dim`` for even ``dim``.
        max_position_embeddings: number of positions cached at construction time.
        base: RoPE frequency base.
        device: device on which ``inv_freq`` is created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached``/``sin_cached`` of shape (1, 1, seq_len, 2 * len(inv_freq)) in ``dtype``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: tensor whose dtype/device the tables should match; when ``seq_len`` is None its
                second-to-last axis is taken as the sequence length (the attention code passes
                (batch, heads, seq, head_dim) tensors).
            seq_len: number of positions needed. Defaults to ``x.shape[-2]``; previously omitting it
                raised ``TypeError`` from comparing None with an int.

        Returns:
            Two tensors of shape (1, 1, seq_len, table_dim). The cache is rebuilt (in ``x``'s
            dtype/device) when ``seq_len`` exceeds what is cached.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
class CodeShellLinearScalingRotaryEmbedding(CodeShellRotaryEmbedding):
    """RoPE variant that divides every position by ``scaling_factor`` before computing angles.

    Credits to the Reddit user /u/kaiokendev.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before the parent constructor, because it builds the cache through the
        # overridden _set_cos_sin_cache below, which reads scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the cos/sin tables from positions ``0..seq_len-1`` divided by ``scaling_factor``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) / self.scaling_factor
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        for buffer_name, trig in (("cos_cached", torch.cos), ("sin_cached", torch.sin)):
            self.register_buffer(buffer_name, trig(table)[None, None, :, :].to(dtype), persistent=False)
class CodeShellDynamicNTKScalingRotaryEmbedding(CodeShellRotaryEmbedding):
    """CodeShellRotaryEmbedding with dynamic NTK scaling.

    When the cache is rebuilt for more than ``max_position_embeddings`` positions, the rotary base
    is enlarged (and ``inv_freq`` recomputed) before the tables are built.
    Credits to the Reddit users /u/bloc97 and /u/emozilla.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must exist before the parent constructor calls the overridden _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the tables for ``seq_len`` positions, rescaling the base beyond the trained length."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            ratio = self.scaling_factor * seq_len / self.max_position_embeddings
            scaled_base = self.base * (ratio - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (scaled_base ** exponents))

        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin()[None, None, :, :].to(dtype), persistent=False)
def rotate_half(x):
    """Return ``cat(-upper, lower)`` along the last dimension of ``x``.

    ``lower`` is the first ``size // 2`` elements and ``upper`` the remainder, so for an odd size
    the upper part is one element longer.
    """
    split_at = x.shape[-1] // 2
    lower = x[..., :split_at]
    upper = x[..., split_at:]
    return torch.cat((upper.neg(), lower), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate query and key tensors by the rotary angles of ``position_ids``.

    Args:
        q, k: query/key tensors; CodeShellAttention passes them as (batch, heads, seq, head_dim).
        cos, sin: tables from CodeShellRotaryEmbedding.forward, shaped (1, 1, cached_len, head_dim).
        position_ids: integer positions indexing the tables; CodeShellModel supplies (batch, seq).

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    # Gather the angles for the requested positions on every backend. Previously the NPU branch
    # passed the whole (1, 1, cached_len, head_dim) table, ignoring position_ids. That is wrong for
    # cached decoding (q_len < cached_len) and for left-padded batches.
    cos = cos.squeeze(1).squeeze(0)  # (cached_len, head_dim)
    sin = sin.squeeze(1).squeeze(0)
    cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq, head_dim), broadcast over heads
    sin = sin[position_ids].unsqueeze(1)

    # Use the fused kernel only for tensors that actually live on an Ascend NPU; having torch_npu
    # installed is not enough (CPU/CUDA tensors would fail inside npu_rotary_mul).
    if torch_npu_available and q.device.type == "npu":
        return torch_npu.npu_rotary_mul(q, cos, sin), torch_npu.npu_rotary_mul(k, cos, sin)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: a
    (batch, kv_heads, seq_len, head_dim) tensor becomes (batch, kv_heads * n_rep, seq_len, head_dim).
    The input is returned unchanged (same object) when ``n_rep == 1``.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
        return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    return hidden_states
class CodeShellAttention(nn.Module):
    """Causal self-attention with optional grouped-query attention (GQA) and rotary embeddings.

    A single fused projection (``c_attn``) yields queries (``embed_dim`` wide) and keys/values
    (``kv_dim`` wide each). Keys/values are cached before being expanded to ``num_heads`` with
    ``repeat_kv``.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        # Lazily created scalar used to fill masked attention scores (see _get_mask_value).
        self.mask_value = None

        self.position_embedding_type = config.position_embedding_type
        self.rope_scaling = config.rope_scaling
        self.max_position_embeddings = config.max_position_embeddings

        self.group_query_attention = config.group_query_attention
        self.num_query_groups = config.num_query_groups
        self.num_key_value_groups = config.num_attention_heads // config.num_query_groups

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # Number of key/value heads: the query-group count under GQA, otherwise one per query head.
        self.kv_heads = config.num_query_groups if self.group_query_attention else self.num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.layer_idx = layer_idx

        self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        if self.position_embedding_type == "rope":
            self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``rope_scaling`` (None, "linear" or "dynamic")."""
        if self.rope_scaling is None:
            self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.rope_scaling["type"]
            scaling_factor = self.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = CodeShellLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = CodeShellDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _get_mask_value(self, device, dtype):
        """Return (and cache) a 0-d tensor holding the minimum finite value of ``dtype`` on ``device``."""
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def _use_fused_npu_attention(self, query_states, head_mask, output_attentions):
        """Whether to run torch_npu's fused kernel instead of the eager path.

        The fused call below returns no attention probabilities and does not take a head mask, so
        requests for either fall back to the eager path. The kernel is only valid for NPU tensors.
        """
        return (
            torch_npu_available
            and query_states.device.type == "npu"
            and head_mask is None
            and not output_attentions
        )

    def _npu_attention(self, query_states, key_states, value_states, attention_mask):
        """Fused attention on Ascend NPUs; returns (bsz, num_heads, q_len, head_dim)."""
        # The kernel receives the logical inverse of attention_mask (which is True where attention
        # is allowed). A missing mask previously crashed on ``None.bool()``.
        atten_mask = None
        if attention_mask is not None:
            atten_mask = torch.logical_not(attention_mask.bool()).to(attention_mask.device)
        return torch_npu.npu_fusion_attention(
            query_states,
            key_states,
            value_states,
            query_states.shape[1],
            input_layout="BNSD",
            pse=None,
            atten_mask=atten_mask,
            scale=1.0 / math.sqrt(query_states.shape[-1]),
            pre_tockens=2147483647,
            next_tockens=2147483647,
            keep_prob=1,
            inner_precise=0,
        )[0]

    def _check_attention_mask(self, attention_mask, bsz, q_len, kv_seq_len):
        """Require ``attention_mask`` to broadcast to (bsz, 1, q_len, kv_seq_len).

        A leading batch size of 1 is accepted: CodeShellModel builds a (1, 1, q_len, kv_len) causal
        mask when no padding mask is given, and torch.where broadcasts it over the batch. The old
        exact-size check rejected that mask for any batch larger than 1.
        """
        shape = tuple(attention_mask.shape)
        if len(shape) != 4 or shape[0] not in (1, bsz) or shape[1:] != (1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

    def _eager_attention(
        self, query_states, key_states, value_states, attention_mask, head_mask, bsz, q_len, kv_seq_len
    ):
        """Reference attention: scaled scores, mask, fp32 softmax, dropout, optional head mask.

        Returns:
            ``(attn_output, attn_weights)``; attn_output is (bsz, num_heads, q_len, head_dim).
        """
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            self._check_attention_mask(attention_mask, bsz, q_len, kv_seq_len)
            mask_value = self._get_mask_value(attn_weights.device, attn_weights.dtype)
            # Boolean mask: True keeps the score, False replaces it with the dtype minimum.
            attn_weights = torch.where(attention_mask, attn_weights, mask_value)
        # Softmax in float32 for numerical stability, then back to the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = self.attn_dropout(attn_weights)
        # Honour the documented head_mask argument (it was previously accepted but ignored).
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
    ]:
        """Attend over ``hidden_states`` (bsz, q_len, embed_dim) plus any cached ``layer_past``.

        Returns:
            ``(attn_output, present)`` and, when ``output_attentions``, the attention probabilities
            (None if the fused NPU path were used, which it is not in that case). ``present`` is the
            (key, value) cache before head repetition, or None when ``use_cache`` is False.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.c_attn(hidden_states).split(
            (self.embed_dim, self.kv_dim, self.kv_dim), dim=2
        )
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # kv_heads (not num_query_groups) matches kv_dim also when GQA is disabled.
        key_states = key_states.view(bsz, q_len, self.kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.kv_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if layer_past is not None:
            kv_seq_len += layer_past[0].shape[-2]

        # rotary_emb only exists for position_embedding_type == "rope" (see __init__).
        if self.position_embedding_type == "rope":
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if layer_past is not None:
            key_states = torch.cat([layer_past[0], key_states], dim=2)
            value_states = torch.cat([layer_past[1], value_states], dim=2)
        # Cache the un-repeated heads so the cache stays small under GQA.
        layer_past = (key_states, value_states) if use_cache else None

        n_rep = self.num_heads // self.kv_heads
        key_states = repeat_kv(key_states, n_rep)
        value_states = repeat_kv(value_states, n_rep)

        attn_weights = None
        if self._use_fused_npu_attention(query_states, head_mask, output_attentions):
            attn_output = self._npu_attention(query_states, key_states, value_states, attention_mask)
        else:
            attn_output, attn_weights = self._eager_attention(
                query_states, key_states, value_states, attention_mask, head_mask, bsz, q_len, kv_seq_len
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, layer_past)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
class CodeShellMLP(nn.Module):
    """Position-wise feed-forward network: Linear(hidden -> inner), activation, Linear(inner -> hidden), dropout."""

    def __init__(self, intermediate_size, config):
        super().__init__()
        hidden = config.hidden_size
        self.c_fc = nn.Linear(hidden, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, hidden)
        # Activation chosen by name from transformers' ACT2FN registry.
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.Tensor]]) -> torch.Tensor:
        activated = self.act(self.c_fc(hidden_states))
        return self.dropout(self.c_proj(activated))
class CodeShellBlock(nn.Module):
    """Pre-LayerNorm transformer layer: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # Feed-forward width: config.n_inner when set, otherwise 4x the hidden size.
        self.inner_dim = 4 * hidden_size if config.n_inner is None else config.n_inner

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = CodeShellAttention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = CodeShellMLP(self.inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.Tensor]],
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """Run one layer.

        Returns ``(hidden_states, present, [attn_weights])`` when ``use_cache`` is set, otherwise
        ``(hidden_states, [attn_weights])``. The encoder_* arguments are accepted but unused.
        """
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = attn_outputs[0] + hidden_states
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))

        # attn_outputs = (attn_output, present, [weights]); drop `present` when not caching.
        extras = attn_outputs[1:] if use_cache else attn_outputs[2:]
        return (hidden_states,) + extras
class CodeShellPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CodeShellConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    # Must name the real layer class so device_map-based sharding never splits one layer (and its
    # residual connections) across devices. The previous value "ShellBlock" names no class in this
    # file, so it never matched anything.
    _no_split_modules = ["CodeShellBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights.

        - Output projections (``c_proj``) of CodeShellMLP/CodeShellAttention: normal with std
          ``initializer_range / sqrt(2 * n_layer)``, then flagged ``_is_hf_initialized``.
        - Other nn.Linear: normal(0, initializer_range) weights, zero bias.
        - nn.Embedding: normal(0, initializer_range), padding row zeroed.
        - nn.LayerNorm: weight 1, bias 0.
        """
        if isinstance(module, (CodeShellMLP, CodeShellAttention)):
            module.c_proj.weight.data.normal_(
                mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))
            )
            module.c_proj._is_hf_initialized = True
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Legacy-signature hook: toggle ``gradient_checkpointing`` on CodeShellModel instances."""
        if isinstance(module, CodeShellModel):
            module.gradient_checkpointing = value
# Class-level docstring text prepended by @add_start_docstrings to CodeShellModel and
# CodeShellForCausalLM in this file. It is a runtime string (becomes part of __doc__), so it is kept verbatim.
GPT_BIGCODE_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`CodeShellConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation prepended to forward() docstrings via @add_start_docstrings_to_model_forward
# (used by CodeShellModel.forward and CodeShellForCausalLM.forward). Runtime string: kept verbatim.
# NOTE(review): the text is inherited from GPT-BigCode; e.g. CodeShellModel embeds token_type_ids with
# the word embedding table (wte), which the token_type_ids entry does not mention.
GPT_BIGCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
sequence tokens in the vocabulary.
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
`past_key_values`. In other words, the `attention_mask` always has to have the length:
`len(past_key_values) + len(input_ids)`
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`).
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.",
    GPT_BIGCODE_START_DOCSTRING,
)
class CodeShellModel(CodeShellPreTrainedModel):
    # Decoder stack: token embeddings (+ learned absolute positions when configured), a list of
    # CodeShellBlock layers (self.h) and a final LayerNorm (self.ln_f).

    def __init__(self, config):
        super().__init__(config)
        self.group_query_attention = config.group_query_attention
        self.num_query_groups = config.num_query_groups
        self.position_embedding_type = config.position_embedding_type
        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        # Only learned absolute positions need an embedding table; RoPE lives inside attention.
        if self.position_embedding_type == "learned_absolute":
            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([CodeShellBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Precomputed boolean lower-triangular (causal) mask for up to max_position_embeddings tokens.
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias", torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)), persistent=False
        )

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _causal_mask(self, query_length, key_length):
        # Return the (1, query_length, key_length) boolean causal mask for the current step.
        # Slicing the fixed-size `bias` buffer silently truncated the mask once key_length exceeded
        # max_position_embeddings (reachable with RoPE scaling), producing a shape mismatch in
        # attention. Build a large-enough mask on demand in that case.
        causal = self.bias
        if key_length > causal.size(0):
            causal = torch.tril(torch.ones((key_length, key_length), dtype=torch.bool, device=causal.device))
        return causal[None, key_length - query_length : key_length, :key_length]

    @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.reshape(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.reshape(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
            # Padding-aware positions: count only attended tokens (padding positions get 1).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_length > 0:
                position_ids = position_ids[:, past_length : input_shape[-1] + past_length :]
        elif position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).reshape(-1, input_shape[-1])

        # Combine the causal mask with the optional 2D padding mask into a boolean
        # (batch or 1, 1, query_length, key_length) mask consumed by the attention layers.
        query_length = input_shape[-1]
        key_length = past_length + query_length
        self_attention_mask = self._causal_mask(query_length, key_length)

        if attention_mask is not None:
            self_attention_mask = self_attention_mask * attention_mask.reshape(batch_size, 1, -1).to(
                dtype=torch.bool, device=self_attention_mask.device
            )

        attention_mask = self_attention_mask.unsqueeze(1)

        encoder_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds
        if self.position_embedding_type == "learned_absolute":
            position_embeds = self.wpe(position_ids)
            hidden_states = hidden_states + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = [] if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # use_cache / output_attentions are appended positionally to match
                        # CodeShellBlock.forward's parameter order.
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    position_ids,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache:
                presents.append(outputs[1])

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.reshape(output_shape)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class EndOfFunctionCriteria(StoppingCriteria):
    """Stopping criterion that fires once every sequence in the batch contains a stop string.

    Only the tokens after each sequence's prompt (``input_lengths[i]``) are decoded and searched.
    """

    def __init__(self, input_lengths, eof_strings, tokenizer):
        self.input_lengths = input_lengths
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Return True iff each generated continuation contains at least one of ``eof_strings``."""
        for sequence, prompt_length in zip(input_ids, self.input_lengths):
            continuation = self.tokenizer.decode(sequence[prompt_length:])
            if not any(marker in continuation for marker in self.eof_strings):
                return False
        return True
class TextIterStreamer:
    """Streamer that exposes decoded generation output as an iterator of strings.

    Each ``put`` appends the received token ids and enqueues the decode of *all* tokens received
    so far, so consumers see cumulative text rather than deltas. With ``skip_prompt`` the first
    ``put`` call is dropped. ``end`` enqueues a ``None`` sentinel that stops iteration.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        if self.skip_prompt and self.next_tokens_are_prompt:
            # First call carries the prompt; swallow it once.
            self.next_tokens_are_prompt = False
            return
        # Batched input: only the first row is streamed.
        row = value[0] if len(value.shape) > 1 else value
        self.tokens.extend(row.tolist())
        text = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
        self.text_queue.put(text)

    def end(self):
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.text_queue.get()
        if item is None:
            raise StopIteration()
        return item
@add_start_docstrings(
"""
The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
GPT_BIGCODE_START_DOCSTRING,
)
class CodeShellForCausalLM(CodeShellPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
    """Build the CodeShell backbone and a bias-free vocabulary projection, then run HF post_init."""
    super().__init__(config)
    hidden_size, vocab_size = config.n_embd, config.vocab_size
    self.transformer = CodeShellModel(config)
    self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    self.post_init()
def quantize(self, bits: int):
    """Quantize the model with the package-local ``quantizer.quantize`` helper.

    Args:
        bits: quantization bit width, forwarded unchanged to ``quantizer.quantize``.

    Returns:
        Whatever ``quantizer.quantize(self, bits)`` returns.

    Raises:
        ImportError: if ``bitsandbytes`` or the local ``quantizer`` module cannot be imported. The
            underlying import error is chained as ``__cause__`` so the real failure stays visible.
    """
    try:
        import bitsandbytes  # noqa: F401  -- imported only to verify the dependency is installed
        from .quantizer import quantize
    except ImportError as err:
        raise ImportError("Needs bitsandbytes to run quantize.") from err
    return quantize(self, bits)
def get_output_embeddings(self):
    """Return the language-modeling head (``lm_head``)."""
    head = self.lm_head
    return head
def set_output_embeddings(self, new_embeddings):
    """Replace the language-modeling head with ``new_embeddings``."""
    setattr(self, "lm_head", new_embeddings)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    """Assemble the keyword arguments for one ``forward`` call during generation.

    With a non-empty cache only the last token (and its token type / position) is fed. Positions
    are derived from ``attention_mask`` when one is given and no explicit ``position_ids`` are;
    otherwise ``position_ids`` is set to None. ``inputs_embeds`` is used only when
    ``past_key_values`` is None.
    """
    token_type_ids = kwargs.get("token_type_ids", None)
    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)
    has_past = bool(past_key_values)

    if has_past:
        input_ids = input_ids[:, -1:]
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1:]

    if attention_mask is None or position_ids is not None:
        position_ids = None
    else:
        # Padding-aware positions: running count of attended tokens; padding slots get 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if has_past:
            position_ids = position_ids[:, -1:]

    use_embeds = inputs_embeds is not None and past_key_values is None
    model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}
    model_inputs.update(
        past_key_values=past_key_values,
        use_cache=kwargs.get("use_cache"),
        position_ids=position_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    return model_inputs
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    r"""
    labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for language modeling. The labels **are shifted** inside the model, so you can set
        `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`; labels set to
        `-100` are ignored (masked) and the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    lm_logits = self.lm_head(transformer_outputs[0])

    loss = None
    if labels is not None:
        # Position t predicts token t+1: drop the last logit and the first label.
        logits_for_loss = lm_logits[..., :-1, :].contiguous()
        targets = labels[..., 1:].contiguous().to(logits_for_loss.device)
        vocab = logits_for_loss.size(-1)
        loss = CrossEntropyLoss()(logits_for_loss.reshape(-1, vocab), targets.reshape(-1))

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return output if loss is None else (loss,) + output

    return CausalLMOutputWithCrossAttentions(
        loss=loss,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
    """Select the cache rows for ``beam_idx`` (batch dim 0) in every layer's key/value tensors.

    Returns a tuple (one entry per layer) of tuples of reindexed tensors; ``beam_idx`` is moved to
    each tensor's device first.
    """
    return tuple(
        tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_cache)
        for layer_cache in past_key_values
    )
def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
    """Render a chat prompt from ``history`` + ``query`` and encode it.

    The prompt alternates ``## human:<q>|<end>|`` and ``## assistant: <r>|<end>|``
    for every past turn, then appends ``## human:<query>|<end>|`` and the
    assistant tag without its trailing space (``## assistant:``).  Only the
    last ``config.n_positions - max_new_tokens`` tokens are kept, so the
    oldest part of the prompt is dropped first.

    Args:
        query: The new user message.
        history: Iterable of ``(user_message, assistant_reply)`` pairs;
            ``None`` is treated as an empty history.
        tokenizer: Object whose ``encode(str)`` returns a list of token ids.
        max_new_tokens: Number of positions reserved for generation. Falls
            back to ``self.generation_config.max_new_tokens`` and then to 128.

    Returns:
        A ``torch.long`` tensor of shape ``(1, num_prompt_tokens)`` on
        ``self.device``.

    Raises:
        ValueError: if ``max_new_tokens`` leaves no room for prompt tokens
            (previously ``tokens[-0:]`` silently kept the *whole* prompt and a
            negative budget dropped arbitrary leading tokens).
    """
    user_name = "## human:"
    ai_name = "## assistant: "
    stop = '|<end>|'
    if history is None:
        history = ()
    parts = []
    for q, r in history:
        parts.append(f"{user_name}{q}{stop}")
        parts.append(f"{ai_name}{r}{stop}")
    parts.append(f"{user_name}{query}{stop}")
    parts.append(ai_name.rstrip())
    prompt = ''.join(parts)

    max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens or 128
    max_input_tokens = self.config.n_positions - max_new_tokens
    if max_input_tokens <= 0:
        raise ValueError(
            f"max_new_tokens ({max_new_tokens}) must be smaller than "
            f"n_positions ({self.config.n_positions}) to leave room for the prompt."
        )
    # Left-truncate: keep the most recent tokens.
    input_tokens = tokenizer.encode(prompt)[-max_input_tokens:]
    return torch.LongTensor([input_tokens]).to(self.device)
def chat(self, query, history, tokenizer, stream=False,
         generation_config: Optional[GenerationConfig] = None):
    """Answer ``query`` in the context of ``history``.

    The prompt is built with ``build_chat_input`` and generation is stopped by
    an ``EndOfFunctionCriteria`` over the chat stop strings.

    Args:
        query: The new user message.
        history: Previous ``(user_message, assistant_reply)`` pairs.
        tokenizer: Tokenizer used for encoding, decoding and stop detection.
        stream: When true, run ``self.generate`` on a background thread and
            return the streamer immediately.
        generation_config: Overrides ``self.generation_config`` when given.

    Returns:
        The ``TextIterStreamer`` when ``stream`` is true; otherwise the decoded
        reply (prompt tokens removed, special tokens skipped).
    """
    config = generation_config or self.generation_config
    prompt_ids = self.build_chat_input(query, history, tokenizer, config.max_new_tokens)
    prompt_len = len(prompt_ids[0])
    stop_words = ['|<end>|', '|end|', '<|endoftext|>', '## human']
    criteria = StoppingCriteriaList(
        [EndOfFunctionCriteria([prompt_len], stop_words, tokenizer)]
    )

    if not stream:
        generated = self.generate(prompt_ids, generation_config=config, stopping_criteria=criteria)
        return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

    streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "inputs": prompt_ids,
        "streamer": streamer,
        "stopping_criteria": criteria,
        "generation_config": config,
    }
    Thread(target=self.generate, kwargs=generate_kwargs).start()
    return streamer
def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
    """Stream a completion for a raw text ``prompt`` (no chat template applied).

    The prompt is encoded and left-truncated so that it plus the generation
    budget fits in ``config.n_positions``. It is then handed to
    ``self.generate`` on a background thread, with an ``EndOfFunctionCriteria``
    stopping criterion built from the chat stop strings.

    Fixes over the previous version: the encoded ids are wrapped in a
    ``(1, seq_len)`` tensor (``len(input_ids[0])`` used to be ``len(int)``), and
    the returned streamer is actually passed to ``generate``. The
    ``generation_config`` argument is also used both for the budget and for
    generation.

    Args:
        prompt: Raw prompt text.
        tokenizer: Tokenizer providing ``encode``; also used by the streamer and
            the stopping criterion.
        generation_config: Overrides ``self.generation_config`` when given.
        **kwargs: Extra keyword arguments forwarded to ``self.generate``. A
            ``streamer`` passed here is used (and returned) instead of a new one.
            A ``max_new_tokens`` passed here also sets the truncation budget.

    Returns:
        The streamer that receives the generated text.

    Raises:
        ValueError: if the generation budget leaves no room for prompt tokens.
    """
    generation_config = generation_config or self.generation_config
    # Budget from what generate() will actually use; 128 mirrors build_chat_input.
    max_new_tokens = kwargs.get("max_new_tokens") or generation_config.max_new_tokens or 128
    max_input_tokens = self.config.n_positions - max_new_tokens
    if max_input_tokens <= 0:
        raise ValueError(
            f"max_new_tokens ({max_new_tokens}) must be smaller than "
            f"n_positions ({self.config.n_positions}) to leave room for the prompt."
        )
    # tokenizer.encode returns a flat id list: keep the most recent ids and wrap
    # them as a single-row batch, matching what build_chat_input produces.
    token_ids = tokenizer.encode(prompt)[-max_input_tokens:]
    input_ids = torch.LongTensor([token_ids]).to(self.device)
    stopping_criteria = StoppingCriteriaList(
        [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '|end|', '<|endoftext|>', '## human'], tokenizer)]
    )
    streamer = kwargs.pop("streamer", None)
    if streamer is None:
        streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        generation_config=generation_config,
        **kwargs,
    )
    Thread(target=self.generate, kwargs=generate_kwargs).start()
    return streamer
class CodeShell4bitForCausalLM(CodeShellForCausalLM):
    """CodeShell causal LM whose modules are quantized by ``.quantizer``.

    NOTE(review): the quantization scheme lives in ``quantizer.quantize_offline``
    / ``load_state_dict_for_qunantied_model``; "4bit" is taken from the class
    name -- confirm there.
    """

    def __init__(self, config):
        # Calls CodeShellPreTrainedModel.__init__ directly instead of
        # super().__init__, so CodeShellForCausalLM.__init__ does not run; the
        # submodules are built here and then quantized in place.
        CodeShellPreTrainedModel.__init__(self, config)
        self.transformer = CodeShellModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Guard only the imports: an ImportError raised *inside*
        # quantize_offline must not be misreported as missing bitsandbytes.
        try:
            import bitsandbytes  # noqa: F401  (availability check only)
            from .quantizer import quantize_offline
        except ImportError as err:
            raise ImportError("Needs bitsandbytes to run quantize.") from err
        quantize_offline(self)
        self.post_init()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: bool = None,
        **kwargs,
    ):
        """Load an offline-quantized checkpoint from a local directory.

        Steps:
        1. Resolve the config.
        2. Build the quantized model.
        3. Load ``<path>/<subfolder>/pytorch_model.bin`` through
           ``load_state_dict_for_qunantied_model``.
        4. Switch the model to eval mode.
        5. Load ``generation_config`` on a best-effort basis.
        6. Optionally move the model to ``device_map``.

        ``model_args``, ``ignore_mismatched_sizes`` and ``use_safetensors`` are
        accepted for signature compatibility but are not used.

        Loader options that callers such as ``AutoModelForCausalLM`` may forward
        in ``kwargs`` (``resume_download``, ``proxies``, ``subfolder``,
        ``_from_auto``, ``_from_pipeline``) are popped and passed exactly once.
        Previously they collided with hard-coded values and raised ``TypeError``.
        For the generation config that error was swallowed, which silently
        dropped the checkpoint's generation settings.

        Args:
            pretrained_model_name_or_path: Local directory holding the weights
                (the weights path is built with ``os.path.join``).
            config: A ``PretrainedConfig``, or a path to load one from
                (defaults to ``pretrained_model_name_or_path``).
            cache_dir, force_download, local_files_only, token, revision:
                Forwarded to the config / generation-config loaders.
            **kwargs: ``device_map`` is interpreted as a single device (anything
                ``torch.device`` accepts); the rest is forwarded to the loaders.

        Returns:
            The loaded model in eval mode.
        """
        # Pop loader-only options up front so they are neither passed twice nor
        # turned into attributes of the generation config.
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        subfolder = kwargs.pop("subfolder", "")
        from_auto_class = kwargs.pop("_from_auto", False)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        device_map = kwargs.pop("device_map", None)
        hub_kwargs = dict(
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
        )

        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, _ = cls.config_class.from_pretrained(
                config_path,
                return_unused_kwargs=True,
                **hub_kwargs,
                **kwargs,
            )

        from .quantizer import load_state_dict_for_qunantied_model

        model = cls(config)
        weights_path = os.path.join(pretrained_model_name_or_path, subfolder, 'pytorch_model.bin')
        # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        model = load_state_dict_for_qunantied_model(model, state_dict)
        model.eval()
        if model.can_generate():
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    **hub_kwargs,
                    **kwargs,
                )
            except (OSError, TypeError):
                # Best effort: keep the model's current generation config when
                # none can be loaded from the checkpoint.
                pass
        if device_map is not None:
            model = model.to(torch.device(device_map))
        return model