IDQL
IDQL class
IDQL (idql_net:Optional[jaxrl5.agents.ddpm_iql_simple.ddpm_iql_learner.DDPMIQLLearner]=None,
      action_space:Optional[gymnasium.spaces.space.Space]=None,
      observation_space:Optional[gymnasium.spaces.space.Space]=None,
      _ckpt_idql_dir:Optional[pathlib.Path]=None,
      _truck:tspace.config.vehicles.Truck,
      _driver:tspace.config.drivers.Driver,
      _resume:bool,
      _coll_type:str,
      _hyper_param:Union[tspace.agent.utils.hyperparams.HyperParamDDPG,tspace.agent.utils.hyperparams.HyperParamRDPG,tspace.agent.utils.hyperparams.HyperParamIDQL],
      _pool_key:str,
      _data_folder:str,
      _infer_mode:bool,
      _buffer:Union[tspace.storage.buffer.mongo.MongoBuffer,tspace.storage.buffer.dask.DaskBuffer,NoneType]=None,
      _episode_start_dt:Optional[pandas._libs.tslibs.timestamps.Timestamp]=None,
      _observation_meta:Union[tspace.data.core.ObservationMetaCloud,tspace.data.core.ObservationMetaECU,NoneType]=None,
      _torque_table_row_names:Optional[list[str]]=None,
      _observations:Optional[list[pandas.core.series.Series]]=None,
      _epi_no:Optional[int]=None,
      logger:Optional[logging.Logger]=None,
      dict_logger:Optional[dict]=None)
IDQL agent for VEOS.
Abstracts:
data interface:
- pool in MongoDB
- buffer in memory (numpy array)
model interface:
- idql_net: the implicit diffusion Q-learning networks, which contain
  - actor_net: the behavior actor network (fitted to the data)
  - critic_net: the critic network (Q-value function)
  - value_net: the value network (V-value function)
The implicit policy re-weights samples from the behavior actor network with importance weights; the paper recommends the expectile loss.
_ckpt_idql_dir: checkpoint directory for critic
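The expectile loss mentioned above is the asymmetric squared loss from the IQL/IDQL line of work. A minimal JAX sketch (the function name and the default τ are illustrative, not the actual jaxrl5 internals):

```python
import jax.numpy as jnp

def expectile_loss(diff: jnp.ndarray, tau: float = 0.7) -> jnp.ndarray:
    """Asymmetric L2 loss |tau - 1(u < 0)| * u**2, with u = Q(s, a) - V(s).

    With tau > 0.5, positive errors are weighted more heavily, so the
    value network is pushed toward an upper expectile of the Q-values.
    """
    weight = jnp.where(diff > 0, tau, 1.0 - tau)  # tau = 0.7 is illustrative
    return weight * diff**2
```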
IDQL.__post_init__
IDQL.__post_init__ ()
Initialize the IDQL agent.
args:
- truck.ObservationNumber (int): dimension of the state space.
- padding_value (float): value to pad the state with; an impossible value for observation, action, or reward.
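As a toy illustration of the padding_value argument, one might pad a short observation like this (the sentinel value and target length are hypothetical):

```python
import numpy as np

def pad_state(state: np.ndarray, target_len: int,
              padding_value: float = -10000.0) -> np.ndarray:
    """Right-pad a flattened state vector to a fixed length with a
    sentinel that cannot occur as a real observation, action, or reward."""
    assert state.shape[0] <= target_len, "state already at full length"
    return np.pad(state, (0, target_len - state.shape[0]),
                  constant_values=padding_value)
```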
IDQL.__repr__
IDQL.__repr__ ()
Return repr(self).
IDQL.__str__
IDQL.__str__ ()
Return str(self).
IDQL.__hash__
IDQL.__hash__ ()
Return hash(self).
IDQL.touch_gpu
IDQL.touch_gpu ()
Touch the GPU to avoid the first-call delay.
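With JAX, the first forward pass pays XLA compilation and device-initialization cost, so a warm-up usually runs one dummy inference. A sketch under that assumption (apply_fn, params, and obs_dim are placeholders):

```python
import jax.numpy as jnp

def touch_gpu_sketch(apply_fn, params, obs_dim: int) -> None:
    """Run one dummy forward pass so JIT compilation and GPU setup
    happen here rather than on the first real control step."""
    dummy_obs = jnp.zeros((1, obs_dim))
    out = apply_fn(params, dummy_obs)
    out.block_until_ready()  # wait until the device has finished
```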
IDQL.init_checkpoint
IDQL.init_checkpoint ()
Create a checkpoint or restore from an existing one.
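A create-or-restore pattern with Flax's checkpoint helpers might look like the sketch below; the directory handling mirrors _ckpt_idql_dir, but the state object and call site are assumptions, not the actual tspace code:

```python
from pathlib import Path
from flax.training import checkpoints

def init_checkpoint_sketch(ckpt_dir: Path, state):
    """Restore the latest checkpoint under ckpt_dir if one exists;
    otherwise return the freshly initialized state unchanged."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    # restore_checkpoint returns `state` as-is when the directory is empty
    return checkpoints.restore_checkpoint(ckpt_dir=str(ckpt_dir), target=state)
```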
IDQL.actor_predict
IDQL.actor_predict (state:pandas.core.series.Series)
Sample actions with additive OU noise.
Input: state is a pd.Series of length 3*103/4*503 (r*c); output is a numpy array.
Action outputs and the noise object are row vectors of length 21*17 (r*c), output as a numpy array.
|  | Type | Details |
|---|---|---|
| state | Series | state sequence of the current episode |
| Returns | ndarray | action sequence of the current episode |
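A hypothetical call, assuming an already-initialized agent and an illustrative observation length:

```python
import numpy as np
import pandas as pd

state = pd.Series(np.zeros(3 * 103))  # flattened r*c observation (length illustrative)
action = agent.actor_predict(state)   # `agent` is an initialized IDQL instance
print(action.shape)                   # flattened numpy action vector
```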
IDQL.sample_minibatch
IDQL.sample_minibatch ()
Convert batch type from DataFrames to flattened tensors.
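The conversion could be sketched as follows; the field names are hypothetical, not the actual buffer schema:

```python
import numpy as np
import pandas as pd

def flatten_batch(batch: pd.DataFrame) -> dict[str, np.ndarray]:
    """Stack each sampled field into a dense float32 tensor of shape
    (batch_size, feature_dim), ready to feed to the JAX networks."""
    fields = ("states", "actions", "rewards", "next_states")  # hypothetical schema
    return {key: np.stack(batch[key].to_list()).astype(np.float32)
            for key in fields}
```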
IDQL.train
IDQL.train ()
Train the networks on the batch sampled from the pool.
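Schematically, one training iteration samples a minibatch, steps the three networks, and refreshes the targets. The sketch below only wires together the methods documented on this page; the idql_net.update signature is an assumption about the jaxrl5 learner:

```python
def train_sketch(agent) -> dict:
    """One schematic IDQL training iteration."""
    batch = agent.sample_minibatch()  # DataFrames -> flattened tensors
    # One gradient step each on the value net (expectile loss), the critic
    # (TD loss), and the diffusion behavior actor (score-matching loss).
    agent.idql_net, info = agent.idql_net.update(batch)  # assumed jaxrl5 signature
    agent.soft_update_target()        # Polyak-average the target networks
    return info
```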
IDQL.soft_update_target
IDQL.soft_update_target ()
Update the target networks with Polyak averaging.
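Polyak averaging blends the online parameters into the target parameters leaf-wise with a small step size; in JAX pytree form (the value of tau is illustrative):

```python
import jax

def polyak_update(online_params, target_params, tau: float = 0.005):
    """target <- tau * online + (1 - tau) * target, applied per leaf."""
    return jax.tree_util.tree_map(
        lambda p, tp: tau * p + (1.0 - tau) * tp,
        online_params, target_params)
```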
IDQL.save_ckpt
IDQL.save_ckpt ()
TODO: Save the checkpoints of the actor, critic, and value networks in Flax.
IDQL.get_losses
IDQL.get_losses ()
Get the losses of the networks on the batch sampled from the pool.