In the standard RL/MDP setting, I have an offline dataset of $(s, a, r, s')$ transitions from expert Atari gameplay.
I'm looking to train a CNN to predict $r$ based on $(s, a)$.
Each state is a $4 \times 84 \times 84$ tensor: 4 consecutive frames of the Atari screen, each $84 \times 84$ pixels. The action is an integer from 0 to 3.
I'm not sure how best to combine the two inputs $(s, a)$. How should I incorporate the action into the CNN?
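
To make the shapes concrete, here is a minimal NumPy sketch of two ways the action could be merged with the state. It is only an illustration under assumptions, not a recommendation: the helper names (`one_hot`, `merge_as_channels`, `merge_with_features`) are hypothetical, and the $64 \times 7 \times 7$ feature map stands in for whatever the convolutional layers would actually output.

```python
import numpy as np

NUM_ACTIONS = 4  # actions are integers 0..3, as described above


def one_hot(action, num_actions=NUM_ACTIONS):
    """Encode an integer action as a one-hot float vector."""
    vec = np.zeros(num_actions, dtype=np.float32)
    vec[action] = 1.0
    return vec


def merge_as_channels(state, action):
    """Option A (assumed): tile the one-hot action into constant image
    planes and stack them onto the state as extra input channels."""
    _, h, w = state.shape
    planes = np.broadcast_to(one_hot(action)[:, None, None], (NUM_ACTIONS, h, w))
    return np.concatenate([state, planes], axis=0)


def merge_with_features(features, action):
    """Option B (assumed): flatten the conv features and append the
    one-hot action before the fully connected layers."""
    return np.concatenate([features.ravel(), one_hot(action)])


# Random placeholder data with the shapes from the question.
state = np.random.rand(4, 84, 84).astype(np.float32)
x_a = merge_as_channels(state, action=2)
print(x_a.shape)  # (8, 84, 84)

# Hypothetical conv output; the real shape depends on the architecture.
features = np.random.rand(64, 7, 7).astype(np.float32)
x_b = merge_with_features(features, action=2)
print(x_b.shape)  # (3140,)
```

With option A the first convolution would take 8 input channels instead of 4; with option B the first fully connected layer would take 4 extra inputs.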