This project demonstrates how to train a Long Short-Term Memory (LSTM) model with Reinforcement Learning (RL) in PyTorch. The code is split into separate components (a custom agent, an environment, and a model) to keep it readable and maintainable. The model is trained on a randomly generated dataset to identify positions in a sequence where the sum of the current datapoint and the preceding datapoint exceeds 3. This toy scenario serves as an illustrative example of training an LSTM with RL techniques.
For more high-level details about the project, you can read this article, where I share some more insights.
The project is organized into the following directories and files:
- agent/
  - __init__.py: Package initialization file.
  - triple_action_agent.py: Defines the TripleActionAgent class, which uses an LSTM model for action selection and learning.
- environment/
  - __init__.py: Package initialization file.
  - simple_env.py: Defines the SimpleEnv class, which represents a custom environment for the agent to interact with.
- model/
  - __init__.py: Package initialization file.
  - lstm.py: Defines the LSTMModel class, which implements the LSTM neural network.
- run.py: The main entry point for running the project. It contains the training and testing logic for the agent.
- README.md: This documentation file.
- requirements.txt: Lists the project's dependencies, which can be installed using a package manager like pip.
- rewards_and_losses_per_episode.png: A plot showing the total rewards and losses per episode during training.
The TripleActionAgent class is a reinforcement learning agent that uses an LSTM model to select actions and learn from experiences. It follows an ε-greedy policy for action selection and uses Q-learning for training.
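The two ideas named above, ε-greedy action selection and the Q-learning update target, can be sketched without any framework. This is a minimal illustration and not the code in triple_action_agent.py: the function names `select_action` and `q_learning_target` are hypothetical, and plain Python lists stand in for the Q-values that the LSTM would produce.

```python
import random


def select_action(q_values, epsilon, rng=random):
    """Epsilon-greedy choice: explore with probability epsilon, else exploit."""
    if rng.random() < epsilon:
        # Exploration: pick any action uniformly at random.
        return rng.randrange(len(q_values))
    # Exploitation: pick the action with the largest estimated Q-value.
    return max(range(len(q_values)), key=lambda a: q_values[a])


def q_learning_target(reward, next_q_values, gamma, done):
    """One-step Q-learning target: r + gamma * max_a' Q(s', a')."""
    if done:
        # Terminal transition: there is no next state to bootstrap from.
        return reward
    return reward + gamma * max(next_q_values)


if __name__ == "__main__":
    q = [0.1, 0.9, 0.3]  # made-up Q-values for three actions
    print(select_action(q, epsilon=0.1))
    print(q_learning_target(1.0, [0.5, 2.0, 1.0], gamma=0.9, done=False))
```

In a training step, the squared difference between the predicted Q-value of the chosen action and this target would typically serve as the loss, although the exact loss used by the project is not stated here.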
The SimpleEnv class represents a custom environment for the agent to interact with. It provides methods for resetting the environment, taking steps, and calculating rewards.
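To make the reset/step/reward interface concrete, here is a small sketch of an environment built around the "sum of current and preceding datapoint exceeds 3" rule. It is a hypothetical stand-in named `ToySumEnv`, not the project's SimpleEnv: the two-action encoding, the value range of the random data, and the +1/-1 reward are all assumptions made for illustration.

```python
import random


class ToySumEnv:
    """Hypothetical stand-in for SimpleEnv; the real class may differ."""

    def __init__(self, length=20, threshold=3.0, seed=None):
        self.length = length
        self.threshold = threshold
        self.rng = random.Random(seed)
        self.data = []
        self.t = 0

    def reset(self):
        # Draw a fresh random sequence; the range [0, 3] is an assumption.
        self.data = [self.rng.uniform(0.0, 3.0) for _ in range(self.length)]
        self.t = 1  # start at index 1 so a preceding datapoint exists
        return self._observation()

    def _observation(self):
        # Observation: (preceding datapoint, current datapoint).
        return (self.data[self.t - 1], self.data[self.t])

    def step(self, action):
        # Assumed encoding: 1 = "sum exceeds threshold", 0 = "it does not".
        exceeds = self.data[self.t - 1] + self.data[self.t] > self.threshold
        reward = 1.0 if action == int(exceeds) else -1.0
        self.t += 1
        done = self.t >= self.length
        obs = None if done else self._observation()
        return obs, reward, done


if __name__ == "__main__":
    env = ToySumEnv(length=5, seed=0)
    obs, done, total = env.reset(), False, 0.0
    while not done:
        action = int(obs[0] + obs[1] > 3.0)  # oracle policy for demonstration
        obs, reward, done = env.step(action)
        total += reward
    print(total)
```

An oracle policy that computes the sum directly earns the maximum reward on every step, which gives a useful upper bound when judging how well a learned agent performs.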
The LSTMModel class implements a Long Short-Term Memory (LSTM) neural network. It consists of an LSTM layer and a fully connected layer for predicting Q-values.
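The architecture described above, an LSTM layer followed by a fully connected layer that outputs Q-values, might look roughly like the sketch below. The class name `LSTMQNetwork` and all sizes are hypothetical: three actions is a guess based on the agent's name, and the actual LSTMModel in lstm.py may use different dimensions or handle the hidden state differently.

```python
import torch
from torch import nn


class LSTMQNetwork(nn.Module):
    """Illustrative LSTM + linear head; not the project's LSTMModel."""

    def __init__(self, input_size=1, hidden_size=32, num_actions=3):
        super().__init__()
        # batch_first=True means inputs are shaped (batch, seq_len, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_actions)

    def forward(self, x, hidden=None):
        out, hidden = self.lstm(x, hidden)
        # Use the LSTM output at the last time step to predict one Q-value per action.
        q_values = self.head(out[:, -1, :])
        return q_values, hidden


if __name__ == "__main__":
    model = LSTMQNetwork()
    batch = torch.zeros(4, 2, 1)  # 4 sequences, 2 time steps, 1 feature
    q_values, _ = model(batch)
    print(q_values.shape)
```

Returning the hidden state alongside the Q-values lets a caller carry memory across consecutive steps of an episode, which is one common way to use a recurrent network in RL.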
The run.py file contains the main logic for training and testing the agent. It initializes the environment and agent, runs the training loop, and evaluates the agent's performance.
- Python 3.12 or higher installed
- Install the required dependencies:
pip install -r requirements.txt
- Run the training and testing script:
python run.py

