diff --git "a/llmlingua/prompt_compressor.py" "b/llmlingua/prompt_compressor.py"
deleted file mode 100644
--- "a/llmlingua/prompt_compressor.py"
+++ /dev/null
@@ -1,2412 +0,0 @@
-# Copyright (c) 2023 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import bisect
-import re
-from collections import defaultdict
-from typing import List
-
-import numpy as np
-import torch
-
-import nltk
-import tiktoken
-from transformers import (
- AutoConfig,
- AutoModelForCausalLM,
- AutoModelForTokenClassification,
- AutoTokenizer,
-)
-import torch.nn.functional as F
-import string
-import copy
-from torch.utils.data import DataLoader
-
-from .utils import TokenClfDataset, seed_everything, is_begin_of_new_word, replace_added_token, get_pure_token
-
-
-class PromptCompressor:
- """
- PromptCompressor is designed for compressing prompts based on a given language model.
-
- This class initializes with the language model and its configuration, preparing it for prompt compression tasks.
- The PromptCompressor class is versatile and can be adapted for various models and specific requirements in prompt processing.
- Users can specify different model names and configurations as needed for their particular use case. The architecture is
- based on the paper "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models". Jiang, Huiqiang, Qianhui Wu,
- Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "Llmlingua: Compressing prompts for accelerated inference of large language models."
- arXiv preprint arXiv:2310.05736 (2023).
-
- Args:
- model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf".
- device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda".
- model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary.
- open_api_config (dict, optional): A dictionary containing configuration for OpenAI APIs that may be used in conjunction with the model. Default is an empty dictionary.
- use_llmlingua2 (bool, optional): Whether to use llmlingua-2 compressor based on the paper
- "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression".
- Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang.
- "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression". arXiv preprint arXiv:,
- Default is True.
- llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is
- {
- "max_batch_size": 50,
- "max_force_token": 100, # max number of the tokens which will be forcely preserved
- }
- Example:
- >>> compress_method = PromptCompressor(model_name="xxx/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True)
- >>> context = ["This is the first context sentence.", "Here is another context sentence."]
- >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5)
- >>> print(result["compressed_prompt"])
- # This will print the compressed version of the context.
-
- Note:
- The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models.
- """
-
- def __init__(
- self,
- model_name: str = "NousResearch/Llama-2-7b-hf",
- device_map: str = "cuda",
- model_config: dict = {},
- open_api_config: dict = {},
- use_llmlingua2: bool = True,
- llmlingua2_config: dict = {},
- ):
- self.model_name = model_name
- self.use_llmlingua2 = use_llmlingua2
- self.retrieval_model = None
- self.retrieval_model_name = None
- self.open_api_config = open_api_config
- self.cache_bos_num = 10
- self.prefix_bos_num = 100
- self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
- self.load_model(model_name, device_map, model_config)
- if use_llmlingua2:
- self.init_llmlingua2(**llmlingua2_config)
-
- def init_llmlingua2(
- self,
- max_batch_size: int = 50,
- max_force_token: int = 100,
- ):
-
- seed_everything(42)
- self.max_batch_size = max_batch_size
- self.max_seq_len = 512
- self.max_force_token = max_force_token
- self.special_tokens = set(self.tokenizer.special_tokens_map.values())
-
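- # Note: the [NEW{i}] placeholders below are reserved as extra special tokens; at compression time,
- # force_tokens that tokenize into more than one piece are mapped onto these single placeholder
- # tokens so they survive token-level filtering intact (see compress_prompt_llmlingua2).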
- self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)]
- self.tokenizer.add_special_tokens(
- {"additional_special_tokens": self.added_tokens}
- )
- self.model.resize_token_embeddings(len(self.tokenizer))
-
- def load_model(
- self, model_name: str, device_map: str = "cuda", model_config: dict = {}
- ):
- trust_remote_code = model_config.get("trust_remote_code", True)
- if "trust_remote_code" not in model_config:
- model_config["trust_remote_code"] = trust_remote_code
- config = AutoConfig.from_pretrained(model_name, **model_config)
- tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config)
- if model_config.get("pad_to_left", True):
- tokenizer.padding_side = "left"
- tokenizer.pad_token_id = (
- config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
- )
- MODEL_CLASS = (
- AutoModelForTokenClassification
- if any("ForTokenClassification" in ar for ar in config.architectures)
- else AutoModelForCausalLM
- )
- self.device = (
- device_map
- if any(key in device_map for key in ["cuda", "cpu", "mps"])
- else "cuda"
- )
- if "cuda" in device_map or "cpu" in device_map:
- model = MODEL_CLASS.from_pretrained(
- model_name,
- torch_dtype=model_config.get(
- "torch_dtype", "auto" if device_map == "cuda" else torch.float32
- ),
- device_map=device_map,
- config=config,
- ignore_mismatched_sizes=True,
- **model_config,
- )
- else:
- model = MODEL_CLASS.from_pretrained(
- model_name,
- device_map=device_map,
- torch_dtype=model_config.get("torch_dtype", "auto"),
- pad_token_id=tokenizer.pad_token_id,
- **model_config,
- )
- self.tokenizer = tokenizer
- self.model = model
- self.context_idxs = []
- self.max_position_embeddings = config.max_position_embeddings
-
- def get_ppl(
- self,
- text: str,
- granularity: str = "sentence",
- input_ids=None,
- attention_mask=None,
- past_key_values=None,
- return_kv=False,
- end=None,
- condition_mode: str = "none",
- condition_pos_id: int = 0,
- ):
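- # Scores `text` (or pre-tokenized input_ids) with the causal LM and returns the cross-entropy loss:
- # the mean over tokens for granularity="sentence", or the per-token losses otherwise. condition_mode
- # "before"/"after" restricts the loss to tokens before/after condition_pos_id, and return_kv=True
- # additionally returns the updated past_key_values for incremental scoring.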
- if input_ids is None:
- tokenized_text = self.tokenizer(text, return_tensors="pt")
- input_ids = tokenized_text["input_ids"].to(self.device)
- attention_mask = tokenized_text["attention_mask"].to(self.device)
- if past_key_values is not None:
- past_length = past_key_values[0][0].shape[2]
- else:
- past_length = 0
- if end is None:
- end = input_ids.shape[1]
- end = min(end, past_length + self.max_position_embeddings)
- with torch.no_grad():
- response = self.model(
- input_ids[:, past_length:end],
- attention_mask=attention_mask[:, :end],
- past_key_values=past_key_values,
- use_cache=True,
- )
- past_key_values = response.past_key_values
-
- shift_logits = response.logits[..., :-1, :].contiguous()
- shift_labels = input_ids[..., past_length + 1 : end].contiguous()
- # Flatten the tokens
- active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
- active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
- active_labels = shift_labels.view(-1)[active]
- loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
- loss = loss_fct(active_logits, active_labels)
- if condition_mode == "before":
- loss = loss[:condition_pos_id]
- elif condition_mode == "after":
- loss = loss[condition_pos_id:]
- res = loss.mean() if granularity == "sentence" else loss
- return (res, past_key_values) if return_kv else res
-
- def __call__(self, *args, **kwargs):
- return self.compress_prompt(*args, **kwargs)
-
- def structured_compress_prompt(
- self,
- context: List[str],
- instruction: str = "",
- question: str = "",
- rate: float = 0.5,
- target_token: float = -1,
- iterative_size: int = 200,
- force_context_ids: List[int] = None,
- force_context_number: int = None,
- use_sentence_level_filter: bool = False,
- use_context_level_filter: bool = True,
- use_token_level_filter: bool = True,
- keep_split: bool = False,
- keep_first_sentence: int = 0,
- keep_last_sentence: int = 0,
- keep_sentence_number: int = 0,
- high_priority_bonus: int = 100,
- context_budget: str = "+100",
- token_budget_ratio: float = 1.4,
- condition_in_question: str = "none",
- reorder_context: str = "original",
- dynamic_context_compression_ratio: float = 0.0,
- condition_compare: bool = False,
- add_instruction: bool = False,
- rank_method: str = "llmlingua",
- concate_question: bool = True,
- ):
- """
- Compresses the given prompt context based on a specified structure.
-
- Each element of context should be segmented using one or more non-nested '<llmlingua></llmlingua>' tags.
- Each '<llmlingua>' tag can include optional parameters 'rate' and 'compress' (e.g., '<llmlingua, rate=0.3, compress=True>'),
- indicating the compression rate for that segment. Default values are 'rate=rate' and 'compress=True'.
- When 'compress' is set to False, it overrides the 'rate' parameter, resulting in no compression for that segment.
-
- Args:
- context (List[str]): List of context strings divided by '<llmlingua></llmlingua>' tags with optional compression settings.
- instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
- question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
- rate (float, optional): The compression rate is defined the same as in paper "Language Modeling Is Compression".
- Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, Elliot Catt, Tim Genewein, Christopher Mattern,
- Jordi Grau-Moya et al. "Language modeling is compression." arXiv preprint arXiv:2309.10668 (2023):
- .. math:: \text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
- Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
- fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
- to 1.0, representing the target compression rate. ``rate``, is applicable only within the context-level filter
- and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global rate.
- However, for segments where no specific rate is defined, the global rate serves as the default value. The final
- compression rate of the entire text is a composite result of multiple compression rates applied across different sections.
- target_token (float, optional): The global maximum number of tokens to be achieved. Default is -1, indicating no
- specific target. The actual number of tokens after compression should generally be less than the specified target_token,
- but there can be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
- the sole criterion, overriding the ``rate``. ``target_token``, is applicable only within the context-level
- filter and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global target token.
- However, for segments where no specific rate is defined, the global rate calculated from global target token serves
- as the default value. The final target token of the entire text is a composite result of multiple compression rates
- applied across different sections.
- iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
- force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
- force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
- use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
- use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
- use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
- keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
- keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
- keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
- keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
- high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
- context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
- token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
- condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
- reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
- dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
- condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
- add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
- rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
- concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
-
- Returns:
- dict: A dictionary containing:
- - "compressed_prompt" (str): The resulting compressed prompt.
- - "origin_tokens" (int): The original number of tokens in the input.
- - "compressed_tokens" (int): The number of tokens in the compressed output.
- - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
- - "rate" (str): The compression rate achieved, in a human-readable format.
- - "saving" (str): Estimated savings in GPT-4 token usage.
- """
- if not context:
- context = [" "]
- if isinstance(context, str):
- context = [context]
- context = [
- self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids)
- for c in context
- ]
- context_tokens_length = [self.get_token_length(c) for c in context]
- instruction_tokens_length, question_tokens_length = self.get_token_length(
- instruction
- ), self.get_token_length(question)
- if target_token == -1:
- target_token = (
- (
- instruction_tokens_length
- + question_tokens_length
- + sum(context_tokens_length)
- )
- * rate
- - instruction_tokens_length
- - (question_tokens_length if concate_question else 0)
- )
- else:
- rate = target_token / sum(context_tokens_length)
- (
- context,
- context_segs,
- context_segs_rate,
- context_segs_compress,
- ) = self.segment_structured_context(context, rate)
- return self.compress_prompt(
- context,
- instruction,
- question,
- rate,
- target_token,
- iterative_size,
- force_context_ids,
- force_context_number,
- use_sentence_level_filter,
- use_context_level_filter,
- use_token_level_filter,
- keep_split,
- keep_first_sentence,
- keep_last_sentence,
- keep_sentence_number,
- high_priority_bonus,
- context_budget,
- token_budget_ratio,
- condition_in_question,
- reorder_context,
- dynamic_context_compression_ratio,
- condition_compare,
- add_instruction,
- rank_method,
- concate_question,
- context_segs=context_segs,
- context_segs_rate=context_segs_rate,
- context_segs_compress=context_segs_compress,
- )
-
- def compress_prompt(
- self,
- context: List[str],
- instruction: str = "",
- question: str = "",
- rate: float = 0.5,
- target_token: float = -1,
- iterative_size: int = 200,
- force_context_ids: List[int] = None,
- force_context_number: int = None,
- use_sentence_level_filter: bool = False,
- use_context_level_filter: bool = True,
- use_token_level_filter: bool = True,
- keep_split: bool = False,
- keep_first_sentence: int = 0,
- keep_last_sentence: int = 0,
- keep_sentence_number: int = 0,
- high_priority_bonus: int = 100,
- context_budget: str = "+100",
- token_budget_ratio: float = 1.4,
- condition_in_question: str = "none",
- reorder_context: str = "original",
- dynamic_context_compression_ratio: float = 0.0,
- condition_compare: bool = False,
- add_instruction: bool = False,
- rank_method: str = "llmlingua",
- concate_question: bool = True,
- context_segs: List[str] = None,
- context_segs_rate: List[float] = None,
- context_segs_compress: List[bool] = None,
- target_context: int = -1,
- context_level_rate: float = 1.0,
- context_level_target_token: int = -1,
- return_word_label: bool = False,
- word_sep: str = "\t\t|\t\t",
- label_sep: str = " ",
- token_to_word: str = "mean",
- force_tokens: List[str] = [],
- force_reserve_digit: bool = False,
- drop_consecutive: bool = False,
- chunk_end_tokens: List[str] = [".", "\n"],
- ):
- """
- Compresses the given context.
-
- Args:
- context (List[str]): List of context strings that form the basis of the prompt.
- instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
- question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
- rate (float, optional): The maximum compression rate target to be achieved. The compression rate is defined
- the same as in paper "Language Modeling Is Compression". Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne,
- Elliot Catt, Tim Genewein, Christopher Mattern, Jordi Grau-Moya et al. "Language modeling is compression."
- arXiv preprint arXiv:2309.10668 (2023):
- .. math:: \text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
- Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
- fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
- to 1.0, representing the target compression rate.
- target_token (float, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
- The actual number of tokens after compression should generally be less than the specified target_token, but there can
- be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
- the sole criterion, overriding the ``rate``.
- iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
- force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
- force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
- use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
- use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
- use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
- keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
- keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
- keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
- keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
- high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
- context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
- token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
- condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
- reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
- dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
- condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
- add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
- rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
- concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
-
- target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
- context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0.
- context_level_target_token (float, optional): The maximum number of tokens to be achieved in context level compression.
- Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
- force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
- return_word_label (bool, optional): Whether to return word with corresponding label. Default is False.
- word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
- label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ".
- token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
- force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
- force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0,...,9). Default is False.
- drop_consecutive (bool, optional): Whether to drop tokens which are in 'force_tokens' but appear consecutively in the compressed prompt.
- Default is False.
- chunk_end_tokens (List[str], optional): The early-stop tokens for segmenting a chunk. Default is [".", "\n"].
- Returns:
- dict: A dictionary containing:
- - "compressed_prompt" (str): The resulting compressed prompt.
- - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. Only used in llmlingua2.
- - "fn_labeled_original_prompt" (str): original words along with their labels
- indicating whether to reserve in compressed prompt, in the format (word label_sep label)
- Only used in llmlingua2 when return_word_label = True.
- - "origin_tokens" (int): The original number of tokens in the input.
- - "compressed_tokens" (int): The number of tokens in the compressed output.
- - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
- - "rate" (str): The compression rate achieved, in a human-readable format.
- - "saving" (str): Estimated savings in GPT-4 token usage.
- """
- if self.use_llmlingua2:
- return self.compress_prompt_llmlingua2(
- context,
- rate=rate,
- target_token=target_token,
- use_context_level_filter=use_context_level_filter,
- use_token_level_filter=use_token_level_filter,
- target_context=target_context,
- context_level_rate=context_level_rate,
- context_level_target_token=context_level_target_token,
- force_context_ids=force_context_ids,
- return_word_label=return_word_label,
- word_sep=word_sep,
- label_sep=label_sep,
- token_to_word=token_to_word,
- force_tokens=force_tokens,
- force_reserve_digit=force_reserve_digit,
- drop_consecutive=drop_consecutive,
- chunk_end_tokens=chunk_end_tokens,
- )
- assert (
- rate <= 1.0
- ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
-
- if not context:
- context = [" "]
- if isinstance(context, str):
- context = [context]
- assert not (
- rank_method == "longllmlingua" and not question
- ), "In the LongLLMLingua, it is necessary to set a question."
- if condition_compare and "_condition" not in condition_in_question:
- condition_in_question += "_condition"
- if rank_method == "longllmlingua":
- if condition_in_question == "none":
- condition_in_question = "after"
- elif rank_method == "llmlingua":
- condition_in_question = (
- "none"
- if "_condition" not in condition_in_question
- else "none_condition"
- )
- origin_tokens = len(
- self.oai_tokenizer.encode(
- "\n\n".join([instruction] + context + [question]).strip()
- )
- )
- context_tokens_length = [self.get_token_length(c) for c in context]
- instruction_tokens_length, question_tokens_length = self.get_token_length(
- instruction
- ), self.get_token_length(question)
- if target_token == -1:
- target_token = (
- (
- instruction_tokens_length
- + question_tokens_length
- + sum(context_tokens_length)
- )
- * rate
- - instruction_tokens_length
- - (question_tokens_length if concate_question else 0)
- )
- condition_flag = "_condition" in condition_in_question
- condition_in_question = condition_in_question.replace("_condition", "")
-
- if len(context) > 1 and use_context_level_filter:
- context, dynamic_ratio, context_used = self.control_context_budget(
- context,
- context_tokens_length,
- target_token,
- force_context_ids,
- force_context_number,
- question,
- condition_in_question,
- reorder_context=reorder_context,
- dynamic_context_compression_ratio=dynamic_context_compression_ratio,
- rank_method=rank_method,
- context_budget=context_budget,
- context_segs=context_segs,
- context_segs_rate=context_segs_rate,
- context_segs_compress=context_segs_compress,
- )
- if context_segs is not None:
- context_segs = [context_segs[idx] for idx in context_used]
- context_segs_rate = [context_segs_rate[idx] for idx in context_used]
- context_segs_compress = [
- context_segs_compress[idx] for idx in context_used
- ]
- else:
- dynamic_ratio = [0.0] * len(context)
-
- segments_info = []
- if use_sentence_level_filter:
- context, segments_info = self.control_sentence_budget(
- context,
- target_token,
- keep_first_sentence=keep_first_sentence,
- keep_last_sentence=keep_last_sentence,
- keep_sentence_number=keep_sentence_number,
- high_priority_bonus=high_priority_bonus,
- token_budget_ratio=token_budget_ratio,
- question=question,
- condition_in_question=condition_in_question,
- rank_method=rank_method,
- context_segs=context_segs,
- context_segs_rate=context_segs_rate,
- context_segs_compress=context_segs_compress,
- )
- elif context_segs is not None:
- for context_idx in range(len(context)):
- segments_info.append(
- [
- (len(seg_text), seg_rate, seg_compress)
- for seg_text, seg_rate, seg_compress in zip(
- context_segs[context_idx],
- context_segs_rate[context_idx],
- context_segs_compress[context_idx],
- )
- ]
- )
- segments_info = [
- self.concate_segment_info(segment_info) for segment_info in segments_info
- ]
-
- if condition_flag:
- prefix = question + "\n\n" + instruction if add_instruction else question
- if (
- self.get_token_length(prefix + "\n\n") + iterative_size * 2
- > self.max_position_embeddings
- ):
- tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
- prefix = self.tokenizer.decode(
- tokens[: self.prefix_bos_num]
- + tokens[
- len(tokens)
- - self.max_position_embeddings
- + 2
- + self.prefix_bos_num
- + 2 * iterative_size :
- ]
- )
- start = self.get_prefix_length(prefix + "\n\n", context[0])
- context = [prefix] + context
- else:
- start = 0
-
- if use_token_level_filter:
- context = self.iterative_compress_prompt(
- context,
- target_token,
- iterative_size=iterative_size,
- keep_split=keep_split,
- start=start,
- dynamic_ratio=dynamic_ratio,
- condition_compare=condition_compare,
- segments_info=segments_info,
- )
- compressed_prompt = (
- self.tokenizer.batch_decode(context[0])[0]
- .replace(" ", "")
- .replace("", "")
- )
- else:
- if condition_flag:
- context = context[1:]
- compressed_prompt = "\n\n".join(context)
-
- res = []
- if instruction:
- res.append(instruction)
- if compressed_prompt.strip():
- res.append(compressed_prompt)
- if question and concate_question:
- res.append(question)
-
- compressed_prompt = "\n\n".join(res)
-
- compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt))
- saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
- ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens
- rate = 1 / ratio
- return {
- "compressed_prompt": compressed_prompt,
- "origin_tokens": origin_tokens,
- "compressed_tokens": compressed_tokens,
- "ratio": f"{ratio:.1f}x",
- "rate": f"{rate * 100:.1f}%",
- "saving": f", Saving ${saving:.1f} in GPT-4.",
- }
-
- def compress_prompt_llmlingua2(
- self,
- context: List[str],
- rate: float = 0.5,
- target_token: int = -1,
- use_context_level_filter: bool = False,
- use_token_level_filter: bool = True,
- target_context: int = -1,
- context_level_rate: float = 1.0,
- context_level_target_token: int = -1,
- force_context_ids: List[int] = [],
- return_word_label: bool = False,
- word_sep: str = "\t\t|\t\t",
- label_sep: str = " ",
- token_to_word: str = "mean",
- force_tokens: List[str] = [],
- force_reserve_digit: bool = False,
- drop_consecutive: bool = False,
- chunk_end_tokens: List[str] = [".", "\n"],
- ):
- """
- Compresses the given context.
-
- Args:
- context (List[str]): List of context strings that form the basis of the prompt.
- rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate
- generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified,
- it should be a float less than or equal to 1.0, representing the target compression rate.
- target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
- The actual number of tokens after compression should generally be less than the specified target_token, but there can
- be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
- the sole criterion, overriding the rate.
- target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
- Only used in the coarse-to-fine compression.
- context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0.
- Only used in the coarse-to-fine compression.
- context_level_target_token (float, optional): The maximum number of tokens to be achieved in context level compression.
- Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
- force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
- return_word_label (bool, optional): Whether to return word with corresponding label. Default is False.
- word_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
- label_sep (str, optional): The sep token used in fn_labeled_original_prompt to partition word and label. Default is " ".
- token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
- force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
- force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0,...,9). Default is False.
- drop_consecutive (bool, optional): Whether to drop tokens which are in 'force_tokens' but appear consecutively in the compressed prompt.
- Default is False.
- chunk_end_tokens (List[str], optional): The early stop tokens for segmenting chunk. Default is [".", "\n"].
- Returns:
- dict: A dictionary containing:
- - "compressed_prompt" (str): The resulting compressed prompt.
- - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt.
- - "fn_labeled_original_prompt" (str): original words along with their labels
- indicating whether to reserve in compressed prompt, in the format (word label_sep label)
- - "origin_tokens" (int): The original number of tokens in the input.
- - "compressed_tokens" (int): The number of tokens in the compressed output.
- - "ratio" (str): The compression ratio achieved, in a human-readable format.
- - "rate" (str): The compression rate achieved, in a human-readable format.
- - "saving" (str): Estimated savings in GPT-4 token usage.
-
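- Example (illustrative sketch, not a recorded run; the transcript context, force_tokens, and the
- ``compressor`` instance are assumptions):
- >>> result = compressor.compress_prompt_llmlingua2(
- ... ["Speaker 1: The budget was approved on May 3rd after a long discussion ..."],
- ... rate=0.33, force_tokens=["\n", "."], return_word_label=True)
- >>> print(result["compressed_prompt"])
- >>> print(result["fn_labeled_original_prompt"])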
- """
- assert len(force_tokens) <= self.max_force_token
- token_map = {}
- for i, t in enumerate(force_tokens):
- if len(self.tokenizer.tokenize(t)) != 1:
- token_map[t] = self.added_tokens[i]
- chunk_end_tokens = copy.deepcopy(chunk_end_tokens)
- for c in chunk_end_tokens:
- if c in token_map:
- chunk_end_tokens.append(token_map[c])
- chunk_end_tokens = set(chunk_end_tokens)
-
- if isinstance(context, str):
- context = [context]
- context = copy.deepcopy(context)
-
- if len(context) == 1 and use_context_level_filter:
- use_context_level_filter = False
-
- n_original_token = 0
- context_chunked = []
- for i in range(len(context)):
- n_original_token += self.get_token_length(context[i], use_oai_tokenizer=True)
- for ori_token, new_token in token_map.items():
- context[i] = context[i].replace(ori_token, new_token)
- context_chunked.append(self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens))
-
- if use_context_level_filter:
- # use_context_level_filter is requested but no context-level parameters were given:
- # fall back to context_level_rate = (rate + 1.0) / 2 if rate is specified, or context_level_target_token = target_token * 2 if target_token is specified
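- # e.g., with use_token_level_filter=True and rate=0.5, context_level_rate becomes 0.75: roughly the
- # lowest-scoring quarter of the contexts is dropped here, and the remaining reduction is left to the token-level filter.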
- if (
- target_context <= 0
- and context_level_rate >= 1.0
- and context_level_target_token <= 0
- ):
- if target_token < 0 and rate < 1.0:
- context_level_rate = (
- (rate + 1.0) / 2 if use_token_level_filter else rate
- )
- print(
- f"set context level compression rate to {context_level_rate}."
- )
- if target_token >= 0:
- context_level_target_token = (
- target_token * 2 if use_token_level_filter else target_token
- )
- print(
- f"set context level target token to {context_level_target_token}."
- )
-
- if target_context >= 0:
- context_level_rate = min(target_context / len(context), 1.0)
- # print(f'override context level compression rate to {context_level_rate} because you specified target_context = {target_context}.')
- if context_level_target_token >= 0:
- context_level_rate = min(
- context_level_target_token / n_original_token, 1.0
- )
- # print(f'override context level compression rate to {context_level_rate} because you specified context_level_target_token = {context_level_target_token}.')
-
- context_probs, context_words = self.__get_context_prob(
- context_chunked,
- token_to_word=token_to_word,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- )
-
- threshold = np.percentile(
- context_probs, int(100 * (1 - context_level_rate))
- )
-
- reserved_context = []
- context_label = [False] * len(context_probs)
- for i, p in enumerate(context_probs):
- if p >= threshold or (
- force_context_ids is not None and i in force_context_ids
- ):
- reserved_context.append(context_chunked[i])
- context_label[i] = True
- n_reserved_token = 0
- for chunks in reserved_context:
- for c in chunks:
- n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True)
- if target_token >= 0:
- rate = min(target_token / n_reserved_token, 1.0)
- print(
- f"override compression rate to {rate} because you specified target_token = {target_token}."
- )
-
- if use_token_level_filter:
- compressed_context, word_list, word_label_list = self.__compress(
- reserved_context,
- reduce_rate=max(0, 1 - rate),
- token_to_word=token_to_word,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- drop_consecutive=drop_consecutive,
- )
- else:
- compressed_context, word_list, word_label_list = self.__compress(
- reserved_context,
- reduce_rate=0,
- token_to_word=token_to_word,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- drop_consecutive=drop_consecutive,
- )
- print(
- "return the original text because you specify use_token_level_filter=False"
- )
-
- n_compressed_token = 0
- for c in compressed_context:
- n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
- saving = (n_original_token - n_compressed_token) * 0.06 / 1000
- ratio = (
- 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
- )
- res = {
- "compressed_prompt": "\n\n".join(compressed_context),
- "compressed_prompt_list": compressed_context,
- "origin_tokens": n_original_token,
- "compressed_tokens": n_compressed_token,
- "ratio": f"{ratio:.1f}x",
- "rate": f"{1 / ratio * 100:.1f}%",
- "saving": f", Saving ${saving:.1f} in GPT-4.",
- }
- if return_word_label:
- words = []
- labels = []
- j = 0
- for i in range(len(context)):
- if context_label[i]:
- words.extend(word_list[j])
- labels.extend(word_label_list[j])
- j += 1
- else:
- words.extend(context_words[i])
- labels.extend([0] * len(context_words[i]))
- word_label_lines = word_sep.join(
- [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
- )
- res["fn_labeled_original_prompt"] = word_label_lines
- return res
-
- if target_token > 0:
- rate = min(target_token / n_original_token, 1.0)
- print(
- f"override compression rate to {rate} \
- because you specified target_token = {target_token}."
- )
-
- if use_token_level_filter:
- compressed_context, word_list, word_label_list = self.__compress(
- context_chunked,
- reduce_rate=max(0, 1 - rate),
- token_to_word=token_to_word,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- drop_consecutive=drop_consecutive,
- )
- else:
- compressed_context, word_list, word_label_list = self.__compress(
- context_chunked,
- reduce_rate=0,
- token_to_word=token_to_word,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- drop_consecutive=drop_consecutive,
- )
- print(
- "return the original text because you specify use_token_level_filter=False"
- )
-
- n_compressed_token = 0
- for c in compressed_context:
- n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
- saving = (n_original_token - n_compressed_token) * 0.06 / 1000
- ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
- res = {
- "compressed_prompt": "\n\n".join(compressed_context),
- "compressed_prompt_list": compressed_context,
- "origin_tokens": n_original_token,
- "compressed_tokens": n_compressed_token,
- "ratio": f"{ratio:.1f}x",
- "rate": f"{1 / ratio * 100:.1f}%",
- "saving": f", Saving ${saving:.1f} in GPT-4.",
- }
- if return_word_label:
- words = []
- labels = []
- for w_list, l_list in zip(word_list, word_label_list):
- words.extend(w_list)
- labels.extend(l_list)
-
- # new_words = []
- # new_labels = []
- # for i in range(len(words)):
- # word, label = words[i], labels[i]
- # if word in string.punctuation:
- # if labels[i-1] == 1 and label == 1 and i > 0:
- # new_words[-1] += word
- # else:
- # new_words.append(word)
- # new_labels.append(label)
- # word_label_lines = word_sep.join([f'{word}{label_sep}{label}' for word, label in zip(new_words, new_labels)])
-
- word_label_lines = word_sep.join(
- [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
- )
- res["fn_labeled_original_prompt"] = word_label_lines
- return res
-
- def get_token_length(self, text: str, add_special_tokens: bool = True, use_oai_tokenizer: bool = False):
- if use_oai_tokenizer:
- return len(self.oai_tokenizer.encode(text))
- else:
- return len(
- self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
- )
-
- def get_prefix_length(self, prefix: str, text: str):
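- # Returns how many tokens `prefix` occupies when it is tokenized together with the start of `text`;
- # token boundaries can shift at the junction, so the split point is found by scanning decoded prefixes.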
- possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
- full_input_ids = self.tokenizer(
- prefix + text[:100], add_special_tokens=False
- ).input_ids
- for i in range(possible_prefix_token, len(full_input_ids)):
- cur_prefix = self.tokenizer.decode(full_input_ids[:i])
- if cur_prefix == prefix:
- break
- assert self.tokenizer.decode(full_input_ids[i:]) == text[:100]
- return i
-
- def get_condition_ppl(
- self,
- text: str,
- question: str,
- condition_in_question: str = "none",
- granularity: str = "sentence",
- ):
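- # Conditional perplexity used for ranking: "before" prepends the question and scores only the text
- # tokens given the question; "after" appends the question and scores only the question tokens given
- # the text (the LongLLMLingua-style relevance signal); "none" scores the text alone.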
- if condition_in_question == "none":
- return self.get_ppl(text, granularity=granularity)
- elif condition_in_question == "before":
- return self.get_ppl(
- question + text,
- granularity=granularity,
- condition_mode="after",
- condition_pos_id=self.get_token_length(question) - 1,
- )
- elif condition_in_question == "after":
- return self.get_ppl(
- text + question,
- granularity=granularity,
- condition_mode="after",
- condition_pos_id=self.get_token_length(text) - 1,
- )
-
- def get_dynamic_compression_ratio(
- self,
- context: list,
- target_token: float,
- iterative_size: int,
- dynamic_ratio: list,
- start: int,
- seg_info: List[List[tuple]] = None,
- ):
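- # Spreads the global token budget over iterative_size windows: the base keep-rate is
- # tau = target_token / total_context_length, shifted per context by its dynamic_ratio delta and
- # clamped to [0, 1]; returns, for each window, a list of (token_count, keep_rate) runs.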
- def get_ratio(base: float, delta: float):
- return max(min(1, base + delta), 0)
-
- context_length = [self.get_token_length(ii, False) + 2 for ii in context]
- if start:
- context_length = context_length[1:]
- tau = target_token / (sum(context_length) + 1)
- res, idx, last, last_target = [], 0, 1, []
- while idx < len(context_length):
- if last + context_length[idx] >= iterative_size:
- last_target.append(
- (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
- )
- res.append(last_target)
- last = last + context_length[idx] - iterative_size
- if last > iterative_size:
- k = last // iterative_size
- res.extend(
- [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
- )
- last -= k * iterative_size
-
- last_target = (
- [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
- )
- else:
- last += context_length[idx]
- last_target.append(
- (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
- )
- idx += 1
- if last_target:
- res.append(last_target)
- return res
-
- def get_structured_dynamic_compression_ratio(
- self,
- context: list,
- iterative_size: int,
- dynamic_ratio: list,
- start: int,
- seg_info: List[List[tuple]] = None,
- ):
- if start:
- pure_context = context[1:]
- else:
- pure_context = context
- global_dynamic_rate, global_dynamic_compress, segments = [], [], []
- for context_idx, text in enumerate(pure_context):
- text_seen = 0
- for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate(
- seg_info[context_idx]
- ):
- seg_text = text[text_seen : text_seen + seg_len]
- if (
- seg_idx == len(seg_info[context_idx]) - 1
- and context_idx != len(pure_context) - 1
- ):
- seg_text += "\n\n"
- segments.append(seg_text)
- if seg_compress:
- global_dynamic_rate.append(seg_rate)
- else:
- global_dynamic_rate.append(1.0)
- global_dynamic_compress.append(seg_compress)
- text_seen += seg_len
- origin_text = "\n\n".join(pure_context)
- assert len("".join(segments)) == len(origin_text)
- assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress)
-
- text_input_ids = self.tokenizer(
- "\n\n".join(context), add_special_tokens=False
- ).input_ids[start:]
- assert self.tokenizer.decode(text_input_ids) == origin_text
- dynamic_compression_ratio = self.token_segment(
- text_input_ids,
- iterative_size,
- segments,
- global_dynamic_rate,
- global_dynamic_compress,
- )
- return dynamic_compression_ratio
-
- def token_segment(
- self,
- text_input_ids: List[int],
- iterative_size: int,
- segments: List[str],
- global_dynamic_rate: List[float],
- global_dynamic_compress: List[bool],
- ):
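- # Aligns the token stream with the character-level segments and emits, per iterative_size window,
- # (token_count, rate) runs so the token-level filter can apply each segment's own compression rate;
- # a token spanning several segments takes the minimum of the spanned rates (or 1.0 if any spanned segment is uncompressed).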
- decode_window = 3
- seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1
- dynamic_compression_rate, local_compresssion_rate = [], []
- for i in range(len(text_input_ids)):
- if i < decode_window:
- id_pre, id_cur = text_input_ids[:i], text_input_ids[: i + 1]
- else:
- id_pre, id_cur = (
- text_input_ids[i - decode_window + 1 : i],
- text_input_ids[i - decode_window + 1 : i + 1],
- )
- cur_word = self.tokenizer.decode(id_cur)[
- len(self.tokenizer.decode(id_pre)) :
- ]
- cur_word_len = len(cur_word)
- if cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen:
- possible_rate, possible_compress = [], []
- while (
- cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen
- ):
- possible_rate.append(global_dynamic_rate[seg_idx])
- possible_compress.append(global_dynamic_compress[seg_idx])
- cur_word_len -= len(segments[seg_idx]) - seg_seen
- seg_idx += 1
- seg_seen = 0
- if cur_word_len:
- possible_rate.append(global_dynamic_rate[seg_idx])
- possible_compress.append(global_dynamic_compress[seg_idx])
- new_rate = 1.0 if False in possible_compress else min(possible_rate)
- else:
- new_rate = global_dynamic_rate[seg_idx]
- if new_rate != last_rate and i - token_seen_num:
- local_compresssion_rate.append((i - token_seen_num, last_rate))
- token_seen_num = i
- last_rate = new_rate
- seg_seen += cur_word_len
- if (i + 1) % iterative_size == 0:
- if token_seen_num != i + 1:
- local_compresssion_rate.append((i + 1 - token_seen_num, last_rate))
- token_seen_num = i + 1
- dynamic_compression_rate.append(local_compresssion_rate[:])
- local_compresssion_rate = []
- if token_seen_num != len(text_input_ids):
- local_compresssion_rate.append(
- (len(text_input_ids) - token_seen_num, last_rate)
- )
- if local_compresssion_rate != []:
- dynamic_compression_rate.append(local_compresssion_rate[:])
- return dynamic_compression_rate
-
- def control_context_budget(
- self,
- context: List[str],
- context_tokens_length: List[int],
- target_token: float,
- force_context_ids: List[int] = None,
- force_context_number: int = None,
- question: str = "",
- condition_in_question: str = "none",
- reorder_context: str = "original",
- dynamic_context_compression_ratio: float = 0.0,
- rank_method: str = "longllmlingua",
- context_budget: str = "+100",
- context_segs: List[List[str]] = None,
- context_segs_rate: List[List[float]] = None,
- context_segs_compress: List[List[bool]] = None,
- ):
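- # Coarse-grained (context-level) filter: rank the contexts via get_rank_results, then greedily keep
- # them in rank order until the budget eval("target_token" + context_budget) (e.g. "+100") is exhausted;
- # force_context_ids and contexts containing compress=False segments are always kept.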
- demostrations_sort = self.get_rank_results(
- context,
- question,
- rank_method,
- condition_in_question,
- context_tokens_length,
- )
-
- if target_token < 0:
- target_token = 100
- target_token = eval("target_token" + context_budget)
- res = []
- used = force_context_ids if force_context_ids is not None else []
- if context_segs is not None:
- for idx, _ in enumerate(context):
- if False in context_segs_compress[idx]:
- used.append(idx)
-
- self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
- for idx, _ in demostrations_sort:
- if idx >= len(context_tokens_length):
- continue
- target_token -= context_tokens_length[idx]
- if idx not in used:
- used.append(idx)
- if target_token < 0 or (
- force_context_number is not None and len(res) >= force_context_number
- ):
- break
- original_used = used
- if reorder_context == "original":
- used = sorted(used)
- elif reorder_context == "two_stage":
- l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
- _ for idx, _ in enumerate(used) if idx % 2 == 1
- ]
- used = l + r[::-1]
-
- if dynamic_context_compression_ratio > 0:
- N = len(used)
- dynamic_ratio = [
- i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
- for i in range(-(N - 1), N, 2)
- ][::-1]
- dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
- dynamic_ratio = [dynamic_ratio_map[i] for i in used]
- else:
- dynamic_ratio = [0.0] * len(used)
-
- res = [context[idx] for idx in used if idx < len(context)]
- return res, dynamic_ratio, used
-
- def control_sentence_budget(
- self,
- context: List[str],
- target_token: float,
- keep_first_sentence: int = 0,
- keep_last_sentence: int = 0,
- keep_sentence_number: int = 0,
- high_priority_bonus: int = 100,
- token_budget_ratio: float = 1.4,
- question: str = "",
- condition_in_question: str = "none",
- rank_method: str = "longllmlingua",
- context_segs: List[List[str]] = None,
- context_segs_rate: List[List[float]] = None,
- context_segs_compress: List[List[bool]] = None,
- ):
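- # Sentence-level filter: split every context into sentences with nltk, score them with conditional
- # perplexity (or the chosen rank_method), optionally boost sentences that must be kept, and mark
- # sentences in score order until the budget target_token * token_budget_ratio is spent.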
- def keep_sentence(dem_idx: int, sent_keep: int):
- idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
- for idx in idxs:
- sentence_ppl[idx] += high_priority_bonus
-
- def sync_sentence(segments, text):
- seg_num = len(segments)
- new_segments = []
- text_seen = 0
- seg_idx, cur_seg_seen = 0, 0
- for i, s in enumerate(text):
- while seg_idx < seg_num and s != segments[seg_idx][cur_seg_seen]:
- if cur_seg_seen < len(segments[seg_idx]) - 1:
- cur_seg_seen += 1
- continue
- new_segments.append(text[text_seen:i])
- text_seen = i
- seg_idx += 1
- cur_seg_seen = 0
- cur_seg_seen += 1
- if seg_idx == seg_num:
- break
- if cur_seg_seen == len(segments[seg_idx]):
- new_segments.append(text[text_seen : i + 1])
- text_seen = i + 1
- seg_idx += 1
- cur_seg_seen = 0
- if text_seen < len(text):
- new_segments.append(text[text_seen:])
- assert len("".join(new_segments)) == len(text)
- return new_segments
-
- sentences = [nltk.sent_tokenize(c) for c in context]
- dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
- for idx_d, s in enumerate(sentences):
- for _ in s:
- dem_g[idx_d].add(idx)
- s2de[idx] = idx_d
- idx += 1
-
- if context_segs is not None:
- context_segs = [
- sync_sentence(s, "".join(c)) for s, c in zip(context_segs, sentences)
- ]
- sen2seg_ratio = {}
- idx = 0
- for idx_d, sentences_each_context in enumerate(sentences):
- segments_length = [len(s) for s in context_segs[idx_d]]
- seg_idx, cur_seg_seen = 0, 0
- for sentence in sentences_each_context:
- sentence_seg_ratio = []
- remain = len(sentence)
- while remain:
- if segments_length[seg_idx] - cur_seg_seen <= remain:
- new_seg_len = segments_length[seg_idx] - cur_seg_seen
- sentence_seg_ratio.append(
- (
- new_seg_len,
- context_segs_rate[idx_d][seg_idx],
- context_segs_compress[idx_d][seg_idx],
- )
- )
- seg_idx += 1
- cur_seg_seen = 0
- remain -= new_seg_len
- else:
- sentence_seg_ratio.append(
- (
- remain,
- context_segs_rate[idx_d][seg_idx],
- context_segs_compress[idx_d][seg_idx],
- )
- )
- cur_seg_seen += remain
- remain = 0
- sen2seg_ratio[idx] = sentence_seg_ratio
- idx += 1
-
- context_sentences = [s for ii in sentences for s in ii]
- sentence_tokens_length = [
- self.get_token_length(sentence) for sentence in context_sentences
- ]
- N = len(context_sentences)
- flags = list(range(len(context_sentences)))
- if len(sentence_tokens_length) == 1:
- return context
- if rank_method == "longllmlingua":
- sentence_ppl = [
- self.get_condition_ppl(sentence, question, condition_in_question)
- .cpu()
- .numpy()
- .item()
- for sentence in context_sentences
- ]
- if keep_first_sentence:
- sentence_ppl[:keep_first_sentence] = [
- ii + high_priority_bonus
- for ii in sentence_ppl[:keep_first_sentence]
- ]
- if keep_last_sentence:
- sentence_ppl[-keep_last_sentence:] = [
- ii + high_priority_bonus
- for ii in sentence_ppl[-keep_last_sentence:]
- ]
- if keep_sentence_number:
- for dem_idx in range(len(sentences)):
- keep_sentence(dem_idx, keep_sentence_number)
- sort_direct = -1 if condition_in_question == "none" else 1
- sent_sort = sorted(
- enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
- )
- else:
- sent_sort = self.get_rank_results(
- context_sentences,
- question,
- rank_method,
- condition_in_question,
- [0] * len(context_sentences),
- )
-
- sentence_flags = [False] * N
- if target_token < 0:
- target_token = 100
- target_token *= token_budget_ratio
- res = []
- for idx, _ in sent_sort:
- idx = flags[idx]
- target_token -= sentence_tokens_length[idx]
- sentence_flags[idx] = True
- if target_token < 0:
- break
-
- if context_segs is not None:
- for idx in range(N):
- preserved = [sen_seg_info[2] for sen_seg_info in sen2seg_ratio[idx]]
- if False in preserved:
- sentence_flags[idx] = True
-
- idx = 0
- res = []
- new_segments_info = []
- for s in sentences:
- tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
- res.append("".join(tmp))
- if context_segs is not None:
- segment_ratio = []
- for ii in range(len(s)):
- if sentence_flags[idx + ii]:
- segment_ratio.extend(sen2seg_ratio[idx + ii])
- new_segments_info.append(segment_ratio)
- idx += len(s)
- if context_segs is not None:
- new_segments_info = [
- self.concate_segment_info(segment_info)
- for segment_info in new_segments_info
- ]
- return res, new_segments_info
-
- def get_compressed_input(
- self,
- loss,
- input_ids,
- attention_mask,
- end=200,
- iterative_size=200,
- threshold=0.5,
- keep_flag=None,
- split_token_id: int = 13,
- start: int = 0,
- self_loss=None,
- self_input_ids=None,
- self_attention_mask=None,
- ):
- if self_loss is not None:
- need_idx = torch.concat(
- [
- loss[:start] > 0,
- self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
- loss[:1] > 0,
- ]
- )
- else:
- need_idx = torch.concat([loss > threshold, loss[:1] > 0])
- need_idx[end:] = 1
- need_idx[: end - iterative_size] = 1
- loss = loss[need_idx[:-1]]
- if self_loss is not None:
- if need_idx.shape[0] < self_loss.shape[0] + start + 1:
- need_idx = torch.cat(
- [
- need_idx,
- torch.ones(
- self_loss.shape[0] - need_idx.shape[0] + start + 1,
- dtype=torch.bool,
- ).to(need_idx.device),
- ]
- )
- self_loss = self_loss[need_idx[start:-1]]
-
- if need_idx.shape[0] < input_ids.shape[1]:
- need_idx = torch.cat(
- [
- need_idx,
- torch.ones(
- input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
- ).to(need_idx.device),
- ]
- )
- elif need_idx.shape[0] > input_ids.shape[1]:
- need_idx = need_idx[: input_ids.shape[1]]
-
- if keep_flag is not None:
- need_idx[keep_flag == 1] = 1
- last = -1
- if keep_flag is not None:
- for ii in range(max(0, end - iterative_size), end):
- if need_idx[ii] != 1:
- continue
- now = input_ids[0][ii].detach().cpu().item()
- if (
- now == split_token_id
- and last == split_token_id
- and keep_flag[ii].detach().cpu().item() == 0
- ):
- need_idx[ii] = 0
- else:
- last = now
- compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
- compressed_attention_mask = attention_mask[attention_mask == 1][
- need_idx
- ].unsqueeze(0)
-
- if self_loss is not None:
- self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
- need_idx[start:]
- ].unsqueeze(0)
- self_compressed_attention_mask = self_attention_mask[
- self_attention_mask == 1
- ][need_idx[start:]].unsqueeze(0)
- else:
- self_compressed_input_ids, self_compressed_attention_mask = None, None
- if keep_flag is not None:
- if len(keep_flag) > len(need_idx):
- keep_flag = torch.cat(
- [
- keep_flag[:start],
- keep_flag[start : len(need_idx) + start][need_idx],
- keep_flag[start + len(need_idx) :],
- ]
- )
- else:
- keep_flag = keep_flag[need_idx]
- end -= (need_idx[:end] == 0).sum()
- return (
- compressed_input_ids,
- compressed_attention_mask,
- keep_flag,
- end,
- loss,
- self_loss,
- self_compressed_input_ids,
- self_compressed_attention_mask,
- )
-
- def get_estimate_threshold_base_distribution(
- self, ppl, ratio: float, condition_flag: bool = False
- ):
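- # Estimates the per-token loss threshold such that keeping tokens whose loss exceeds it retains
- # roughly `ratio` of them (the sort direction flips when condition_flag is set for contrastive
- # scoring); entries equal to the 10000 sentinel are ignored.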
- if ratio == 1.0:
- return float("-inf")
- ppl = ppl[ppl != 10000]
- target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
- return (
- ppl.sort(descending=not condition_flag)
- .values[target_token]
- .detach()
- .cpu()
- .item()
- )
-
- def iterative_compress_prompt(
- self,
- context: List[str],
- target_token: float,
- iterative_size: int = 200,
- keep_split: bool = False,
- split_token_id: int = 13,
- start: int = 0,
- dynamic_ratio: list = None,
- condition_compare: bool = False,
- segments_info: List[List[tuple]] = None,
- ):
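- # Token-level filter: score the concatenated context window by window with get_ppl, drop tokens whose
- # loss falls below the window's estimated threshold (via get_compressed_input), and reuse the KV cache
- # across windows; once `end` exceeds max_position_embeddings, cache entries after the first
- # cache_bos_num positions are evicted ("KV-Cache Compression" below).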
- if segments_info is None or segments_info == []:
- iterative_ratios = self.get_dynamic_compression_ratio(
- context, target_token, iterative_size, dynamic_ratio, start
- )
- else:
- iterative_ratios = self.get_structured_dynamic_compression_ratio(
- context, iterative_size, dynamic_ratio, start, segments_info
- )
- context = "\n\n".join(context)
- tokenized_text = self.tokenizer(
- context, return_tensors="pt", add_special_tokens=False
- )
- input_ids = tokenized_text["input_ids"].to(self.device)
- attention_mask = tokenized_text["attention_mask"].to(self.device)
-
- N = (attention_mask == 1).sum()
- compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
- if condition_compare:
- self_input_ids, self_attention_mask = (
- input_ids[:, start:],
- attention_mask[:, start:],
- )
- self_compressed_input_ids, self_compressed_attention_mask = (
- self_input_ids,
- self_attention_mask,
- )
-
- end = min(iterative_size + start, compressed_input_ids.shape[1])
- threshold, keep_flag = None, None
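- # keep_split: mark tokens adjacent to a doubled split token (consecutive
- # occurrences of split_token_id, typically newlines) so that segment
- # boundaries are never removed during compression.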
- if keep_split:
- input_ids_numpy = input_ids.cpu().detach().numpy()[0]
- N = len(input_ids_numpy)
- keep_flag = [
- int(
- (
- ii > 0
- and input_ids_numpy[ii] == split_token_id
- and input_ids_numpy[ii - 1] == split_token_id
- )
- or (
- ii < N - 1
- and input_ids_numpy[ii] == split_token_id
- and input_ids_numpy[ii + 1] == split_token_id
- )
- )
- for ii in range(N)
- ]
- keep_flag = torch.tensor(keep_flag).to(self.device)
- past_key_values, past_loss, ready_end = None, None, 0
- self_past_key_values, self_past_loss, self_ready_end = None, None, 0
- pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
- idx = 0
- while end <= compressed_input_ids.shape[1]:
- if end > self.max_position_embeddings and past_key_values is not None:
- # KV-Cache Compression
- e, s = end - self.max_position_embeddings, min(
- self.cache_bos_num + start, self.max_position_embeddings
- )
- if pop_compressed_input_ids is None:
- pop_compressed_input_ids = compressed_input_ids[:, :e]
- else:
- pop_compressed_input_ids = torch.cat(
- [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
- )
- compressed_input_ids = compressed_input_ids[:, e:]
- compressed_attention_mask = compressed_attention_mask[:, e:]
- past_key_values = [
- [
- torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
- torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
- ]
- for k, v in past_key_values
- ]
- if keep_flag is not None:
- keep_flag = keep_flag[e:]
- end, ready_end = end - e, ready_end - e
- if condition_compare:
- s = min(s, self_past_key_values[0][0].shape[2] - e)
- self_ready_end -= e
- if pop_self_compressed_input_ids is None:
- pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
- else:
- pop_self_compressed_input_ids = torch.cat(
- [
- pop_self_compressed_input_ids,
- self_compressed_input_ids[:, :e],
- ],
- dim=-1,
- )
- self_compressed_input_ids = self_compressed_input_ids[:, e:]
- self_compressed_attention_mask = self_compressed_attention_mask[
- :, e:
- ]
- self_past_key_values = [
- [
- torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
- torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
- ]
- for k, v in self_past_key_values
- ]
-
- loss, past_key_values = self.get_ppl(
- "",
- "token",
- compressed_input_ids,
- compressed_attention_mask,
- past_key_values=past_key_values,
- return_kv=True,
- end=end if idx else None,
- )
- if loss.shape[0] == 0:
- break
- if past_loss is not None:
- if end - 1 > len(past_loss):
- past_loss = torch.cat(
- [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
- )
- past_loss[ready_end : end - 1] = loss
- loss = past_loss
- else:
- past_loss = loss
- if idx:
- past_key_values = [
- [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
- for k, v in past_key_values
- ]
- else:
- past_key_values = None
-
- if condition_compare:
- self_loss, self_past_key_values = self.get_ppl(
- "",
- "token",
- self_compressed_input_ids,
- self_compressed_attention_mask,
- past_key_values=self_past_key_values,
- return_kv=True,
- end=end - start if idx else None,
- )
- if self_past_loss is not None:
- if end - start - 1 > len(self_past_loss):
- self_past_loss = torch.cat(
- [
- self_past_loss,
- torch.zeros_like(self_loss)[
- : end - 1 - start - len(self_past_loss)
- ],
- ]
- )
- self_past_loss[self_ready_end : end - start - 1] = self_loss
- self_loss = self_past_loss
- else:
- self_past_loss = self_loss
- if idx:
- self_past_key_values = [
- [
- k[:, :, : end - iterative_size - start],
- v[:, :, : end - iterative_size - start],
- ]
- for k, v in self_past_key_values
- ]
- else:
- self_past_key_values = None
-
- self_ready_end = (
- end - start - iterative_size if not (start and idx == 0) else 0
- )
- ready_end = end - iterative_size if not (start and idx == 0) else 0
-
- for delta_end, ratio in iterative_ratios[idx]:
- loss = past_loss
- if condition_compare:
- self_loss = self_past_loss
- threshold = self.get_estimate_threshold_base_distribution(
- self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
- )
- else:
- threshold = self.get_estimate_threshold_base_distribution(
- loss, ratio, False
- )
-
- (
- compressed_input_ids,
- compressed_attention_mask,
- keep_flag,
- end,
- past_loss,
- self_past_loss,
- self_compressed_input_ids,
- self_compressed_attention_mask,
- ) = self.get_compressed_input(
- loss,
- compressed_input_ids,
- compressed_attention_mask,
- end - iterative_size + delta_end,
- iterative_size=delta_end,
- threshold=threshold,
- keep_flag=keep_flag,
- split_token_id=split_token_id,
- start=start,
- self_loss=self_loss if condition_compare else None,
- self_input_ids=(
- self_compressed_input_ids if condition_compare else None
- ),
- self_attention_mask=(
- self_compressed_attention_mask if condition_compare else None
- ),
- )
- end += iterative_size
- idx += 1
- if pop_compressed_input_ids is not None:
- compressed_input_ids = torch.cat(
- [pop_compressed_input_ids, compressed_input_ids], dim=-1
- )
- return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
-
- def recover(
- self,
- original_prompt: str,
- compressed_prompt: str,
- response: str,
- ):
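- # Map spans of the response that were copied verbatim from the compressed
- # prompt back to the original prompt: greedily extend each run of response
- # words found in compressed_prompt, locate the closest matching token-id
- # subsequence in the original prompt, and substitute its decoded text.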
- def match_from_compressed(response_word):
- response_input_ids = self.tokenizer(
- response_word, add_special_tokens=False
- )["input_ids"]
- response_set, response_c = set(response_input_ids), defaultdict(list)
- for idx in range(M):
- if original_input_ids[idx] in response_set:
- response_c[original_input_ids[idx]].append(idx)
- res, res_min, res_c = None, float("inf"), 1
- n = len(response_input_ids)
- for l in response_c[response_input_ids[0]]:
- x, y, c = 0, l, 1
- for x in range(1, n):
- idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
- if (
- idx >= len(response_c[response_input_ids[x]])
- or response_c[response_input_ids[x]][idx] - y > 10
- ):
- continue
- c += 1
- y = response_c[response_input_ids[x]][idx]
- if c > res_c:
- res_c = c
- res_min = y - l + 1
- res = (l, y + 1)
- elif c == res_c and y - l + 1 < res_min:
- res_min = y - l + 1
- res = (l, y + 1)
-
- if res is None:
- return response_word
- return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
-
- response_words = response.split(" ")
-
- original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
- "input_ids"
- ]
- N, M = len(response_words), len(original_input_ids)
- recovered_response_words = []
- l = 0
- while l < N:
- if response_words[l] not in compressed_prompt:
- recovered_response_words.append(response_words[l])
- l += 1
- continue
- r = l
- while (
- r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
- ):
- r += 1
-
- match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
- recovered_response_words.append(match_words)
- l = r + 1
- return " ".join(recovered_response_words)
-
- def get_rank_results(
- self,
- context: list,
- question: str,
- rank_method: str,
- condition_in_question: str,
- context_tokens_length: list,
- ):
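- # Dispatch to the retrieval backend selected by rank_method (BM25, gzip
- # compression distance, SentenceTransformers, OpenAI/Voyage/Cohere APIs,
- # BGE embedder/reranker, Jina embeddings, or LongLLMLingua's conditional
- # perplexity) and return (document index, score) pairs ordered from most
- # to least relevant to the question.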
- def get_distance_bm25(corpus, query):
- from rank_bm25 import BM25Okapi
-
- tokenized_corpus = [doc.split(" ") for doc in corpus]
- bm25 = BM25Okapi(tokenized_corpus)
- tokenized_query = query.split(" ")
- doc_scores = bm25.get_scores(tokenized_query)
- idx = [(ii, 0) for ii in (-doc_scores).argsort()]
- return idx
-
- def get_distance_gzip(corpus, query):
- def get_score(x, y):
- cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
- cxy = len(gzip.compress(f"{x} {y}".encode()))
- return (cxy - min(cx, cy)) / max(cx, cy)
-
- import gzip
-
- doc_scores = [get_score(doc, query) for doc in corpus]
- idx = [(ii, 0) for ii in np.argsort(doc_scores)]
- return idx
-
- def get_distance_sentbert(corpus, query):
- from sentence_transformers import SentenceTransformer, util
-
- if self.retrieval_model is None or self.retrieval_model_name != rank_method:
- self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
- self.retrieval_model_name = rank_method
- doc_embeds = self.retrieval_model.encode(corpus)
- query = self.retrieval_model.encode(query)
- doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
- idx = [(ii, 0) for ii in np.argsort(doc_scores)]
- return idx
-
- def get_distance_openai(corpus, query):
- import openai
- from sentence_transformers import util
-
- openai.api_key = self.open_api_config.get("api_key", "")
- openai.api_base = self.open_api_config.get(
- "api_base", "https://api.openai.com/v1"
- )
- openai.api_type = self.open_api_config.get("api_type", "open_ai")
- openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
- engine = self.open_api_config.get("engine", "text-embedding-ada-002")
-
- def get_embed(text):
- return openai.Embedding.create(
- input=[text.replace("\n", " ")], engine=engine
- )["data"][0]["embedding"]
-
- doc_embeds = [get_embed(i) for i in corpus]
- query = get_embed(query)
- doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
- idx = [(ii, 0) for ii in np.argsort(doc_scores)]
- return idx
-
- def get_distance_sentbert_bge(corpus, query):
- from sentence_transformers import SentenceTransformer, util
-
- if self.retrieval_model is None or self.retrieval_model_name != rank_method:
- self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
- self.retrieval_model_name = rank_method
- doc_embeds = self.retrieval_model.encode(
- [i for i in corpus], normalize_embeddings=True
- )
- query = self.retrieval_model.encode(query, normalize_embeddings=True)
- doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
- idx = [(ii, 0) for ii in np.argsort(doc_scores)]
- return idx
-
- def get_distance_bge_ranker(corpus, query):
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
- pairs = [[i, query] for i in corpus]
- if self.retrieval_model is None or self.retrieval_model_name != rank_method:
- tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
- model = (
- AutoModelForSequenceClassification.from_pretrained(
- "BAAI/bge-reranker-large"
- )
- .eval()
- .to(self.device)
- )
- self.retrieval_model = [tokenizer, model]
- self.retrieval_model_name = rank_method
- with torch.no_grad():
- inputs = self.retrieval_model[0](
- pairs,
- padding=True,
- truncation=True,
- return_tensors="pt",
- max_length=512,
- ).to(self.device)
- scores = (
- self.retrieval_model[1](**inputs, return_dict=True)
- .logits.view(
- -1,
- )
- .float()
- )
- idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
- return idx
-
- def get_distance_bge_llmembedder(corpus, query):
- from transformers import AutoModel, AutoTokenizer
-
- if self.retrieval_model is None or self.retrieval_model_name != rank_method:
- tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
- model = (
- AutoModel.from_pretrained("BAAI/llm-embedder")
- .eval()
- .to(self.device)
- )
- self.retrieval_model = [tokenizer, model]
- self.retrieval_model_name = rank_method
-
- instruction_qa_query = (
- "Represent this query for retrieving relevant documents: "
- )
- instruction_qa_key = "Represent this document for retrieval: "
- queries = [instruction_qa_query + query for _ in corpus]
- keys = [instruction_qa_key + key for key in corpus]
- with torch.no_grad():
- query_inputs = self.retrieval_model[0](
- queries,
- padding=True,
- truncation=True,
- return_tensors="pt",
- max_length=512,
- ).to(self.device)
- key_inputs = self.retrieval_model[0](
- keys,
- padding=True,
- truncation=True,
- return_tensors="pt",
- max_length=512,
- ).to(self.device)
- query_outputs = self.retrieval_model[1](**query_inputs)
- key_outputs = self.retrieval_model[1](**key_inputs)
- # CLS pooling
- query_embeddings = query_outputs.last_hidden_state[:, 0]
- key_embeddings = key_outputs.last_hidden_state[:, 0]
- # Normalize
- query_embeddings = torch.nn.functional.normalize(
- query_embeddings, p=2, dim=1
- )
- key_embeddings = torch.nn.functional.normalize(
- key_embeddings, p=2, dim=1
- )
- similarity = query_embeddings @ key_embeddings.T
- idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
- return idx
-
- def get_distance_jinza(corpus, query):
- from numpy.linalg import norm
-
- from transformers import AutoModel
-
- def cos_sim(a, b):
- return (a @ b.T) / (norm(a) * norm(b))
-
- if self.retrieval_model is None or self.retrieval_model_name != rank_method:
- model = (
- AutoModel.from_pretrained(
- "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
- )
- .eval()
- .to(self.device)
- )
- self.retrieval_model = model
- self.retrieval_model_name = rank_method
-
- doc_embeds = self.retrieval_model.encode(corpus)
- query = self.retrieval_model.encode(query)
- doc_scores = cos_sim(doc_embeds, query)
- idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
- return idx
-
- def get_distance_voyageai(corpus, query):
- import voyageai
- from sentence_transformers import util
-
- voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")
-
- def get_embed(text):
- return voyageai.get_embedding(text, model="voyage-01")
-
- doc_embeds = [get_embed(i) for i in corpus]
- query = get_embed(query)
- doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
- idx = [(ii, 0) for ii in np.argsort(doc_scores)]
- return idx
-
- def get_distance_cohere(corpus, query):
- import cohere
-
- api_key = self.open_api_config.get("cohere_api_key", "")
- co = cohere.Client(api_key)
- results = co.rerank(
- model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
- )
- c_map = {jj: ii for ii, jj in enumerate(corpus)}
- doc_rank = [c_map[ii.document["text"]] for ii in results]
- idx = [(ii, 0) for ii in doc_rank]
- return idx
-
- def get_distance_longllmlingua(corpus, query):
- context_ppl = [
- self.get_condition_ppl(
- d,
- query
- + " We can get the answer to this question in the given documents.",
- condition_in_question,
- )
- - dl * 2 / 250 * 0
- for d, dl in zip(corpus, context_tokens_length)
- ]
- sort_direct = -1 if condition_in_question == "none" else 1
- ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
- return ys
-
- method = None
- if rank_method == "bm25":
- method = get_distance_bm25
- elif rank_method == "gzip":
- method = get_distance_gzip
- elif rank_method == "sentbert":
- method = get_distance_sentbert
- elif rank_method == "openai":
- method = get_distance_openai
- elif rank_method in ["longllmlingua", "llmlingua"]:
- method = get_distance_longllmlingua
- elif rank_method == "bge":
- method = get_distance_sentbert_bge
- elif rank_method == "bge_reranker":
- method = get_distance_bge_ranker
- elif rank_method == "bge_llmembedder":
- method = get_distance_bge_llmembedder
- elif rank_method == "jinza":
- method = get_distance_jinza
- elif rank_method == "voyageai":
- method = get_distance_voyageai
- elif rank_method == "cohere":
- method = get_distance_cohere
- return method(context, question)
-
- def segment_structured_context(
- self,
- context: List[str],
- global_rate: float,
- ):
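- # Split each structured context into segments carrying their own rate /
- # compress settings. Segments are delimited by <llmlingua ...></llmlingua>
- # tags (see the regex below); untagged text is first wrapped in a default
- # tag. E.g. "<llmlingua, rate=0.4>some text</llmlingua>" yields one segment
- # with rate 0.4 and compress=True.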
- new_context, context_segs, context_segs_rate, context_segs_compress = (
- [],
- [],
- [],
- [],
- )
- for text in context:
- if not text.startswith("<llmlingua"):
- text = "<llmlingua>" + text
- if not text.endswith("</llmlingua>"):
- text = text + "</llmlingua>"
-
- # Regular expression to match <llmlingua, rate=..., compress=...> segments, allowing rate and compress in any order
- pattern = r"<llmlingua\s*,?\s*(?:rate\s*=\s*([\d\.]+)\s*,?\s*)?(?:compress\s*=\s*(True|False)\s*,?\s*)?(?:rate\s*=\s*([\d\.]+)\s*,?\s*)?(?:compress\s*=\s*(True|False)\s*,?\s*)?>([^<]+)</llmlingua\s*>"
- matches = re.findall(pattern, text)
-
- # Extracting segment contents
- segments = [match[4] for match in matches]
-
- # Extracting rate and compress, considering their possible positions
- segs_rate = [
- float(match[0]) if match[0] else (float(match[2]) if match[2] else None)
- for match in matches
- ]
- segs_compress = [
- (
- match[1] == "True"
- if match[1]
- else (match[3] == "True" if match[3] else None)
- )
- for match in matches
- ]
-
- segs_compress = [
- compress if compress is not None else True for compress in segs_compress
- ]
- segs_rate = [
- rate if rate else (global_rate if compress else 1.0)
- for rate, compress in zip(segs_rate, segs_compress)
- ]
- assert (
- len(segments) == len(segs_rate) == len(segs_compress)
- ), "The number of segments, rates, and compress flags should be the same."
- assert all(
- seg_rate <= 1.0 for seg_rate in segs_rate
- ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
-
- new_context.append("".join(segments))
- context_segs.append(segments)
- context_segs_rate.append(segs_rate)
- context_segs_compress.append(segs_compress)
-
- return new_context, context_segs, context_segs_rate, context_segs_compress
-
- def concate_segment_info(
- self,
- segment_info: List[List[tuple]],
- ):
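- # Merge adjacent segments that share the same rate and compress flag so
- # they are handled as a single span downstream.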
- new_segment_info = []
- for i, (seg_len, seg_ratio, seg_compress) in enumerate(segment_info):
- if (
- new_segment_info
- and new_segment_info[-1][1] == seg_ratio
- and new_segment_info[-1][2] == seg_compress
- ):
- new_segment_info[-1] = (
- new_segment_info[-1][0] + seg_len,
- seg_ratio,
- seg_compress,
- )
- else:
- new_segment_info.append((seg_len, seg_ratio, seg_compress))
- return new_segment_info
-
- def __get_context_prob(
- self,
- context_list: list,
- token_to_word="mean",
- force_tokens: List[str] = [],
- token_map: dict = {},
- force_reserve_digit: bool = False,
- ):
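- # Run the token-classification model over every chunk, merge sub-token
- # probabilities back into words, and return each context's mean word-level
- # keep probability (without force-token overrides) together with the
- # recovered words.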
- chunk_list = []
- for chunks in context_list:
- for c in chunks:
- chunk_list.append(c)
-
- dataset = TokenClfDataset(
- chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
- )
- dataloader = DataLoader(
- dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
- )
-
- chunk_probs = []
- chunk_words = []
- with torch.no_grad():
- for batch in dataloader:
- ids = batch["ids"].to(self.device, dtype=torch.long)
- mask = batch["mask"].to(self.device, dtype=torch.long) == 1
-
- outputs = self.model(input_ids=ids, attention_mask=mask)
- loss, logits = outputs.loss, outputs.logits
- probs = F.softmax(logits, dim=-1)
-
- for j in range(ids.shape[0]):
- _probs = probs[j, :, 1]
- _ids = ids[j]
- _mask = mask[j]
-
- active_probs = torch.masked_select(_probs, _mask)
- active_ids = torch.masked_select(_ids, _mask)
-
- tokens = self.tokenizer.convert_ids_to_tokens(
- active_ids.squeeze().tolist()
- )
- token_probs = [prob for prob in active_probs.cpu().numpy()]
-
- (
- words,
- valid_token_probs,
- valid_token_probs_no_force,
- ) = self.__merge_token_to_word(
- tokens,
- token_probs,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- )
- word_probs_no_force = self.__token_prob_to_word_prob(
- valid_token_probs_no_force, convert_mode=token_to_word
- )
-
- if "xlm-roberta-large" in self.model_name:
- for i in range(len(words)):
- words[i] = words[i].lstrip("▁")
- chunk_words.append(words)
- chunk_probs.append(word_probs_no_force)
-
- prev_idx = 0
- context_probs = []
- context_words = []
- for chunk_list in context_list:
- n_chunk = len(chunk_list)
- context_probs.append([])
- context_words.append([])
- for i in range(n_chunk):
- context_probs[-1].extend(chunk_probs[prev_idx + i])
- context_words[-1].extend(chunk_words[prev_idx + i])
- prev_idx = prev_idx + n_chunk
- context_probs = [sum(probs) / len(probs) for probs in context_probs]
- return context_probs, context_words
-
- def __chunk_context(self, origin_text, chunk_end_tokens):
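- # Split origin_text into chunks of at most max_seq_len tokens, preferring
- # to end a chunk at the nearest preceding token listed in chunk_end_tokens
- # so that chunks break at natural boundaries.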
- origin_list = []
- origin_tokens = self.tokenizer.tokenize(origin_text)
- n = len(origin_tokens)
- st = 0
- while st < n:
- if st + self.max_seq_len > n - 1:
- chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n])
- origin_list.append(chunk)
- break
- else:
- ed = st + self.max_seq_len
- for j in range(0, ed - st):
- if origin_tokens[ed - j] in chunk_end_tokens:
- ed = ed - j
- break
- chunk = self.tokenizer.convert_tokens_to_string(
- origin_tokens[st : ed + 1]
- )
- origin_list.append(chunk)
- st = ed + 1
- return origin_list
-
- def __merge_token_to_word(
- self, tokens, token_probs, force_tokens, token_map, force_reserve_digit
- ):
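- # Merge sub-word tokens back into words and collect per-token keep
- # probabilities for each word. force_tokens, mapped special tokens and
- # (optionally) digit-bearing tokens get probability 1.0 in word_probs,
- # while word_probs_no_force keeps the raw classifier probabilities.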
- words = []
- word_probs = []
- word_probs_no_force = []
-
- for token, prob in zip(tokens, token_probs):
- if token in self.special_tokens:
- continue
- # add a new word
- elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map):
- pure_token = get_pure_token(token, self.model_name)
- prob_no_force = prob
- if pure_token in force_tokens or pure_token in set(token_map.values()):
- prob = 1.0
- token = replace_added_token(token, token_map)
- words.append(token)
- word_probs.append(
- [
- 1.0
- if force_reserve_digit
- and bool(re.search(r"\d", token))
- else prob
- ]
- )
- word_probs_no_force.append([prob_no_force])
- # concatenate with previous token
- else:
- pure_token = get_pure_token(token, self.model_name)
- words[-1] += pure_token
- word_probs[-1].append(
- 1.0
- if force_reserve_digit
- and bool(re.search(r"\d", token))
- else prob
- )
- # Continuation tokens are never force-boosted, so append their raw probability.
- word_probs_no_force[-1].append(prob)
-
- return words, word_probs, word_probs_no_force
-
- def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"):
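- # Collapse each word's list of token probabilities into a single value:
- # "mean" averages them, "first" uses the first sub-token's probability.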
- if convert_mode == "mean":
- word_probs = [sum(p) / len(p) for p in token_probs]
- elif convert_mode == "first":
- word_probs = [p[0] for p in token_probs]
- else:
- raise NotImplementedError()
-
- return word_probs
-
- def __compress(
- self,
- context_list: list,
- reduce_rate: float = 0.5,
- token_to_word: str = "mean",
- force_tokens: List[str] = [],
- token_map: dict = {},
- force_reserve_digit: bool = False,
- drop_consecutive: bool = False,
- ):
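- # LLMLingua-2 word-level compression: score every token's keep probability
- # with the classifier, merge tokens into words, pick a percentile threshold
- # (re-weighted by GPT-4 token counts so the ratio is measured in GPT-4
- # tokens), and keep only words scoring above it. reduce_rate <= 0 returns
- # the text uncompressed, with token_map placeholders restored. Returns the
- # compressed chunks plus the original words and their 0/1 keep labels.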
- def split_string_to_words(input_string):
- pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]'
- result = re.findall(pattern, input_string)
- return result
-
- if reduce_rate <= 0:
- words, word_labels = [], []
- for i in range(len(context_list)):
- chunk_list = context_list[i]
- chunk_words = []
- chunk_word_labels = []
- for j in range(len(chunk_list)):
- # replace to original token
- for ori_token, new_token in token_map.items():
- chunk_list[j] = chunk_list[j].replace(new_token, ori_token)
- ws = split_string_to_words(chunk_list[j])
- chunk_words.extend(ws)
- chunk_word_labels.extend([1 for _ in range(len(ws))])
- context_list[i] = "".join(chunk_list)
- words.append(chunk_words)
- word_labels.append(chunk_word_labels)
- return context_list, words, word_labels
-
- chunk_list = []
- for chunks in context_list:
- for c in chunks:
- chunk_list.append(c)
-
- dataset = TokenClfDataset(
- chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
- )
- dataloader = DataLoader(
- dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
- )
-
- compressed_chunk_list = []
- word_list = []
- word_label_list = []
- with torch.no_grad():
- for batch in dataloader:
- ids = batch["ids"].to(self.device, dtype=torch.long)
- mask = batch["mask"].to(self.device, dtype=torch.long) == 1
-
- outputs = self.model(input_ids=ids, attention_mask=mask)
- loss, logits = outputs.loss, outputs.logits
- probs = F.softmax(logits, dim=-1)
-
- for j in range(ids.shape[0]):
- chunk_probs = probs[j, :, 1]
- chunk_ids = ids[j]
- chunk_mask = mask[j]
-
- active_probs = torch.masked_select(chunk_probs, chunk_mask)
- active_ids = torch.masked_select(chunk_ids, chunk_mask)
-
- tokens = self.tokenizer.convert_ids_to_tokens(
- active_ids.squeeze().tolist()
- )
- token_probs = [prob for prob in active_probs.cpu().numpy()]
-
- words, valid_token_probs, _ = self.__merge_token_to_word(
- tokens=tokens,
- token_probs=token_probs,
- force_tokens=force_tokens,
- token_map=token_map,
- force_reserve_digit=force_reserve_digit,
- )
- word_probs = self.__token_prob_to_word_prob(
- valid_token_probs, convert_mode=token_to_word
- )
-
- if drop_consecutive:
- threshold = np.percentile(word_probs, int(100 * reduce_rate))
- is_token_between = False
- prev = None
- for i, (word, word_prob) in enumerate(zip(words, word_probs)):
- if word in force_tokens:
- if is_token_between:
- is_token_between = False
- elif not is_token_between and word == prev:
- word_probs[i] = 0.0
- prev = word
- else:
- is_token_between |= word_prob > threshold
-
- # calculate compression ratio w.r.t. gpt-4 tokenizer
- new_token_probs = []
- for word, word_prob in zip(words, word_probs):
- num_token = len(self.oai_tokenizer.encode(word))
- new_token_probs.extend([word_prob for _ in range(num_token)])
- threshold = np.percentile(
- new_token_probs, int(100 * reduce_rate + 1)
- )
-
- keep_words = []
- word_labels = []
- assert len(words) == len(word_probs)
- for word, word_prob in zip(words, word_probs):
- if word_prob > threshold:
- if (
- drop_consecutive
- and word in force_tokens
- and len(keep_words) > 0
- and keep_words[-1] == word
- ):
- word_labels.append(0)
- else:
- keep_words.append(word)
- word_labels.append(1)
- else:
- word_labels.append(0)
- keep_str = self.tokenizer.convert_tokens_to_string(keep_words)
- if "xlm-roberta-large" in self.model_name:
- for i in range(len(words)):
- words[i] = words[i].lstrip("▁")
-
- compressed_chunk_list.append(keep_str)
- word_list.append(words[:])
- word_label_list.append(word_labels[:])
-
- compressed_context_list = []
- original_word_list = []
- original_word_label_list = []
- prev_idx = 0
- for chunk_list in context_list:
- n_chunk = len(chunk_list)
- compressed_context_list.append(
- "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk])
- )
- original_word_list.append([])
- original_word_label_list.append([])
- for i in range(n_chunk):
- original_word_list[-1].extend(word_list[prev_idx + i])
- original_word_label_list[-1].extend(word_label_list[prev_idx + i])
- prev_idx = prev_idx + n_chunk
-
- return compressed_context_list, original_word_list, original_word_label_list