import io
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Type
class BufferedStream(Iterator):
    """In-memory byte stream with an independent read cursor and rollback.

    The stream wraps an ``io.BytesIO`` (``self.input``).  Bytes are appended
    with :meth:`push` and consumed with :meth:`read` or by iterating, which
    yields one byte (as an ``int``) at a time.  :meth:`rollback` rewinds the
    cursor when a chosen exception type escapes a ``with`` block.
    """

    def __init__(self):
        # Backing buffer; its current position is the read cursor.
        self.input = io.BytesIO()
        # Stack of cursor offsets saved by the active rollback() contexts.
        self.pos_queue = []

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next unread byte as an ``int``.

        Raises:
            StopIteration: when no unread bytes remain.
        """
        data = self.input.read(1)
        if not data:
            raise StopIteration
        return data[0]

    @contextmanager
    def rollback(self, error_type: Type[Exception]):
        """Rewind the cursor if ``error_type`` is raised inside the block.

        The cursor offset at entry is saved.  If an exception matching
        ``error_type`` propagates out of the ``with`` body, the cursor is
        moved back to that offset and the exception is re-raised.  Any other
        exception, or a normal exit, leaves the cursor where it is.

        Args:
            error_type: exception class that triggers the rewind.
        """
        start = self.input.tell()
        self.pos_queue.append(start)
        try:
            yield
        except error_type:
            self.input.seek(start)
            raise
        finally:
            # Always drop this context's saved offset, whatever the outcome.
            self.pos_queue.pop()

    def is_eof(self) -> bool:
        """Return ``True`` when there are no unread bytes left."""
        return self.size() == 0

    def size(self) -> int:
        """Return the number of unread bytes (from the cursor to the end)."""
        # getbuffer() exposes the contents without copying them (getvalue()
        # would duplicate the whole buffer).  The view is released right away
        # because the BytesIO cannot be resized while a view is alive.
        with self.input.getbuffer() as view:
            total = view.nbytes
        length = total - self.input.tell()
        assert length >= 0
        return length

    def close(self) -> None:
        """Close the backing buffer and discard saved rollback offsets."""
        self.input.close()
        self.pos_queue = []

    def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes from the cursor (all remaining if -1)."""
        return self.input.read(size)

    def write(self, data: bytes) -> int:
        """Write ``data`` at the current cursor position.

        The cursor advances past the written bytes.  Use :meth:`push` to
        append without moving the cursor.

        Returns:
            The number of bytes written.
        """
        return self.input.write(data)

    def push(self, data: bytes) -> None:
        """Append ``data`` to the end of the buffer, keeping the cursor."""
        pos = self.input.tell()
        try:
            self.input.seek(0, io.SEEK_END)
            self.input.write(data)
        finally:
            # Restore the read cursor even if the write fails.
            self.input.seek(pos)

    def tell(self) -> int:
        """Return the current cursor offset within the buffer."""
        return self.input.tell()