Commit 4da2003b authored by Raul Sirel's avatar Raul Sirel
Browse files

Add word features to MLP output; remove the GPU (CUDA OOM) exception fallback.

parent a4da11a1
......@@ -46,6 +46,8 @@ def test_mlp_process(mlp: MLP):
assert "lemmas" in mlp_text
assert isinstance(mlp_text["lemmas"], str) is True
assert "word_features" in mlp_text
assert "language" in mlp_text
assert isinstance(mlp_text["language"], dict) is True
......
......@@ -71,6 +71,7 @@ class Document:
self.__words = []
self.__lemmas = []
self.__pos_tags = []
self.__word_features = []
self.__transliteration = []
self.__texta_facts: List[Fact] = []
......@@ -144,7 +145,6 @@ class Document:
sent_index = tokenized_text[:span[0]].count("\n")
# find last sent break before the match
matches = list(re.finditer(" \n ", tokenized_text[:span[0]]))
print(sent_index)
# check if any sentences
if matches:
# find the last sentence break
......@@ -204,6 +204,8 @@ class Document:
container["lemmas"] = self.get_lemma()
if "pos_tags" in self.analyzers:
container["pos_tags"] = self.get_pos_tags()
if "word_features" in self.analyzers:
container["word_features"] = self.get_word_features()
if "transliteration" in self.analyzers and self.__transliteration:
container["transliteration"] = self.get_transliteration()
if use_default_doc_path:
......@@ -248,6 +250,7 @@ class Document:
def pos_tags(self):
if "sentences" in self.analyzers:
for i,sent in enumerate(self.stanza_sentences):
#print(sent)
tags_in_sent = [word.upos if word and word.upos and word.upos != "_" else "X" if word.upos == "_" else "X" for word in sent]
for tag in tags_in_sent:
self.__pos_tags.append(tag)
......@@ -262,6 +265,23 @@ class Document:
return " ".join([a.strip() for a in self.__pos_tags])
def word_features(self):
    """Populate self.__word_features with one morphological-feature tag per word.

    When the "word_features" analyzer is enabled, tags are gathered per
    stanza sentence with an "LBR" marker appended between sentences (but not
    after the last one).  Otherwise all words are taken flat from
    self.stanza_words.  Words with no features, or with the stanza
    placeholder value "_", are recorded as "X".

    NOTE(review): assumes items of self.stanza_sentences / self.stanza_words
    are stanza Word objects exposing a `.feats` attribute — confirm upstream.
    """
    # Original expression ended with `else "X" if word.feats == "_" else "X"`:
    # both arms produced "X" and the `word.feats` access crashed on a falsy
    # word, so it collapses to a single "X" fallback.
    def _tag(word):
        return word.feats if word and word.feats and word.feats != "_" else "X"

    if "word_features" in self.analyzers:
        sentences = self.stanza_sentences
        for i, sent in enumerate(sentences):
            self.__word_features.extend(_tag(word) for word in sent)
            # Sentence break marker between sentences (not after the last).
            if i + 1 < len(sentences):
                self.__word_features.append("LBR")
    else:
        self.__word_features = [_tag(word) for word in self.stanza_words]
def get_word_features(self) -> str:
    """Return the collected word-feature tags as a single space-separated string."""
    stripped_tags = (tag.strip() for tag in self.__word_features)
    return " ".join(stripped_tags)
def entities(self):
"""
Retrieves list-based entities.
......
......@@ -50,6 +50,7 @@ REFRESH_DATA = parse_bool_env("TEXTA_MLP_REFRESH_DATA", False)
SUPPORTED_ANALYZERS = (
"lemmas",
"pos_tags",
"word_features",
"transliteration",
"ner",
"addresses",
......@@ -63,6 +64,7 @@ SUPPORTED_ANALYZERS = (
DEFAULT_ANALYZERS = [
"lemmas",
"pos_tags",
"word_features",
"transliteration",
"ner",
"addresses",
......@@ -268,44 +270,22 @@ class MLP:
def get_stanza_pipeline(self, lang: str):
    """Return a cached stanza Pipeline for *lang*, creating it on first use.

    Languages with a custom NER model additionally get `ner_model_path`
    pointed at that model.  The CUDA-OOM try/except fallback to CPU was
    removed by this commit ("remove gpu exception"); the pasted diff kept
    both the deleted try/except bodies and their replacements, which would
    have constructed every pipeline twice — this version keeps only the
    post-commit construction.

    :param lang: language code used as the cache key and pipeline language.
    :return: the (possibly newly created) stanza.Pipeline for *lang*.
    """
    if lang not in self._stanza_pipelines:
        if lang in self.custom_ner_model_langs:
            # Custom NER model available for this language.
            self._stanza_pipelines[lang] = stanza.Pipeline(
                lang=lang,
                dir=str(self.stanza_resource_path),
                processors=self._get_stanza_processors(lang),
                ner_model_path=f"{self.custom_ner_model_Path}/{lang}",
                use_gpu=self.use_gpu,
                logging_level=self.logging_level,
            )
        else:
            self._stanza_pipelines[lang] = stanza.Pipeline(
                lang=lang,
                dir=str(self.stanza_resource_path),
                processors=self._get_stanza_processors(lang),
                use_gpu=self.use_gpu,
                logging_level=self.logging_level,
            )
    return self._stanza_pipelines[lang]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment