Source code for nltk.sem.drt

# Natural Language Toolkit: Discourse Representation Theory (DRT)
#
# Author: Dan Garrette <dhgarrette@gmail.com>
#
# Copyright (C) 2001-2015 NLTK Project
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
from __future__ import print_function, unicode_literals

import operator
from functools import reduce

from nltk.compat import string_types, python_2_unicode_compatible
from nltk.sem.logic import (APP, AbstractVariableExpression, AllExpression,
                            AndExpression, ApplicationExpression, BinaryExpression,
                            BooleanExpression, ConstantExpression, EqualityExpression,
                            EventVariableExpression, ExistsExpression, Expression,
                            FunctionVariableExpression, ImpExpression,
                            IndividualVariableExpression, LambdaExpression, Tokens,
                            LogicParser, NegatedExpression, OrExpression, Variable,
                            is_eventvar, is_funcvar, is_indvar, unique_variable)

# Import Tkinter-based modules if they are available
try:
    # imports are fixed for Python 2.x by nltk.compat
    from tkinter import Canvas
    from tkinter import Tk
    from tkinter.font import Font
    from nltk.util import in_idle

except ImportError:
    # No need to print a warning here, nltk.draw has already printed one.
    pass

class DrtTokens(Tokens):
    """Token inventory for the DRT parser, extending the base ``Tokens``."""

    # Keywords recognised in DRT strings.
    DRS = 'DRS'
    PRONOUN = 'PRO'

    # Punctuation that DRT adds on top of the base logic tokens.
    DRS_CONC = '+'
    OPEN_BRACKET = '['
    CLOSE_BRACKET = ']'
    COLON = ':'

    PUNCT = [
        DRS_CONC,
        OPEN_BRACKET,
        CLOSE_BRACKET,
        COLON,
    ]

    # Base-class inventories widened with the DRT-specific entries.
    SYMBOLS = Tokens.SYMBOLS + PUNCT
    TOKENS = Tokens.TOKENS + [DRS] + PUNCT


class DrtParser(LogicParser):
    """Parser for DRT expressions, built on top of ``LogicParser``."""

    def __init__(self):
        LogicParser.__init__(self)

        # (symbols, rank) pairs listed in the order they are registered; a
        # symbol that appears twice keeps the rank of its last occurrence.
        ranked_groups = [
            (DrtTokens.LAMBDA_LIST, 1),
            (DrtTokens.NOT_LIST, 2),
            ([APP], 3),
            (DrtTokens.EQ_LIST + Tokens.NEQ_LIST, 4),
            ([DrtTokens.COLON], 5),
            ([DrtTokens.DRS_CONC], 6),
            (DrtTokens.OR_LIST, 7),
            (DrtTokens.IMP_LIST, 8),
            ([None], 9),
        ]
        precedence = {}
        for symbols, rank in ranked_groups:
            for symbol in symbols:
                precedence[symbol] = rank
        self.operator_precedence = precedence

    def get_all_symbols(self):
        """Return the symbol list used by this parser (``DrtTokens.SYMBOLS``)."""
        return DrtTokens.SYMBOLS

    def isvariable(self, tok):
        """Treat every token that is not listed in ``DrtTokens.TOKENS`` as a variable."""
        return tok not in DrtTokens.TOKENS

    def handle(self, tok, context):
        """Dispatch ``tok`` to the handler for the construct it introduces.

        Returns the handler's result, or ``None`` when no branch applies.
        """
        if tok in DrtTokens.NOT_LIST:
            return self.handle_negation(tok, context)

        if tok in DrtTokens.LAMBDA_LIST:
            return self.handle_lambda(tok, context)

        if tok == DrtTokens.OPEN:
            # "([" opens an anonymous DRS; any other "(" is a plain group.
            starts_drs = self.inRange(0) and self.token(0) == DrtTokens.OPEN_BRACKET
            if starts_drs:
                return self.handle_DRS(tok, context)
            return self.handle_open(tok, context)

        if tok.upper() == DrtTokens.DRS:
            self.assertNextToken(DrtTokens.OPEN)
            return self.handle_DRS(tok, context)

        if self.isvariable(tok):
            # "label:" introduces a labelled proposition.
            labels_prop = self.inRange(0) and self.token(0) == DrtTokens.COLON
            if labels_prop:
                return self.handle_prop(tok, context)
            return self.handle_variable(tok, context)

        return None

    def make_NegatedExpression(self, expression):
        return DrtNegatedExpression(expression)

    def handle_DRS(self, tok, context):
        """Parse ``[refs],[conds])`` and return the resulting ``DRS``."""
        referents = self.handle_refs()
        # The comma between the two bracketed lists is optional.
        comma_follows = self.inRange(0) and self.token(0) == DrtTokens.COMMA
        if comma_follows:
            self.token()  # swallow the comma
        conditions = self.handle_conds(context)
        self.assertNextToken(DrtTokens.CLOSE)
        return DRS(referents, conditions, None)

    def handle_refs(self):
        """Parse the bracketed referent list and return it as a list."""
        self.assertNextToken(DrtTokens.OPEN_BRACKET)
        referents = []
        while self.inRange(0) and self.token(0) != DrtTokens.CLOSE_BRACKET:
            # Commas between entries are optional:
            # DRS([x y],C) == DRS([x,y],C)
            if referents and self.token(0) == DrtTokens.COMMA:
                self.token()  # swallow the comma
            referents.append(self.get_next_token_variable('quantified'))
        self.assertNextToken(DrtTokens.CLOSE_BRACKET)
        return referents

    def handle_conds(self, context):
        """Parse the bracketed condition list and return it as a list."""
        self.assertNextToken(DrtTokens.OPEN_BRACKET)
        conditions = []
        while self.inRange(0) and self.token(0) != DrtTokens.CLOSE_BRACKET:
            # Commas between entries are optional, as in the referent list.
            if conditions and self.token(0) == DrtTokens.COMMA:
                self.token()  # swallow the comma
            conditions.append(self.process_next_expression(context))
        self.assertNextToken(DrtTokens.CLOSE_BRACKET)
        return conditions

    def handle_prop(self, tok, context):
        """Parse ``tok:expr`` into a ``DrtProposition`` labelled by ``tok``."""
        label = self.make_VariableExpression(tok)
        self.assertNextToken(':')
        body = self.process_next_expression(DrtTokens.COLON)
        return DrtProposition(label, body)

    def make_EqualityExpression(self, first, second):
        """Hook: build the DRT flavour of an equality expression."""
        return DrtEqualityExpression(first, second)

    def get_BooleanExpression_factory(self, tok):
        """Hook: return a two-argument factory for the boolean operator ``tok``.

        Returns ``None`` when ``tok`` is not one of the operators handled here.
        """
        if tok == DrtTokens.DRS_CONC:
            def make_concatenation(first, second):
                return DrtConcatenation(first, second, None)
            return make_concatenation

        if tok in DrtTokens.OR_LIST:
            return DrtOrExpression

        if tok in DrtTokens.IMP_LIST:
            def make_imp_expression(first, second):
                # An implication is stored as the antecedent carrying its
                # consequent, so the antecedent has to be DRS-shaped.
                if isinstance(first, DRS):
                    return DRS(first.refs, first.conds, second)
                if isinstance(first, DrtConcatenation):
                    return DrtConcatenation(first.first, first.second, second)
                raise Exception('Antecedent of implication must be a DRS')
            return make_imp_expression

        return None

    def make_BooleanExpression(self, factory, first, second):
        return factory(first, second)

    def make_ApplicationExpression(self, function, argument):
        return DrtApplicationExpression(function, argument)

    def make_VariableExpression(self, name):
        return DrtVariableExpression(Variable(name))

    def make_LambdaExpression(self, variables, term):
        return DrtLambdaExpression(variables, term)


class DrtExpression(object):
    """
    Abstract base class that every DRT expression type derives from.
    """

    # One parser instance shared by all expression classes.
    _drt_parser = DrtParser()

    @classmethod
    def fromstring(cls, s):
        """Parse the string ``s`` with the shared DRT parser."""
        return cls._drt_parser.parse(s)

    def applyto(self, other):
        """Build the application of ``self`` to ``other``."""
        return DrtApplicationExpression(self, other)

    def __neg__(self):
        return DrtNegatedExpression(self)

    def __and__(self, other):
        raise NotImplementedError()

    def __or__(self, other):
        assert isinstance(other, DrtExpression)
        return DrtOrExpression(self, other)

    def __gt__(self, other):
        """Build the implication ``self -> other``.

        The antecedent is returned as a copy that carries ``other`` as its
        consequent, so ``self`` must be a ``DRS`` or a ``DrtConcatenation``.
        """
        assert isinstance(other, DrtExpression)
        if isinstance(self, DRS):
            return DRS(self.refs, self.conds, other)
        if isinstance(self, DrtConcatenation):
            return DrtConcatenation(self.first, self.second, other)
        raise Exception('Antecedent of implication must be a DRS')

    def equiv(self, other, prover=None):
        """
        Check for logical equivalence.

        Both expressions are simplified and converted with ``fol()``; the
        equivalence check is then delegated to the converted form of
        ``self``.

        :param other: a ``DrtExpression`` to check equality against
        :param prover: a ``nltk.inference.api.Prover``
        """
        assert isinstance(other, DrtExpression)

        own_fol = self.simplify().fol()
        other_fol = other.simplify().fol()
        return own_fol.equiv(other_fol, prover)

    @property
    def type(self):
        raise AttributeError("'%s' object has no attribute 'type'" %
                             self.__class__.__name__)

    def typecheck(self, signature=None):
        raise NotImplementedError()

    def __add__(self, other):
        return DrtConcatenation(self, other, None)

    def get_refs(self, recursive=False):
        """
        Return the set of discourse referents in this DRS.

        :param recursive: bool Also find discourse referents in subterms?
        :return: list of ``Variable`` objects
        """
        raise NotImplementedError()

    def is_pronoun_function(self):
        """Is self of the form "PRO(x)"?"""
        if not isinstance(self, DrtApplicationExpression):
            return False
        if not isinstance(self.function, DrtAbstractVariableExpression):
            return False
        if self.function.variable.name != DrtTokens.PRONOUN:
            return False
        return isinstance(self.argument, DrtIndividualVariableExpression)

    def make_EqualityExpression(self, first, second):
        return DrtEqualityExpression(first, second)

    def make_VariableExpression(self, variable):
        return DrtVariableExpression(variable)

    def resolve_anaphora(self):
        return resolve_anaphora(self)

    def eliminate_equality(self):
        """Apply ``eliminate_equality`` to every sub-expression and rebuild
        an expression of the same class from the results."""
        def eliminate(subexpression):
            return subexpression.eliminate_equality()

        return self.visit_structured(eliminate, self.__class__)

    def pretty_format(self):
        """
        Draw the DRS

        :return: the pretty print string
        """
        rows = self._pretty()
        return '\n'.join(rows)

    def pretty_print(self):
        print(self.pretty_format())

    def draw(self):
        DrsDrawer(self).draw()
@python_2_unicode_compatible
class DRS(DrtExpression, Expression):
    """A Discourse Representation Structure."""

    def __init__(self, refs, conds, consequent=None):
        """
        :param refs: list of discourse referents (the methods of this class
            store and compare ``Variable`` objects here)
        :param conds: list of ``Expression`` for the conditions
        :param consequent: optional expression; when set, this DRS is the
            antecedent of an implication whose right-hand side is
            ``consequent``
        """
        self.refs = refs
        self.conds = conds
        self.consequent = consequent

    def replace(self, variable, expression, replace_bound=False, alpha_convert=True):
        """Replace all instances of variable v with expression E in self,
        where v is free in self.

        :param variable: the variable to replace
        :param expression: the expression to put in its place
        :param replace_bound: also rename ``variable`` when it is one of
            this DRS's own referents
        :param alpha_convert: rename referents that clash with variables
            free in ``expression`` before substituting
        :return: ``self`` when nothing may be replaced, otherwise a new ``DRS``
        """
        if variable in self.refs:
            # A bound variable is the thing being replaced.
            if not replace_bound:
                return self

            i = self.refs.index(variable)
            if self.consequent:
                consequent = self.consequent.replace(variable, expression, True, alpha_convert)
            else:
                consequent = None
            return DRS(self.refs[:i] + [expression.variable] + self.refs[i+1:],
                       [cond.replace(variable, expression, True, alpha_convert)
                        for cond in self.conds],
                       consequent)

        drs = self
        if alpha_convert:
            # Any bound variable that appears in the expression must
            # be alpha converted to avoid a conflict.
            for ref in (set(drs.refs) & expression.free()):
                newvar = unique_variable(ref)
                newvarex = DrtVariableExpression(newvar)
                i = drs.refs.index(ref)
                if drs.consequent:
                    consequent = drs.consequent.replace(ref, newvarex, True, alpha_convert)
                else:
                    consequent = None
                drs = DRS(drs.refs[:i] + [newvar] + drs.refs[i+1:],
                          [cond.replace(ref, newvarex, True, alpha_convert)
                           for cond in drs.conds],
                          consequent)

        # Replace in the conditions and the consequent.
        if drs.consequent:
            consequent = drs.consequent.replace(variable, expression, replace_bound, alpha_convert)
        else:
            consequent = None
        return DRS(drs.refs,
                   [cond.replace(variable, expression, replace_bound, alpha_convert)
                    for cond in drs.conds],
                   consequent)

    def free(self):
        """:see: Expression.free()"""
        conds_free = reduce(operator.or_, [c.free() for c in self.conds], set())
        if self.consequent:
            conds_free.update(self.consequent.free())
        return conds_free - set(self.refs)

    def get_refs(self, recursive=False):
        """:see: AbstractExpression.get_refs()"""
        if recursive:
            conds_refs = self.refs + sum((c.get_refs(True) for c in self.conds), [])
            if self.consequent:
                conds_refs.extend(self.consequent.get_refs(True))
            return conds_refs
        else:
            return self.refs

    def visit(self, function, combinator):
        """:see: Expression.visit()"""
        parts = list(map(function, self.conds))
        if self.consequent:
            parts.append(function(self.consequent))
        return combinator(parts)

    def visit_structured(self, function, combinator):
        """:see: Expression.visit_structured()"""
        consequent = (function(self.consequent) if self.consequent else None)
        return combinator(self.refs, list(map(function, self.conds)), consequent)

    def eliminate_equality(self):
        """Remove variable-to-variable equality conditions.

        For each condition ``a = b`` between two variable expressions, the
        condition and the referent ``b`` are dropped and ``b`` is replaced
        by ``a`` in what remains.  Conditions that reduce to an empty DRS
        are dropped as well.

        :return: a new ``DRS``
        """
        drs = self
        i = 0
        while i < len(drs.conds):
            cond = drs.conds[i]
            if isinstance(cond, EqualityExpression) and \
               isinstance(cond.first, AbstractVariableExpression) and \
               isinstance(cond.second, AbstractVariableExpression):
                # Drop the right-hand referent (and any duplicates) while
                # keeping the remaining referents in their original order.
                # Going through a set here would make the order depend on
                # hashing, and the order matters to __eq__ and fol().
                dropped = cond.second.variable
                kept_refs = []
                for ref in drs.refs:
                    if ref != dropped and ref not in kept_refs:
                        kept_refs.append(ref)

                drs = DRS(kept_refs,
                          drs.conds[:i] + drs.conds[i+1:],
                          drs.consequent)
                if cond.second.variable != cond.first.variable:
                    drs = drs.replace(cond.second.variable, cond.first, False, False)
                    # The substitution may have changed earlier conditions,
                    # so rescan from the start (i becomes 0 after the
                    # decrement/increment pair below).
                    i = 0
                # Compensate for the removed condition.
                i -= 1
            i += 1

        conds = []
        for cond in drs.conds:
            new_cond = cond.eliminate_equality()
            new_cond_simp = new_cond.simplify()
            # Keep everything except conditions that simplify to an
            # entirely empty DRS.
            if not isinstance(new_cond_simp, DRS) or \
               new_cond_simp.refs or new_cond_simp.conds or \
               new_cond_simp.consequent:
                conds.append(new_cond)

        consequent = (drs.consequent.eliminate_equality() if drs.consequent else None)
        return DRS(drs.refs, conds, consequent)

    def fol(self):
        """Convert this DRS to a first-order expression.

        With a consequent the referents are universally quantified over
        ``conds -> consequent``; without one they are existentially
        quantified over the conjunction of the conditions.

        :raise Exception: if there is neither a consequent nor a condition
        """
        if self.consequent:
            accum = None
            if self.conds:
                accum = reduce(AndExpression, [c.fol() for c in self.conds])

            if accum:
                accum = ImpExpression(accum, self.consequent.fol())
            else:
                accum = self.consequent.fol()

            # Wrap from the innermost referent outwards.
            for ref in self.refs[::-1]:
                accum = AllExpression(ref, accum)

            return accum

        else:
            if not self.conds:
                raise Exception("Cannot convert DRS with no conditions to FOL.")
            accum = reduce(AndExpression, [c.fol() for c in self.conds])
            for ref in map(Variable, self._order_ref_strings(self.refs)[::-1]):
                accum = ExistsExpression(ref, accum)
            return accum

    def _pretty(self):
        """Return the box drawing of this DRS as a list of text rows."""
        refs_line = ' '.join(self._order_ref_strings(self.refs))

        # All non-blank rows of every condition, in condition order.
        cond_lines = [line
                      for cond in self.conds
                      for line in cond._pretty()
                      if line.strip()]
        length = max([len(refs_line)] + list(map(len, cond_lines)))
        drs = ([' _' + '_' * length + '_ ',
                '| ' + refs_line.ljust(length) + ' |',
                '|-' + '-' * length + '-|'] +
               ['| ' + line.ljust(length) + ' |' for line in cond_lines] +
               ['|_' + '_' * length + '_|'])
        if self.consequent:
            return DrtBinaryExpression._assemble_pretty(drs, DrtTokens.IMP,
                                                        self.consequent._pretty())
        return drs

    def _order_ref_strings(self, refs):
        """Return the string forms of ``refs`` in display order.

        Names are grouped as other, event, function and individual
        variables (in that order); each group is sorted on its own.
        """
        def suffix_number(digits):
            # A missing numeric suffix sorts before every numbered name.
            return int(digits) if digits else -1

        strings = ["%s" % ref for ref in refs]
        ind_vars = []
        func_vars = []
        event_vars = []
        other_vars = []
        for s in strings:
            if is_indvar(s):
                ind_vars.append(s)
            elif is_funcvar(s):
                func_vars.append(s)
            elif is_eventvar(s):
                event_vars.append(s)
            else:
                other_vars.append(s)
        # NOTE(review): event names are keyed on v[2:], i.e. the text after
        # the second character, unlike the v[1:] used for the other groups;
        # kept as-is -- confirm whether that is intended.
        return sorted(other_vars) + \
               sorted(event_vars, key=lambda v: suffix_number(v[2:])) + \
               sorted(func_vars, key=lambda v: (v[0], suffix_number(v[1:]))) + \
               sorted(ind_vars, key=lambda v: (v[0], suffix_number(v[1:])))

    def __eq__(self, other):
        r"""Defines equality modulo alphabetic variance.
        If we are comparing \x.M and \y.N, then check equality of M and N[x/y].

        The referents of ``other`` are renamed position by position to the
        referents of ``self`` before the consequents and the conditions are
        compared pairwise, so referent order is significant."""
        if isinstance(other, DRS):
            if len(self.refs) == len(other.refs):
                converted_other = other
                for (r1, r2) in zip(self.refs, converted_other.refs):
                    varex = self.make_VariableExpression(r1)
                    converted_other = converted_other.replace(r2, varex, True)
                if self.consequent == converted_other.consequent and \
                   len(self.conds) == len(converted_other.conds):
                    for c1, c2 in zip(self.conds, converted_other.conds):
                        if not (c1 == c2):
                            return False
                    return True
        return False

    def __ne__(self, other):
        return not self == other

    # __eq__ is overridden above, so the hash is inherited explicitly.
    __hash__ = Expression.__hash__

    def __str__(self):
        drs = '([%s],[%s])' % (','.join(self._order_ref_strings(self.refs)),
                               ', '.join("%s" % cond for cond in self.conds))
        if self.consequent:
            return DrtTokens.OPEN + drs + ' ' + DrtTokens.IMP + ' ' + \
                   "%s" % self.consequent + DrtTokens.CLOSE
        return drs
def DrtVariableExpression(variable): """ This is a factory method that instantiates and returns a subtype of ``DrtAbstractVariableExpression`` appropriate for the given variable. """ if is_indvar(variable.name): return DrtIndividualVariableExpression(variable) elif is_funcvar(variable.name): return DrtFunctionVariableExpression(variable) elif is_eventvar(variable.name): return DrtEventVariableExpression(variable) else: return DrtConstantExpression(variable) class DrtAbstractVariableExpression(DrtExpression, AbstractVariableExpression): def fol(self): return self def get_refs(self, recursive=False): """:see: AbstractExpression.get_refs()""" return [] def _pretty(self): s = "%s" % self blank = ' '*len(s) return [blank, blank, s, blank] def eliminate_equality(self): return self class DrtIndividualVariableExpression(DrtAbstractVariableExpression, IndividualVariableExpression): pass class DrtFunctionVariableExpression(DrtAbstractVariableExpression, FunctionVariableExpression): pass class DrtEventVariableExpression(DrtIndividualVariableExpression, EventVariableExpression): pass class DrtConstantExpression(DrtAbstractVariableExpression, ConstantExpression): pass @python_2_unicode_compatible class DrtProposition(DrtExpression, Expression): def __init__(self, variable, drs): self.variable = variable self.drs = drs def replace(self, variable, expression, replace_bound=False, alpha_convert=True): if self.variable == variable: assert isinstance(expression, DrtAbstractVariableExpression), "Can only replace a proposition label with a variable" return DrtProposition(expression.variable, self.drs.replace(variable, expression, replace_bound, alpha_convert)) else: return DrtProposition(self.variable, self.drs.replace(variable, expression, replace_bound, alpha_convert)) def eliminate_equality(self): return DrtProposition(self.variable, self.drs.eliminate_equality()) def get_refs(self, recursive=False): return (self.drs.get_refs(True) if recursive else []) def __eq__(self, other): 
return self.__class__ == other.__class__ and \ self.variable == other.variable and \ self.drs == other.drs def __ne__(self, other): return not self == other __hash__ = Expression.__hash__ def fol(self): return self.drs.fol() def _pretty(self): drs_s = self.drs._pretty() blank = ' ' * len("%s" % self.variable) return ([blank + ' ' + line for line in drs_s[:1]] + ["%s" % self.variable + ':' + line for line in drs_s[1:2]] + [blank + ' ' + line for line in drs_s[2:]]) def visit(self, function, combinator): """:see: Expression.visit()""" return combinator([function(self.drs)]) def visit_structured(self, function, combinator): """:see: Expression.visit_structured()""" return combinator(self.variable, function(self.drs)) def __str__(self): return 'prop(%s, %s)' % (self.variable, self.drs) class DrtNegatedExpression(DrtExpression, NegatedExpression): def fol(self): return NegatedExpression(self.term.fol()) def get_refs(self, recursive=False): """:see: AbstractExpression.get_refs()""" return self.term.get_refs(recursive) def _pretty(self): term_lines = self.term._pretty() return ([' ' + line for line in term_lines[:2]] + ['__ ' + line for line in term_lines[2:3]] + [' | ' + line for line in term_lines[3:4]] + [' ' + line for line in term_lines[4:]]) class DrtLambdaExpression(DrtExpression, LambdaExpression): def alpha_convert(self, newvar): """Rename all occurrences of the variable introduced by this variable binder in the expression to ``newvar``. 
:param newvar: ``Variable``, for the new variable """ return self.__class__(newvar, self.term.replace(self.variable, DrtVariableExpression(newvar), True)) def fol(self): return LambdaExpression(self.variable, self.term.fol()) def _pretty(self): variables = [self.variable] term = self.term while term.__class__ == self.__class__: variables.append(term.variable) term = term.term var_string = ' '.join("%s" % v for v in variables) + DrtTokens.DOT term_lines = term._pretty() blank = ' ' * len(var_string) return ([' ' + blank + line for line in term_lines[:1]] + [' \ ' + blank + line for line in term_lines[1:2]] + [' /\ ' + var_string + line for line in term_lines[2:3]] + [' ' + blank + line for line in term_lines[3:]]) class DrtBinaryExpression(DrtExpression, BinaryExpression): def get_refs(self, recursive=False): """:see: AbstractExpression.get_refs()""" return self.first.get_refs(True) + self.second.get_refs(True) if recursive else [] def _pretty(self): return DrtBinaryExpression._assemble_pretty(self._pretty_subex(self.first), self.getOp(), self._pretty_subex(self.second)) @staticmethod def _assemble_pretty(first_lines, op, second_lines): max_lines = max(len(first_lines), len(second_lines)) first_lines = _pad_vertically(first_lines, max_lines) second_lines = _pad_vertically(second_lines, max_lines) blank = ' ' * len(op) first_second_lines = list(zip(first_lines, second_lines)) return ([' ' + first_line + ' ' + blank + ' ' + second_line + ' ' for first_line, second_line in first_second_lines[:2]] + ['(' + first_line + ' ' + op + ' ' + second_line + ')' for first_line, second_line in first_second_lines[2:3]] + [' ' + first_line + ' ' + blank + ' ' + second_line + ' ' for first_line, second_line in first_second_lines[3:]]) def _pretty_subex(self, subex): return subex._pretty() class DrtBooleanExpression(DrtBinaryExpression, BooleanExpression): pass class DrtOrExpression(DrtBooleanExpression, OrExpression): def fol(self): return OrExpression(self.first.fol(), 
self.second.fol()) def _pretty_subex(self, subex): if isinstance(subex, DrtOrExpression): return [line[1:-1] for line in subex._pretty()] return DrtBooleanExpression._pretty_subex(self, subex) class DrtEqualityExpression(DrtBinaryExpression, EqualityExpression): def fol(self): return EqualityExpression(self.first.fol(), self.second.fol()) @python_2_unicode_compatible class DrtConcatenation(DrtBooleanExpression): """DRS of the form '(DRS + DRS)'""" def __init__(self, first, second, consequent=None): DrtBooleanExpression.__init__(self, first, second) self.consequent = consequent def replace(self, variable, expression, replace_bound=False, alpha_convert=True): """Replace all instances of variable v with expression E in self, where v is free in self.""" first = self.first second = self.second consequent = self.consequent # If variable is bound if variable in self.get_refs(): if replace_bound: first = first.replace(variable, expression, replace_bound, alpha_convert) second = second.replace(variable, expression, replace_bound, alpha_convert) if consequent: consequent = consequent.replace(variable, expression, replace_bound, alpha_convert) else: if alpha_convert: # alpha convert every ref that is free in 'expression' for ref in (set(self.get_refs(True)) & expression.free()): v = DrtVariableExpression(unique_variable(ref)) first = first.replace(ref, v, True, alpha_convert) second = second.replace(ref, v, True, alpha_convert) if consequent: consequent = consequent.replace(ref, v, True, alpha_convert) first = first.replace(variable, expression, replace_bound, alpha_convert) second = second.replace(variable, expression, replace_bound, alpha_convert) if consequent: consequent = consequent.replace(variable, expression, replace_bound, alpha_convert) return self.__class__(first, second, consequent) def eliminate_equality(self): #TODO: at some point. for now, simplify. 
drs = self.simplify() assert not isinstance(drs, DrtConcatenation) return drs.eliminate_equality() def simplify(self): first = self.first.simplify() second = self.second.simplify() consequent = (self.consequent.simplify() if self.consequent else None) if isinstance(first, DRS) and isinstance(second, DRS): # For any ref that is in both 'first' and 'second' for ref in (set(first.get_refs(True)) & set(second.get_refs(True))): # alpha convert the ref in 'second' to prevent collision newvar = DrtVariableExpression(unique_variable(ref)) second = second.replace(ref, newvar, True) return DRS(first.refs + second.refs, first.conds + second.conds, consequent) else: return self.__class__(first, second, consequent) def get_refs(self, recursive=False): """:see: AbstractExpression.get_refs()""" refs = self.first.get_refs(recursive) + self.second.get_refs(recursive) if self.consequent and recursive: refs.extend(self.consequent.get_refs(True)) return refs def getOp(self): return DrtTokens.DRS_CONC def __eq__(self, other): r"""Defines equality modulo alphabetic variance. If we are comparing \x.M and \y.N, then check equality of M and N[x/y].""" if isinstance(other, <