Add files via upload

This commit is contained in:
Peter Norvig 2023-01-05 18:54:23 -08:00 committed by GitHub
parent fb661727f8
commit c13ef643cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4292 additions and 861 deletions

File diff suppressed because one or more lines are too long

View File

@ -13,13 +13,13 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter, defaultdict, namedtuple, deque, abc\n",
"from dataclasses import dataclass\n",
"from itertools import permutations, combinations, cycle, chain\n",
"from dataclasses import dataclass, field\n",
"from itertools import permutations, combinations, cycle, chain, islice\n",
"from itertools import count as count_from, product as cross_product\n",
"from typing import *\n",
"from statistics import mean, median\n",
@ -28,6 +28,7 @@
"import matplotlib.pyplot as plt\n",
"\n",
"import ast\n",
"import fractions\n",
"import functools\n",
"import heapq\n",
"import operator\n",
@ -41,73 +42,69 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Daily Input Parsing\n",
"# Daily Workflow\n",
"\n",
"Each day's work will consist of three tasks, denoted by three sections in the notebook:\n",
"- **Input**: Parse the day's input file. I will use the function `parse(day, parser, sep)`, which:\n",
" - Reads the input file for `day`.\n",
" - Breaks the file into a sequence of *items* separated by `sep` (default newline).\n",
" - Applies `parser` to each item and returns the results as a tuple.\n",
" - Useful parser functions include `ints`, `digits`, `atoms`, `words`, and the built-ins `int` and `str`.\n",
" - Prints the first few input lines and output records. This is useful to me as a debugging tool, and to the reader.\n",
"- **Input**: Parse the day's input file with the function `parse`.\n",
"- **Part 1**: Understand the day's instructions and:\n",
" - Write code to compute the answer to Part 1.\n",
" - Once I have computed the answer and submitted it to the AoC site to verify it is correct, I record it with the `answer` function.\n",
" - Once I have computed the answer and submitted it to the AoC site to verify it is correct, I record it with the `answer` class.\n",
"- **Part 2**: Repeat the above steps for Part 2.\n",
"- Occasionally I'll introduce a **Part 3** where I explore beyond the official instructions.\n",
"\n",
"Here is `parse`:"
"# Parsing Input Files\n",
"\n",
"The function `parse` is meant to handle each day's input. A call `parse(day, parser, sections)` does the following:\n",
" - Reads the input file for `day`.\n",
" - Breaks the file into a *sections*. By default, this is lines, but you can use `paragraphs`, or pass in a custom function.\n",
" - Applies `parser` to each section and returns the results as a tuple of records.\n",
" - Useful parser functions include `ints`, `digits`, `atoms`, `words`, and the built-ins `int` and `str`.\n",
" - Prints the first few input lines and output records. This is useful to me as a debugging tool, and to the reader.\n",
" - The defaults are `parser=str, sections=lines`, so by default `parse(n)` gives a tuple of lines from fuile *day*."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"current_year = 2022 # Subdirectory name for input files\n",
"lines = '\\n' # For inputs where each record is a line\n",
"paragraphs = '\\n\\n' # For inputs where each record is a paragraph \n",
"current_year = 2022 # Subdirectory name for input files\n",
"lines = str.splitlines # By default, split input text into lines\n",
"def paragraphs(text): \"Split text into paragraphs\"; return text.split('\\n\\n')\n",
"\n",
"def parse(day_or_text:Union[int, str], parser:Callable=str, sep:str=lines, show=6) -> tuple:\n",
" \"\"\"Split the input text into items separated by `sep`, and apply `parser` to each.\n",
"def parse(day_or_text:Union[int, str], parser:Callable=str, sections:Callable=lines, show=8) -> tuple:\n",
" \"\"\"Split the input text into `sections`, and apply `parser` to each.\n",
" The first argument is either the text itself, or the day number of a text file.\"\"\"\n",
" if isinstance(day_or_text, str) and show == 8: \n",
" show = 0 # By default, don't show lines when parsing exampole text.\n",
" start = time.time()\n",
" text = get_text(day_or_text)\n",
" print_parse_items('Puzzle input', text.splitlines(), show, 'line')\n",
" records = mapt(parser, text.rstrip().split(sep))\n",
" if parser != str or sep != lines:\n",
" print_parse_items('Parsed representation', records, show, f'{type(records[0]).__name__}')\n",
" show_items('Puzzle input', text.splitlines(), show)\n",
" records = mapt(parser, sections(text.rstrip()))\n",
" if parser != str or sections != lines:\n",
" show_items('Parsed representation', records, show)\n",
" return records\n",
"\n",
"def get_text(day_or_text:Union[int, str]) -> str:\n",
" \"\"\"The text used as input to the puzzle: either a string or the day number of a file.\"\"\"\n",
" if isinstance(day_or_text, int):\n",
" return pathlib.Path(f'AOC/{current_year}/input{day_or_text}.txt').read_text()\n",
" else:\n",
" \"\"\"The text used as input to the puzzle: either a string or the day number,\n",
" which denotes the file 'AOC/year/input{day}.txt'.\"\"\"\n",
" if isinstance(day_or_text, str):\n",
" return day_or_text\n",
" else:\n",
" filename = f'AOC/{current_year}/input{day_or_text}.txt'\n",
" return pathlib.Path(filename).read_text()\n",
"\n",
"def print_parse_items(source, items, show:int, name:str, sep=\"─\"*100):\n",
" \"\"\"Print verbose output from `parse` for lines or records.\"\"\"\n",
" if not show:\n",
" return\n",
" count = f'1 {name}' if len(items) == 1 else f'{len(items)} {name}s'\n",
" for line in (sep, f'{source} ➜ {count}:', sep, *items[:show]):\n",
" print(truncate(line))\n",
" if show < len(items):\n",
" print('...')\n",
" \n",
"def truncate(object, width=100) -> str:\n",
" \"\"\"Use elipsis to truncate `str(object)` to `width` characters, if necessary.\"\"\"\n",
" string = str(object)\n",
" return string if len(string) <= width else string[:width-4] + ' ...'\n",
"\n",
"def parse_sections(specs: Iterable) -> Callable:\n",
" \"\"\"Return a parser that uses the first spec to parse the first section, the second for second, etc.\n",
" Each spec is either parser or [parser, sep].\"\"\"\n",
" specs = ([spec] if callable(spec) else spec for spec in specs)\n",
" fns = ((lambda section: parse(section, *spec, show=0)) for spec in specs)\n",
" return lambda section: next(fns)(section)"
"def show_items(source, items, show:int, hr=\"─\"*100):\n",
" \"\"\"Show the first few items, in a pretty format.\"\"\"\n",
" if show:\n",
" types = Counter(map(type, items))\n",
" counts = ', '.join(f'{n} {t.__name__}{\"\" if n == 1 else \"s\"}' for t, n in types.items())\n",
" print(f'{hr}\\n{source} ➜ {counts}:\\n{hr}')\n",
" for line in items[:show]:\n",
" print(truncate(line))\n",
" if show < len(items):\n",
" print('...')"
]
},
{
@ -119,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -152,11 +149,34 @@
" x = float(text)\n",
" return round(x) if x.is_integer() else x\n",
" except ValueError:\n",
" return text.strip()\n",
" \n",
" return text.strip()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Helper functions:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def truncate(object, width=100, ellipsis=' ...') -> str:\n",
" \"\"\"Use elipsis to truncate `str(object)` to `width` characters, if necessary.\"\"\"\n",
" string = str(object)\n",
" return string if len(string) <= width else string[:width-len(ellipsis)] + ellipsis\n",
"\n",
"def mapt(function: Callable, *sequences) -> tuple:\n",
" \"\"\"`map`, with the result as a tuple.\"\"\"\n",
" return tuple(map(function, *sequences))"
" return tuple(map(function, *sequences))\n",
"\n",
"def mapl(function: Callable, *sequences) -> list:\n",
" \"\"\"`map`, with the result as a list.\"\"\"\n",
" return list(map(function, *sequences))"
]
},
{
@ -168,7 +188,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -190,31 +210,50 @@
"source": [
"# Daily Answers\n",
"\n",
"Here is the `answer` function, which gives verification of a correct computation (or an error message for an incorrect computation), times how long the computation took, ans stores the result in the dict `answers`."
"Here is the `answer` class, which gives verification of a correct computation (or an error message for an incorrect computation), times how long the computation took, and stores the result in the dict `answers`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 91,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
" .0000 seconds, answer: 3 INCORRECT!!!! Expected 2"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `answers` is a dict of {puzzle_number_id: message_about_results}\n",
"answers = {} \n",
"answers = {} # `answers` is a dict of {puzzle_number: answer}\n",
"\n",
"def answer(puzzle, correct, code: callable):\n",
" \"\"\"Verify that calling `code` computes the `correct` answer for `puzzle`. \n",
" Record results in the dict `answers`. Prints execution time.\"\"\"\n",
" def pretty(x): return f'{x:,d}' if is_int(x) else truncate(x)\n",
" start = time.time()\n",
" got = code()\n",
" secs = time.time() - start\n",
" ans = pretty(got)\n",
" msg = f'{secs:5.3f} seconds for ' + (\n",
" f'correct answer: {ans}' if (got == correct) else\n",
" f'WRONG!! ANSWER: {ans}; EXPECTED {pretty(correct)}')\n",
" answers[puzzle] = msg\n",
" print(msg)"
"class answer:\n",
" \"\"\"Verify that calling `code` computes the `solution` to `puzzle`. \n",
" Record results in the dict `answers`.\"\"\"\n",
" def __init__(self, puzzle, solution, code:callable):\n",
" self.solution, self.code = solution, code\n",
" answers[puzzle] = self\n",
" self.check()\n",
" \n",
" def check(self):\n",
" \"\"\"Check if the code computes the correct solution; record run time.\"\"\"\n",
" start = time.time()\n",
" self.got = self.code()\n",
" self.secs = time.time() - start\n",
" self.ok = (self.got == self.solution)\n",
" return self.ok\n",
" \n",
" def __repr__(self):\n",
" \"\"\"The repr of an answer shows what happened.\"\"\"\n",
" def commas(x) -> str: return f'{x:,d}' if is_int(x) else f'{x}'\n",
" secs = f'{self.secs:7.4f} seconds'.replace(' 0.', ' .')\n",
" ok = '' if self.ok else f' !!!! INCORRECT !!!! Expected {commas(self.solution)}'\n",
" return f'{secs}, answer: {commas(self.got)}{ok}'"
]
},
{
@ -228,13 +267,13 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class multimap(defaultdict):\n",
" \"\"\"A mapping of {key: [val1, val2, ...]}.\"\"\"\n",
" def __init__(self, pairs: Iterable[tuple], symmetric=False):\n",
" def __init__(self, pairs:Iterable[tuple]=(), symmetric=False):\n",
" \"\"\"Given (key, val) pairs, return {key: [val, ...], ...}.\n",
" If `symmetric` is True, treat (key, val) as (key, val) plus (val, key).\"\"\"\n",
" self.default_factory = list\n",
@ -265,15 +304,16 @@
"\n",
"def cover(*integers) -> range:\n",
" \"\"\"A `range` that covers all the given integers, and any in between them.\n",
" cover(lo, hi) is a an inclusive (or closed) range, equal to range(lo, hi + 1).\"\"\"\n",
" cover(lo, hi) is an inclusive (or closed) range, equal to range(lo, hi + 1).\n",
" The same range results from cover(hi, lo) or cover([hi, lo]).\"\"\"\n",
" if len(integers) == 1: integers = the(integers)\n",
" return range(min(integers), max(integers) + 1)\n",
"\n",
"def the(sequence) -> object:\n",
" \"\"\"Return the one item in a sequence. Raise error if not exactly one.\"\"\"\n",
" items = list(sequence)\n",
" if not len(items) == 1:\n",
" raise ValueError(f'Expected exactly one item in the sequence {items}')\n",
" return items[0]\n",
" for i, item in enumerate(sequence, 1):\n",
" if i > 1: raise ValueError(f'Expected exactly one item in the sequence.')\n",
" return item\n",
"\n",
"def split_at(sequence, i) -> Tuple[Sequence, Sequence]:\n",
" \"\"\"The sequence split into two pieces: (before position i, and i-and-after).\"\"\"\n",
@ -285,6 +325,8 @@
"\n",
"def sign(x) -> int: \"0, +1, or -1\"; return (0 if x == 0 else +1 if x > 0 else -1)\n",
"\n",
"def lcm(i, j) -> int: \"Least common multiple\"; return i * j // gcd(i, j)\n",
"\n",
"def union(sets) -> set: \"Union of several sets\"; return set().union(*sets)\n",
"\n",
"def intersection(sets):\n",
@ -306,8 +348,19 @@
" # This is like a clock, where 24 mod 12 is 12, not 0.\n",
" return (i % m) or m\n",
"\n",
"def invert_dict(dic) -> dict:\n",
" \"\"\"Invert a dict, e.g. {1: 'a', 2: 'b'} -> {'a': 1, 'b': 2}.\"\"\"\n",
" return {dic[x]: x for x in dic}\n",
"\n",
"def walrus(name, value):\n",
" \"\"\"If you're not in 3.8, and you can't do `x := val`,\n",
" then you can use `walrus('x', val)`.\"\"\"\n",
" globals()[name] = value\n",
" return value\n",
"\n",
"cat = ''.join\n",
"cache = functools.lru_cache(None)"
"cache = functools.lru_cache(None)\n",
"Ø = frozenset() # empty set"
]
},
{
@ -350,7 +403,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -362,14 +415,25 @@
" \"\"\"The dot product of two vectors.\"\"\"\n",
" return sum(map(operator.mul, vec1, vec2))\n",
"\n",
"def powerset(iterable) -> Iterable[tuple]:\n",
" \"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)\"\n",
" s = list(iterable)\n",
" return flatten(combinations(s, r) for r in range(len(s) + 1))\n",
"\n",
"flatten = chain.from_iterable # Yield items from each sequence in turn\n",
"\n",
"def append(sequences) -> Sequence: \"Append into a list\"; return list(flatten(sequences))\n",
"\n",
"def batched(data, n) -> list:\n",
" \"Batch data into lists of length n. The last batch may be shorter.\"\n",
"def batched(iterable, n) -> Iterable[tuple]:\n",
" \"Batch data into non-overlapping tuples of length n. The last batch may be shorter.\"\n",
" # batched('ABCDEFG', 3) --> ABC DEF G\n",
" return [data[i:i+n] for i in range(0, len(data), n)]\n",
" it = iter(iterable)\n",
" while True:\n",
" batch = tuple(islice(it, n))\n",
" if batch:\n",
" yield batch\n",
" else:\n",
" return\n",
"\n",
"def sliding_window(sequence, n) -> Iterable[Sequence]:\n",
" \"\"\"All length-n subsequences of sequence.\"\"\"\n",
@ -379,6 +443,16 @@
" \"\"\"The first element in an iterable, or the default if iterable is empty.\"\"\"\n",
" return next(iter(iterable), default)\n",
"\n",
"def last(iterable) -> Optional[object]: \n",
" \"\"\"The last element in an iterable.\"\"\"\n",
" for item in iterable:\n",
" pass\n",
" return item\n",
"\n",
"def nth(iterable, n, default=None):\n",
" \"Returns the nth item or a default value\"\n",
" return next(islice(iterable, n, None), default)\n",
"\n",
"def first_true(iterable, default=False):\n",
" \"\"\"Returns the first true value in the iterable.\n",
" If no true value is found, returns `default`.\"\"\"\n",
@ -394,7 +468,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -402,10 +476,79 @@
"assert dotproduct([1, 2, 3, 4], [1000, 100, 10, 1]) == 1234\n",
"assert list(flatten([{1, 2, 3}, (4, 5, 6), [7, 8, 9]])) == [1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
"assert append(([1, 2], [3, 4], [5, 6])) == [1, 2, 3, 4, 5, 6]\n",
"assert batched('abcdefghi', 3) == ['abc', 'def', 'ghi']\n",
"assert list(batched(range(11), 3)) == [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10)]\n",
"assert list(sliding_window('abcdefghi', 3)) == ['abc', 'bcd', 'cde', 'def', 'efg', 'fgh', 'ghi']\n",
"assert first('abc') == 'a'\n",
"assert first_true([0, None, False, 42, 99]) == 42"
"assert first('') == None\n",
"assert last('abc') == 'c'\n",
"assert first_true([0, None, False, 42, 99]) == 42\n",
"assert first_true([0, None, '', 0.0]) == False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Points in Space\n",
"\n",
"Many puzzles involve points; usually two-dimensional points on a plane. A few puzzles involve three-dimensional points, and perhaps one might involve non-integers, so I'll try to make my `Point` implementation flexible in a duck-typing way. A point can also be considered a `Vector`; that is, `(1, 0)` can be a `Point` that means \"this is location x=1, y=0 in the plane\" and it also can be a `Vector` that means \"move Eat (+1 in the along the x axis).\" First we'll define points/vectors:"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"Point = Tuple[int, ...] # Type for points\n",
"Vector = Point # E.g., (1, 0) can be a point, or can be a direction, a Vector\n",
"Zero = (0, 0)\n",
"\n",
"directions4 = East, South, West, North = ((1, 0), (0, 1), (-1, 0), (0, -1))\n",
"diagonals = SE, NE, SW, NW = ((1, 1), (1, -1), (-1, 1), (-1, -1))\n",
"directions8 = directions4 + diagonals\n",
"directions5 = directions4 + (Zero,)\n",
"directions9 = directions8 + (Zero,)\n",
"arrow_direction = {'^': North, 'v': South, '>': East, '<': West, '.': Zero,\n",
" 'U': North, 'D': South, 'R': East, 'L': West}\n",
"\n",
"def X_(point) -> int: \"X coordinate of a point\"; return point[0]\n",
"def Y_(point) -> int: \"Y coordinate of a point\"; return point[1]\n",
"def Z_(point) -> int: \"Z coordinate of a point\"; return point[2]\n",
"\n",
"def Xs(points) -> Tuple[int]: \"X coordinates of a collection of points\"; return mapt(X_, points)\n",
"def Ys(points) -> Tuple[int]: \"Y coordinates of a collection of points\"; return mapt(Y_, points)\n",
"def Zs(points) -> Tuple[int]: \"X coordinates of a collection of points\"; return mapt(Z_, points)\n",
"\n",
"def add(p: Point, q: Point) -> Point: return mapt(operator.add, p, q)\n",
"def sub(p: Point, q: Point) -> Point: return mapt(operator.sub, p, q)\n",
"def neg(p: Point) -> Vector: return mapt(operator.neg, p)\n",
"def mul(p: Point, k: float) -> Vector: return tuple(k * c for c in p)\n",
"\n",
"def distance(p: Point, q: Point) -> float:\n",
" \"\"\"Euclidean (L2) distance between two points.\"\"\"\n",
" d = sum((pi - qi) ** 2 for pi, qi in zip(p, q)) ** 0.5\n",
" return int(d) if d.is_integer() else d\n",
"\n",
"def slide(points: Set[Point], delta: Vector) -> Set[Point]: \n",
" \"\"\"Slide all the points in the set of points by the amount delta.\"\"\"\n",
" return {add(p, delta) for p in points}\n",
"\n",
"def make_turn(facing:Vector, turn:str) -> Vector:\n",
" \"\"\"Turn 90 degrees left or right. `turn` can be 'L' or 'Left' or 'R' or 'Right' or lowercase.\"\"\"\n",
" (x, y) = facing\n",
" return (y, -x) if turn[0] in ('L', 'l') else (-y, x)\n",
"\n",
"# Profiling found that `add` and `taxi_distance` were speed bottlenecks; \n",
"# I define below versions that are specialized for 2D points only.\n",
"\n",
"def add2(p: Point, q: Point) -> Point: \n",
" \"\"\"Specialized version of point addition for 2D Points only. Faster.\"\"\"\n",
" return (p[0] + q[0], p[1] + q[1])\n",
"\n",
"def taxi_distance(p: Point, q: Point) -> int:\n",
" \"\"\"Manhattan (L1) distance between two 2D Points.\"\"\"\n",
" return abs(p[0] - q[0]) + abs(p[1] - q[1])"
]
},
{
@ -414,102 +557,106 @@
"source": [
"# Points on a Grid\n",
"\n",
"Many puzzles seem to involve a two-dimensional rectangular grid with integer coordinates. First we'll define the two-dimensional `Point`, then the `Grid`."
"Many puzzles seem to involve a two-dimensional rectangular grid with integer coordinates. A `Grid` is a rectangular array of (integer, integer) points, where each point holds some contents. Important things to know:\n",
"- `Grid` is a subclass of `dict`\n",
"- Usually the contents will be a character or an integer, but that's not specified or restricted. \n",
"- A Grid can be initialized three ways:\n",
" - With another dict of `{point: contents}`, or an iterable of `(point, contents) pairs.\n",
" - With an iterable of strings, each depicting a row (e.g. `[\"#..\", \"..#\"]`.\n",
" - With a single string, which will be split on newlines.\n",
"- Contents that are a member of `skip` will be skipped. (For example, you could do `skip=[' ']` to not store any point that has a space as its contents.\n",
"- There is a `grid.neighbors(point)` method. By default it returns the 4 orthogonal neighbors but you could make it all 8 adjacent squares, or something else, by specifying the `directions` keyword value in the `Grid` constructor.\n",
"- By default, grids have bounded size; accessing a point outside the grid results in a `KeyError`. But some grids extend in all directions without limit; you can implement that by specifying, say, `default='.'` to make `'.'` contents in all directions."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"Point = Tuple[int, int] # (x, y) points on a grid\n",
"\n",
"def X_(point) -> int: \"X coordinate\"; return point[0]\n",
"def Y_(point) -> int: \"Y coordinate\"; return point[1]\n",
"\n",
"def distance(p: Point, q: Point) -> float:\n",
" \"\"\"Distance between two points.\"\"\"\n",
" dx, dy = abs(X_(p) - X_(q)), abs(Y_(p) - Y_(q))\n",
" return dx + dy if dx == 0 or dy == 0 else (dx ** 2 + dy ** 2) ** 0.5\n",
"\n",
"def manhatten_distance(p: Point, q: Point) -> int:\n",
" \"\"\"Distance along grid lines between two points.\"\"\"\n",
" return sum(abs(pi - qi) for pi, qi in zip(p, q))\n",
"\n",
"def add(p: Point, q: Point) -> Point:\n",
" \"\"\"Add two points.\"\"\"\n",
" return (X_(p) + X_(q), Y_(p) + Y_(q))\n",
"\n",
"def sub(p: Point, q: Point) -> Point:\n",
" \"\"\"Subtract point q from point p.\"\"\"\n",
" return (X_(p) - X_(q), Y_(p) - Y_(q))\n",
"\n",
"directions4 = North, South, East, West = ((0, -1), (0, 1), (1, 0), (-1, 0))\n",
"directions8 = directions4 + ((1, 1), (1, -1), (-1, 1), (-1, -1))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"class Grid(dict):\n",
" \"\"\"A 2D grid, implemented as a mapping of {(x, y): cell_contents}.\"\"\"\n",
" def __init__(self, mapping_or_rows=(), directions=directions4):\n",
" \"\"\"Initialize with either (e.g.) `Grid({(0, 0): 1, (1, 0): 2, ...})`, or\n",
" `Grid([(1, 2, 3), (4, 5, 6)]).\"\"\"\n",
" def __init__(self, grid=(), directions=directions4, skip=(), default=KeyError):\n",
" \"\"\"Initialize with either (e.g.) `Grid({(0, 0): '#', (1, 0): '.', ...})`, or\n",
" `Grid([\"#..\", \"..#\"]) or `Grid(\"#..\\n..#\")`.\"\"\"\n",
" self.directions = directions\n",
" self.update(mapping_or_rows if isinstance(mapping_or_rows, abc.Mapping) else\n",
" {(x, y): val \n",
" for y, row in enumerate(mapping_or_rows) \n",
" for x, val in enumerate(row)})\n",
"\n",
" self.default = default\n",
" if isinstance(grid, abc.Mapping): \n",
" self.update(grid) \n",
" else:\n",
" if isinstance(grid, str): \n",
" grid = grid.splitlines()\n",
" self.update({(x, y): val \n",
" for y, row in enumerate(grid) \n",
" for x, val in enumerate(row)\n",
" if val not in skip})\n",
" \n",
" def copy(self): return Grid(self, directions=self.directions)\n",
" def __missing__(self, point): \n",
" \"\"\"If asked for a point off the grid, either return default or raise error.\"\"\"\n",
" if self.default == KeyError:\n",
" raise KeyError(point)\n",
" else:\n",
" return self.default\n",
"\n",
" def copy(self): return Grid(self, directions=self.directions, default=self.default)\n",
" \n",
" def neighbors(self, point) -> List[Point]:\n",
" \"\"\"Points on the grid that neighbor `point`.\"\"\"\n",
" return [add(point, Δ) for Δ in self.directions if add(point, Δ) in self]\n",
" return [add2(point, Δ) for Δ in self.directions \n",
" if add2(point, Δ) in self or self.default != KeyError]\n",
" \n",
" def to_rows(self, default='.', Xs=None, Ys=None) -> List[List[object]]:\n",
" \"\"\"The contents of the grid in a rectangular list of lists.\"\"\"\n",
" Xs = Xs or range(max(map(X_, self)) + 1)\n",
" Ys = Ys or range(max(map(Y_, self)) + 1)\n",
" return [[self.get((x, y), default) for x in Xs] for y in Ys]\n",
" def neighbor_contents(self, point) -> Iterable:\n",
" \"\"\"The contents of the neighboring points.\"\"\"\n",
" return (self[p] for p in self.neighbors(point))\n",
" \n",
" def to_picture(self, sep='', default='.', Xs=None, Ys=None) -> str:\n",
" \"\"\"The contents of the grid as a picture. Youi can specify the `Xs` and `Ys` to include.\"\"\"\n",
" return '\\n'.join(map(sep.join, self.to_rows(default, Xs, Ys)))\n",
" def to_rows(self, xrange=None, yrange=None) -> List[List[object]]:\n",
" \"\"\"The contents of the grid, as a rectangular list of lists.\n",
" You can define a window with an xrange and yrange; or they default to the whole grid.\"\"\"\n",
" xrange = xrange or cover(Xs(self))\n",
" yrange = yrange or cover(Ys(self))\n",
" default = ' ' if self.default is KeyError else self.default\n",
" return [[self.get((x, y), default) for x in xrange] \n",
" for y in yrange]\n",
"\n",
" def print(self, sep='', xrange=None, yrange=None):\n",
" \"\"\"Print a representation of the grid.\"\"\"\n",
" for row in self.to_rows(xrange, yrange):\n",
" print(*row, sep=sep)\n",
" \n",
" def plot(self, markers, figsize=(14, 14), **kwds):\n",
" def plot(self, markers={'#': 's', '.': ','}, figsize=(14, 14), **kwds):\n",
" \"\"\"Plot a representation of the grid.\"\"\"\n",
" plt.figure(figsize=figsize)\n",
" plt.gca().invert_yaxis()\n",
" for m in markers:\n",
" plt.plot(*T(p for p in self if self[p] == m), markers.get(m, m), **kwds)"
" plt.plot(*T(p for p in self if self[p] == m), markers[m], **kwds)\n",
" \n",
"def neighbors(point, directions=directions4) -> List[Point]:\n",
" \"\"\"Neighbors of this point, in the given directions.\n",
" (This function can be used outside of a Grid class.)\"\"\"\n",
" return [add(point, Δ) for Δ in directions]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tests:"
"Here are some tests:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p, q = (0, 3), (4, 0)\n",
"assert Y_(p) == 3 and X_(q) == 4\n",
"assert distance(p, q) == 5\n",
"assert manhatten_distance(p, q) == 7\n",
"assert taxi_distance(p, q) == 7\n",
"assert add(p, q) == (4, 3)\n",
"assert sub(p, q) == (-4, 3)\n",
"assert add(North, South) == (0,0)"
"assert add(North, South) == (0, 0)"
]
},
{
@ -518,12 +665,18 @@
"source": [
"# A* Search\n",
"\n",
"Many puzzles involve searching over a branching tree of possibilities. For many puzzles, an ad-hoc solution is fine. But when there is a larger search space, it is useful to have a pre-defined efficient best-first search algorithm, and in particular an A* search, which incorporates a heuristic function to estimate the remaining distance to the goal. This is a somewhat heavy-weight approach, as it requires the solver to define a subclass of `SearchProblem`."
"Many puzzles involve searching over a branching tree of possibilities. For many puzzles, an ad-hoc solution is fine. Different problems require different things from a search: \n",
"- Some just need to know the final goal state.\n",
"- Some need to know the sequence of actions that led to the final state.\n",
"- Some neeed to know the sequence of intermediate states. \n",
"- Some need to know the number of steps (or the total cost) to get to the final state.\n",
"\n",
"But sometimes you need all of that (or you think you might need it in Part 2), and sometimes you have a good heuristic estimate of the distance to a goal state, and you want to make sure to use it. If that's the case, then my `SearchProblem` class and `A_star_search` function may be approopriate."
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -551,7 +704,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -579,15 +732,14 @@
" A state is just an (x, y) location in the grid.\"\"\"\n",
" def actions(self, loc): return self.grid.neighbors(loc)\n",
" def result(self, loc1, loc2): return loc2\n",
" def action_cost(self, s1, a, s2): return self.grid[s2]\n",
" def h(self, node): return manhatten_distance(node.state, self.goal) \n",
" def h(self, node): return taxi_distance(node.state, self.goal) \n",
"\n",
"class Node:\n",
" \"A Node in a search tree.\"\n",
" def __init__(self, state, parent=None, action=None, path_cost=0):\n",
" self.__dict__.update(state=state, parent=parent, action=action, path_cost=path_cost)\n",
"\n",
" def __repr__(self): return f'Node({self.state})'\n",
" def __repr__(self): return f'Node({self.state}, path_cost={self.path_cost})'\n",
" def __len__(self): return 0 if self.parent is None else (1 + len(self.parent))\n",
" def __lt__(self, other): return self.path_cost < other.path_cost\n",
" \n",
@ -614,9 +766,22 @@
" return path_states(node.parent) + [node.state]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Other Data Structures\n",
"\n",
"Here I define a few data types:\n",
"- The priority queue, which is needed for A* search.\n",
"- Hashable versions of dicts and Counters. These can be used in sets or as keys in dicts. Beware: unlike the `frozenset`, these are not safe: if you modify one after inserting it in a set or dict, it probably will not be found.\n",
"- Graphs of `{node: [neighboring_node, ...]}`.\n",
"- An `AttrCounter`, which is just like a `Counter`, but can be accessed with, say, `ctr.name` as well as `ctr['name']`. "
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -642,11 +807,54 @@
"\n",
" def __len__(self): return len(self.items)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Hdict(dict):\n",
" \"\"\"A dict, but it is hashable.\"\"\"\n",
" def __hash__(self): return hash(tuple(sorted(self.items())))\n",
" \n",
"class HCounter(Counter):\n",
" \"\"\"A Counter, but it is hashable.\"\"\"\n",
" def __hash__(self): return hash(tuple(sorted(self.items())))"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"class Graph(dict):\n",
" \"\"\"A graph of {node: [neighboring_nodes...]}. \n",
" Can store other kwd attributes on it (which you can't do with a dict).\"\"\"\n",
" def __init__(self, contents, **kwds):\n",
" self.update(contents)\n",
" self.__dict__.update(**kwds)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class AttrCounter(Counter):\n",
" \"\"\"A Counter, but `ctr['name']` and `ctr.name` are the same.\"\"\"\n",
" def __getattr__(self, attr):\n",
" return self[attr]\n",
" def __setattr__(self, attr, value):\n",
" self[attr] = value"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -660,7 +868,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.15"
}
},
"nbformat": 4,