Spaces:
Paused
Paused
| # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license | |
| import collections | |
| from typing import Any, Callable, Iterator, List, Optional, Tuple, Union | |
| import dns.exception | |
| import dns.name | |
| import dns.node | |
| import dns.rdataclass | |
| import dns.rdataset | |
| import dns.rdatatype | |
| import dns.rrset | |
| import dns.serial | |
| import dns.ttl | |
class TransactionManager:
    """Interface for objects that create transactions.

    All methods except ``from_wire_origin()`` raise ``NotImplementedError``
    in this class.
    """

    def reader(self) -> "Transaction":
        """Start and return a read-only transaction."""
        raise NotImplementedError  # pragma: no cover

    def writer(self, replacement: bool = False) -> "Transaction":
        """Start and return a writable transaction.

        *replacement*, a ``bool``.  If ``True``, the content of the
        transaction completely replaces any prior content.  If ``False``
        (the default), the content of the transaction updates the
        existing content.
        """
        raise NotImplementedError  # pragma: no cover

    def origin_information(
        self,
    ) -> Tuple[Optional[dns.name.Name], bool, Optional[dns.name.Name]]:
        """Return the tuple ``(absolute_origin, relativize, effective_origin)``.

        *absolute_origin* is the absolute name of the default origin for any
        relative domain names, *relativize* says whether names should be
        relativized, and *effective_origin* is the absolute origin if
        *relativize* is ``False`` and the empty name if it is ``True``.
        (The effective origin is provided even though it can be computed
        from the other two values because that avoids a lot of code
        duplication.)

        If the returned names are ``None``, then no origin information is
        available.

        This information is used by code working with transactions to
        allow it to coordinate relativization.  The transaction code
        itself takes what it gets (i.e. does not change name
        relativity).
        """
        raise NotImplementedError  # pragma: no cover

    def get_class(self) -> dns.rdataclass.RdataClass:
        """Return the class of the transaction manager."""
        raise NotImplementedError  # pragma: no cover

    def from_wire_origin(self) -> Optional[dns.name.Name]:
        """Return the origin to use in from_wire() calls.

        This is the absolute origin when relativizing, and ``None``
        otherwise.
        """
        origin, relativizing, _unused = self.origin_information()
        return origin if relativizing else None
# Raised by Transaction._delete() when called in "exact" mode and the name,
# rdataset, or some of the rdatas to delete are not present.
# NOTE(review): the dns.exception.DNSException base class presumably uses the
# class docstring as the default exception message -- confirm before
# rewording the docstring.
class DeleteNotExact(dns.exception.DNSException):
    """Existing data did not match data specified by an exact delete."""
# Raised by Transaction._check_read_only(), which the mutating methods
# (add, replace, delete, delete_exact) call before making changes.
# NOTE(review): the dns.exception.DNSException base class presumably uses the
# class docstring as the default exception message -- confirm before
# rewording the docstring.
class ReadOnly(dns.exception.DNSException):
    """Tried to write to a read-only transaction."""
# Raised by Transaction._check_ended() once the transaction has been
# committed or rolled back.
# NOTE(review): the dns.exception.DNSException base class presumably uses the
# class docstring as the default exception message -- confirm before
# rewording the docstring.
class AlreadyEnded(dns.exception.DNSException):
    """Tried to use an already-ended transaction."""
| def _ensure_immutable_rdataset(rdataset): | |
| if rdataset is None or isinstance(rdataset, dns.rdataset.ImmutableRdataset): | |
| return rdataset | |
| return dns.rdataset.ImmutableRdataset(rdataset) | |
| def _ensure_immutable_node(node): | |
| if node is None or node.is_immutable(): | |
| return node | |
| return dns.node.ImmutableNode(node) | |
# Signature of a check registered with Transaction.check_put_rdataset();
# it is called as check(transaction, name, rdataset) before an rdataset
# is stored.
CheckPutRdatasetType = Callable[
    ["Transaction", dns.name.Name, dns.rdataset.Rdataset], None
]
# Signature of a check registered with Transaction.check_delete_rdataset();
# it is called as check(transaction, name, rdtype, covers) before an
# rdataset is deleted.
CheckDeleteRdatasetType = Callable[
    ["Transaction", dns.name.Name, dns.rdatatype.RdataType, dns.rdatatype.RdataType],
    None,
]
# Signature of a check registered with Transaction.check_delete_name();
# it is called as check(transaction, name) before a name is deleted.
CheckDeleteNameType = Callable[["Transaction", dns.name.Name], None]
class Transaction:
    """Base class for transactions created by a ``TransactionManager``.

    The public ("high level") methods validate and normalize their
    arguments and then call the underscore-prefixed "low level" methods,
    which subclasses must implement.  Transactions are context managers:
    leaving the ``with`` block commits if no exception occurred and rolls
    back otherwise.
    """

    def __init__(
        self,
        manager: TransactionManager,
        replacement: bool = False,
        read_only: bool = False,
    ):
        """Initialize the transaction.

        *manager*, the ``TransactionManager`` this transaction belongs to.

        *replacement*, a ``bool``, stored as ``self.replacement``.

        *read_only*, a ``bool``.  If ``True``, the mutating methods raise
        ``ReadOnly``.
        """
        self.manager = manager
        self.replacement = replacement
        self.read_only = read_only
        self._ended = False
        # Hooks registered through check_put_rdataset(),
        # check_delete_rdataset() and check_delete_name().
        self._check_put_rdataset: List[CheckPutRdatasetType] = []
        self._check_delete_rdataset: List[CheckDeleteRdatasetType] = []
        self._check_delete_name: List[CheckDeleteNameType] = []

    #
    # This is the high level API
    #
    # Note that we currently use non-immutable types in the return type signature to
    # avoid covariance problems, e.g. if the caller has a List[Rdataset], mypy will be
    # unhappy if we return an ImmutableRdataset.

    def get(
        self,
        name: Optional[Union[dns.name.Name, str]],
        rdtype: Union[dns.rdatatype.RdataType, str],
        covers: Union[dns.rdatatype.RdataType, str] = dns.rdatatype.NONE,
    ) -> dns.rdataset.Rdataset:
        """Return the rdataset associated with *name*, *rdtype*, and *covers*,
        or `None` if not found.

        *name* may be given as text, in which case it is converted with
        ``dns.name.from_text(name, None)``.

        Note that the returned rdataset is immutable.

        Raises ``AlreadyEnded`` if the transaction has ended.
        """
        self._check_ended()
        if isinstance(name, str):
            name = dns.name.from_text(name, None)
        rdtype = dns.rdatatype.RdataType.make(rdtype)
        covers = dns.rdatatype.RdataType.make(covers)
        rdataset = self._get_rdataset(name, rdtype, covers)
        return _ensure_immutable_rdataset(rdataset)

    def get_node(self, name: dns.name.Name) -> Optional[dns.node.Node]:
        """Return the node at *name*, if any.

        Returns an immutable node or ``None``.
        """
        return _ensure_immutable_node(self._get_node(name))

    def _check_read_only(self) -> None:
        """Raise ``ReadOnly`` if this is a read-only transaction."""
        if self.read_only:
            raise ReadOnly

    def add(self, *args: Any) -> None:
        """Add records.

        The arguments may be:

            - rrset

            - name, rdataset...

            - name, ttl, rdata...

        Raises ``AlreadyEnded`` if the transaction has ended and
        ``ReadOnly`` if it is read-only.
        """
        self._check_ended()
        self._check_read_only()
        self._add(False, args)

    def replace(self, *args: Any) -> None:
        """Replace the existing rdataset at the name with the specified
        rdataset, or add the specified rdataset if there was no existing
        rdataset.

        The arguments may be:

            - rrset

            - name, rdataset...

            - name, ttl, rdata...

        Note that if you want to replace the entire node, you should do
        a delete of the name followed by one or more calls to add() or
        replace().

        Raises ``AlreadyEnded`` if the transaction has ended and
        ``ReadOnly`` if it is read-only.
        """
        self._check_ended()
        self._check_read_only()
        self._add(True, args)

    def delete(self, *args: Any) -> None:
        """Delete records.

        It is not an error if some of the records are not in the existing
        set.

        The arguments may be:

            - rrset

            - name

            - name, rdatatype, [covers]

            - name, rdataset...

            - name, rdata...

        Raises ``AlreadyEnded`` if the transaction has ended and
        ``ReadOnly`` if it is read-only.
        """
        self._check_ended()
        self._check_read_only()
        self._delete(False, args)

    def delete_exact(self, *args: Any) -> None:
        """Delete records.

        The arguments may be:

            - rrset

            - name

            - name, rdatatype, [covers]

            - name, rdataset...

            - name, rdata...

        Raises dns.transaction.DeleteNotExact if some of the records
        are not in the existing set.

        Raises ``AlreadyEnded`` if the transaction has ended and
        ``ReadOnly`` if it is read-only.
        """
        self._check_ended()
        self._check_read_only()
        self._delete(True, args)

    def name_exists(self, name: Union[dns.name.Name, str]) -> bool:
        """Does the specified name exist?"""
        self._check_ended()
        if isinstance(name, str):
            name = dns.name.from_text(name, None)
        return self._name_exists(name)

    def update_serial(
        self,
        value: int = 1,
        relative: bool = True,
        name: Union[dns.name.Name, str] = dns.name.empty,
    ) -> None:
        """Update the serial number.

        *value*, an `int`, is an increment if *relative* is `True`, or the
        actual value to set if *relative* is `False`.

        *name*, the owner name of the SOA rdataset; text is converted with
        ``dns.name.from_text(name, None)``.

        Raises `KeyError` if there is no SOA rdataset at *name*.

        Raises `ValueError` if *value* is negative or if the increment is
        so large that it would cause the new serial to be less than the
        prior value.
        """
        self._check_ended()
        if value < 0:
            raise ValueError("negative update_serial() value")
        if isinstance(name, str):
            name = dns.name.from_text(name, None)
        rdataset = self._get_rdataset(name, dns.rdatatype.SOA, dns.rdatatype.NONE)
        if rdataset is None or len(rdataset) == 0:
            raise KeyError
        if relative:
            serial = dns.serial.Serial(rdataset[0].serial) + value
        else:
            serial = dns.serial.Serial(value)
        serial = serial.value  # convert back to int
        # A resulting serial of 0 is never stored; 1 is used instead.
        if serial == 0:
            serial = 1
        rdata = rdataset[0].replace(serial=serial)
        new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
        # replace() performs the read-only check.
        self.replace(name, new_rdataset)

    def __iter__(self):
        """Iterate the transaction's (name, rdataset) tuples."""
        self._check_ended()
        return self._iterate_rdatasets()

    def changed(self) -> bool:
        """Has this transaction changed anything?

        For read-only transactions, the result is always `False`.

        For writable transactions, the result is `True` if at some time
        during the life of the transaction, the content was changed.
        """
        self._check_ended()
        return self._changed()

    def commit(self) -> None:
        """Commit the transaction.

        Normally transactions are used as context managers and commit
        or rollback automatically, but it may be done explicitly if needed.
        A ``dns.transaction.AlreadyEnded`` exception will be raised if you try
        to use a transaction after it has been committed or rolled back.

        Raises an exception if the commit fails (in which case the transaction
        is also rolled back).
        """
        self._end(True)

    def rollback(self) -> None:
        """Rollback the transaction.

        Normally transactions are used as context managers and commit
        or rollback automatically, but it may be done explicitly if needed.
        A ``dns.transaction.AlreadyEnded`` exception will be raised if you try
        to use a transaction after it has been committed or rolled back.

        Rollback cannot otherwise fail.
        """
        self._end(False)

    def check_put_rdataset(self, check: CheckPutRdatasetType) -> None:
        """Call *check* before putting (storing) an rdataset.

        The function is called with the transaction, the name, and the rdataset.

        The check function may safely make non-mutating transaction method
        calls, but behavior is undefined if mutating transaction methods are
        called.  The check function should raise an exception if it objects to
        the put, and otherwise should return ``None``.
        """
        self._check_put_rdataset.append(check)

    def check_delete_rdataset(self, check: CheckDeleteRdatasetType) -> None:
        """Call *check* before deleting an rdataset.

        The function is called with the transaction, the name, the rdatatype,
        and the covered rdatatype.

        The check function may safely make non-mutating transaction method
        calls, but behavior is undefined if mutating transaction methods are
        called.  The check function should raise an exception if it objects to
        the delete, and otherwise should return ``None``.
        """
        self._check_delete_rdataset.append(check)

    def check_delete_name(self, check: CheckDeleteNameType) -> None:
        """Call *check* before deleting a name.

        The function is called with the transaction and the name.

        The check function may safely make non-mutating transaction method
        calls, but behavior is undefined if mutating transaction methods are
        called.  The check function should raise an exception if it objects to
        the delete, and otherwise should return ``None``.
        """
        self._check_delete_name.append(check)

    def iterate_rdatasets(
        self,
    ) -> Iterator[Tuple[dns.name.Name, dns.rdataset.Rdataset]]:
        """Iterate all the rdatasets in the transaction, returning
        (`dns.name.Name`, `dns.rdataset.Rdataset`) tuples.

        Note that as is usual with python iterators, adding or removing items
        while iterating will invalidate the iterator and may raise `RuntimeError`
        or fail to iterate over all entries."""
        self._check_ended()
        return self._iterate_rdatasets()

    def iterate_names(self) -> Iterator[dns.name.Name]:
        """Iterate all the names in the transaction.

        Note that as is usual with python iterators, adding or removing names
        while iterating will invalidate the iterator and may raise `RuntimeError`
        or fail to iterate over all entries."""
        self._check_ended()
        return self._iterate_names()

    #
    # Helper methods
    #

    def _raise_if_not_empty(self, method, args):
        """Raise ``TypeError`` if any unconsumed arguments remain in *args*."""
        if len(args) != 0:
            raise TypeError(f"extra parameters to {method}")

    def _rdataset_from_args(self, method, deleting, args):
        """Consume arguments from the deque *args* and build an rdataset.

        Accepts an RRset, an rdataset, or (when not *deleting*) a TTL
        followed by an rdata; when *deleting*, a bare rdata is accepted and
        a TTL of 0 is used.

        Returns the rdataset, or ``None`` if *deleting* and *args* ran out.
        Raises ``TypeError`` for missing or wrongly typed arguments and
        ``ValueError`` if the TTL exceeds ``dns.ttl.MAX_TTL``.
        """
        try:
            arg = args.popleft()
            if isinstance(arg, dns.rrset.RRset):
                rdataset = arg.to_rdataset()
            elif isinstance(arg, dns.rdataset.Rdataset):
                rdataset = arg
            else:
                if deleting:
                    # No TTL is supplied when deleting.
                    ttl = 0
                else:
                    if isinstance(arg, int):
                        ttl = arg
                        if ttl > dns.ttl.MAX_TTL:
                            raise ValueError(f"{method}: TTL value too big")
                    else:
                        raise TypeError(f"{method}: expected a TTL")
                    # The rdata follows the TTL.
                    arg = args.popleft()
                if isinstance(arg, dns.rdata.Rdata):
                    rdataset = dns.rdataset.from_rdata(ttl, arg)
                else:
                    raise TypeError(f"{method}: expected an Rdata")
            return rdataset
        except IndexError:
            if deleting:
                # "delete(name)" legitimately has nothing after the name.
                return None
            # Report the exhausted deque as a caller error.
            raise TypeError(f"{method}: expected more arguments") from None

    def _add(self, replace, args):
        """Implement add() (*replace* false) and replace() (*replace* true).

        Raises ``TypeError`` for malformed arguments and ``ValueError`` if
        the rdataset's class differs from the manager's class or an SOA is
        given for a name other than the effective origin.
        """
        method = "replace()" if replace else "add()"
        try:
            args = collections.deque(args)
            arg = args.popleft()
            if isinstance(arg, str):
                arg = dns.name.from_text(arg, None)
            if isinstance(arg, dns.name.Name):
                name = arg
                rdataset = self._rdataset_from_args(method, False, args)
            elif isinstance(arg, dns.rrset.RRset):
                rrset = arg
                name = rrset.name
                # rrsets are also rdatasets, but they don't print the
                # same and can't be stored in nodes, so convert.
                rdataset = rrset.to_rdataset()
            else:
                raise TypeError(
                    f"{method} requires a name or RRset as the first argument"
                )
            if rdataset.rdclass != self.manager.get_class():
                raise ValueError(f"{method} has objects of wrong RdataClass")
            if rdataset.rdtype == dns.rdatatype.SOA:
                (_, _, origin) = self._origin_information()
                if name != origin:
                    raise ValueError(f"{method} has non-origin SOA")
            self._raise_if_not_empty(method, args)
            if not replace:
                existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
                if existing is not None:
                    if isinstance(existing, dns.rdataset.ImmutableRdataset):
                        # Copy into a plain Rdataset before taking the union.
                        trds = dns.rdataset.Rdataset(
                            existing.rdclass, existing.rdtype, existing.covers
                        )
                        trds.update(existing)
                        existing = trds
                    rdataset = existing.union(rdataset)
            self._checked_put_rdataset(name, rdataset)
        except IndexError:
            raise TypeError(f"not enough parameters to {method}") from None

    def _delete(self, exact, args):
        """Implement delete() (*exact* false) and delete_exact() (*exact* true).

        Raises ``TypeError`` for malformed arguments, ``ValueError`` if the
        rdataset's class differs from the manager's class, and, when *exact*,
        ``DeleteNotExact`` if the data to delete is not present.
        """
        method = "delete_exact()" if exact else "delete()"
        try:
            args = collections.deque(args)
            arg = args.popleft()
            if isinstance(arg, str):
                arg = dns.name.from_text(arg, None)
            if isinstance(arg, dns.name.Name):
                name = arg
                if len(args) > 0 and isinstance(args[0], (int, str)):
                    # deleting by type and (optionally) covers
                    rdtype = dns.rdatatype.RdataType.make(args.popleft())
                    if len(args) > 0:
                        covers = dns.rdatatype.RdataType.make(args.popleft())
                    else:
                        covers = dns.rdatatype.NONE
                    self._raise_if_not_empty(method, args)
                    existing = self._get_rdataset(name, rdtype, covers)
                    if existing is None:
                        if exact:
                            raise DeleteNotExact(f"{method}: missing rdataset")
                    else:
                        # Go through the checked variant so that hooks
                        # registered with check_delete_rdataset() run for
                        # deletes by type too.
                        self._checked_delete_rdataset(name, rdtype, covers)
                    return
                else:
                    rdataset = self._rdataset_from_args(method, True, args)
            elif isinstance(arg, dns.rrset.RRset):
                rdataset = arg  # rrsets are also rdatasets
                name = rdataset.name
            else:
                raise TypeError(
                    f"{method} requires a name or RRset as the first argument"
                )
            self._raise_if_not_empty(method, args)
            # Compare against None explicitly: an empty rdataset is falsy and
            # must not be mistaken for "no rdataset given", which deletes the
            # whole name.
            if rdataset is not None:
                if rdataset.rdclass != self.manager.get_class():
                    raise ValueError(f"{method} has objects of wrong RdataClass")
                existing = self._get_rdataset(name, rdataset.rdtype, rdataset.covers)
                if existing is not None:
                    if exact:
                        intersection = existing.intersection(rdataset)
                        if intersection != rdataset:
                            raise DeleteNotExact(f"{method}: missing rdatas")
                    rdataset = existing.difference(rdataset)
                    if len(rdataset) == 0:
                        self._checked_delete_rdataset(
                            name, rdataset.rdtype, rdataset.covers
                        )
                    else:
                        self._checked_put_rdataset(name, rdataset)
                elif exact:
                    raise DeleteNotExact(f"{method}: missing rdataset")
            else:
                if exact and not self._name_exists(name):
                    raise DeleteNotExact(f"{method}: name not known")
                self._checked_delete_name(name)
        except IndexError:
            raise TypeError(f"not enough parameters to {method}") from None

    def _check_ended(self):
        """Raise ``AlreadyEnded`` if the transaction has ended."""
        if self._ended:
            raise AlreadyEnded

    def _end(self, commit):
        """End the transaction, committing if *commit* is true.

        The transaction is marked as ended even if ``_end_transaction()``
        raises.
        """
        self._check_ended()
        try:
            self._end_transaction(commit)
        finally:
            self._ended = True

    def _checked_put_rdataset(self, name, rdataset):
        """Run the registered put checks, then store *rdataset* at *name*."""
        for check in self._check_put_rdataset:
            check(self, name, rdataset)
        self._put_rdataset(name, rdataset)

    def _checked_delete_rdataset(self, name, rdtype, covers):
        """Run the registered rdataset-delete checks, then delete the rdataset."""
        for check in self._check_delete_rdataset:
            check(self, name, rdtype, covers)
        self._delete_rdataset(name, rdtype, covers)

    def _checked_delete_name(self, name):
        """Run the registered name-delete checks, then delete *name*."""
        for check in self._check_delete_name:
            check(self, name)
        self._delete_name(name)

    #
    # Transactions are context managers.
    #

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Commit on a clean exit, roll back on an exception; do nothing if
        # the transaction was already ended explicitly.
        if not self._ended:
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
        # Never suppress the exception.
        return False

    #
    # This is the low level API, which must be implemented by subclasses
    # of Transaction.
    #

    def _get_rdataset(self, name, rdtype, covers):
        """Return the rdataset associated with *name*, *rdtype*, and *covers*,
        or `None` if not found.
        """
        raise NotImplementedError  # pragma: no cover

    def _put_rdataset(self, name, rdataset):
        """Store the rdataset."""
        raise NotImplementedError  # pragma: no cover

    def _delete_name(self, name):
        """Delete all data associated with *name*.

        It is not an error if the name does not exist.
        """
        raise NotImplementedError  # pragma: no cover

    def _delete_rdataset(self, name, rdtype, covers):
        """Delete all data associated with *name*, *rdtype*, and *covers*.

        It is not an error if the rdataset does not exist.
        """
        raise NotImplementedError  # pragma: no cover

    def _name_exists(self, name):
        """Does name exist?

        Returns a bool.
        """
        raise NotImplementedError  # pragma: no cover

    def _changed(self):
        """Has this transaction changed anything?"""
        raise NotImplementedError  # pragma: no cover

    def _end_transaction(self, commit):
        """End the transaction.

        *commit*, a bool.  If ``True``, commit the transaction, otherwise
        roll it back.

        If committing and the commit fails, then roll back and raise an
        exception.
        """
        raise NotImplementedError  # pragma: no cover

    def _set_origin(self, origin):
        """Set the origin.

        This method is called when reading a possibly relativized
        source, and an origin setting operation occurs (e.g. $ORIGIN
        in a zone file).
        """
        raise NotImplementedError  # pragma: no cover

    def _iterate_rdatasets(self):
        """Return an iterator that yields (name, rdataset) tuples."""
        raise NotImplementedError  # pragma: no cover

    def _iterate_names(self):
        """Return an iterator that yields a name."""
        raise NotImplementedError  # pragma: no cover

    def _get_node(self, name):
        """Return the node at *name*, if any.

        Returns a node or ``None``.
        """
        raise NotImplementedError  # pragma: no cover

    #
    # Low-level API with a default implementation, in case a subclass needs
    # to override.
    #

    def _origin_information(self):
        """Return the manager's origin information tuple."""
        # This is only used by _add()
        return self.manager.origin_information()