o
    hK                     @   s  d Z ddlZddlZddlZddlmZ ddlmZmZ ddlZ	ddl
mZ ddlmZmZmZmZ ddlmZ dd	lmZ eeZeG d
d dZG dd dZeG dd deZG dd dZG dd deZG dd deZG dd deZG dd deZG dd deeZ dS )zJ
Callbacks to use with the Trainer class and customize the training loop.
    N)	dataclass)OptionalUnion)tqdm   )HPSearchBackendIntervalStrategySaveStrategy
has_length)TrainingArguments)loggingc                   @   sv  e Zd ZU dZdZee ed< dZe	ed< dZ
e	ed< dZe	ed< dZe	ed	< dZe	ed
< dZee	 ed< dZe	ed< dZe	ed< dZeed< dZeeeef  ed< dZee ed< dZee	 ed< dZee ed< dZeed< dZeed< dZeed< dZee ed< dZeeeeee	ef f ed< dZ ed ed< dd Z!defdd Z"e#defd!d"Z$d#d$ Z%d%d& Z&dS )'TrainerStatea  
    A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
    and passed to the [`TrainerCallback`].

    <Tip>

    In all this class, one step is to be understood as one update step. When using gradient accumulation, one update
    step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one update
    step requires going through *n* batches.

    </Tip>

    Args:
        epoch (`float`, *optional*):
            Only set during training, will represent the epoch the training is at (the decimal part being the
            percentage of the current epoch completed).
        global_step (`int`, *optional*, defaults to 0):
            During training, represents the number of update steps completed.
        max_steps (`int`, *optional*, defaults to 0):
            The number of update steps to do during the current training.
        logging_steps (`int`, *optional*, defaults to 500):
            Log every X updates steps
        eval_steps (`int`, *optional*):
            Run an evaluation every X steps.
        save_steps (`int`, *optional*, defaults to 500):
            Save checkpoint every X updates steps.
        train_batch_size (`int`, *optional*):
            The batch size for the training dataloader. Only needed when
            `auto_find_batch_size` has been used.
        num_input_tokens_seen (`int`, *optional*, defaults to 0):
            When tracking the inputs tokens, the number of tokens seen during training (number of input tokens, not the
            number of prediction tokens).
        total_flos (`float`, *optional*, defaults to 0):
            The total number of floating operations done by the model since the beginning of training (stored as floats
            to avoid overflow).
        log_history (`List[Dict[str, float]]`, *optional*):
            The list of logs done since the beginning of training.
        best_metric (`float`, *optional*):
            When tracking the best model, the value of the best metric encountered so far.
        best_global_step (`int`, *optional*):
            When tracking the best model, the step at which the best metric was encountered.
            Used for setting `best_model_checkpoint`.
        best_model_checkpoint (`str`, *optional*):
            When tracking the best model, the value of the name of the checkpoint for the best model encountered so
            far.
        is_local_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
            several machines) main process.
        is_world_process_zero (`bool`, *optional*, defaults to `True`):
            Whether or not this process is the global main process (when training in a distributed fashion on several
            machines, this is only going to be `True` for one process).
        is_hyper_param_search (`bool`, *optional*, defaults to `False`):
            Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will
            impact the way data will be logged in TensorBoard.
        stateful_callbacks (`List[StatefulTrainerCallback]`, *optional*):
            Callbacks attached to the `Trainer` that should have their states be saved or restored.
            Relevant callbacks should implement a `state` and `from_state` function.
    Nepochr   global_step	max_stepsi  logging_steps
eval_steps
save_stepstrain_batch_sizenum_train_epochsnum_input_tokens_seen
total_floslog_historybest_metricbest_global_stepbest_model_checkpointTis_local_process_zerois_world_process_zeroFis_hyper_param_search
trial_nametrial_paramsTrainerCallbackstateful_callbacksc                 C   s   | j d u rg | _ | jd u ri | _d S t| jtrd S i }| jD ]6}t|ts/tdt| |jj}||v rOt|| t	sE|| g||< || 
|  q| ||< q|| _d S )NzNAll callbacks passed to be saved must inherit `ExportableState`, but received )r   r"   
isinstancedictExportableState	TypeErrortype	__class____name__listappendstate)selfr"   callbackname r0   q/var/www/html/construction_image-detection-poc/venv/lib/python3.10/site-packages/transformers/trainer_callback.py__post_init__u   s&   





zTrainerState.__post_init__	json_pathc                 C   sX   t jt| dddd }t|ddd}|| W d   dS 1 s%w   Y  dS )	zDSave the content of this instance in JSON format inside `json_path`.   T)indent	sort_keys
wutf-8encodingN)jsondumpsdataclassesasdictopenwrite)r-   r3   json_stringfr0   r0   r1   save_to_json   s   "zTrainerState.save_to_jsonc                 C   sH   t |dd}| }W d   n1 sw   Y  | di t|S )z3Create an instance from the content of `json_path`.r9   r:   Nr0   )r@   readr<   loads)clsr3   rC   textr0   r0   r1   load_from_json   s   
zTrainerState.load_from_jsonc                 C   sN   dD ]"}t || d}|dur$|dk rt|| }t| | d| qdS )z
        Calculates and stores the absolute value for logging,
        eval, and save steps based on if it was a proportion
        or not.
        )r   evalsave_stepsNr   )getattrmathceilsetattr)r-   argsr   	step_kind	num_stepsr0   r0   r1   compute_steps   s   zTrainerState.compute_stepsc                 C   s   |j dur|jdur| |j| _d| _|dur.ddlm} |jtjkr'|j	n|}||| _|| _
|| _| | _| | _dS )zI
        Stores the initial training references needed in `self`
        Nr   )	hp_params)hp_name_trialr   r    transformers.integrationsrU   hp_search_backendr   SIGOPTassignmentsr   r   r   r   )r-   trainerr   r   trialrU   r[   r0   r0   r1   init_training_references   s   

z%TrainerState.init_training_references)'r)   
__module____qualname____doc__r   r   float__annotations__r   intr   r   r   r   r   r   r   r   r   r*   r$   strr   r   r   r   boolr   r   r   r    r   r"   r2   rD   classmethodrI   rT   r^   r0   r0   r0   r1   r   #   s8   
 ; r   c                   @   s*   e Zd ZdZdefddZedd ZdS )r%   aj  
    A class for objects that include the ability to have its state
    be saved during `Trainer._save_checkpoint` and loaded back in during
    `Trainer._load_from_checkpoint`.

    These must implement a `state` function that gets called during the respective
    Trainer function call. It should only include parameters and attributes needed to
    recreate the state at a particular time, to avoid utilizing pickle/maintain standard
    file IO writing.

    Example:

    ```python
    class EarlyStoppingCallback(TrainerCallback, ExportableState):
        def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
            self.early_stopping_patience = early_stopping_patience
            self.early_stopping_threshold = early_stopping_threshold
            # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
            self.early_stopping_patience_counter = 0

        def state(self) -> dict:
            return {
                "args": {
                    "early_stopping_patience": self.early_stopping_patience,
                    "early_stopping_threshold": self.early_stopping_threshold,
                },
                "attributes": {
                    "early_stopping_patience_counter": self.early_stopping_patience_counter,
                }
            }
    ```returnc                 C   s   t d)Nz<You must implement a `state` function to utilize this class.)NotImplementedErrorr-   r0   r0   r1   r,      s   zExportableState.statec                 C   s8   | di |d }|d   D ]
\}}t||| q|S )NrQ   
attributesr0   )itemsrP   )rG   r,   instancekvr0   r0   r1   
from_state   s   zExportableState.from_stateN)r)   r_   r`   ra   r$   r,   rg   rp   r0   r0   r0   r1   r%      s
     r%   c                   @   st   e Zd ZU dZdZeed< dZeed< dZeed< dZ	eed< dZ
eed< dd	 Zd
d Zdd ZdefddZdS )TrainerControlaA  
    A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
    switches in the training loop.

    Args:
        should_training_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the training should be interrupted.

            If `True`, this variable will not be set back to `False`. The training will just stop.
        should_epoch_stop (`bool`, *optional*, defaults to `False`):
            Whether or not the current epoch should be interrupted.

            If `True`, this variable will be set back to `False` at the beginning of the next epoch.
        should_save (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be saved at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_evaluate (`bool`, *optional*, defaults to `False`):
            Whether or not the model should be evaluated at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
        should_log (`bool`, *optional*, defaults to `False`):
            Whether or not the logs should be reported at this step.

            If `True`, this variable will be set back to `False` at the beginning of the next step.
    Fshould_training_stopshould_epoch_stopshould_saveshould_evaluate
should_logc                 C   
   d| _ dS )z<Internal method that resets the variable for a new training.FN)rr   rj   r0   r0   r1   _new_training     
zTrainerControl._new_trainingc                 C   rw   )z9Internal method that resets the variable for a new epoch.FN)rs   rj   r0   r0   r1   
_new_epoch  ry   zTrainerControl._new_epochc                 C   s   d| _ d| _d| _dS )z8Internal method that resets the variable for a new step.FN)rt   ru   rv   rj   r0   r0   r1   	_new_step  s   
zTrainerControl._new_steprh   c                 C   s    | j | j| j| j| jdi dS )Nrr   rs   rt   ru   rv   rQ   rk   r|   rj   r0   r0   r1   r,     s   zTrainerControl.stateN)r)   r_   r`   ra   rr   rf   rc   rs   rt   ru   rv   rx   rz   r{   r$   r,   r0   r0   r0   r1   rq      s   
 rq   c                   @   sZ  e Zd ZdZdededefddZdededefddZdededefd	d
Z	dededefddZ
dededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefddZdededefdd Zdededefd!d"Zd#S )$r!   a	  
    A class for objects that will inspect the state of the training loop at some events and take some decisions. At
    each of those events the following arguments are available:

    Args:
        args ([`TrainingArguments`]):
            The training arguments used to instantiate the [`Trainer`].
        state ([`TrainerState`]):
            The current state of the [`Trainer`].
        control ([`TrainerControl`]):
            The object that is returned to the [`Trainer`] and can be used to make some decisions.
        model ([`PreTrainedModel`] or `torch.nn.Module`):
            The model being trained.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`.
        processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
            The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
        optimizer (`torch.optim.Optimizer`):
            The optimizer used for the training steps.
        lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
            The scheduler used for setting the learning rate.
        train_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for training.
        eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
            The current dataloader used for evaluation.
        metrics (`Dict[str, float]`):
            The metrics computed by the last evaluation phase.

            Those are only accessible in the event `on_evaluate`.
        logs  (`Dict[str, float]`):
            The values to log.

            Those are only accessible in the event `on_log`.

    The `control` object is the only one that can be changed by the callback, in which case the event that changes it
    should return the modified version.

    The argument `args`, `state` and `control` are positionals for all events, all the others are grouped in `kwargs`.
    You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
    simple [`~transformers.PrinterCallback`].

    Example:

    ```python
    class PrinterCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            _ = logs.pop("total_flos", None)
            if state.is_local_process_zero:
                print(logs)
    ```rQ   r,   controlc                 K      dS )zS
        Event called at the end of the initialization of the [`Trainer`].
        Nr0   r-   rQ   r,   r~   kwargsr0   r0   r1   on_init_end^     zTrainerCallback.on_init_endc                 K   r   )z<
        Event called at the beginning of training.
        Nr0   r   r0   r0   r1   on_train_begind  r   zTrainerCallback.on_train_beginc                 K   r   )z6
        Event called at the end of training.
        Nr0   r   r0   r0   r1   on_train_endj  r   zTrainerCallback.on_train_endc                 K   r   )z<
        Event called at the beginning of an epoch.
        Nr0   r   r0   r0   r1   on_epoch_beginp  r   zTrainerCallback.on_epoch_beginc                 K   r   )z6
        Event called at the end of an epoch.
        Nr0   r   r0   r0   r1   on_epoch_endv  r   zTrainerCallback.on_epoch_endc                 K   r   )z
        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr0   r   r0   r0   r1   on_step_begin|     zTrainerCallback.on_step_beginc                 K   r   )zv
        Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
        Nr0   r   r0   r0   r1   on_pre_optimizer_step  r   z%TrainerCallback.on_pre_optimizer_stepc                 K   r   )z}
        Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
        Nr0   r   r0   r0   r1   on_optimizer_step  r   z!TrainerCallback.on_optimizer_stepc                 K   r   )zU
        Event called at the end of an substep during gradient accumulation.
        Nr0   r   r0   r0   r1   on_substep_end  r   zTrainerCallback.on_substep_endc                 K   r   )z
        Event called at the end of a training step. If using gradient accumulation, one training step might take
        several inputs.
        Nr0   r   r0   r0   r1   on_step_end  r   zTrainerCallback.on_step_endc                 K   r   )z9
        Event called after an evaluation phase.
        Nr0   r   r0   r0   r1   on_evaluate  r   zTrainerCallback.on_evaluatec                 K   r   )z=
        Event called after a successful prediction.
        Nr0   )r-   rQ   r,   r~   metricsr   r0   r0   r1   
on_predict  r   zTrainerCallback.on_predictc                 K   r   )z7
        Event called after a checkpoint save.
        Nr0   r   r0   r0   r1   on_save  r   zTrainerCallback.on_savec                 K   r   )z;
        Event called after logging the last logs.
        Nr0   r   r0   r0   r1   on_log  r   zTrainerCallback.on_logc                 K   r   )z7
        Event called after a prediction step.
        Nr0   r   r0   r0   r1   on_prediction_step  r   z"TrainerCallback.on_prediction_stepN)r)   r_   r`   ra   r   r   rq   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r0   r0   r0   r1   r!   )  s"    3r!   c                   @   s  e Zd ZdZdd Zdd Zdd Zdd	 Zed
d Z	de
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefddZde
dedefdd Zde
dedefd!d"Zde
dedefd#d$Zde
dedefd%d&Zde
dedefd'd(Zde
dedefd)d*Zde
dedefd+d,Zd-d. Zd/S )0CallbackHandlerz>Internal class that just calls the list of callbacks in order.c                 C   sj   g | _ |D ]}| | q|| _|| _|| _|| _d | _d | _tdd | j D s3t	
d| j  d S d S )Nc                 s   s    | ]}t |tV  qd S N)r#   DefaultFlowCallback.0cbr0   r0   r1   	<genexpr>  s    z+CallbackHandler.__init__.<locals>.<genexpr>zThe Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You
should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list ofcallbacks is
:)	callbacksadd_callbackmodelprocessing_class	optimizerlr_schedulertrain_dataloadereval_dataloaderanyloggerwarningcallback_list)r-   r   r   r   r   r   r   r0   r0   r1   __init__  s    zCallbackHandler.__init__c                 C   sh   t |tr| n|}t |tr|n|j}|dd | jD v r,td| dd | j  | j| d S )Nc                 S   s   g | ]}|j qS r0   )r(   )r   cr0   r0   r1   
<listcomp>  s    z0CallbackHandler.add_callback.<locals>.<listcomp>zYou are adding a zH to the callbacks of this Trainer, but there is already one. The currentzlist of callbacks is
:)r#   r'   r(   r   r   r   r   r+   )r-   r.   r   cb_classr0   r0   r1   r     s   
zCallbackHandler.add_callbackc                 C   sd   t |tr| jD ]}t ||r| j| |  S qd S | jD ]}||kr/| j| |  S qd S r   r#   r'   r   remover-   r.   r   r0   r0   r1   pop_callback  s   



zCallbackHandler.pop_callbackc                 C   sF   t |tr| jD ]}t ||r| j|  d S qd S | j| d S r   r   r   r0   r0   r1   remove_callback  s   


zCallbackHandler.remove_callbackc                 C   s   d dd | jD S )Nr7   c                 s   s    | ]}|j jV  qd S r   )r(   r)   r   r0   r0   r1   r     s    z0CallbackHandler.callback_list.<locals>.<genexpr>)joinr   rj   r0   r0   r1   r     s   zCallbackHandler.callback_listrQ   r,   r~   c                 C      |  d|||S )Nr   
call_eventr-   rQ   r,   r~   r0   r0   r1   r        zCallbackHandler.on_init_endc                 C      d|_ | d|||S )NFr   )rr   r   r   r0   r0   r1   r        zCallbackHandler.on_train_beginc                 C   r   )Nr   r   r   r0   r0   r1   r     r   zCallbackHandler.on_train_endc                 C   r   )NFr   )rs   r   r   r0   r0   r1   r     r   zCallbackHandler.on_epoch_beginc                 C   r   )Nr   r   r   r0   r0   r1   r     r   zCallbackHandler.on_epoch_endc                 C   s"   d|_ d|_d|_| d|||S )NFr   )rv   ru   rt   r   r   r0   r0   r1   r     s   zCallbackHandler.on_step_beginc                 C   r   )Nr   r   r   r0   r0   r1   r     r   z%CallbackHandler.on_pre_optimizer_stepc                 C   r   )Nr   r   r   r0   r0   r1   r     r   z!CallbackHandler.on_optimizer_stepc                 C   r   )Nr   r   r   r0   r0   r1   r     r   zCallbackHandler.on_substep_endc                 C   r   )Nr   r   r   r0   r0   r1   r     r   zCallbackHandler.on_step_endc                 C      d|_ | jd||||dS )NFr   r   )ru   r   r-   rQ   r,   r~   r   r0   r0   r1   r        zCallbackHandler.on_evaluatec                 C   s   | j d||||dS )Nr   r   r   r   r0   r0   r1   r     s   zCallbackHandler.on_predictc                 C   r   )NFr   )rt   r   r   r0   r0   r1   r     r   zCallbackHandler.on_savec                 C   r   )NFr   )logs)rv   r   )r-   rQ   r,   r~   r   r0   r0   r1   r   #  r   zCallbackHandler.on_logc                 C   r   )Nr   r   r   r0   r0   r1   r   '  r   z"CallbackHandler.on_prediction_stepc              
   K   sP   | j D ]"}t|||||f| j| j| j| j| j| jd|}|d ur%|}q|S )N)r   r   r   r   r   r   )r   rM   r   r   r   r   r   r   )r-   eventrQ   r,   r~   r   r.   resultr0   r0   r1   r   *  s&   

zCallbackHandler.call_eventN)r)   r_   r`   ra   r   r   r   r   propertyr   r   r   rq   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r0   r0   r0   r1   r     s0    	
r   c                   @   s<   e Zd ZdZdededefddZdededefddZd	S )
r   zx
    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
    rQ   r,   r~   c                 K   s   |j dkr|jrd|_|jtjkr|j |j dkrd|_|jtjkr3|j |j dkr3|j	|j kr3d|_
|jtjkrI|jdkrI|j |j dkrId|_|j |jkr[d|_|jtjkr[d|_|S )Nr   Tr   )r   logging_first_steprv   logging_strategyr   STEPSr   eval_strategyr   
eval_delayru   save_strategyr	   r   rt   r   rr   r   r0   r0   r1   r   C  s"   
zDefaultFlowCallback.on_step_endc                 K   sF   |j tjkr	d|_|jtjkr|j|jkrd|_|jt	jkr!d|_
|S )NT)r   r   EPOCHrv   r   r   r   ru   r   r	   rt   r   r0   r0   r1   r   c  s   z DefaultFlowCallback.on_epoch_endN)	r)   r_   r`   ra   r   r   rq   r   r   r0   r0   r0   r1   r   >  s     r   c                   @   s\   e Zd ZdZddefddZdd Zdd	 ZdddZdd Z	dd Z
dddZdd Zd
S )ProgressCallbackz
    A [`TrainerCallback`] that displays the progress of training or evaluation.
    You can modify `max_str_len` to control how long strings are truncated when logging.
    d   max_str_lenc                 C   s   d| _ d| _|| _dS )a!  
        Initialize the callback with optional max_str_len parameter to control string truncation length.

        Args:
            max_str_len (`int`):
                Maximum length of strings to display in logs.
                Longer strings will be truncated with a message.
        N)training_barprediction_barr   )r-   r   r0   r0   r1   r   y  s   	
zProgressCallback.__init__c                 K   s    |j rt|jdd| _d| _d S )NT)totaldynamic_ncolsr   )r   r   r   r   current_stepr   r0   r0   r1   r     s   
zProgressCallback.on_train_beginc                 K   s*   |j r| j|j| j  |j| _d S d S r   )r   r   updater   r   r   r0   r0   r1   r     s   zProgressCallback.on_step_endNc                 K   sJ   |j r!t|r#| jd u rtt|| jd u dd| _| jd d S d S d S )NT)r   leaver   r   )r   r
   r   r   lenr   r   )r-   rQ   r,   r~   r   r   r0   r0   r1   r     s   
z#ProgressCallback.on_prediction_stepc                 K   (   |j r| jd ur| j  d | _d S d S r   r   r   closer   r0   r0   r1   r     
   


zProgressCallback.on_evaluatec                 K   r   r   r   r   r0   r0   r1   r     r   zProgressCallback.on_predictc           
      K   s   |j rO| jd urQi }| D ]#\}}t|tr-t|| jkr-dt| d| j d||< q|||< q|dd }	d|v rEt|d d|d< | j	t| d S d S d S )Nz%[String too long to display, length: z > z/. Consider increasing `max_str_len` if needed.]r   r   r4   )
r   r   rl   r#   re   r   r   poproundrA   )
r-   rQ   r,   r~   r   r   shallow_logsrn   ro   _r0   r0   r1   r     s   
zProgressCallback.on_logc                 K   s   |j r| j  d | _d S d S r   )r   r   r   r   r0   r0   r1   r     s   

zProgressCallback.on_train_end)r   r   )r)   r_   r`   ra   rd   r   r   r   r   r   r   r   r   r0   r0   r0   r1   r   s  s    

r   c                   @   s   e Zd ZdZdddZdS )PrinterCallbackz?
    A bare [`TrainerCallback`] that just prints the logs.
    Nc                 K   s"   | dd }|jrt| d S d S )Nr   )r   r   print)r-   rQ   r,   r~   r   r   r   r0   r0   r1   r     s   zPrinterCallback.on_logr   )r)   r_   r`   ra   r   r0   r0   r0   r1   r     s    r   c                   @   sN   e Zd ZdZddedee fddZdd	 Zd
d Z	dd Z
defddZdS )EarlyStoppingCallbacka1  
    A [`TrainerCallback`] that handles early stopping.

    Args:
        early_stopping_patience (`int`):
            Use with `metric_for_best_model` to stop training when the specified metric worsens for
            `early_stopping_patience` evaluation calls.
        early_stopping_threshold(`float`, *optional*):
            Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
            specified metric must improve to satisfy early stopping conditions. `

    This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
    in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
    early stopping will not occur until the next save step.
    r           early_stopping_patienceearly_stopping_thresholdc                 C   s   || _ || _d| _d S )Nr   r   r   early_stopping_patience_counter)r-   r   r   r0   r0   r1   r     s   
zEarlyStoppingCallback.__init__c                 C   sX   |j rtjntj}|jd u s|||jr#t||j | jkr#d| _d S |  jd7  _d S )Nr   r   )greater_is_betternpgreaterlessr   absr   r   )r-   rQ   r,   r~   metric_valueoperatorr0   r0   r1   check_metric_value  s   


z(EarlyStoppingCallback.check_metric_valuec                 K   s:   |j std |jd usJ d|jtjksJ dd S )NzUsing EarlyStoppingCallback without load_best_model_at_end=True. Once training is finished, the best model will not be loaded automatically.zBEarlyStoppingCallback requires metric_for_best_model to be definedzAEarlyStoppingCallback requires IntervalStrategy of steps or epoch)load_best_model_at_endr   r   metric_for_best_modelr   r   NOr   r0   r0   r1   r     s   z$EarlyStoppingCallback.on_train_beginc                 K   sl   |j }|dsd| }||}|d u r!td| d d S | |||| | j| jkr4d|_d S d S )Neval_z@early stopping required metric_for_best_model, but did not find z so early stopping is disabledT)	r   
startswithgetr   r   r   r   r   rr   )r-   rQ   r,   r~   r   r   metric_to_checkr   r0   r0   r1   r     s   




z!EarlyStoppingCallback.on_evaluaterh   c                 C   s   | j | jdd| jidS )N)r   r   r   r}   r   rj   r0   r0   r1   r,     s   zEarlyStoppingCallback.stateN)r   r   )r)   r_   r`   ra   rd   r   rb   r   r   r   r   r$   r,   r0   r0   r0   r1   r     s    r   )!ra   r>   r<   rN   r   typingr   r   numpyr   	tqdm.autor   trainer_utilsr   r   r	   r
   training_argsr   utilsr   
get_loggerr)   r   r   r%   rq   r!   r   r   r   r   r   r0   r0   r0   r1   <module>   s4   
 ,=  5J