Improve support for functions #2

This commit is contained in:
Bartłomiej Pluta
2019-07-04 11:31:02 +02:00
parent 6390ac20de
commit ce101df380
7 changed files with 96 additions and 50 deletions

View File

@@ -1,5 +1,8 @@
from enum import Enum, auto
from smnp.error.function import IllegalFunctionInvocationException
from smnp.type.model import Type
class FunctionType(Enum):
FUNCTION = auto()
@@ -7,12 +10,52 @@ class FunctionType(Enum):
class Function:
def __init__(self, signature, function):
def __init__(self, name, signature, function):
self.name = name
self.signature = signature
self.function = function
def stringSignature(self):
return f"{self.name}{self.signature.string}"
def call(self, env, args):
result = self.signature(args)
result = self.signature.check(args)
if result[0]:
return self.function(env, *result[1:])
# todo: raise illegal signature exception or something
raise IllegalFunctionInvocationException(self.stringSignature(), f"{self.name}{types(args)}") #TODO: argumenty do typów, nie wartości
class CombinedFunction(Function):
def __init__(self, name, *functions):
super().__init__(None, None, None)
self.name = name
self.functions = functions
def stringSignature(self):
return "\nor\n".join([f"{self.name}{function.signature.string}" for function in self.functions])
def call(self, env, args):
for function in self.functions:
result = function.signature.check(args)
if result[0]:
return function.function(env, *result[1:])
raise IllegalFunctionInvocationException(self.stringSignature(), f"{self.name}{types(args)}")
def types(args):
output = []
for arg in args:
if arg.type == Type.LIST:
output.append(listTypes(arg.value, []))
else:
output.append(arg.type.name)
return f"({', '.join(output)})"
def listTypes(l, output=[]):
for item in l:
if item.type == Type.LIST:
output.append(listTypes(item.value, []))
else:
output.append(item.type.name)
return f"LIST<{'|'.join(set(output))}>"

View File

@@ -4,9 +4,10 @@ from smnp.type.model import Type
class Matcher:
def __init__(self, objectType, matcher):
def __init__(self, objectType, matcher, string):
self.type = objectType
self.matcher = matcher
self.string = string
def match(self, value):
if self.type is not None and self.type != value.type:
@@ -16,7 +17,26 @@ class Matcher:
def andWith(self, matcher):
if self.type != matcher.type:
raise RuntimeError("Support types of matches are not the same")
return Matcher(self.type, lambda x: self.matcher(x) and matcher.matcher(x))
string = f"[{self.string} and {matcher.string}]"
return Matcher(self.type, lambda x: self.match(x) and matcher.match(x), string)
def orWith(self, matcher):
string = f"[{self.string} or {matcher.string}]"
return Matcher(None, lambda x: self.match(x) or matcher.match(x), string)
def __str__(self):
return self.string
def __repr__(self):
return self.__str__()
class Signature:
def __init__(self, check, string):
self.check = check
self.string = string
def varargSignature(varargMatcher, *basicSignature):
@@ -33,7 +53,10 @@ def varargSignature(varargMatcher, *basicSignature):
return doesNotMatchVararg(basicSignature)
return True, (*args[:len(basicSignature)]), args[len(basicSignature):]
return check
string = f"({', '.join([str(m) for m in basicSignature])}{', ' if len(basicSignature) > 0 else ''}{str(varargMatcher)}...)"
return Signature(check, string)
def doesNotMatchVararg(basicSignature):
@@ -52,7 +75,9 @@ def signature(*signature):
return (True, *args)
return check
string = f"({', '.join([str(m) for m in signature])})"
return Signature(check, string)
def doesNotMatch(sign):
@@ -62,21 +87,21 @@ def doesNotMatch(sign):
def ofTypes(*types):
def check(value):
return value.type in types
return Matcher(None, check)
return Matcher(None, check, f"<{'|'.join([t.name for t in types])}>")
def listOf(*types):
def check(value):
return len([item for item in value.value if not item.type in types]) == 0
return Matcher(Type.LIST, check)
return Matcher(Type.LIST, check, f"{Type.LIST.name}<{'|'.join([t.name for t in types])}>")
def listMatches(*pattern):
def check(value):
return signature(pattern)(value.value)[0]
return signature(*pattern).check(value.value)[0]
return Matcher(Type.LIST, check)
return Matcher(Type.LIST, check, f"({', '.join([str(m) for m in pattern])})")
def recursiveListMatcher(matcher):
@@ -89,5 +114,5 @@ def recursiveListMatcher(matcher):
for item in value.value:
return check(item)
return Matcher(Type.LIST, check)
return Matcher(Type.LIST, check, f"[LISTS OF {str(matcher)}]")

View File

@@ -1,38 +1,2 @@
from smnp.environment.function.model import Function
def returnElementOrList(list):
return list[0] if len(list) == 1 else list
def combineFunctions(*functions):
if len(functions) == 0:
raise RuntimeError("Must be passed one function at least")
def signature(args):
ret = None
for fun in functions:
ret = fun.signature(args)
if ret[0] == True:
return ret
return ret
def function(env, *args):
originalArgs = removeFirstLevelNesting(args)
for fun in functions:
if fun.signature(originalArgs)[0]:
return fun.function(env, *args)
return None
return Function(signature, function)
def removeFirstLevelNesting(l):
flat = []
for item in l:
if type(item) == list:
for i in item:
flat.append(i)
else:
flat.append(item)
return flat

2
smnp/error/base.py Normal file
View File

@@ -0,0 +1,2 @@
class SmnpException(Exception):
pass

6
smnp/error/function.py Normal file
View File

@@ -0,0 +1,6 @@
from smnp.error.base import SmnpException
class IllegalFunctionInvocationException(SmnpException):
def __init__(self, expected, found):
self.msg = f"Illegal function invocation\n\nExpected signature:\n{expected}\n\nFound:\n{found}"

View File

@@ -1,4 +1,7 @@
class RuntimeException(Exception):
from smnp.error.base import SmnpException
class RuntimeException(SmnpException):
def __init__(self, pos, msg):
posStr = "" if pos is None else f" [line {pos[0]+1}, col {pos[1]+1}]"
self.msg = f"Runtime error{posStr}:\n{msg}"

View File

@@ -1,4 +1,7 @@
class SyntaxException(Exception):
from smnp.error.base import SmnpException
class SyntaxException(SmnpException):
def __init__(self, pos, msg):
posStr = "" if pos is None else f" [line {pos[0]+1}, col {pos[1]+1}]"
self.msg = f"Syntax error{posStr}:\n{msg}"