| |
| |
| ''' |
| @license: (C) Copyright 2025, Hey. |
| @author: Hey |
| @email: sanyuan.hy@alibaba-inc.com |
| @tel: 137****6540 |
| @datetime: 2025/12/30 11:33 |
| @project: lucaone |
| @file: tokenization_lucaone |
| @desc: tokenization_lucaone |
| ''' |
|
|
| import os |
| import json |
| import itertools |
| from typing import List, Optional, Dict, Any, Tuple, Union |
| from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast |
|
|
def gene_seq_replace(seq):
    """
    Convert a nucleotide sequence into the numeric gene alphabet.

    The mapping is case-insensitive: A -> '1', T/U -> '2', C -> '3', G -> '4'.
    Every other character (N, gaps, whitespace, ...) becomes '5'.

    >>> gene_seq_replace("ACGTUn")
    '134225'
    """
    codes = {}
    for code, letters in (("1", "Aa"), ("2", "TtUu"), ("3", "Cc"), ("4", "Gg")):
        # Register the upper- and lower-case spelling of each nucleotide.
        codes.update(dict.fromkeys(letters, code))
    return "".join(codes.get(base, "5") for base in seq)
|
|
class LucaGPLMTokenizer(PreTrainedTokenizer):
    """
    HuggingFace-compatible character-level tokenizer for LucaOne / LucaGPLM.

    Intended to reproduce the tokenization of the old model's ``Alphabet``
    class. The vocabulary is selected with ``vocab_type``:

    * ``"gene"``: nucleotide digits produced by :func:`gene_seq_replace`
      ('1'..'5') plus '.', '-', '*'.
    * ``"prot"``: amino-acid letters plus '.', '-', '*'.
    * ``"gene_prot"`` (alias ``"prot_gene"``): gene digits followed by the
      protein alphabet.

    In every vocabulary, ids 0..4 are [PAD], [UNK], [CLS], [SEP] and [MASK].
    They are followed by the standard tokens in the order listed below.
    """

    gene_prepend_toks = ['[PAD]', '[UNK]']
    gene_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    gene_standard_toks = ['1', '2', '3', '4', '5', '.', '-', '*']

    prot_prepend_toks = ['[PAD]', '[UNK]']
    prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    prot_standard_toks = [
        'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N',
        'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*',
    ]

    gene_prot_prepend_toks = ['[PAD]', '[UNK]']
    gene_prot_append_toks = ['[CLS]', '[SEP]', '[MASK]']
    # Gene digits first, then the protein alphabet in prot_standard_toks order.
    gene_prot_standard_toks = [
        '1', '2', '3', '4', '5',
        'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N',
        'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', 'J', '.', '-', '*',
    ]

    def __init__(
        self,
        vocab_type: str = "gene_prot",
        prepend_bos: bool = True,
        append_eos: bool = True,
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        sep_token="[SEP]",
        mask_token="[MASK]",
        **kwargs
    ):
        """
        Build the vocabulary for ``vocab_type`` and initialise the HF base class.

        Args:
            vocab_type: "gene", "prot" or "gene_prot"/"prot_gene" (case-insensitive).
            prepend_bos: Prepend [CLS] when special tokens are added.
            append_eos: Append [SEP] when special tokens are added.
            unk_token, pad_token, cls_token, sep_token, mask_token: Special
                token strings forwarded to ``PreTrainedTokenizer``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.

        Raises:
            ValueError: If ``vocab_type`` is not one of the supported values.
        """
        vocab_key = vocab_type.lower()
        if vocab_key == "prot":
            prepend_toks = self.prot_prepend_toks
            append_toks = self.prot_append_toks
            standard_toks = self.prot_standard_toks
        elif vocab_key == "gene":
            prepend_toks = self.gene_prepend_toks
            append_toks = self.gene_append_toks
            standard_toks = self.gene_standard_toks
        elif vocab_key in ("gene_prot", "prot_gene"):
            prepend_toks = self.gene_prot_prepend_toks
            append_toks = self.gene_prot_append_toks
            standard_toks = self.gene_prot_standard_toks
        else:
            raise ValueError(f"Not support tokenizer vocab_type: {vocab_type}")

        self.all_toks = list(prepend_toks) + list(append_toks) + list(standard_toks)
        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
        self.idx_to_tok = {i: tok for i, tok in enumerate(self.all_toks)}

        self.vocab_type = vocab_type
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.unique_no_split_tokens = self.all_toks.copy()
        # Multi-character vocabulary entries ("[PAD]", "[MASK]", ...) that the
        # character-level tokenizer must keep whole. This is kept in a private
        # attribute because the HF base class may manage
        # ``unique_no_split_tokens`` itself (version dependent).
        self._multi_char_toks = tuple(tok for tok in self.all_toks if len(tok) > 1)

        self.unk_idx = self.tok_to_idx.get("[UNK]", 1)
        self.padding_idx = self.tok_to_idx.get("[PAD]", 0)
        self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
        self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
        self.eos_idx = self.tok_to_idx.get("[SEP]", 3)

        # vocab_type / prepend_bos / append_eos are forwarded so that the base
        # class records them in ``init_kwargs``. That lets ``save_pretrained``
        # persist them for ``from_pretrained``.
        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            vocab_type=vocab_type,
            prepend_bos=prepend_bos,
            append_eos=append_eos,
            **kwargs
        )

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return self.tok_to_idx.copy()

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary."""
        return len(self.all_toks)

    def get_idx(self, tok):
        """Id of ``tok``, or the [UNK] id when ``tok`` is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, idx):
        """Token for ``idx``, or "[UNK]" when ``idx`` is out of range."""
        return self.idx_to_tok.get(idx, "[UNK]")

    def _tokenize_char_level(self, text: str) -> List[str]:
        """Simple character-level tokenization (fallback)."""
        return list(text)

    def _split_on_special_tokens(self, text: str) -> List[Tuple[str, bool]]:
        """
        Split ``text`` into ``(piece, is_special)`` pairs.

        Multi-character vocabulary entries ("[MASK]", "[UNK]", ...) come back
        whole with ``is_special=True``. The text between them comes back as
        plain pieces. Text without such entries yields one plain piece, or no
        pieces for an empty string.
        """
        specials = getattr(self, "_multi_char_toks", ())
        # Cached next occurrence of each special token. A token is re-searched
        # only after the scan has moved past its cached hit, so each token's
        # search advances through the text just once.
        next_hit = {tok: text.find(tok) for tok in specials}
        pieces: List[Tuple[str, bool]] = []
        pos = 0
        while True:
            hits = [(idx, -len(tok), tok) for tok, idx in next_hit.items() if idx != -1]
            if not hits:
                break
            # The earliest hit wins; on a tie the longer token wins.
            idx, _, tok = min(hits)
            if idx > pos:
                pieces.append((text[pos:idx], False))
            pieces.append((tok, True))
            pos = idx + len(tok)
            for other, other_idx in list(next_hit.items()):
                if other_idx != -1 and other_idx < pos:
                    next_hit[other] = text.find(other, pos)
        if pos < len(text):
            pieces.append((text[pos:], False))
        return pieces

    def _tokenize(self, text: str) -> List[str]:
        """
        Character-level tokenization of ``text`` after stripping outer whitespace.

        Multi-character vocabulary entries such as "[MASK]" are emitted as
        single tokens. Every other character becomes its own token. This is
        intended to mirror the old ``Alphabet.tokenize()``.
        """
        text = text.strip()
        if not text:
            return []
        tokens: List[str] = []
        for piece, is_special in self._split_on_special_tokens(text):
            if is_special:
                tokens.append(piece)
            else:
                tokens.extend(piece)  # one token per character
        return tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self.get_idx(token)

    def _convert_id_to_token(self, index: int) -> str:
        return self.get_tok(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate tokens without separators."""
        return "".join(tokens)

    def _convert_text_to_ids(self, text: str, seq_type: str) -> List[int]:
        """Convert ``text`` to ids without special tokens.

        For ``seq_type == "gene"`` the nucleotides are first mapped to digits
        with :func:`gene_seq_replace`. Embedded special tokens such as
        "[MASK]" are left intact rather than being converted letter by letter.
        """
        if seq_type == "gene":
            text = "".join(
                piece if is_special else gene_seq_replace(piece)
                for piece, is_special in self._split_on_special_tokens(text)
            )
        return [self._convert_token_to_id(token) for token in self._tokenize(text)]

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Add special tokens according to ``prepend_bos`` / ``append_eos``.

        Single sequence: ``[CLS]? A [SEP]?``.
        Pair: ``[CLS]? A [SEP]? B [SEP]?``.
        """
        bos = [self.cls_idx] if self.prepend_bos else []
        eos = [self.eos_idx] if self.append_eos else []
        result = bos + list(token_ids_0) + eos
        if token_ids_1 is not None:
            result += list(token_ids_1) + eos
        return result

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask with 1 for special-token positions and 0 for sequence tokens.

        When ``already_has_special_tokens`` is False, the layout matches
        :meth:`build_inputs_with_special_tokens`.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        bos = [1] if self.prepend_bos else []
        eos = [1] if self.append_eos else []
        mask = bos + [0] * len(token_ids_0) + eos
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + eos
        return mask

    def encode(
        self,
        text: str,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        **kwargs
    ) -> List[int]:
        """
        Encode one sequence into a list of ids.

        Args:
            text: Raw sequence.
            seq_type: "gene" applies :func:`gene_seq_replace`; any other value
                encodes the characters as-is.
            add_special_tokens: Add [CLS]/[SEP] per ``prepend_bos``/``append_eos``.
            padding: Accepted for signature compatibility only and not applied
                here. :meth:`encode_plus` performs padding.
            truncation: Cut the ids (including special tokens) to ``max_length``.
            max_length: Maximum length used when ``truncation`` is True.

        Returns:
            A new list of token ids.
        """
        token_ids = self._convert_text_to_ids(text, seq_type)

        if add_special_tokens:
            token_ids = self.build_inputs_with_special_tokens(token_ids)

        if truncation and max_length is not None and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
            # Keep the truncated sequence terminated: the last kept position
            # becomes [SEP]. The guard avoids IndexError when nothing is kept.
            if add_special_tokens and self.append_eos and token_ids:
                token_ids[-1] = self.eos_idx

        return token_ids

    def __call__(
        self,
        text: Union[str, List[str]],
        text_pair: Optional[Union[str, List[str]]] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Main callable method for tokenization (HuggingFace standard interface).

        A list or tuple of sequences is routed to :meth:`batch_encode_plus`.
        A single sequence is routed to :meth:`encode_plus`. ``text_pair`` is
        accepted but ignored.
        """
        encode_fn = self.batch_encode_plus if isinstance(text, (list, tuple)) else self.encode_plus
        return encode_fn(
            text,
            text_pair=text_pair,
            seq_type=seq_type,
            add_special_tokens=add_special_tokens,
            padding=padding,
            max_length=max_length,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_tensors=return_tensors,
            truncation=truncation,
            **kwargs
        )

    @staticmethod
    def _token_type_value(seq_type: str) -> int:
        """Token-type id used for every position: 0 for gene input, 1 otherwise."""
        return 0 if seq_type == "gene" else 1

    @staticmethod
    def _to_tensors(encoded: Dict[str, List[List[int]]], return_tensors: Optional[str]) -> Dict[str, Any]:
        """
        Convert nested ``(batch, seq_len)`` id lists to int64 tensors or arrays.

        "pt" gives ``torch.long`` tensors and "np" gives ``numpy.int64`` arrays.
        Any other value (including None) returns ``encoded`` unchanged.
        """
        if return_tensors == "pt":
            import torch
            return {key: torch.tensor(value, dtype=torch.long) for key, value in encoded.items()}
        if return_tensors == "np":
            import numpy as np
            return {key: np.asarray(value, dtype=np.int64) for key, value in encoded.items()}
        return encoded

    def batch_encode_plus(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Encode a batch of sequences and return column-wise lists (or tensors).

        The batch is the first positional argument, or the ``text=`` /
        ``batch_text_or_text_pairs=`` keyword. ``seq_type``, ``padding``,
        ``return_tensors``, ``return_attention_mask`` and
        ``return_token_type_ids`` are read here. All remaining keyword
        arguments (``max_length``, ``truncation``, ``add_special_tokens``, ...)
        are forwarded to :meth:`encode_plus`. ``text_pair`` is ignored.

        ``padding=True`` / ``"longest"`` right-pads every row to the longest
        row. ``"max_length"`` pads each row to ``max_length``.

        Raises:
            ValueError: If no batch is given, or if tensors are requested for
                rows of different lengths.
        """
        text_kw = kwargs.pop("text", None)
        pairs_kw = kwargs.pop("batch_text_or_text_pairs", None)
        if args:
            batch_text = args[0]
        elif text_kw is not None:
            batch_text = text_kw
        else:
            batch_text = pairs_kw
        if batch_text is None:
            raise ValueError("batch_encode_plus() requires a list of sequences")
        if isinstance(batch_text, str):
            batch_text = [batch_text]

        kwargs.pop("text_pair", None)  # pairs are not supported
        seq_type = kwargs.pop("seq_type", "gene")
        padding = kwargs.pop("padding", False)
        return_tensors = kwargs.pop("return_tensors", None)
        return_attention_mask = kwargs.pop("return_attention_mask", True)
        return_token_type_ids = kwargs.pop("return_token_type_ids", True)

        # Encode rows as plain lists; tensor conversion happens once for the batch.
        rows = [
            self.encode_plus(
                seq,
                seq_type=seq_type,
                padding=padding,
                return_attention_mask=return_attention_mask,
                return_token_type_ids=return_token_type_ids,
                return_tensors=None,
                **kwargs
            )
            for seq in batch_text
        ]

        keys = ["input_ids"]
        if return_attention_mask:
            keys.append("attention_mask")
        if return_token_type_ids:
            keys.append("token_type_ids")
        combined = {key: [row[key] for row in rows] for key in keys}

        if padding is True or padding == "longest":
            longest = max((len(ids) for ids in combined["input_ids"]), default=0)
            fill = {
                "input_ids": self.padding_idx,
                "attention_mask": 0,
                "token_type_ids": self._token_type_value(seq_type),
            }
            for key, sequences in combined.items():
                for seq_ids in sequences:
                    seq_ids.extend([fill[key]] * (longest - len(seq_ids)))

        if return_tensors in ("pt", "np"):
            if len({len(ids) for ids in combined["input_ids"]}) > 1:
                raise ValueError(
                    "Sequences in the batch have different lengths; use padding=True "
                    "(or padding='max_length' with truncation=True) to return tensors."
                )
            combined = self._to_tensors(combined, return_tensors)

        return combined

    def encode_plus(
        self,
        text: str,
        text_pair: Optional[str] = None,
        seq_type: str = "gene",
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_attention_mask: bool = True,
        return_token_type_ids: bool = True,
        return_tensors: Optional[str] = None,
        truncation: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Encode one sequence into ``input_ids`` and optional mask / type ids.

        Only ``padding="max_length"`` (with ``max_length`` set) pads, on the
        right, using the [PAD] id and attention 0. ``token_type_ids`` is 0 for
        gene input and 1 otherwise, including padded positions.
        ``return_tensors`` "pt"/"np" returns arrays with a leading batch
        dimension of 1. ``text_pair`` is ignored.
        """
        token_ids = self.encode(
            text,
            seq_type=seq_type,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            max_length=max_length
        )

        attention_mask = [1] * len(token_ids)
        if padding == "max_length" and max_length is not None and len(token_ids) < max_length:
            pad_length = max_length - len(token_ids)
            token_ids.extend([self.padding_idx] * pad_length)
            attention_mask.extend([0] * pad_length)

        result = {"input_ids": token_ids}
        if return_attention_mask:
            result["attention_mask"] = attention_mask
        if return_token_type_ids:
            result["token_type_ids"] = [self._token_type_value(seq_type)] * len(token_ids)

        if return_tensors in ("pt", "np"):
            result = self._to_tensors({key: [value] for key, value in result.items()}, return_tensors)

        return result

    def encode_old_model_style(
        self,
        text: str,
        seq_type: str = "gene",
        max_length: int = None
    ) -> List[int]:
        """
        Encode the same way as the old model's encoder function.

        This replicates the logic from src/llm/lucaone_virus/get_embedding.py:encoder().
        The sequence ids (without special tokens) are cut to ``max_length``
        first. [CLS]/[SEP] are then added per ``prepend_bos``/``append_eos``,
        so the result can be up to ``max_length + 2`` long.
        """
        # ``encode`` already applies gene_seq_replace for gene input. Replacing
        # here as well would map the resulting digits to '5' (unknown).
        seq_encoded = self.encode(text, seq_type=seq_type, add_special_tokens=False)

        if max_length and len(seq_encoded) > max_length:
            seq_encoded = seq_encoded[:max_length]

        return self.build_inputs_with_special_tokens(seq_encoded)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the token -> id mapping to ``[<prefix>-]vocab.json`` in ``save_directory``.

        Required by the HuggingFace tokenizer interface. The directory is
        created if it is missing.
        """
        filename_prefix = "" if filename_prefix is None else filename_prefix + "-"

        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(save_directory, f"{filename_prefix}vocab.json")
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.get_vocab(), f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load the tokenizer from a pretrained model path (standard HuggingFace interface).

        The vocabulary is built in code, not read from ``vocab.json``.
        ``vocab_type``, ``prepend_bos`` and ``append_eos`` come from, in order:
        explicit ``kwargs``, then a local ``tokenizer_config.json`` if one is
        present, then the defaults (``vocab_type="gene_prot"``).
        """
        config_file = os.path.join(str(pretrained_model_name_or_path), "tokenizer_config.json")
        if os.path.isfile(config_file):
            try:
                with open(config_file, "r", encoding="utf-8") as f:
                    saved = json.load(f)
            except (OSError, ValueError):
                # Best effort: an unreadable config falls back to the defaults.
                saved = {}
            if isinstance(saved, dict):
                for key in ("vocab_type", "prepend_bos", "append_eos"):
                    if saved.get(key) is not None:
                        kwargs.setdefault(key, saved[key])
        kwargs.setdefault("vocab_type", "gene_prot")
        return cls(**kwargs)
|
|
class LucaGPLMTokenizerFast(PreTrainedTokenizerFast):
    """
    Placeholder "fast" (Rust-backed) variant of ``LucaGPLMTokenizer``.

    This class does not delegate any work to the slow tokenizer. It only
    forwards ``kwargs`` to ``PreTrainedTokenizerFast.__init__`` and declares
    ``LucaGPLMTokenizer`` as its ``slow_tokenizer_class``.

    NOTE(review): without a ``tokenizer_file`` / ``tokenizer_object`` kwarg,
    ``PreTrainedTokenizerFast`` presumably tries to convert the slow tokenizer.
    transformers does not appear to ship a converter for this class, so
    construction is expected to fail in that case. Prefer ``LucaGPLMTokenizer``
    unless a serialized fast tokenizer is supplied, and confirm against the
    installed transformers version.
    """
    slow_tokenizer_class = LucaGPLMTokenizer

    def __init__(self, **kwargs):
        # All configuration (tokenizer_file, special tokens, ...) is passed through unchanged.
        super().__init__(**kwargs)
|
|
# Public API of this module.
__all__ = [
    "LucaGPLMTokenizer",
    "LucaGPLMTokenizerFast",
    "gene_seq_replace",
]