Commit 1613a25b authored by Raul Sirel's avatar Raul Sirel
Browse files

Merge branch 'custom_models' into 'master'

Custom models

See merge request !15
parents fa0e0b24 969e2e33
Pipeline #6568 canceled with stage
in 32 seconds
......@@ -158,8 +158,7 @@ def test_removal_of_duplicate_facts(mlp: MLP):
facts = result[0]["texta_facts"]
fact = facts[0]
assert len(facts) == 1
assert fact["str_val"] == "nõmme tänav 24"
assert len(facts) == 3
assert fact["doc_path"] == "texts_mlp.text"
......
......@@ -35,6 +35,11 @@ CONCATENATOR_DATA_FILES = (
"https://packages.texta.ee/texta-resources/concatenator/space_between_not_ok.txt",
)
# URLs for Custom NER model downloads, keyed by ISO 639-1 language code.
# Downloaded into <resource_dir>/ner_models by MLP.download_custom_ner_models().
CUSTOM_NER_MODELS = {
    "et": "https://packages.texta.ee/texta-resources/ner_models/_estonian_nertagger.pt",
}
# Location of the resource dir where models are downloaded.
# Override with the TEXTA_MLP_DATA_DIR environment variable; defaults to ./data.
DEFAULT_RESOURCE_DIR = os.getenv("TEXTA_MLP_DATA_DIR", os.path.join(os.getcwd(), "data"))
......@@ -74,6 +79,9 @@ DEFAULT_ANALYZERS = [
# Languages with NER support in the default Stanza models, see:
# https://stanfordnlp.github.io/stanza/available_models.html#available-ner-models
STANZA_NER_SUPPORT = ("ar", "zh", "nl", "en", "fr", "de", "ru", "es", "uk")
# Here we add langs that will have custom ner models (keys of CUSTOM_NER_MODELS)
# instead of the Stanza-provided NER model.
CUSTOM_NER_MODEL_LANGS = ["et"]
class MLP:
......@@ -83,6 +91,7 @@ class MLP:
default_language_code=DEFAULT_LANG_CODES[0],
use_default_language_code=True,
resource_dir: str = DEFAULT_RESOURCE_DIR,
ner_model_langs: list = CUSTOM_NER_MODEL_LANGS,
logging_level="error",
use_gpu=True,
refresh_data=REFRESH_DATA
......@@ -94,9 +103,11 @@ class MLP:
self.resource_dir = resource_dir
self._stanza_pipelines = {}
self.custom_ner_model_langs = ner_model_langs
self.logging_level = logging_level
self.use_gpu = use_gpu
self.stanza_resource_path = pathlib.Path(self.resource_dir) / "stanza"
self.custom_ner_model_Path = pathlib.Path(self.resource_dir) / "ner_models"
self.prepare_resources(refresh_data)
......@@ -114,9 +125,27 @@ class MLP:
shutil.rmtree(self.resource_dir)
self.logger.info("MLP data directory deleted.")
# download resources
self.download_custom_ner_models(self.resource_dir, logger=self.logger, model_langs=self.custom_ner_model_langs)
self.download_stanza_resources(self.resource_dir, self.supported_langs, logger=self.logger)
self.download_entity_mapper_resources(self.resource_dir, logger=self.logger)
@staticmethod
def download_custom_ner_models(resource_dir: str, logger=None, custom_ner_model_urls: dict = CUSTOM_NER_MODELS, model_langs: list = None):
    """
    Downloads custom NER models into <resource_dir>/ner_models if not already present.

    :param resource_dir: Root resource directory; a "ner_models" subdirectory is created inside it.
    :param logger: Optional logger for progress messages.
    :param custom_ner_model_urls: Mapping of language code -> model download URL.
    :param model_langs: Language codes whose models should be downloaded; None/empty downloads nothing.
    """
    ner_resource_dir = pathlib.Path(resource_dir) / "ner_models"
    ner_resource_dir.mkdir(parents=True, exist_ok=True)
    # Guard: the previous code raised TypeError ("in None") when called with the
    # default model_langs=None.
    model_langs = model_langs or []
    for lang, url in custom_ner_model_urls.items():
        if lang in model_langs:
            file_name = urlparse(url).path.split("/")[-1]
            # The model is stored under its bare language code so that
            # get_stanza_pipeline() can locate it as <ner_models>/<lang>.
            file_path = ner_resource_dir / lang
            if not file_path.exists():
                if logger:
                    logger.info(f"Downloading custom ner model file {file_name} from: {url}")
                # Close the HTTP response deterministically instead of leaking it.
                with closing(urlopen(url)) as response:
                    content = response.read()
                with open(file_path, "wb") as fh:
                    fh.write(content)
@staticmethod
def download_stanza_resources(resource_dir: str, supported_langs: List[str], logger=None):
......@@ -238,24 +267,44 @@ class MLP:
def get_stanza_pipeline(self, lang: str):
    """
    Returns a cached Stanza pipeline for the given language, creating it on first use.

    For languages listed in self.custom_ner_model_langs the pipeline is built with
    the custom NER model downloaded by download_custom_ner_models(); otherwise the
    default Stanza models are used.

    :param lang: Language code of the pipeline to fetch.
    :return: stanza.Pipeline for the language.
    """
    if lang not in self._stanza_pipelines:
        # Shared constructor arguments. The previous code duplicated the whole
        # Pipeline(...) construction in three near-identical call sites (and the
        # merged text even built a pipeline twice for custom-NER languages);
        # build the kwargs once and vary only use_gpu / ner_model_path.
        pipeline_kwargs = {
            "lang": lang,
            "dir": str(self.stanza_resource_path),
            "processors": self._get_stanza_processors(lang),
            "logging_level": self.logging_level,
        }
        if lang in self.custom_ner_model_langs:
            # Point Stanza at the custom model file stored as <ner_models>/<lang>.
            pipeline_kwargs["ner_model_path"] = str(self.custom_ner_model_Path / lang)
        try:
            self._stanza_pipelines[lang] = stanza.Pipeline(use_gpu=self.use_gpu, **pipeline_kwargs)
        # This is for CUDA OOM exceptions. Fall back to CPU if needed.
        except RuntimeError:
            self._stanza_pipelines[lang] = stanza.Pipeline(use_gpu=False, **pipeline_kwargs)
    return self._stanza_pipelines[lang]
......@@ -312,7 +361,7 @@ class MLP:
# For every analyzer, activate the function that processes it from the
# document class.
self.__apply_analyzer(document, analyzer)
if "sentences" in analyzers and spans == "sentence":
document.fact_spans_to_sent()
......@@ -404,7 +453,6 @@ class MLP:
return container
@staticmethod
def download_concatenator_resources(resource_dir: str, logger):
concat_resource_dir = pathlib.Path(resource_dir) / "concatenator"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment