;-*- Package: (discrete-walk) -*- ;;; A simulation of a TD(lambda) learning system to predict the expected outcome ;;; of a discrete-state random walk like that in the original 1988 TD paper. (defpackage :discrete-walk (:use :common-lisp :g :ut :graph) (:nicknames :dwalk)) (in-package :dwalk) (defvar n 5) ; the number of nonterminal states (defvar w) ; the vector of weights = predictions (defvar e) ; the eligibility trace (defvar lambda .9) ; trace decay parameter (defvar alpha 0.1) ; learning-rate parameter (defvar initial-w 0.5) (defvar standard-walks nil) ; list of standard walks (defvar trace-type :none) ; :replace, :accumulate, :average, :1/t or :none (defvar alpha-type :fixed) ; :fixed, :1/t, or :1/t-max (defvar alpha-array) ; used when each state has a different alpha (defvar u) ; usage count = number of times updated (defvar delta-w) (defun setup (num-runs num-walks) (setq w (make-array n)) (setq delta-w (make-array n)) (setq e (make-array n)) (setq u (make-array n)) (setq alpha-array (make-array n)) (setq standard-walks (standard-walks num-runs num-walks)) (length standard-walks)) (defun init () (loop for i below n do (setf (aref w i) initial-w)) (loop for i below n do (setf (aref alpha-array i) alpha)) (loop for i below n do (setf (aref u i) 0))) (defun init-traces () (loop for i below n do (setf (aref e i) 0))) (defun learn (x target) (ecase alpha-type (:1/t (incf (aref u x)) (setf (aref alpha-array x) (/ 1.0 (aref u x)))) (:fixed) (:1/t-max (when (<= (aref u x) (/ 1 alpha)) (incf (aref u x)) (setf (aref alpha-array x) (/ 1.0 (aref u x)))))) (ecase trace-type (:none) (:replace (loop for i below n do (setf (aref e i) (* lambda (aref e i)))) (decf (aref u x) (aref e x)) (setf (aref e x) 1)) (:accumulate (loop for i below n do (setf (aref e i) (* lambda (aref e i)))) (incf (aref e x) 1)) (:average (loop for i below n do (setf (aref e i) (* lambda (aref e i)))) (setf (aref e x) (+ 1 (* (aref e x) (- 1 (aref alpha-array x)))))) (:1/t (incf (aref u x)) (incf (aref e x) 1) (loop for i below n for lambda = (float (/ (aref u x))) do (setf (aref e i) (* lambda (aref e i)))))) (if (eq trace-type :none) (incf (aref delta-w x) (* alpha (- target (aref w x)))) (loop for i below n with error = (- target (aref w x)) do (incf (aref delta-w i) (* (aref alpha-array i) error (aref e i)))))) (defun process-walk (walk) (destructuring-bind (outcome states) walk (unless (eq trace-type :none) (init-traces)) (loop for s1 in states for s2 in (rest states) do (learn s1 (aref w s2))) (learn (first (last states)) outcome))) (defun process-walk-backwards (walk) (destructuring-bind (outcome states) walk (unless (eq trace-type :none) (init-traces)) (learn (first (last states)) outcome) (loop for s1 in (reverse (butlast states)) for s2 in (reverse (rest states)) do (learn s1 (aref w s2))))) (defun process-walk-MC (walk) (destructuring-bind (outcome states) walk (loop for s in (reverse states) do (learn s outcome)))) (defun standard-walks (num-sets-of-walks num-walks) (loop repeat num-sets-of-walks with random-state = (ut::copy-of-standard-random-state) collect (loop repeat num-walks collect (random-walk n random-state)))) (defun random-walk (n &optional (random-state *random-state*)) (loop with start-state = (round (/ n 2)) for x = start-state then (with-prob .5 (+ x 1) (- x 1) random-state) while (AND (>= x 0) (< x n)) collect x into xs finally (return (list (if (< x 0) 0 1) xs)))) (defun residual-error () "Returns the residual RMSE between the current and correct predictions" (rmse 0 (loop for i below n when (>= (aref w i) -.1) collect (- (aref w i) (/ (+ i 1) (+ n 1) ))))) (defun batch-exp () (setq lambda 0.0) (setq trace-type :none) (setq initial-w -1) (loop for walk-set in standard-walks for run-num from 0 do (loop for l in '(0 1) do (init) (record l run-num (loop for num-walks from 1 to (length walk-set) for walk-subset = (firstn num-walks walk-set) do (setf alpha (/ 1.0 n num-walks 3)) (loop do (loop for i below n do (setf (aref delta-w i) 0)) do (loop for walk in walk-subset do (ecase l (0 (process-walk walk)) (1 (process-walk-mc walk)))) do (loop for i below n do (incf (aref w i) (aref delta-w i))) until (> .0000001 (loop for i below n sum (abs (aref delta-w i))))) collect (residual-error))))))