Jeg løb ind i præcis det samme problem, og mand var det et kaninhul. Ønskede at poste min løsning her, da det måske kan spare nogen for en arbejdsdag:
TensorFlow trådspecifikke datastrukturer
I TensorFlow er der to nøgledatastrukturer, der arbejder bag kulisserne, når du kalder model.predict
(eller keras.models.load_model
, eller keras.backend.clear_session
, eller stort set enhver anden funktion, der interagerer med TensorFlow-backend):
- En TensorFlow-graf, som repræsenterer strukturen af din Keras-model
- En TensorFlow-session, som er forbindelsen mellem din aktuelle graf og TensorFlow-runtiden
Noget, der ikke er eksplicit tydeligt i dokumenterne uden lidt gravering, er, at både sessionen og grafen er egenskaber for den aktuelle tråd . Se API-dokumenter her og her.
Brug af TensorFlow-modeller i forskellige tråde
Det er naturligt at ønske at indlæse din model én gang og derefter kalde .predict()
på den flere gange senere:
from keras.models import load_model
MY_MODEL = load_model('path/to/model/file')
def some_worker_function(inputs):
return MY_MODEL.predict(inputs)
I en webserver eller worker pool kontekst som Celery betyder det, at du vil indlæse modellen, når du importerer modulet, der indeholder load_model
linje, så vil en anden tråd udføre some_worker_function
, kører forudsigelse på den globale variabel, der indeholder Keras-modellen. Men at prøve at køre forudsigelse på en model, der er indlæst i en anden tråd, giver "tensor er ikke et element i denne graf"-fejl. Takket være de adskillige SO-indlæg, der berørte dette emne, såsom ValueError:Tensor Tensor(...) er ikke et element i denne graf. Ved brug af global variabel keras-model. For at få dette til at virke, skal du hænge på TensorFlow-grafen, der blev brugt - som vi så tidligere, er grafen en egenskab for den aktuelle tråd. Den opdaterede kode ser sådan ud:
from keras.models import load_model
import tensorflow as tf
MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()
def some_worker_function(inputs):
with MY_GRAPH.as_default():
return MY_MODEL.predict(inputs)
Det noget overraskende twist her er:ovenstående kode er tilstrækkelig, hvis du bruger Thread
s, men hænger på ubestemt tid, hvis du bruger Process
es. Og som standard bruger Celery processer til at administrere alle sine arbejderpuljer. Så på dette tidspunkt er tingene stadig virker ikke på selleri.
Hvorfor virker dette kun på Thread
s?
I Python, Thread
s deler den samme globale eksekveringskontekst som den overordnede proces. Fra Python _thread docs:
Dette modul giver primitiver på lavt niveau til at arbejde med flere tråde (også kaldet letvægtsprocesser eller opgaver) - flere kontroltråde, der deler deres globale datarum.
Fordi tråde ikke er egentlige separate processer, bruger de den samme pythonfortolker og er derfor underlagt den berygtede Global Interpeter Lock (GIL). Måske endnu vigtigere for denne undersøgelse, de deler globalt datarum med forælderen.
I modsætning til dette, Process
de er faktiske nye processer affødt af programmet. Det betyder:
- Ny Python-fortolkerinstans (og ingen GIL)
- Globalt adresseområde er duplikeret
Bemærk forskellen her. Mens Thread
s har adgang til en delt enkelt global sessionsvariabel (gemt internt i tensorflow_backend
modul af Keras), Process
es har dubletter af Session-variablen.
Min bedste forståelse af dette problem er, at Session-variablen formodes at repræsentere en unik forbindelse mellem en klient (proces) og TensorFlow-runtiden, men ved at blive duplikeret i gaffelprocessen, er denne forbindelsesinformation ikke korrekt justeret. Dette får TensorFlow til at hænge, når du forsøger at bruge en session, der er oprettet i en anden proces. Hvis nogen har mere indsigt i, hvordan dette fungerer under motorhjelmen i TensorFlow, ville jeg elske at høre det!
Løsningen/løsningen
Jeg gik med at justere selleri, så den bruger Thread
s i stedet for Process
es til pooling. Der er nogle ulemper ved denne tilgang (se GIL-kommentaren ovenfor), men dette giver os mulighed for kun at indlæse modellen én gang. Vi er alligevel ikke rigtig CPU-bundet, da TensorFlow-runtiden maksimerer alle CPU-kernerne (den kan omgå GIL, da den ikke er skrevet i Python). Du skal forsyne Selleri med et separat bibliotek for at udføre trådbaseret pooling; dokumenterne foreslår to muligheder:gevent
eller eventlet
. Du sender derefter det bibliotek, du vælger, til arbejderen via --pool
kommandolinjeargument.
Alternativt ser det ud til (som du allerede har fundet ud af @pX0r), at andre Keras-backends såsom Theano ikke har dette problem. Det giver mening, da disse problemer er tæt forbundet med TensorFlow implementeringsdetaljer. Jeg personligt har endnu ikke prøvet Theano, så dit kilometertal kan variere.
Jeg ved, at dette spørgsmål blev stillet for et stykke tid siden, men problemet er stadig derude, så forhåbentlig vil dette hjælpe nogen!