src.dl_trainer

Functions

predict(model, X_test)

Makes predictions with a trained PyTorch model.
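Below is a minimal sketch of the decoding step such a function typically performs. It assumes the model returns one logit per class and that the predicted label is the highest-scoring class. `predict_from_logits` is a hypothetical helper, not part of `src.dl_trainer`. Plain lists stand in for tensors so the example runs without PyTorch.

```python
from typing import List, Sequence


# Hypothetical helper: decodes per-class logits into class indices.
# Assumes each row holds one score per class. Softmax is monotonic,
# so the argmax of the logits equals the argmax of the probabilities.
def predict_from_logits(logits: Sequence[Sequence[float]]) -> List[int]:
    predictions = []
    for row in logits:
        best = 0
        for i, score in enumerate(row):
            if score > row[best]:  # ties keep the lowest index
                best = i
        predictions.append(best)
    return predictions


print(predict_from_logits([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]))  # [1, 0]
```

In PyTorch, the equivalent would usually be `logits.argmax(dim=1)`, evaluated inside `torch.no_grad()` after calling `model.eval()`.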

predict_adjacent(model, X_test)

Makes predictions with a trained from-scratch Adjacent Category model.
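The adjacent-category logit model is usually written as log(P(Y = k+1) / P(Y = k)) = α_k + η. Here η is the model's scalar output for an example, and α_k is a per-boundary intercept. The sketch below turns those quantities into class probabilities and a predicted label. It assumes that parameterization, and sign conventions differ between implementations. `adjacent_category_probs` and `predict_adjacent_label` are hypothetical names.

```python
import math
from typing import List, Sequence


# Hypothetical helper. Assumed parameterization:
#   log(P(Y = k+1) / P(Y = k)) = alphas[k] + eta
def adjacent_category_probs(eta: float, alphas: Sequence[float]) -> List[float]:
    # Class 0 is the reference with log-score 0. Each higher class adds
    # one adjacent log-odds term, so log-scores are cumulative sums.
    log_scores = [0.0]
    for alpha in alphas:
        log_scores.append(log_scores[-1] + alpha + eta)
    # Numerically stable softmax over the cumulative log-scores.
    top = max(log_scores)
    exps = [math.exp(s - top) for s in log_scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical helper: the predicted label is the most probable class.
def predict_adjacent_label(eta: float, alphas: Sequence[float]) -> int:
    probs = adjacent_category_probs(eta, alphas)
    return max(range(len(probs)), key=probs.__getitem__)


print(predict_adjacent_label(5.0, [0.0, 0.0]))   # 2
print(predict_adjacent_label(-5.0, [0.0, 0.0]))  # 0
```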

predict_coral(model, X_test)

Makes predictions with a trained CORAL model.
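CORAL (COnsistent RAnk Logits) models an ordinal target with K classes through K - 1 binary tasks. Task k estimates P(Y > k), and the predicted rank is the number of tasks whose probability exceeds 0.5. The sketch below shows that decoding step. It assumes the model outputs K - 1 logits per example, and `coral_predict` is a hypothetical name, not this module's API.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Hypothetical helper: each row holds K-1 logits, one per task "Y > k".
def coral_predict(logits: Sequence[Sequence[float]]) -> List[int]:
    # Rank = number of tasks with P(Y > k) > 0.5.
    return [sum(1 for z in row if sigmoid(z) > 0.5) for row in logits]


print(coral_predict([[2.0, 0.5, -1.0], [-1.0, -2.0, -3.0]]))  # [2, 0]
```

Because sigmoid(z) > 0.5 exactly when z > 0, implementations often threshold the logits at zero directly.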

predict_poker_cnn(model, X_test)

Makes predictions with a trained PokerCNNModel.

predict_poker_lstm(model, X_test)

Makes predictions with a trained PokerLSTMModel.

predict_pom_scratch(model, X_test)

Makes predictions with a trained from-scratch proportional odds model (POM).
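In a proportional odds model, a single linear score η is compared against ordered thresholds θ_0 < θ_1 < … < θ_{K-2}. Under this form, P(Y ≤ k) = σ(θ_k - η), and each class probability is the difference of consecutive cumulative probabilities. The sketch below shows prediction under that parameterization. Some libraries use θ_k + η instead. `pom_class_probs` and `pom_predict` are hypothetical names.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Hypothetical helper. Assumes P(Y <= k) = sigmoid(thresholds[k] - eta)
# with strictly increasing thresholds, giving len(thresholds) + 1 classes.
def pom_class_probs(eta: float, thresholds: Sequence[float]) -> List[float]:
    cumulative = [sigmoid(t - eta) for t in thresholds] + [1.0]
    probs = [cumulative[0]]
    for k in range(1, len(cumulative)):
        probs.append(cumulative[k] - cumulative[k - 1])
    return probs


# Hypothetical helper: predicted label = most probable class per score.
def pom_predict(etas: Sequence[float], thresholds: Sequence[float]) -> List[int]:
    preds = []
    for eta in etas:
        probs = pom_class_probs(eta, thresholds)
        preds.append(max(range(len(probs)), key=probs.__getitem__))
    return preds


print(pom_predict([-3.0, 0.0, 3.0], [-1.0, 1.0]))  # [0, 1, 2]
```

Under this sign convention, a larger η shifts probability mass toward higher classes.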

train_adjacent_model(model, X_train, y_train)

Handles the training loop for the from-scratch Adjacent Category model.

train_coral_model(model, X_train, y_train, ...)

Handles the training loop for a CORAL model.
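Training a CORAL model typically needs two pieces. First, each integer label y is expanded into K - 1 binary targets [y > 0, y > 1, …]. Second, a loss sums binary cross-entropy over those tasks. The sketch below covers both and assumes equal task weights, although the original CORAL formulation allows per-task importance weights. `coral_levels` and `coral_loss` are hypothetical names.

```python
import math
from typing import List, Sequence


def coral_levels(label: int, num_classes: int) -> List[int]:
    # Binary target for task k is 1 exactly when label > k.
    return [1 if label > k else 0 for k in range(num_classes - 1)]


def bce_with_logits(z: float, target: int) -> float:
    # Stable binary cross-entropy computed directly from a logit.
    return max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))


def coral_loss(logits: Sequence[Sequence[float]], labels: Sequence[int],
               num_classes: int) -> float:
    # Sum the per-task losses for each example, then average over the batch.
    total = 0.0
    for row, label in zip(logits, labels):
        levels = coral_levels(label, num_classes)
        total += sum(bce_with_logits(z, t) for z, t in zip(row, levels))
    return total / len(labels)


print(coral_levels(2, 5))  # [1, 1, 0, 0]
```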

train_corn_model(model, X_train, y_train, ...)

Handles the training loop for a CORN model.
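CORN (Conditional Ordinal Regression for Neural networks) also uses K - 1 binary tasks. However, task k is trained only on examples with Y > k - 1, and it estimates the conditional probability P(Y > k | Y > k - 1). At prediction time, the conditionals are multiplied to recover P(Y > k). The sketch below implements that loss and decoding. It assumes the loss is averaged over every (example, task) pair that enters a conditional subset. `corn_loss` and `corn_predict` are hypothetical names.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(z: float, target: int) -> float:
    # Stable binary cross-entropy computed directly from a logit.
    return max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))


def corn_loss(logits: Sequence[Sequence[float]], labels: Sequence[int],
              num_classes: int) -> float:
    total, count = 0.0, 0
    for k in range(num_classes - 1):
        for row, label in zip(logits, labels):
            if label > k - 1:  # task k only sees examples with Y > k - 1
                total += bce_with_logits(row[k], 1 if label > k else 0)
                count += 1
    return total / count


def corn_predict(logits: Sequence[Sequence[float]]) -> List[int]:
    preds = []
    for row in logits:
        prob_greater = 1.0  # running product of conditionals: P(Y > k)
        rank = 0
        for z in row:
            prob_greater *= sigmoid(z)
            if prob_greater > 0.5:
                rank += 1
        preds.append(rank)
    return preds


print(corn_predict([[1.0, 1.0, 1.0]]))  # [2]
```

CORAL-style thresholding of the same logits would give rank 3 here. The two methods differ because CORN compares the product of conditionals against 0.5, not each task on its own.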

train_emd_model(model, X_train, y_train[, ...])

Handles the training loop for a model that uses an Earth Mover's Distance (EMD) loss.
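When the classes are ordered and one unit apart, the Earth Mover's Distance between a predicted distribution and the one-hot target reduces to a comparison of their cumulative distribution functions. As a result, predictions that land far from the true class cost more than near misses. The sketch below implements one common variant, the squared-CDF form, and assumes the model outputs per-class logits. The listing does not say whether `train_emd_model` uses this form, the absolute-difference form, or a normalized r-th-root version. `squared_emd_loss` is a hypothetical name.

```python
import math
from itertools import accumulate
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable softmax.
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def squared_emd_loss(logits: Sequence[float], label: int) -> float:
    # Compare the CDF of the predicted distribution with the one-hot CDF.
    probs = softmax(logits)
    target = [1.0 if k == label else 0.0 for k in range(len(logits))]
    cdf_pred = list(accumulate(probs))
    cdf_true = list(accumulate(target))
    return sum((p - t) ** 2 for p, t in zip(cdf_pred, cdf_true))


near = squared_emd_loss([0.0, 10.0, 0.0, 0.0, 0.0], label=0)
far = squared_emd_loss([0.0, 0.0, 0.0, 0.0, 10.0], label=0)
print(near < far)  # True
```

Plain cross-entropy would score both wrong predictions above almost identically. The CDF-based loss is roughly four times larger for the prediction four classes away.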

train_model(model, X_train, y_train[, ...])

Handles the training loop for a standard PyTorch classification model.

train_poker_cnn_model(model, X_train, ...[, ...])

Handles the training loop for a PokerCNNModel.

train_poker_lstm_model(model, X_train, ...)

Handles the training loop for a PokerLSTMModel.

train_pom_scratch_model(model, X_train, ...)

Handles the training loop for the from-scratch proportional odds model (POM).
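Training a proportional odds model from scratch has to keep the thresholds strictly increasing while still optimizing them freely. One common trick is to learn a first threshold plus unconstrained gap parameters passed through softplus, and then to minimize the negative log-likelihood of the observed class. This trick is an assumption here and is not taken from this module. The sketch below shows those two pieces. `ordered_thresholds` and `pom_nll` are hypothetical names, and in PyTorch the gradients would come from autograd rather than being written by hand.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softplus(x: float) -> float:
    # Stable log(1 + exp(x)); always positive.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def ordered_thresholds(first: float, raw_gaps: Sequence[float]) -> List[float]:
    # Each gap is softplus(raw) > 0, so the thresholds strictly increase
    # no matter what values the optimizer assigns to raw_gaps.
    thresholds = [first]
    for raw in raw_gaps:
        thresholds.append(thresholds[-1] + softplus(raw))
    return thresholds


def pom_nll(eta: float, label: int, thresholds: Sequence[float]) -> float:
    num_classes = len(thresholds) + 1

    def cdf(k: int) -> float:
        # P(Y <= k) = sigmoid(thresholds[k] - eta), with fixed end points.
        if k < 0:
            return 0.0
        if k >= num_classes - 1:
            return 1.0
        return sigmoid(thresholds[k] - eta)

    prob = cdf(label) - cdf(label - 1)
    return -math.log(max(prob, 1e-12))  # clamp to avoid log(0)


th = ordered_thresholds(-1.0, [0.0, -5.0, 3.0])
print(all(a < b for a, b in zip(th, th[1:])))  # True
```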