IDQL

IDQL class

source

IDQL

 IDQL (idql_net:Optional[jaxrl5.agents.ddpm_iql_simple.ddpm_iql_learner.DDPMIQLLearner]=None,
       action_space:Optional[gymnasium.spaces.space.Space]=None,
       observation_space:Optional[gymnasium.spaces.space.Space]=None,
       _ckpt_idql_dir:Optional[pathlib.Path]=None,
       _truck:tspace.config.vehicles.Truck,
       _driver:tspace.config.drivers.Driver, _resume:bool, _coll_type:str,
       _hyper_param:Union[tspace.agent.utils.hyperparams.HyperParamDDPG,
                          tspace.agent.utils.hyperparams.HyperParamRDPG,
                          tspace.agent.utils.hyperparams.HyperParamIDQL],
       _pool_key:str, _data_folder:str, _infer_mode:bool,
       _buffer:Union[tspace.storage.buffer.mongo.MongoBuffer,
                     tspace.storage.buffer.dask.DaskBuffer, NoneType]=None,
       _episode_start_dt:Optional[pandas._libs.tslibs.timestamps.Timestamp]=None,
       _observation_meta:Union[tspace.data.core.ObservationMetaCloud,
                               tspace.data.core.ObservationMetaECU, NoneType]=None,
       _torque_table_row_names:Optional[list[str]]=None,
       _observations:Optional[list[pandas.core.series.Series]]=None,
       _epi_no:Optional[int]=None, logger:Optional[logging.Logger]=None,
       dict_logger:Optional[dict]=None)

IDQL agent for VEOS.

Abstracts:

data interface:
    - pool in MongoDB
    - buffer in memory (numpy array)
model interface:
    - idql_net: the implicit diffusion Q-learning networks, which contain
        - actor_net: the behavior actor network (trained from the data)
        - critic_net: the critic network (Q-value function)
        - value_net: the value network (V-value function)
        The implicit policy re-weights samples drawn from the behavior actor network with the
        importance weights derived from the expectile loss, as recommended by the paper.
    _ckpt_idql_dir: checkpoint directory for the critic
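
The re-weighting step can be sketched as follows. This is a minimal illustration of the IDQL-style implicit policy, not the actual code path of `idql_net`; the plain-callable interfaces of `actor_net`, `critic_net`, `value_net` and the expectile `tau` are assumptions.

```python
import numpy as np

def implicit_policy_sample(state, actor_net, critic_net, value_net,
                           num_samples: int = 64, tau: float = 0.7):
    """Sketch: re-weight behavior-actor samples with expectile weights.

    Assumed interfaces: actor_net(state) draws one action from the
    diffusion behavior policy; critic_net(state, a) -> Q(s, a);
    value_net(state) -> V(s).
    """
    actions = np.stack([actor_net(state) for _ in range(num_samples)])
    q = np.array([critic_net(state, a) for a in actions])  # Q(s, a) per sample
    v = value_net(state)                                   # V(s)
    w = np.abs(tau - (q < v).astype(np.float32))           # expectile weight |tau - 1{Q < V}|
    idx = np.random.choice(num_samples, p=w / w.sum())     # sample by importance
    return actions[idx]
```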

source

IDQL.__post_init__

 IDQL.__post_init__ ()

Initialize the IDQL agent.

args:

- truck.ObservationNumber (int): dimension of the state space.
- padding_value (float): value to pad the state with; an impossible value for observation, action, or reward (see the sketch below).
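
To make the padding concrete, here is a minimal sketch of padding a variable-length episode with an impossible sentinel value; the helper name and the sentinel are hypothetical.

```python
import numpy as np

PADDING_VALUE = -10000.0  # hypothetical sentinel; impossible as a real observation

def pad_episode(states: list[np.ndarray], target_len: int) -> np.ndarray:
    """Pad a variable-length episode to target_len rows (hypothetical helper)."""
    dim = states[0].shape[0]
    padded = np.full((target_len, dim), PADDING_VALUE, dtype=np.float32)
    padded[: len(states)] = np.stack(states)
    return padded
```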

source

IDQL.__repr__

 IDQL.__repr__ ()

Return repr(self).


source

IDQL.__str__

 IDQL.__str__ ()

Return str(self).


source

IDQL.__hash__

 IDQL.__hash__ ()

Return hash(self).


source

IDQL.touch_gpu

 IDQL.touch_gpu ()

Touch the GPU to avoid the first-time delay.
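
A warm-up usually amounts to one throwaway forward pass so device allocation and JIT compilation happen before the first real request; a minimal sketch, assuming `actor_predict` and a known observation dimension:

```python
import numpy as np
import pandas as pd

def touch_gpu_sketch(agent, obs_dim: int) -> None:
    """Run one dummy inference to pay JIT/allocation costs up front (assumed names)."""
    dummy_state = pd.Series(np.zeros(obs_dim, dtype=np.float32))
    _ = agent.actor_predict(dummy_state)  # result discarded; device is now warm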


source

IDQL.init_checkpoint

 IDQL.init_checkpoint ()

Create a checkpoint, or restore from an existing one.
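
A minimal create-or-restore sketch with `flax.training.checkpoints`; the directory handling and the `train_state` argument are assumptions, not the method's actual implementation:

```python
from pathlib import Path
from flax.training import checkpoints

def init_checkpoint_sketch(ckpt_dir: Path, train_state):
    """Restore the latest checkpoint if one exists, else keep the fresh state."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    # restore_checkpoint returns `train_state` unchanged when no checkpoint is found
    return checkpoints.restore_checkpoint(ckpt_dir=str(ckpt_dir), target=train_state)
```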


source

IDQL.actor_predict

 IDQL.actor_predict (state:pandas.core.series.Series)

Sample actions with additive OU noise.

input: state is a pd.Series of length 3103/4503 (r*c); output is a numpy array.

Action outputs and the noise object are all row vectors of length 2117 (r*c); output is a numpy array.

|  | **Type** | **Details** |
| --- | --- | --- |
| state | Series | state sequence of the current episode |
| **Returns** | **ndarray** | **action sequence of the current episode** |
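
A typical call, assuming an already constructed agent and the flattened state length from the docstring:

```python
import numpy as np
import pandas as pd

# Usage sketch: `agent` is an assumed, already-constructed IDQL instance.
state = pd.Series(np.zeros(3103, dtype=np.float32))  # flattened state, length r*c
action = agent.actor_predict(state)                  # -> flat np.ndarray action vector
assert isinstance(action, np.ndarray)
```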

source

IDQL.sample_minibatch

 IDQL.sample_minibatch ()

Convert batch type from DataFrames to flattened tensors.
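
The conversion can be sketched as follows; the column names are hypothetical, and the real schema lives in the tspace storage layer:

```python
import numpy as np
import pandas as pd

def to_tensors(batch: list[pd.DataFrame]) -> dict[str, np.ndarray]:
    """Flatten a minibatch of episode DataFrames into dense float32 arrays.

    Column names ("state", "action", "reward", "next_state") are hypothetical.
    """
    def stack(col: str) -> np.ndarray:
        return np.stack([df[col].to_numpy(np.float32) for df in batch])

    return {
        "states": stack("state"),
        "actions": stack("action"),
        "rewards": stack("reward"),
        "next_states": stack("next_state"),
    }
```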


source

IDQL.train

 IDQL.train ()

Train the networks on the batch sampled from the pool.
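
Putting the methods on this page together, a typical outer loop might look like this (the loop constants are assumptions for illustration):

```python
# Outer-loop sketch using only methods documented on this page.
num_epochs, save_every = 100, 10  # assumed values
for epoch in range(num_epochs):
    agent.train()                 # sample a batch from the pool and update the networks
    agent.soft_update_target()    # Polyak-average the target networks
    if epoch % save_every == 0:
        agent.save_ckpt()
```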


source

IDQL.soft_update_target

 IDQL.soft_update_target ()

Update the target networks with Polyak averaging.
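
Polyak averaging blends each target parameter toward its online counterpart, theta_target <- tau * theta + (1 - tau) * theta_target; a minimal sketch over JAX pytrees (the tau value is an assumption):

```python
import jax

def polyak_update(online_params, target_params, tau: float = 0.005):
    """theta_target <- tau * theta_online + (1 - tau) * theta_target."""
    return jax.tree_util.tree_map(
        lambda p, tp: tau * p + (1.0 - tau) * tp, online_params, target_params
    )
```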


source

IDQL.save_ckpt

 IDQL.save_ckpt ()

TODO Save the checkpoints of the actor, critic and value networks in Flax.
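
Given the TODO, here is a sketch of what saving could look like with `flax.training.checkpoints`; the attribute names on `idql_net` are assumptions:

```python
from flax.training import checkpoints

def save_ckpt_sketch(agent, step: int) -> None:
    """Save actor/critic/value states in one Flax checkpoint (assumed attributes)."""
    checkpoints.save_checkpoint(
        ckpt_dir=str(agent._ckpt_idql_dir),
        target={
            "actor": agent.idql_net.actor_state,    # assumed attribute
            "critic": agent.idql_net.critic_state,  # assumed attribute
            "value": agent.idql_net.value_state,    # assumed attribute
        },
        step=step,
        overwrite=True,
    )
```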


source

IDQL.get_losses

 IDQL.get_losses ()

Get the losses of the networks on the batch sampled from the pool.