o
    hE\                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlZd dlmZmZ d dl	m
Z
mZ d dlmZ d dlmZ d dlZd dlmZmZ ddlmZ e d	krPd dlZe red d
lmZ d dlmZ d dlmZ e rxd dlZd dlm Z m!Z!m"Z"m#Z# e$ej%ej& Z'e$ej%ej( ej& d Z)dZ*g dZ+ddiddiddiddiddiddidZ,dd Z-dd Z.dd Z/dd  Z0d!d" Z1d#ed$ fd%d&Z2d'd( Z3d)d* Z4G d+d, d,Z5e
G d-d. d.Z6d/efd0d1Z7G d2d3 d3eZ8dS )4    N)ArgumentParser	Namespace)	dataclassfield)Thread)Optional)is_rich_availableis_torch_available   )BaseTransformersCLICommandWindows)Console)Live)Markdown)AutoModelForCausalLMAutoTokenizerBitsAndBytesConfigTextIteratorStreamerz .!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~a#  
**TRANSFORMERS CHAT INTERFACE**

The chat interface is a simple tool to try out a chat model.

Besides talking to the model there are several commands:
- **help**: show this help message
- **clear**: clears the current conversation and start a new one
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **exit**: closes the interface
)max_new_tokens	do_sample	num_beamstemperaturetop_ptop_krepetition_penaltytextz5There is a Llama in my lawn, how can I get rid of it?zyWrite a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].z4How many helicopters can a human eat in one sitting?z4Count to 10 but skip every number ending with an 'e'zWhy aren't birds real?z2Why is it important to eat socks after meditating?)llamacode
helicopternumbersbirdssocksc                   C   s$   t  dkr
t S tt jS )Nr   )platformsystemosgetloginpwdgetpwuidgetuidpw_name r*   r*   n/var/www/html/construction_image-detection-poc/venv/lib/python3.10/site-packages/transformers/commands/chat.pyget_usernameZ   s   r,   c                 C   s   t d}|  d| dS )Nz%Y-%m-%d_%H-%M-%Sz/chat_z.json)timestrftime)
model_nametime_strr*   r*   r+   create_default_filenamea   s   
r1   c                 C   s   i }t ||d< | |d< |j}|d u rt|j}tj||}tjtj|dd t	|d}t
j||dd W d    n1 sBw   Y  tj|S )Nsettingschat_historyT)exist_okw   )indent)varssave_folderr1   model_name_or_pathr$   pathjoinmakedirsdirnameopenjsondumpabspath)chatargsfilenameoutput_dictfolderfr*   r*   r+   	save_chatf   s   
rI   c                 C   s    | d u rg }|S d| dg}|S )Nr#   rolecontentr*   )system_promptrC   r*   r*   r+   clear_chat_historyw   s
   rN   c                 C   sR  | dd    d}dd |D }t|}d}|D ]_}t||rpz/tt||trB|| dkr5d||< n|| dkr@d||< nttt|||| ||< W q tyo   |	d	| d
||  dtt|| d Y qw |	d| d q|r|	d |dfS |D ]}t
||||  |d| d||  d qtd |dfS )Nr6   ;c                 S   s8   g | ]}| d d |t| d d d d fqS )=r   r
   N)splitlen).0settingr*   r*   r+   
<listcomp>   s   8 z"parse_settings.<locals>.<listcomp>FTrueTFalsezCannot cast setting z (=z) to .zThere is no 'z
' setting.zGThere was an issue parsing the settings. No settings have been changed.zSet z to g      ?)striprQ   dicthasattr
isinstancegetattrbool
ValueErrortype	print_redsetattrprint_greenr-   sleep)
user_inputcurrent_args	interfacer2   errornamer*   r*   r+   parse_settings   s:   


$

rj   returnr   c                 C   s@   | j rtd| j| j| j| jd}|S | jrtdd}|S d }|S )NT)load_in_4bitbnb_4bit_compute_dtypebnb_4bit_quant_typebnb_4bit_use_double_quantbnb_4bit_quant_storage)load_in_8bit)rl   r   torch_dtypern   use_bnb_nested_quantrq   )
model_argsquantization_configr*   r*   r+   get_quantization_config   s    rv   c                 C   s   t j| j| j| jd}| jdv r| jntt| j}t| }| j| j	|d|d}t
j| jfd| ji|}t|dd d u rB|| j}||fS )N)revisiontrust_remote_code)autoNry   )rw   attn_implementationrr   
device_mapru   rx   hf_device_map)r   from_pretrainedr:   model_revisionrx   rr   r]   torchrv   rz   r   todevice)rD   	tokenizerrr   ru   model_kwargsmodelr*   r*   r+   load_model_and_tokenizer   s.   r   c                 C   s|   | j d u r	| j}n| j }g }|d ur|| |d |d ur.|dd |dD  t|dkr:|| j ||fS )N,c                 S   s   g | ]}t |qS r*   )int)rS   token_idr*   r*   r+   rU      s    z$parse_eos_tokens.<locals>.<listcomp>r   )pad_token_ideos_token_idextendconvert_tokens_to_idsrQ   rR   append)r   
eos_tokenseos_token_idsr   all_eos_token_idsr*   r*   r+   parse_eos_tokens   s   
r   c                   @   sN   e Zd ZdddZdd Zdd Zdd	 Zd
d Zdd Zdd Z	dd Z
dS )RichInterfaceNc                 C   s:   t  | _|d u rd| _n|| _|d u rd| _d S || _d S )N	assistantuser)r   _consoler/   	user_name)selfr/   r   r*   r*   r+   __init__   s   

zRichInterface.__init__c           	      C   s   d}| j d| j d t| j ddJ}t|D ]=\}}|r#|dkr$q||7 }g }| D ]}|| |dr@|d q.|d	 q.td	|
 d
d}|| qW d   n1 saw   Y  | j   |S )zStream output from a role. z[bold blue]<z>:r6   )consolerefresh_per_secondr   z```
z  
zgithub-dark)
code_themeN)r   printr/   r   	enumerate
splitlinesr   
startswithr   r<   rY   update)	r   output_streamr   liveioutputslineslinemarkdownr*   r*   r+   stream_output   s&   


zRichInterface.stream_outputc                 C   s$   | j d| j d}| j   |S )N[bold red]<z>:
)r   inputr   r   )r   r   r*   r*   r+   r     s   
zRichInterface.inputc                 C   s   | j   d S N)r   clearr   r*   r*   r+   r     s   zRichInterface.clearc                 C   s(   | j d| j d|  | j   d S )Nr   z>:[/ bold red]
)r   r   r   r   r   r*   r*   r+   print_user_message  s   z RichInterface.print_user_messagec                 C       | j d|  | j   d S )Nz[bold green]r   r   r   r*   r*   r+   rc   #     zRichInterface.print_greenc                 C   r   )Nz
[bold red]r   r   r*   r*   r+   ra   '  r   zRichInterface.print_redc                 C   s   | j tt | j   d S r   )r   r   r   HELP_STRINGr   r*   r*   r+   
print_help+  s   zRichInterface.print_help)NN)__name__
__module____qualname__r   r   r   r   r   rc   ra   r   r*   r*   r*   r+   r      s    
&r   c                   @   s`  e Zd ZU dZeddidZeed< edddidZe	e ed	< eddd
idZ
e	e ed< edddidZeed< edddidZeed< edddidZe	e ed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< ed dd!idZeed"< eddd#idZeed$< eddd%idZeed&< eddd'idZe	e ed(< eddd)idZe	e ed*< ed+dd,idZeed-< ed.d/g d0d1dZe	e ed2< ed3dd4idZeed5< eddd6idZe	e ed7< ed3dd8idZeed9< ed3dd:idZeed;< ed<d=d>d<gd1dZ eed?< ed3dd@idZ!eedA< dS )BChatArgumentsa%  
    Arguments for the chat script.

    Args:
        model_name_or_path (`str`):
            Name of the pre-trained model.
        user (`str` or `None`, *optional*, defaults to `None`):
            Username to display in chat interface.
        system_prompt (`str` or `None`, *optional*, defaults to `None`):
            System prompt.
        save_folder (`str`, *optional*, defaults to `"./chat_history/"`):
            Folder to save chat history.
        device (`str`, *optional*, defaults to `"cpu"`):
            Device to use for inference.
        examples_path (`str` or `None`, *optional*, defaults to `None`):
            Path to a yaml file with examples.
        max_new_tokens (`int`, *optional*, defaults to `256`):
            Maximum number of tokens to generate.
        do_sample (`bool`, *optional*, defaults to `True`):
            Whether to sample outputs during generation.
        num_beams (`int`, *optional*, defaults to `1`):
            Number of beams for beam search.
        temperature (`float`, *optional*, defaults to `1.0`):
            Temperature parameter for generation.
        top_k (`int`, *optional*, defaults to `50`):
            Value of k for top-k sampling.
        top_p (`float`, *optional*, defaults to `1.0`):
            Value of p for nucleus sampling.
        repetition_penalty (`float`, *optional*, defaults to `1.0`):
            Repetition penalty.
        eos_tokens (`str` or `None`, *optional*, defaults to `None`):
            EOS tokens to stop the generation. If multiple they should be comma separated.
        eos_token_ids (`str` or `None`, *optional*, defaults to `None`):
            EOS token IDs to stop the generation. If multiple they should be comma separated.
        model_revision (`str`, *optional*, defaults to `"main"`):
            Specific model version to use (can be a branch name, tag name or commit id).
        torch_dtype (`str` or `None`, *optional*, defaults to `None`):
            Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype
            will be automatically derived from the model's weights.
        trust_remote_code (`bool`, *optional*, defaults to `False`):
            Whether to trust remote code when loading a model.
        attn_implementation (`str` or `None`, *optional*, defaults to `None`):
            Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case
            you must install this manually by running `pip install flash-attn --no-build-isolation`.
        load_in_8bit (`bool`, *optional*, defaults to `False`):
            Whether to use 8 bit precision for the base model - works only with LoRA.
        load_in_4bit (`bool`, *optional*, defaults to `False`):
            Whether to use 4 bit precision for the base model - works only with LoRA.
        bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
            Quantization type.
        use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
            Whether to use nested quantization.
    helpzName of the pre-trained model.)metadatar:   Nz&Username to display in chat interface.)defaultr   r   zSystem prompt.rM   z./chat_history/zFolder to save chat history.r9   cpuzDevice to use for inference.r   z"Path to a yaml file with examples.examples_path   z%Maximum number of tokens to generate.r   Tz,Whether to sample outputs during generation.r   r
   z Number of beams for beam search.r   g      ?z%Temperature parameter for generation.r   2   zValue of k for top-k sampling.r   z Value of p for nucleus sampling.r   zRepetition penalty.r   zNEOS tokens to stop the generation. If multiple they should be comma separated.r   zQEOS token IDs to stop the generation. If multiple they should be comma separated.r   mainzLSpecific model version to use (can be a branch name, tag name or commit id).r~   ry   zOverride the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights.)ry   bfloat16float16float32)r   choicesrr   Fz2Whether to trust remote code when loading a model.rx   zWhich attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`.rz   zIWhether to use 8 bit precision for the base model - works only with LoRA.rq   zIWhether to use 4 bit precision for the base model - works only with LoRA.rl   nf4zQuantization type.fp4rn   z#Whether to use nested quantization.rs   )"r   r   r   __doc__r   r:   str__annotations__r   r   rM   r9   r   r   r   r   r   r^   r   r   floatr   r   r   r   r   r~   rr   rx   rz   rq   rl   rn   rs   r*   r*   r*   r+   r   0  sf   
 7r   rD   c                 C   s   t | S )z;
    Factory function used to chat with a local model.
    )ChatCommandrD   r*   r*   r+   chat_command_factory  s   r   c                   @   sD   e Zd ZedefddZdd Zededefdd	Z	d
d Z
dS )r   parserc                 C   s&   t f}| jdt|d}|jtd dS )z
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        rC   )r   dataclass_types)funcN)r   
add_parserr   set_defaultsr   )r   r   chat_parserr*   r*   r+   register_subcommand  s   zChatCommand.register_subcommandc                 C   s
   || _ d S r   r   )r   rD   r*   r*   r+   r     s   
zChatCommand.__init__srk   c                 C   s   |  dr	d| vrdS dd | dd  dD }|D ]7}|ddkr' dS |dd\}}| }| }|r;|s> dS t|tsH dS t|tsR dS qd	S )
Nzset rP   Fc                 S   s   g | ]
}|  r|  qS r*   )rY   )rS   ar*   r*   r+   rU     s    z8ChatCommand.is_valid_setting_command.<locals>.<listcomp>r6   rO   r
   T)r   rQ   countrY   setissubsetALLOWED_KEY_CHARSALLOWED_VALUE_CHARS)r   assignments
assignmentkeyvaluer*   r*   r+   is_valid_setting_command  s"   z$ChatCommand.is_valid_setting_commandc                 C   s  t  stdt std| j}|jd u rt}nt|j}t|}W d    n1 s.w   Y  t	
|}|jd u rAt }n|j}t|\}}t|ddd}t||j|j\}	}
t|j|d}|  t|j}	 z| }|dkrt|j}|  W qk|dkr|  W qk|dkrW d S |d	kr|  t	
|}t|j}W qk|d
rt| dk r| }t|dkr|d }nd }t|||}|d| d W qk| |rt|||\}}|rg }|  W qk|dr.t| dkr.| d }||v r|  g }| || d  || d }n|!d| dt"|#  d W qk|$d|d |j%|ddd&|j'}t()|}||||j*|j+|j,|j-|j.|j/|j0|	|
d}t1|j2|d}|3  |4|}|5  |$d|d W n t6y   Y d S w ql)NzHYou need to install rich to use the chat interface. (`pip install rich`)zJYou need to install torch to use the chat interface. (`pip install torch`)T)skip_special_tokensskip_prompt)r/   r   r   r   exitresetsave   r
   zChat saved in !exampler   zExample z* not found in list of available examples: rX   r   rJ   pt)return_tensorsadd_generation_prompt)inputsattention_maskstreamerr   r   r   r   r   r   r   r   r   )targetkwargsr   )7r   ImportErrorr	   rD   r   DEFAULT_EXAMPLESr?   yaml	safe_loadcopydeepcopyr   r,   r   r   r   r   r   r   r:   r   rN   rM   r   r   r   rR   rQ   rI   rc   r   rj   r   ra   listkeysr   apply_chat_templater   r   r   	ones_liker   r   r   r   r   r   r   r   generatestartr   r<   KeyboardInterrupt)r   rD   examplesrH   rf   r   r   r   generation_streamerr   r   rg   rC   re   split_inputrE   successexample_namer   r   generation_kwargsthreadmodel_outputr*   r*   r+   run  s   











zChatCommand.runN)r   r   r   staticmethodr   r   r   r   r^   r   r  r*   r*   r*   r+   r     s    r   )9r   r@   r$   r"   stringr-   argparser   r   dataclassesr   r   	threadingr   typingr   r   transformers.utilsr   r	   r   r   r#   r&   rich.consoler   	rich.liver   rich.markdownr   r   transformersr   r   r   r   r   ascii_letters
whitespacer   digitsr   r   SUPPORTED_GENERATION_KWARGSr   r,   r1   rI   rN   rj   rv   r   r   r   r   r   r   r*   r*   r*   r+   <module>   sb   %Ks