Skip to content

Hash Table Implementation

Question

Implement a hash table using array in Python.

Key trick

Use an array of buckets, where each bucket holds all entries whose keys hash to that index.

  • Index formula:

    • index = hash(key) % len(buckets)
  • If several keys land in the same bucket (a collision):

    • Search that bucket linearly.
    • Replace value if key already exists.
    • Append if it does not.
  • To keep average \(O(1)\):

    • Resize when load factor becomes too large.
    • Rehash all entries into a bigger array.

Trap

  • Forgetting collision handling.
  • Forgetting to update an existing key and instead inserting duplicates.
  • Forgetting to rehash on resize.
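  • Using mutable, unhashable keys such as lists: hash() raises TypeError for them, and a key whose hash changed after insertion would be searched for in the wrong bucket.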
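  • Worrying about negative hashes: Python's hash() can return a negative number, but Python's % with a positive divisor always returns a value in [0, len(buckets)), so the index is valid without extra handling (unlike % in C or Java).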

Why is this algorithm interesting?

It gives array-like average \(O(1)\) lookup, insertion, and deletion for arbitrary hashable keys, which is why Python's dict and set are so important.

Idiomatic Python implementation

from collections.abc import Iterator
from typing import Any, Optional, Tuple

class HashTable:
    """Hash table backed by a Python list of buckets (separate chaining).

    Each bucket is a list of ``(key, value)`` pairs whose keys hash to that
    bucket's index. When inserting a *new* key would push the load factor
    (entries / buckets) above ``_MAX_LOAD_FACTOR``, the bucket array doubles
    and every entry is redistributed. This keeps chains short, so operations
    are O(1) on average.

    Keys must be hashable. Keys are matched by identity first and then by
    ``==``, mirroring how Python's built-in containers test membership. An
    object that is not equal to itself, such as ``float("nan")``, can therefore
    still be found again using the same object.
    """

    # Grow the bucket array once entries / buckets would exceed this ratio.
    _MAX_LOAD_FACTOR = 0.75

    def __init__(self, capacity: int = 8) -> None:
        """Create an empty table with ``capacity`` buckets (clamped to at least 1)."""
        capacity = max(1, capacity)
        self._buckets: list[list[Tuple[Any, Any]]] = [[] for _ in range(capacity)]
        self._size = 0

    def _index(self, key: Any) -> int:
        """Map a hashable key to a bucket index.

        ``hash()`` may be negative, but Python's ``%`` with a positive divisor
        always yields a value in ``[0, len(self._buckets))``.
        """
        return hash(key) % len(self._buckets)

    @staticmethod
    def _find(bucket: list[Tuple[Any, Any]], key: Any) -> int:
        """Return the position of ``key`` within ``bucket``, or -1 if absent."""
        for i, (k, _) in enumerate(bucket):
            # Identity first: it is cheaper than __eq__, and it lets keys that
            # are not equal to themselves (e.g. NaN) be found again.
            if k is key or k == key:
                return i
        return -1

    def _resize(self) -> None:
        """Double the bucket array and redistribute every entry.

        The keys already stored are distinct, so each pair is appended straight
        into its new bucket with no duplicate scan. ``_size`` is unchanged.
        """
        old_buckets = self._buckets
        self._buckets = [[] for _ in range(len(old_buckets) * 2)]
        for bucket in old_buckets:
            for key, value in bucket:
                self._buckets[self._index(key)].append((key, value))

    def put(self, key: Any, value: Any) -> None:
        """Map ``key`` to ``value``, replacing the value if ``key`` is present."""
        bucket = self._buckets[self._index(key)]
        i = self._find(bucket, key)
        if i >= 0:
            # Overwrite in place: the size does not change, so never resize here.
            bucket[i] = (key, value)
            return

        # A new key is being added: grow first if the load factor would get too high.
        if (self._size + 1) / len(self._buckets) > self._MAX_LOAD_FACTOR:
            self._resize()
            # The index depends on the array length, so recompute the bucket.
            bucket = self._buckets[self._index(key)]

        bucket.append((key, value))
        self._size += 1

    def get(self, key: Any) -> Any:
        """Return the value stored for ``key``.

        Raises:
            KeyError: if ``key`` is not in the table.
        """
        bucket = self._buckets[self._index(key)]
        i = self._find(bucket, key)
        if i < 0:
            raise KeyError(key)
        return bucket[i][1]

    def delete(self, key: Any) -> None:
        """Remove ``key`` and its value from the table.

        Raises:
            KeyError: if ``key`` is not in the table.
        """
        bucket = self._buckets[self._index(key)]
        i = self._find(bucket, key)
        if i < 0:
            raise KeyError(key)
        bucket.pop(i)
        self._size -= 1

    def contains(self, key: Any) -> bool:
        """Return True if ``key`` is stored in the table."""
        return self._find(self._buckets[self._index(key)], key) >= 0

    def __len__(self) -> int:
        """Return the number of stored keys."""
        return self._size

    # dict-style aliases so the table works with `in`, `[]`, `[] =`, `del` and `for`.

    def __contains__(self, key: Any) -> bool:
        """Support ``key in table``; same as :meth:`contains`."""
        return self.contains(key)

    def __getitem__(self, key: Any) -> Any:
        """Support ``table[key]``; same as :meth:`get` (raises KeyError)."""
        return self.get(key)

    def __setitem__(self, key: Any, value: Any) -> None:
        """Support ``table[key] = value``; same as :meth:`put`."""
        self.put(key, value)

    def __delitem__(self, key: Any) -> None:
        """Support ``del table[key]``; same as :meth:`delete` (raises KeyError)."""
        self.delete(key)

    def __iter__(self) -> Iterator[Any]:
        """Yield every stored key in bucket order (not insertion order).

        The table must not be modified during iteration; this is not checked.
        """
        for bucket in self._buckets:
            for key, _ in bucket:
                yield key

Pytest test

import pytest

class BadHash:
    """Test key whose hash is always 1, forcing every instance into one bucket.

    Instances compare equal when their ``value`` attributes are equal, so the
    table must rely on chaining plus ``__eq__`` to tell keys apart.
    """

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        # Constant hash: every BadHash collides with every other one.
        return 1

    def __eq__(self, other):
        # Defer to Python's fallback for foreign types, per the data-model
        # protocol, instead of claiming inequality ourselves.
        if not isinstance(other, BadHash):
            return NotImplemented
        return self.value == other.value

    def __repr__(self):
        # Readable failure output in pytest assertion diffs.
        return f"BadHash({self.value!r})"


@pytest.mark.parametrize(
    "ops, expected",
    [
        (
            [("put", "a", 1), ("get", "a", None)],
            [1],
        ),
        (
            [("put", "a", 1), ("put", "a", 2), ("get", "a", None), ("len", None, None)],
            [2, 1],
        ),
        (
            [("put", "a", 1), ("put", "b", 2), ("get", "a", None), ("get", "b", None)],
            [1, 2],
        ),
        (
            [("put", BadHash("x"), 10), ("put", BadHash("y"), 20), ("get", BadHash("x"), None), ("get", BadHash("y"), None)],
            [10, 20],
        ),
        (
            [("put", "a", 1), ("delete", "a", None), ("contains", "a", None), ("len", None, None)],
            [False, 0],
        ),
        # Enough inserts to force several resizes from the default capacity;
        # every entry must survive rehashing.
        (
            [("put", i, i * 10) for i in range(50)]
            + [("get", i, None) for i in range(50)]
            + [("len", None, None)],
            [i * 10 for i in range(50)] + [50],
        ),
        # Deleting one of several colliding keys must leave the others intact.
        (
            [
                ("put", BadHash("x"), 1),
                ("put", BadHash("y"), 2),
                ("delete", BadHash("x"), None),
                ("contains", BadHash("x"), None),
                ("contains", BadHash("y"), None),
                ("get", BadHash("y"), None),
                ("len", None, None),
            ],
            [False, True, 2, 1],
        ),
    ],
    ids=[
        "put-get",
        "overwrite",
        "two-keys",
        "collisions",
        "delete",
        "resize",
        "delete-among-collisions",
    ],
)
def test_hash_table(ops, expected):
    """Replay ``(op, key, value)`` triples on a fresh table and compare outputs.

    ``get``, ``contains`` and ``len`` append their result to ``out``; ``put``
    and ``delete`` produce no output. An unrecognised op fails the test rather
    than being silently skipped, so a typo cannot make a case pass vacuously.
    """
    ht = HashTable()
    out = []

    for op, key, value in ops:
        if op == "put":
            ht.put(key, value)
        elif op == "get":
            out.append(ht.get(key))
        elif op == "delete":
            ht.delete(key)
        elif op == "contains":
            out.append(ht.contains(key))
        elif op == "len":
            out.append(len(ht))
        else:
            pytest.fail(f"unknown op: {op!r}")

    assert out == expected


def test_missing_key_raises_key_error():
    """``get`` and ``delete`` of an absent key raise KeyError and leave the table empty."""
    ht = HashTable()
    with pytest.raises(KeyError):
        ht.get("missing")
    with pytest.raises(KeyError):
        ht.delete("missing")
    assert len(ht) == 0