详解Keras3.0 Models API: Model training APIs

2023-12-13 05:34:27

1、compile?方法

Model.compile(
    optimizer="rmsprop",
    loss=None,
    loss_weights=None,
    metrics=None,
    weighted_metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
    jit_compile="auto",
    auto_scale_loss=True,
)

?

参数说明
  • optimizer: 优化器,用于指定在训练过程中使用的优化算法。在这个例子中,使用的是RMSprop优化器。
  • loss: 损失函数,用于衡量模型预测值与真实值之间的差距。在这个例子中,损失函数被设置为None,表示使用模型自带的默认损失函数。
  • loss_weights: 损失权重,用于为不同的损失项分配权重。在这个例子中,损失权重被设置为None,表示使用模型自带的默认损失权重。
  • metrics: 评估指标,用于在训练和验证过程中评估模型的性能。在这个例子中,评估指标被设置为None,表示使用模型自带的默认评估指标。
  • weighted_metrics: 加权评估指标,用于在训练和验证过程中评估模型的性能,并考虑样本的权重。在这个例子中,加权评估指标被设置为None,表示使用模型自带的默认加权评估指标。
  • run_eagerly: 是否立即执行计算图。如果设置为True,则在每次迭代时都会立即执行计算图;如果设置为False,则只有在需要时才会执行计算图。在这个例子中,设置为False,表示只在需要时执行计算图。
  • steps_per_execution: 每次执行的步数。这个参数用于控制每次执行模型前向传播的次数。在这个例子中,设置为1,表示每次执行只进行一次前向传播。
  • jit_compile: JIT编译模式。如果设置为"auto",则根据输入数据的形状自动选择是否进行JIT编译;如果设置为True,则总是进行JIT编译;如果设置为False,则不进行JIT编译。在这个例子中,设置为"auto",表示根据输入数据的形状自动选择是否进行JIT编译。
  • auto_scale_loss: 是否自动缩放损失值。如果设置为True,则根据损失值的范围自动缩放损失值;如果设置为False,则不进行自动缩放。在这个例子中,设置为True,表示根据损失值的范围自动缩放损失值。

2、fit?方法

Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
)

?

参数说明?
  • x: 输入数据,通常是一个numpy数组或者TensorFlow张量。
  • y: 目标值,通常是一个numpy数组或者TensorFlow张量。
  • batch_size: 每次训练迭代中使用的样本数量。
  • epochs: 训练的总轮数。
  • verbose: 日志显示模式。可以是0(不显示日志)、1(只显示进度条)或2(同时显示进度条和日志)。
  • callbacks: 回调函数列表,可以在训练过程中执行自定义操作。
  • validation_split: 用于验证集的比例,取值范围为0到1之间。如果设置为0,则不使用验证集;如果设置为1,则将整个数据集用作验证集。
  • validation_data: 验证集的数据和标签,可以是一个元组,包含两个numpy数组或者TensorFlow张量。
  • shuffle: 是否在每个epoch开始时打乱数据顺序。
  • class_weight: 类别权重,用于处理不平衡数据集。可以是一个字典,键为类别索引,值为对应的权重。
  • sample_weight: 样本权重,用于处理不同样本的重要性。可以是一个numpy数组或者TensorFlow张量。
  • initial_epoch: 从哪个epoch开始训练。默认为0。
  • steps_per_epoch: 每个epoch中的步数。如果设置了这个参数,那么validation_steps将被忽略。
  • validation_steps: 验证集中的步数。如果设置了这个参数,那么steps_per_epoch将被忽略。
  • validation_batch_size: 验证集的批量大小。
  • validation_freq: 每隔多少个epoch进行一次验证。

文章来源:https://blog.csdn.net/lymake/article/details/134884581
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。