From c21a0681d0a419946e3c1c8c43d781a6c7f14595 Mon Sep 17 00:00:00 2001 From: vik Date: Mon, 6 Nov 2017 16:17:50 +0100 Subject: [PATCH] popleft bug buffer fixed and double deep q learning added --- reinforcement_learning/deep_Q_bridge.ipynb | 1460 +++----------------- 1 file changed, 194 insertions(+), 1266 deletions(-) diff --git a/reinforcement_learning/deep_Q_bridge.ipynb b/reinforcement_learning/deep_Q_bridge.ipynb index 40e2ee0..fd118a8 100644 --- a/reinforcement_learning/deep_Q_bridge.ipynb +++ b/reinforcement_learning/deep_Q_bridge.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": { "collapsed": true }, @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -58,7 +58,7 @@ " [ 0. 3. 0. 5.]\n", " [ 1. 2. 0. 0.]]\n", "(array([ 0., 0., 0., 0., 0., 0., -1., 0., 0., -1., 0., -1., -1.,\n", - " -1., 0., 1.]), 4.8895833333333334, True)\n", + " -1., 0., 1.]), 3.4441558772842846, True)\n", "[[ 0. 0. 0. 0.]\n", " [ 0. 0. 4. 0.]\n", " [ 0. 3. 0. 5.]\n", @@ -121,15 +121,15 @@ " \n", " # Bridge is build\n", " if self.state[-1][-1] != 0:\n", - " r = r + 5 - self.structure() / (0.5*5*self.length**2) # that is moment to the power 2\n", + " r = r + 2 + 1 / (self.structure() / (1 / 16 * self.length**2))**2 # that is moment to the power 2\n", " done = True\n", - " return s, r , done\n", + " return s, r, done\n", " \n", " if len(self.valid_actions) == 0:\n", " done = True\n", " r -= 2\n", " \n", - " return s, r , done\n", + " return s, r, done\n", " \n", " def det_valid_actions(self):\n", " no_action = set()\n", @@ -226,7 +226,7 @@ "# d_distance = self.current_distance - distance\n", "# self.current_distance = distance\n", "\n", - " return self.return_action(-0.1 )\n", + " return self.return_action(-0.1)\n", " \n", " def structure(self):\n", " self.ss = SystemElements()\n", @@ -238,13 +238,13 @@ " x = col[0] \n", "\n", " current_loc = [x, y]\n", - " self.ss.add_element([last_loc, [x, y]])\n", + " self.ss.add_element([last_loc, [x, y]], g=-1)\n", " last_loc = current_loc\n", " \n", - " n_nodes = len(self.ss.node_map)\n", - " forces = -1 / (n_nodes - 2)\n", - " for i in range(2, n_nodes):\n", - " self.ss.point_load(node_id=i, Fz=forces)\n", + "# n_nodes = len(self.ss.node_map)\n", + "# forces = -1 / (n_nodes - 2)\n", + "# for i in range(2, n_nodes):\n", + "# self.ss.point_load(node_id=i, Fz=forces)\n", " \n", " self.ss.add_support_hinged(1)\n", " self.ss.add_support_hinged(len(self.ss.node_map))\n", @@ -286,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": { "collapsed": true }, @@ -350,40 +350,23 @@ " # mse: ( target - prediction)^2\n", " # r + max(Q(s', a') - Q(s, a) )^2\n", " \n", - " #self.loss = tf.losses.huber_loss(self.next_Q_r, - self.Q, delta=2)\n", + " #self.loss = tf.losses.huber_loss(self.next_Q_r, - self.Q, delta=15)\n", "# self.clipped_error = tf.maximum(tf.abs(self.next_Q_r - self.Q), tf.ones(tf.shape(self.Q)))\n", "# self.loss = tf.reduce_sum(tf.square(self.clipped_error)) \n", + "\n", + " starter_learning_rate = learning_rate\n", + " self.train_count = tf.Variable(0, trainable=False, name=f\"{name}_train_count\")\n", + " learning_rate = tf.train.exponential_decay(starter_learning_rate, self.train_count,\n", + " 1000, 0.96)\n", + " \n", " self.loss = tf.reduce_sum(tf.square(self.next_Q_r - self.Q)) \n", " optimizer = tf.train.AdamOptimizer(learning_rate)\n", - " self.train_count = tf.Variable(0, trainable=False)\n", + " \n", " self.train = optimizer.minimize(self.loss, self.train_count)\n", " \n", " \n", " \n", - " \n", - "class FrozenAgent:\n", - " def __init__(self, data_size):\n", - " \"\"\"\n", - " :param data_size: (int) Columns of the data vector.\n", - " \"\"\"\n", - " # Step 1: Feed forward\n", - " # The argmax is the maximum Q-value.\n", - " self.input_s = tf.placeholder(tf.float32, [None, data_size], name=\"input_s\")\n", - " \n", - " # weights and biases\n", - " self.wb = tuple([tf.placeholder(dtype=tf.float32) for _ in range(6)])\n", - " \n", - " w1 = self.wb[0]\n", - " b1 = self.wb[1]\n", - " w2 = self.wb[2]\n", - " b2 = self.wb[3]\n", - " w_out = self.wb[4]\n", - " b_out = self.wb[5]\n", - " \n", - " self.layer_1 = leaky_relu(tf.matmul(self.input_s, w1) + b1)\n", - "# self.layer_2 = leaky_relu(tf.matmul(self.layer_1, w2) + b2)\n", - " \n", - " self.predict_Q = tf.matmul(self.layer_1, w_out) + b_out # actual Q-value\n" + " " ] }, { @@ -397,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 155, + "execution_count": 4, "metadata": { "collapsed": true }, @@ -416,7 +399,7 @@ " \"\"\"\n", " The weights and biases of the target will be a depended of the primary network.\n", " \n", - " wb[target] = t\n", + " wb[target] = tau * wb[primary] + (1-tau) * wb[target]\n", " \n", " This is a tensorflow operation and still needs to be run with Session.run(operation_holder)\n", " \n", @@ -440,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 157, + "execution_count": 7, "metadata": { "scrolled": false }, @@ -449,11 +432,37 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.116247427468 train_count 0 target_count 0 loss 0 eps 0.2714376441341128\n", - "0.0162883100961 train_count 0 target_count 0 loss 0 eps 0.2455946488435909\n", - "0.0599824019159 train_count 63 target_count 20 loss 770.631 eps 0.22221210964683\n", - "0.148132212666 train_count 163 target_count 43 loss 416.57 eps 0.20105577180202236\n", - "1.67631473098 train_count 263 target_count 48 loss 221.547 eps 0.18191368345835565\n" + "0.667360347494 train_count 15488 loss 0.477816 eps 0.09990628614486748\n", + "0.443707837601 train_count 15587 loss 1.38988 eps 0.09990628614486748\n", + "0.524500981561 train_count 15686 loss 0.593845 eps 0.09990628614486748\n", + "0.359304577248 train_count 15785 loss 4.05721 eps 0.09990628614486748\n", + "0.374203011064 train_count 15884 loss 0.752228 eps 0.09990628614486748\n", + "0.580970897654 train_count 15984 loss 0.668351 eps 0.09990628614486748\n", + "0.876998766431 train_count 16084 loss 5.69235 eps 0.09990628614486748\n", + "0.0465974575166 train_count 16178 loss 0.603223 eps 0.09990628614486748\n", + "0.0747488582335 train_count 16276 loss 0.859723 eps 0.09990628614486748\n", + "0.0416824847548 train_count 16373 loss 1.52317 eps 0.09990628614486748\n", + "0.215243585896 train_count 16471 loss 4.26813 eps 0.09990628614486748\n", + "0.539468365938 train_count 16570 loss 2.51492 eps 0.09990628614486748\n", + "0.0675016616232 train_count 16667 loss 0.938396 eps 0.09990628614486748\n", + "0.211260882325 train_count 16767 loss 1.46387 eps 0.09990628614486748\n", + "0.101623406461 train_count 16865 loss 0.67447 eps 0.09990628614486748\n", + "-0.123378754201 train_count 16960 loss 3.43497 eps 0.09990628614486748\n", + "0.0591191103686 train_count 17057 loss 55.103 eps 0.09990628614486748\n", + "0.0143498359155 train_count 17155 loss 2.58285 eps 0.09990628614486748\n", + "0.0932387574004 train_count 17253 loss 2.82041 eps 0.09990628614486748\n", + "0.218451295979 train_count 17351 loss 3.64749 eps 0.09990628614486748\n", + "0.384775273901 train_count 17451 loss 4.34415 eps 0.09990628614486748\n", + "0.20984679961 train_count 17548 loss 1.36758 eps 0.09990628614486748\n", + "0.444519653928 train_count 17646 loss 3.15587 eps 0.09990628614486748\n", + "-0.0906771169707 train_count 17740 loss 3.77309 eps 0.09990628614486748\n", + "-0.0520069782439 train_count 17836 loss 2.13154 eps 0.09990628614486748\n", + "0.349923459072 train_count 17936 loss 1.03072 eps 0.09990628614486748\n", + "0.327600297472 train_count 18035 loss 1.62981 eps 0.09990628614486748\n", + "0.762544380003 train_count 18133 loss 2.19067 eps 0.09990628614486748\n", + "1.13968537324 train_count 18232 loss 5.1917 eps 0.09990628614486748\n", + "0.965316180691 train_count 18332 loss 1.88149 eps 0.09990628614486748\n", + "0.74757660094 train_count 18432 loss 2.14978 eps 0.09990628614486748\n" ] }, { @@ -463,11 +472,10 @@ "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 88\u001b[0m \u001b[0mbatch\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m5000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msize\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1000\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 89\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 90\u001b[1;33m \u001b[0ms\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 91\u001b[0m \u001b[0ms_new\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 92\u001b[0m \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 88\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 89\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[0mbuffer_size\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 90\u001b[1;33m \u001b[0mbatch\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 91\u001b[0m \u001b[0mbatch\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrandint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbuffer_size\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msize\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m500\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 92\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mC:\\Anaconda3\\lib\\site-packages\\numpy\\core\\shape_base.py\u001b[0m in \u001b[0;36mvstack\u001b[1;34m(tup)\u001b[0m\n\u001b[0;32m 228\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 229\u001b[0m \"\"\"\n\u001b[1;32m--> 230\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_nx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0matleast_2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_m\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0m_m\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtup\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 231\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 232\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mhstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtup\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mC:\\Anaconda3\\lib\\site-packages\\numpy\\core\\shape_base.py\u001b[0m in \u001b[0;36m\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m 228\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 229\u001b[0m \"\"\"\n\u001b[1;32m--> 230\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0m_nx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0matleast_2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_m\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0m_m\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mtup\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 231\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 232\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mhstack\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtup\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mC:\\Anaconda3\\lib\\site-packages\\numpy\\core\\shape_base.py\u001b[0m in \u001b[0;36matleast_2d\u001b[1;34m(*arys)\u001b[0m\n\u001b[0;32m 98\u001b[0m \u001b[0mres\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 99\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mary\u001b[0m \u001b[1;32min\u001b[0m \u001b[0marys\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 100\u001b[1;33m \u001b[0mary\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0masanyarray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mary\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 101\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mary\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mary\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32mC:\\Anaconda3\\lib\\site-packages\\numpy\\core\\numeric.py\u001b[0m in \u001b[0;36masanyarray\u001b[1;34m(a, dtype, order)\u001b[0m\n\u001b[0;32m 531\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 532\u001b[0m \"\"\"\n\u001b[1;32m--> 533\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0morder\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0morder\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msubok\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 534\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 535\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mascontiguousarray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mC:\\Anaconda3\\lib\\site-packages\\numpy\\core\\shape_base.py\u001b[0m in \u001b[0;36matleast_2d\u001b[1;34m(*arys)\u001b[0m\n\u001b[0;32m 105\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 106\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mary\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 107\u001b[1;33m \u001b[0mres\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 108\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mres\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 109\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } @@ -486,24 +494,25 @@ "# fig.show()\n", "# fig.canvas.draw()\n", "\n", - "env = Environment(3, 2, \"moment\")\n", + "env = Environment(4, 4, \"moment\")\n", "env.reset()\n", "\n", - "H = [16] # hidden neurons\n", + "H = [32] # hidden neurons\n", "D = env.state.size # input (state of the environment)\n", - "learning_rate = 1e-3\n", + "learning_rate = 1e-2\n", "gamma = 0.99 # discount factor\n", "epochs = 50000\n", - "max_frames = 500\n", + "buffer_size = 5000\n", + "max_frames = 150\n", "action_space = 8\n", "\n", - "contin = 0\n", + "contin = 1\n", "\n", "if not contin:\n", " eps = 0.3\n", " tf.reset_default_graph()\n", " agent = Agent(D, H, action_space, learning_rate, \"agent\")\n", - " target = Agent(D, H, action_space, 0, \"target\")\n", + " target = Agent(D, H, action_space, learning_rate, \"target\")\n", " \n", " init = tf.global_variables_initializer()\n", " sess = tf.Session()\n", @@ -514,27 +523,23 @@ " # The last half the variables of the target\n", " variables = tf.trainable_variables() \n", " operation_holder = prepare_update_target(variables)\n", - " \n", - "\n", - "scores= []\n", - "n_done = 0\n", + " scores = deque()\n", "\n", "last_ep = 0\n", "\n", "\n", "#https://github.com/awjuliani/DeepRL-Agents/blob/master/Q-Network.ipynb\n", "train_count = 0\n", - "target_update_count = 0\n", "loss = 0\n", - "variable_update_help = None\n", - "for ep in range(epochs):\n", - " if eps > 0.001:\n", - " eps *= 0.999\n", "\n", - " if (ep + 1) % 100 == 0:\n", - " print(np.mean(scores), \"train_count\", train_count, \"target_count\", target_update_count,\n", + "for ep in range(epochs):\n", + " \n", + " if eps > 0.1 and len(buffer) >= buffer_size:\n", + " eps *= 0.999\n", + " \n", + " if (ep + 1) % 100 == 0 and len(buffer) >= buffer_size:\n", + " print(np.mean(list(scores)[-1000:]), \"train_count\", train_count,\n", " \"loss\", loss, \"eps\", eps)\n", - " scores = []\n", " \n", " s = env.reset()\n", " s = [s]\n", @@ -543,24 +548,30 @@ " Q = sess.run(agent.predict_Q, {agent.input_s: s})\n", " \n", " if np.random.rand(1) < eps:\n", + " rand = True\n", " a = env.sample_action()\n", " else:\n", + " rand = False\n", " a = np.argmax(Q)\n", " \n", " s_new, r, done = env.step(a)\n", " scores.append(r)\n", - "\n", - " buffer.append([s, a, r, s_new, done])\n", " \n", - " if len(buffer) > 5000:\n", - " buffer.pop()\n", + " if rand == True:\n", + " buffer.append([s, a, r, s_new, done])\n", + " if rand == False and c % 50 == 0:\n", + " buffer.append([s, a, r, s_new, done])\n", + " \n", + " if len(buffer) > buffer_size:\n", + " buffer.popleft()\n", + " scores.popleft()\n", " s = [s_new]\n", - " \n", + " \n", " if done:\n", - " \n", - " if len(buffer) >= 5000:\n", + " \n", + " if len(buffer) >= buffer_size:\n", " batch = np.vstack(buffer)\n", - " batch = batch[np.random.randint(0, 5000, size=1000)]\n", + " batch = batch[np.random.randint(0, buffer_size, size=500)]\n", "\n", " s = np.vstack(batch[:, 0])\n", " s_new = np.vstack(batch[:, 3])\n", @@ -568,29 +579,51 @@ " a = batch[:, 1]\n", " done_ = np.array(batch[:, 4], dtype=bool)\n", " Q = sess.run(agent.predict_Q, {agent.input_s: s})\n", - " Q_new = sess.run(target.predict_Q, {target.input_s: s_new})\n", - " max_Q_new = np.max(Q_new, 1)\n", "\n", - " target_Q = (r + gamma * max_Q_new)\n", + " # Double Q-learning primary network chooses an action:\n", + " Q_i = np.argmax(sess.run(agent.predict_Q, {agent.input_s: s_new}), 1)\n", + " \n", + " # Target network produces the Q-value of the chosen action\n", + " Q_new = sess.run(target.predict_Q, {target.input_s: s_new})\n", + " max_Q_new =Q_new[np.arange(Q_i.size)[:, None], Q_i[:, None]][:, -1]\n", + " \n", + " target_Q = r + gamma * max_Q_new\n", " target_Q[done_] = r[done_]\n", " \n", " train_count, loss, _ = sess.run([agent.train_count, agent.loss, agent.train], \n", " feed_dict={agent.input_s: s, \n", " agent.executed_actions: a, \n", " agent.next_Q_r: target_Q})\n", - "\n", " if c % 3 == 0:\n", " # update target network\n", - " target_update_count += 1\n", " update_target(operation_holder, sess)\n", - " \n", - " \n", - " \n", + "\n", " break\n", " \n", + " \n", "\n" ] }, + { + "cell_type": "code", + "execution_count": 406, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32749" + ] + }, + "execution_count": 406, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(buffer)" + ] + }, { "cell_type": "code", "execution_count": 161, @@ -633,55 +666,85 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": {}, + "execution_count": 10, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", - " [[ 0. 0. 0.]\n", - " [ 1. 0. 0.]]\n", - "[[ 0.0157699 -0.00472315 -0.00631001 -0.0119287 -0.01288951 -0.00707795\n", - " 0.01728024 0.0105312 ]]\n", - "action 6\n", - "-0.2\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 1. 0. 0. 0.]]\n", + "[[-0.12808736 -0.09334762 -0.10103387 -0.23614739 -0.23016927 -0.19282101\n", + " -0.19304731 -0.2984907 ]]\n", + "action 1\n", + "-0.1 False\n", "\n", - " [[ 0. 0. 0.]\n", - " [ 1. 0. 0.]]\n", - "[[ 0.0157699 -0.00472315 -0.00631001 -0.0119287 -0.01288951 -0.00707795\n", - " 0.01728024 0.0105312 ]]\n", - "action 6\n", - "-0.2\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 2. 0. 0.]\n", + " [ 1. 0. 0. 0.]]\n", + "[[-0.10368284 -0.11574875 -0.12774605 -0.12350658 -0.1642811 -0.20029946\n", + " -0.16407546 -0.08707853]]\n", + "action 7\n", + "-0.1 False\n", "\n", - " [[ 0. 0. 0.]\n", - " [ 1. 0. 0.]]\n", - "[[ 0.0157699 -0.00472315 -0.00631001 -0.0119287 -0.01288951 -0.00707795\n", - " 0.01728024 0.0105312 ]]\n", - "action 6\n", - "-0.2\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 2. 0. 0.]\n", + " [ 1. 0. 3. 0.]]\n", + "[[ 2.46880364 -0.09685077 -0.15725884 -0.14345071 -0.17737159 -0.20034733\n", + " -0.20910886 -0.16599104]]\n", + "action 0\n", + "2.51404717866 True\n", "\n", - " [[ 0. 0. 0.]\n", - " [ 1. 0. 0.]]\n", - "[[ 0.0157699 -0.00472315 -0.00631001 -0.0119287 -0.01288951 -0.00707795\n", - " 0.01728024 0.0105312 ]]\n", - "action 6\n", - "-0.2\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 2. 0. 0.]\n", + " [ 1. 0. 3. 4.]]\n", + "[[-0.11165825 -0.13312925 -0.15180516 -0.14280865 -0.20325702 -0.18633628\n", + " -0.0402295 0.09051565]]\n", + "action 7\n", + "2.41404717866 True\n", "\n", - " [[ 0. 0. 0.]\n", - " [ 1. 0. 0.]]\n", - "[[ 0.0157699 -0.00472315 -0.00631001 -0.0119287 -0.01288951 -0.00707795\n", - " 0.01728024 0.0105312 ]]\n", - "action 6\n", - "-0.2\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 2. 0. 0.]\n", + " [ 1. 0. 3. 4.]]\n", + "[[-0.11165825 -0.13312925 -0.15180516 -0.14280865 -0.20325702 -0.18633628\n", + " -0.0402295 0.09051565]]\n", + "action 7\n", + "2.41404717866 True\n", + "\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 2. 0. 0.]\n", + " [ 1. 0. 3. 4.]]\n", + "[[-0.11165825 -0.13312925 -0.15180516 -0.14280865 -0.20325702 -0.18633628\n", + " -0.0402295 0.09051565]]\n", + "action 7\n", + "2.41404717866 True\n", + "\n", + " [[ 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0.]\n", + " [ 0. 2. 0. 0.]\n", + " [ 1. 0. 3. 4.]]\n", + "[[-0.11165825 -0.13312925 -0.15180516 -0.14280865 -0.20325702 -0.18633628\n", + " -0.0402295 0.09051565]]\n", + "action 7\n", + "2.41404717866 True\n", "\r", - " -1.0" + " 11.9702358933 1.27614237492\n" ] } ], "source": [ - "env = Environment(3, 2)\n", + "env = Environment(4, 4)\n", "s = env.reset()\n", "\n", "\n", @@ -696,11 +759,11 @@ " ↘ 7\n", "\"\"\"\n", "\n", - "bot = target\n", + "bot = agent\n", "\n", "total_r = 0\n", "j = 0\n", - "for a in [1, 0, 0, 0, 0]:\n", + "for a in [1, 0, 7, 7, 1, 0, 1]:\n", " j += 1\n", "\n", " print(\"\\n\", env.state)\n", @@ -714,7 +777,7 @@ "\n", " print(a_dst)\n", " print(\"action\", a)\n", - " print(r)\n", + " print(r, d)\n", " total_r += r\n", " \n", "# if d == True:\n", @@ -723,1167 +786,32 @@ "# break\n", "# #env.reset()\n", " \n", - "print(\"\\r\", total_r, end=\"\")" + "print(\"\\r\", total_r, env.structure())" ] }, { "cell_type": "code", - "execution_count": 82, - "metadata": { - "run_control": { - "marked": true - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-1.47237194, -0.45450115, -1.1099143 , 0.40738916, -0.14500076,\n", - " -1.27405345, -1.88907146, -1.09437132, -1.61892748, -1.89695144,\n", - " -0.40908128, -0.8260721 , -1.37111664, -1.85722709, -0.28329837,\n", - " 0.00418878],\n", - " [-0.82942849, -0.06093037, -0.98205984, -1.03273547, -0.73091006,\n", - " -1.58550084, 0.10534322, -0.96237028, -0.77367294, -0.5560708 ,\n", - " -0.57345325, -1.33160257, -0.61364865, 0.44076538, -0.6399169 ,\n", - " -1.78533554],\n", - " [-0.37469065, -1.39439011, -1.14717066, -0.95678383, -0.83726352,\n", - " -0.88811439, -1.0354259 , -1.96661639, -0.09760374, -1.15479875,\n", - " -0.92978543, -1.10216045, -0.7275157 , -0.92905378, 0.5746758 ,\n", - " -0.68988132],\n", - " [-1.44932294, -0.95880038, -0.84693551, -0.9915266 , -1.32087588,\n", - " -1.57927155, -1.47370005, -0.30102974, -1.3647579 , -0.76708555,\n", - " -1.69573951, -1.70829821, -1.13219559, -1.22856677, -0.95958227,\n", - " -0.4713431 ],\n", - " [-1.27410245, -0.45723015, 0.74974871, -1.06617165, -0.26810974,\n", - " -0.57501745, -0.57527494, -1.18630826, -1.31585658, -1.45227206,\n", - " -1.09144866, -0.74763477, 0.50344932, -1.14395046, -0.56995487,\n", - " -0.84480089],\n", - " [-1.20297003, -1.69174564, -1.67804408, -1.90985036, -0.87092429,\n", - " -1.15351367, -1.55876327, -0.29219115, -1.46831954, -1.04226673,\n", - " -0.86516273, -1.35606909, -1.72454762, -1.71955156, -0.82059425,\n", - " -0.60765541]], dtype=float32)" - ] - }, - "execution_count": 82, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#update_target(operation_holder, sess)\n", - "sess.run(variables[0].value() - variables[6].value())" - ] - }, - { - "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([ 1., 1., 1., 0., 0., 0., 0., 0.])" + "(array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., -1., -1., 0., -1.,\n", + " 0., 0., 1.]), 10.140799116945235, True)" ] }, - "execution_count": 13, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "env = Environment(5, 4)\n", - "s = env.reset()\n", - "actions = s[-8:]\n", - "actions" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 1. 1. 1. 1. 1. 0. 0. 1.]\n", - "[0, 1, 2, 3, 4, 7]\n" - ] - }, - { - "data": { - "text/plain": [ - "array([[ 0., 0., 0., 0., 0.],\n", - " [ 0., 0., 5., 0., 0.],\n", - " [ 0., 4., 3., 0., 0.],\n", - " [ 1., 2., 0., 0., 0.]])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "s = env.step(0)[0]\n", - "s = env.step(1)[0]\n", - "s = env.step(4)[0]\n", - "s = env.step(1)[0]\n", - "actions = s[-8:]\n", - "print(actions)\n", - "print(env.valid_actions)\n", - "env.state" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.1266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-0.96334002]),\n", - " array([-1.19628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.28920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.07098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.17098354]),\n", - " array([-1.1266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.1]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-0.96334002]),\n", - " array([-1.19628939]),\n", - " array([-3.29628939]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-0.91649658]),\n", - " array([-1.1266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.07098354]),\n", - " array([-1.28920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.38920712]),\n", - " array([-1.19628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.19628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-0.96334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-0.96334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-0.91649658]),\n", - " array([-1.1266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.1]),\n", - " array([-3.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.1266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.19628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-0.96334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.07098354]),\n", - " array([-1.28920712]),\n", - " array([-1.19628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.29628939]),\n", - " array([-1.1266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.2266901]),\n", - " array([-1.1]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-1.2]),\n", - " array([-0.96334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-1.06334002]),\n", - " array([-0.67735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.91649658]),\n", - " array([-0.78658905]),\n", - " array([-0.91649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-1.01649658]),\n", - " array([-0.67735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " array([-0.77735027]),\n", - " ...]" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "scores" - ] - }, - { - "cell_type": "code", - "execution_count": 149, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{5, 6, 7}" - ] - }, - "execution_count": 149, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\n", - "\"\"\" \n", - " → 0\n", - " ↗ 1\n", - " ↑ 2\n", - " ↖ 3\n", - " ← 4\n", - " ↙ 5\n", - " ↓ 6\n", - " ↘ 7\n", - "\"\"\"\n", - "env.no_action" + "env = Environment(4, 4)\n", + "env.reset()\n", + "env.step(1)\n", + "env.step(0)\n", + "env.step(7)" ] } ],