目录

第七章 迭代器与生成器

本章概要

迭代器和生成器是Python中处理序列数据的强大工具。本章将深入探讨迭代协议、生成器函数、yield表达式和协程等高级特性。

7.1 迭代协议

7.1.1 可迭代对象与迭代器

# Iterable: any object that implements __iter__()
# Iterator: an object that implements both __iter__() and __next__()

# Ask the iterable for an iterator; iter() invokes my_list.__iter__()
my_list = [1, 2, 3]
iterator = iter(my_list)

# Pull values one at a time; each next() invokes iterator.__next__()
print(next(iterator))  # 1
print(next(iterator))  # 2
print(next(iterator))  # 3
# A fourth next(iterator) would raise StopIteration: the iterator is exhausted

7.1.2 自定义迭代器

class CountDown:
    """Iterator that yields start, start - 1, ..., 1.

    It returns itself from __iter__, so an instance can be consumed only once.
    """

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        # An iterator is its own iterator
        return self

    def __next__(self):
        remaining = self.start
        if remaining <= 0:
            raise StopIteration
        self.start = remaining - 1
        return self.start + 1


# Usage: the for loop calls iter() once, then next() until StopIteration
for num in CountDown(5):
    print(num, end=" ")  # 5 4 3 2 1

7.1.3 可迭代对象 vs 迭代器

class Range:
    """Iterable over the integers in [start, end).

    Every call to __iter__ builds a fresh RangeIterator, so a Range can be
    iterated any number of times.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # Hand out a brand-new iterator each time
        return RangeIterator(self.start, self.end)


class RangeIterator:
    """One-shot iterator walking from start up to (but not including) end."""

    def __init__(self, start, end):
        self.current = start
        self.end = end

    def __iter__(self):
        return self

    def __next__(self):
        value = self.current
        if value >= self.end:
            raise StopIteration
        self.current = value + 1
        return value


# Because Range creates a new iterator each time, it can be iterated repeatedly
r = Range(1, 4)
print(list(r))  # [1, 2, 3]
print(list(r))  # [1, 2, 3]

7.2 生成器

7.2.1 生成器函数

使用 ``yield`` 关键字的函数就是生成器函数。

def countdown(n):
    """Generator: yield n, n - 1, ..., 1, announcing the start and the end."""
    print(f"Starting countdown from {n}")
    while n > 0:
        yield n
        n -= 1
    print("Countdown finished!")


# Calling a generator function only creates the generator; no body code runs yet
gen = countdown(3)

# Each next() resumes the body and runs it up to the following yield
print(next(gen))  # prints "Starting countdown from 3", then 3
print(next(gen))  # 2
print(next(gen))  # 1
# One more next(gen) prints "Countdown finished!" and raises StopIteration

7.2.2 生成器的状态

def generator_example():
    """Show that a generator pauses at each yield and resumes on next()."""
    print("Start")
    yield 1
    print("Continue")
    yield 2
    print("End")


gen = generator_example()  # creating the generator runs none of its body
print("Created generator")
print(next(gen))  # runs up to the first yield: prints Start, then 1
print("---")
print(next(gen))  # resumes after the first yield: prints Continue, then 2
print("---")
# A third next(gen) would print End and then raise StopIteration

7.2.3 生成器表达式

# List comprehension: evaluated eagerly, all one million squares are stored at once
squares_list = [n**2 for n in range(1000000)]

# Generator expression: evaluated lazily, one square at a time
squares_gen = (n**2 for n in range(1000000))

print(sum(squares_gen))  # values are produced on demand, so memory use stays small

# A generator expression that is a call's only argument needs no extra parentheses
print(sum(n**2 for n in range(10)))
print(max(map(len, ["hello", "world", "python"])))

7.2.4 生成器方法

def counter(maximum):
    """Count from 0 up to (but not including) *maximum*, accepting jumps via send().

    Each ``yield`` hands the current value to the caller. ``gen.send(v)``
    resumes the generator with ``v`` as the result of the ``yield``
    expression, which moves the counter to ``v``; ``next(gen)`` resumes it
    with ``None``, which simply advances the counter by one.
    """
    i = 0
    while i < maximum:
        val = yield i  # value passed in by send(); None when resumed by next()
        print(f"Got value: {val}")
        if val is not None:
            i = val
        else:
            i += 1


gen = counter(10)
print(next(gen))      # 0
print(gen.send(5))    # Got value: 5, 5
print(next(gen))      # Got value: None, 6

# Throw an exception into the generator at its paused yield. counter() does not
# handle it, so it propagates back out of throw() and the generator is finished.
# Pass an exception instance: throw(type, value) is deprecated since Python 3.12.
try:
    gen.throw(ValueError("Custom error"))
except ValueError as exc:
    print(f"Caught: {exc}")  # Caught: Custom error

# Close a generator: GeneratorExit is raised at the paused yield, then the
# generator is finished and further next() calls raise StopIteration.
gen = counter(10)
print(next(gen))            # 0
gen.close()
print(next(gen, "closed"))  # closed

7.3 yield from

7.3.1 委托子生成器

def sub_generator():
    """Sub-generator: yield 1, 2 and 3."""
    for number in (1, 2, 3):
        yield number


def main_generator():
    """Yield 'A', then everything sub_generator() produces, then 'B'."""
    yield "A"
    yield from sub_generator()  # delegate until the sub-generator is exhausted
    yield "B"


print(list(main_generator()))  # ['A', 1, 2, 3, 'B']

7.3.2 双向通信

def accumulator():
    """Running-sum coroutine: yield the total after each sent value.

    Sending None ends the loop; the final total becomes the generator's
    return value (StopIteration.value).
    """
    total = 0
    value = yield total
    while value is not None:
        total += value
        value = yield total
    return total


def delegator():
    """Delegate to accumulator() and receive its return value when it finishes."""
    final = yield from accumulator()
    print(f"Final total: {final}")
    return final


d = delegator()
print(next(d))       # 0
print(d.send(10))    # 10
print(d.send(20))    # 30
print(d.send(30))    # 60
try:
    d.send(None)  # None tells accumulator() to stop
except StopIteration as stop:
    print(f"Returned: {stop.value}")  # Final total: 60, Returned: 60

7.4 生成器应用

7.4.1 无限序列

def fibonacci():
    """Infinite generator of Fibonacci numbers: 0, 1, 1, 2, 3, 5, ..."""
    current, following = 0, 1
    while True:
        yield current
        current, following = following, current + following


# Take the first 10 Fibonacci numbers from the infinite sequence
fib = fibonacci()
print([next(fib) for _ in range(10)])
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

7.4.2 文件逐行读取

def read_large_file(file_path):
    """Yield the lines of *file_path* one at a time, stripped of surrounding whitespace.

    Only the current line is held in memory, which makes this suitable for
    files too large to load at once. The file is closed when the generator
    finishes or is closed.
    """
    with open(file_path, 'r') as handle:
        yield from (raw.strip() for raw in handle)

# Process a large file without loading it into memory
# for line in read_large_file("large_file.txt"):
#     process(line)

7.4.3 流水线处理

def read_lines(file_path):
    """Pipeline source: yield each line of *file_path* with surrounding whitespace stripped."""
    with open(file_path) as source:
        for raw_line in source:
            yield raw_line.strip()
 
def filter_comments(lines):
    """Pipeline stage: pass through every line that does not start with '#'."""
    for line in lines:
        if line.startswith('#'):
            continue
        yield line
 
def convert_to_int(lines):
    """Pipeline stage: convert each line to an int, skipping lines that are not integers.

    Blank lines and any other text that ``int()`` rejects are dropped
    silently; this is deliberate best-effort parsing.
    """
    for line in lines:
        # Only the conversion belongs in the try: if the yield were inside it,
        # a ValueError thrown into this generator by the consumer (gen.throw)
        # would be swallowed here instead of propagating.
        try:
            number = int(line)
        except ValueError:
            continue
        yield number

# Pipeline: each stage lazily pulls values from the previous one
# numbers = convert_to_int(filter_comments(read_lines("data.txt")))

7.5 协程(Coroutine)

7.5.1 基本概念

协程是可以在执行过程中暂停和恢复的函数。本节介绍的是基于生成器的协程:调用方通过 send() 把数据传给暂停在 yield 处的生成器;现代异步编程使用的则是以 async def/await 定义的原生协程。

def simple_coroutine():
    """Generator-based coroutine that receives two values via send() and prints them."""
    print("协程启动")
    x = yield  # pause here until a value is sent in
    print(f"收到值: {x}")
    y = yield
    print(f"收到值: {y}")
    print("协程结束")

# Start the coroutine
coro = simple_coroutine()
next(coro)  # prime the coroutine: run it up to the first yield
coro.send(10)  # send a value
# After the last value the coroutine body runs to completion, so send()
# raises StopIteration; catch it instead of letting the script crash.
try:
    coro.send(20)
except StopIteration:
    pass

7.5.2 装饰器预激

from functools import wraps
 
def coroutine(func):
    """Decorator that primes a generator-based coroutine.

    The wrapper calls ``func`` and advances the resulting generator to its
    first ``yield``, so callers can ``send()`` values right away without an
    explicit ``next()``. ``functools.wraps`` preserves func's name and docstring.
    """
    @wraps(func)
    def start(*args, **kwargs):
        generator = func(*args, **kwargs)
        next(generator)  # run up to the first yield
        return generator
    return start
 
@coroutine
def averager():
    """Running-average coroutine: each number sent in returns the mean so far."""
    total, count = 0.0, 0
    average = None
    while True:
        term = yield average
        total += term
        count += 1
        average = total / count

# Usage
avg = averager()  # already primed by @coroutine, so send() works right away
print(avg.send(10))   # 10.0
print(avg.send(20))   # 15.0
print(avg.send(30))   # 20.0

7.5.3 使用协程处理数据

@coroutine
def printer():
    """Sink coroutine: print every value that is sent to it."""
    while True:
        received = yield
        print(f"Received: {received}")
 
@coroutine
def filter_target(target, predicate):
    """Forward to *target* only the values for which predicate(value) is true."""
    while True:
        item = yield
        if not predicate(item):
            continue
        target.send(item)
 
@coroutine
def broadcast(targets):
    """Send every received value to each coroutine in *targets*, in order."""
    while True:
        item = yield
        for destination in targets:
            destination.send(item)

# Usage: values <= 5 are dropped by the filter before they reach the printer
p = printer()
f = filter_target(p, lambda x: x > 5)
f.send(3)   # filtered out, nothing printed
f.send(10)  # Received: 10

7.6 itertools模块

import itertools
 
# Infinite iterators: they never stop on their own, so take a finite prefix
# with itertools.islice before printing. (repeat is infinite only when no
# count is given.) Names end in _it so that the counter() generator function
# from section 7.2.4 is not overwritten.
count_it = itertools.count(start=10, step=2)   # 10, 12, 14, ...
cycle_it = itertools.cycle([1, 2, 3])          # 1, 2, 3, 1, 2, 3, ...
repeat_it = itertools.repeat(10, 3)            # 10, 10, 10
print(list(itertools.islice(count_it, 3)))     # [10, 12, 14]
print(list(itertools.islice(cycle_it, 6)))     # [1, 2, 3, 1, 2, 3]
print(list(repeat_it))                         # [10, 10, 10]

# Finite iterators: they stop when their input runs out
chain_it = itertools.chain([1, 2], [3, 4], [5, 6])
compress_it = itertools.compress('ABCDEF', [1, 0, 1, 0, 1, 1])
dropwhile_it = itertools.dropwhile(lambda x: x < 5, [1, 3, 5, 7, 2])
takewhile_it = itertools.takewhile(lambda x: x < 5, [1, 3, 5, 7, 2])
print(list(chain_it))      # [1, 2, 3, 4, 5, 6]
print(list(compress_it))   # ['A', 'C', 'E', 'F']
print(list(dropwhile_it))  # [5, 7, 2]  (only the leading run is dropped; the final 2 is kept)
print(list(takewhile_it))  # [1, 3]     (stops at the first element that fails the test)

# Combinatoric iterators
product_it = itertools.product([1, 2], ['a', 'b'])
permutations_it = itertools.permutations([1, 2, 3], 2)  # ordered: (1, 2) and (2, 1) are different
combinations_it = itertools.combinations([1, 2, 3], 2)  # unordered, no repeated elements
combinations_r_it = itertools.combinations_with_replacement([1, 2], 2)
print(list(product_it))         # [(1, 'a'), (1, 'b'), (2, 'a'), (2, 'b')]
print(list(permutations_it))    # [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)]
print(list(combinations_it))    # [(1, 2), (1, 3), (2, 3)]
print(list(combinations_r_it))  # [(1, 1), (1, 2), (2, 2)]

7.7 代码示例

示例1:实现一个树形结构的迭代器

class TreeNode:
    """A tree node holding a value and a list of child nodes.

    Iterating a node walks its subtree depth-first (pre-order), visiting
    nodes in the same order as dfs(); bfs() walks it level by level.
    """

    def __init__(self, value, children=None):
        self.value = value
        # Use a fresh list when no (or an empty) children argument is given
        self.children = children if children else []

    def __iter__(self):
        """Depth-first traversal through the iteration protocol."""
        yield self.value
        for subtree in self.children:
            yield from subtree  # recurses through the child's own __iter__

    def dfs(self):
        """Depth-first search: yield this node's value, then each child subtree's."""
        yield self.value
        for subtree in self.children:
            yield from subtree.dfs()

    def bfs(self):
        """Breadth-first search: yield values level by level, left to right."""
        from collections import deque
        pending = deque((self,))
        while pending:
            current = pending.popleft()
            yield current.value
            pending.extend(current.children)

# Build the tree: A has children B (with D, E) and C (with F)
root = TreeNode("A", [
    TreeNode("B", [
        TreeNode("D"),
        TreeNode("E"),
    ]),
    TreeNode("C", [
        TreeNode("F"),
    ]),
])

print("DFS:", list(root.dfs()))  # ['A', 'B', 'D', 'E', 'C', 'F']
print("BFS:", list(root.bfs()))  # ['A', 'B', 'C', 'D', 'E', 'F']

示例2:使用生成器实现上下文管理器

from contextlib import contextmanager
 
@contextmanager
def managed_resource(name):
    """Context manager built from a generator: acquire, hand out, always release.

    Code before ``yield`` runs on entering the with-block, the yielded value
    is what ``as`` binds, and the ``finally`` clause runs on exit, even when
    the with-body raises (the exception still propagates afterwards).
    """
    print(f"Acquiring {name}...")
    handle = f"Resource({name})"
    try:
        yield handle
    finally:
        print(f"Releasing {name}...")

# Usage
with managed_resource("database") as db:
    print(f"Using {db}")
 
# Equivalent class-based context manager
class ManagedResource:
    """Class-based version of managed_resource: prints on acquire and on release."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        label = self.name
        print(f"Acquiring {label}...")
        return f"Resource({label})"

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Implicitly returns None (falsy), so any exception from the body propagates
        print(f"Releasing {self.name}...")

7.8 练习题

练习1:实现flatten生成器

def flatten(nested):
    """Yield the leaf items of an arbitrarily nested list, left to right.

    Only ``list`` instances are descended into; every other object (tuples,
    strings, ...) is yielded as a single item.
    """
    for element in nested:
        if not isinstance(element, list):
            yield element
            continue
        yield from flatten(element)

# Test
nested = [1, [2, [3, 4], 5], 6, [7, 8]]
print(list(flatten(nested)))  # [1, 2, 3, 4, 5, 6, 7, 8]

练习2:实现惰性读取大文件的类

class LazyFileReader:
    """Lazy text-file reader that yields fixed-size chunks or individual lines.

    Nothing is read until iteration starts. The file is reopened (and closed
    again) on every iteration, so one reader can be iterated many times.
    """
    def __init__(self, file_path, chunk_size=1024, encoding=None):
        """Store the reading parameters without touching the file.

        Args:
            file_path: path of the text file to read.
            chunk_size: maximum number of characters per chunk yielded by
                ``__iter__``. A negative value or ``None`` follows
                ``file.read`` semantics and yields the whole file as one chunk.
            encoding: text encoding passed to ``open()``. ``None`` (the
                default) uses the platform default encoding.

        Raises:
            ValueError: if ``chunk_size`` is 0. ``file.read(0)`` always
                returns ``''``, so iteration would silently yield nothing,
                which is indistinguishable from an empty file.
        """
        if chunk_size == 0:
            raise ValueError("chunk_size must not be 0")
        self.file_path = file_path
        self.chunk_size = chunk_size
        self.encoding = encoding

    def __iter__(self):
        """Yield successive chunks of at most ``chunk_size`` characters."""
        with open(self.file_path, 'r', encoding=self.encoding) as f:
            # iter(callable, sentinel) keeps calling f.read until it returns '' (EOF)
            for chunk in iter(lambda: f.read(self.chunk_size), ''):
                yield chunk

    def lines(self):
        """Yield the file's lines one at a time, with the trailing newline removed."""
        with open(self.file_path, 'r', encoding=self.encoding) as f:
            for line in f:
                yield line.rstrip('\n')

# Usage
# reader = LazyFileReader("large_file.txt")
# for chunk in reader:
#     process(chunk)

本章小结

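本章学习了Python的迭代器和生成器:

- 迭代协议:可迭代对象(实现 __iter__)与迭代器(同时实现 __iter__ 和 __next__)的区别
- 生成器函数与生成器表达式,以及 send()、throw()、close() 等生成器方法
- yield from:委托子生成器,并获取子生成器的返回值
- 生成器的典型应用:无限序列、大文件逐行读取、流水线处理
- 基于生成器的协程及预激装饰器
- itertools 模块中的常用迭代器工具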

迭代器和生成器让Python能够高效处理大量数据,是编写Pythonic代码的重要工具。

进一步阅读